optimize deduplicate.py

Add timing print information
Save the duplicate dataset as well
Remove the print(content) debug output
HongCheng 2024-03-23 15:24:45 +09:00
parent 66fa15da5d
commit 950cab0262

@@ -5,6 +5,9 @@ from datasketch import MinHash
from hashlib import md5
from simhash import Simhash
import time
import numpy as np
def extract_text_from_json(obj, content):
# print(content)
if isinstance(obj, dict):
@@ -29,7 +32,7 @@ def is_duplicate_absolutely(d1, d2):
def hash_dict(dict_obj):
content = extract_text_from_json(dict_obj,'')
content = content.replace('\n', '').replace('\t', '').replace(' ', '')
print(content)
# print(content)
# m = get_minhash(content)
m = Simhash(content)
return m
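For reference, the Simhash fingerprints returned by hash_dict are later compared by their 64-bit Hamming distance. A minimal standalone sketch of that similarity check (the helper name and sample strings are illustrative, not part of this commit):

from simhash import Simhash

def simhash_similarity(a, b):
    # Similarity in [0, 1]: 1 minus the normalized 64-bit Hamming distance
    return 1 - Simhash(a).distance(Simhash(b)) / 64.0

# Near-identical strings score close to 1.0
print(simhash_similarity('hello world', 'hello world!'))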
@@ -43,10 +46,19 @@ def get_simhash(dict_obj):
return Simhash(dict_obj)
# Deduplicate a list of dicts using exact matching and MinHash
def deduplicate_json(data_list, threshold=0.8):
def deduplicate_json(data_list, threshold=0.8, time_print=True):
seen_hashes = []
keep = []
duplicate = []
# global start
start = time.time()
last_start_seen_hashes = start
last_start_duplicate = start
stop1 = 0
stop2 = 0
print_interval = 500
for item in data_list:
if not item['conversation']:
continue
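The new start / last_start_* / print_interval fields implement a simple interval-logging pattern: every print_interval items, report the time since the previous report and since the start of the run. A standalone sketch of the same idea, with illustrative names:

import time

def process(items, print_interval=500, time_print=True):
    start = time.time()
    last = start
    for i, _ in enumerate(items, 1):
        if i % print_interval == 0:
            if time_print:
                now = time.time()
                # time since the previous report, and total elapsed time
                print(f'processed={i} Time:', round(now - last, 5), round(now - start, 5))
                last = now
            else:
                print(f'processed={i}')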
@@ -60,15 +72,36 @@ def deduplicate_json(data_list, threshold=0.8):
has_similar = False
# for stored_min_hash, stored_text in seen_hashes:
# if stored_min_hash.jaccard(min_hash) > threshold:
for stored_min_hash, stored_text in seen_hashes:
if 1 - (stored_min_hash.distance(sim_hash)/64.0) > threshold:
has_similar = True
duplicate.append(item)
print_len_duplicate = len(duplicate)+1
if print_len_duplicate%print_interval == 0:
if time_print:
stop1 = time.time()
print(f'print_len_duplicate={print_len_duplicate} Time: ', np.round(stop1 - last_start_duplicate, 5), np.round(stop1 - start , 5))
last_start_duplicate = stop1
else:
print(f'print_len_duplicate={print_len_duplicate}')
break
if not has_similar:
# seen_hashes.append((min_hash,item))
seen_hashes.append((sim_hash,item))
keep.append(item)
print_len_seen_hashes = len(seen_hashes)+1
if print_len_seen_hashes%print_interval == 0:
if time_print:
stop2 = time.time()
print(f'print_len_seen_hashes={print_len_seen_hashes} Time: ', str(np.round(stop2 - last_start_seen_hashes,5)), str(np.round(stop2 - start, 5)))
last_start_seen_hashes = stop2
else:
print(f'print_len_seen_hashes={print_len_seen_hashes}')
else:
duplicate.append(item)
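With this change deduplicate_json returns both the kept items and the duplicates, so callers can persist each set separately. A hedged usage sketch (the sample records are made up for illustration):

from deduplicate import deduplicate_json  # assumes deduplicate.py is importable

data = [
    {'conversation': [{'input': 'how do I reset my password', 'output': 'click the forgot password link'}]},
    {'conversation': [{'input': 'how do I reset my password', 'output': 'click the forgot password link'}]},  # duplicate
    {'conversation': [{'input': 'what is the weather like today', 'output': 'it is sunny and warm outside'}]},
]
dedup_data, duplicate = deduplicate_json(data, threshold=0.8, time_print=True)
print(len(dedup_data), len(duplicate))  # expected: 2 kept, 1 duplicate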
@@ -77,7 +110,8 @@ def deduplicate_json(data_list, threshold=0.8):
if __name__ == '__main__':
DUP_THRESH = 0.8
data_ai = 'qwen'
data_ai = 'FatherLikeBF'
# root_dir = rf'./datasets/{data_ai}/'
root_dir = rf'./{data_ai}/'
dedup_output_dir = os.path.join(root_dir,'dedup')
if not os.path.exists(dedup_output_dir):
@@ -93,9 +127,14 @@ if __name__ == '__main__':
if is_json_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
dedup_data, duplicate = deduplicate_json(data, DUP_THRESH)
dedup_data, duplicate = deduplicate_json(data, DUP_THRESH)
with open(os.path.join(root_dir, 'dedup','dedup_' + file), 'w', encoding='utf-8') as output_file:
json.dump(dedup_data, output_file, ensure_ascii=False, indent=4)
with open(os.path.join(root_dir, 'dedup','dup_' + file), 'w', encoding='utf-8') as output_file:
json.dump(duplicate, output_file, ensure_ascii=False, indent=4)
for item in dedup_data:
logger.info(f'dedup_data: {item}')
for item in duplicate: