optimize deduplicate.py

Add timing print information
Save the duplicate dataset as well
Remove print(content)
HongCheng 2024-03-23 15:24:45 +09:00
parent 66fa15da5d
commit 950cab0262

--- a/deduplicate.py
+++ b/deduplicate.py

@@ -5,6 +5,9 @@ from datasketch import MinHash
 from hashlib import md5
 from simhash import Simhash
+import time
+import numpy as np
 def extract_text_from_json(obj, content):
     # print(content)
     if isinstance(obj, dict):
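
For context, `extract_text_from_json` (only its first lines appear in this hunk) walks a nested JSON object and concatenates its string content into one text blob for hashing. A minimal, self-contained sketch of that idea; the full traversal below is assumed rather than taken from the diff:

```python
def extract_text_sketch(obj, content=''):
    # Recursively collect every string value in a nested JSON-like object.
    if isinstance(obj, dict):
        for value in obj.values():
            content = extract_text_sketch(value, content)
    elif isinstance(obj, list):
        for element in obj:
            content = extract_text_sketch(element, content)
    elif isinstance(obj, str):
        content += obj
    return content

sample = {'conversation': [{'input': '你好', 'output': 'Hello'}]}
print(extract_text_sketch(sample))  # -> '你好Hello'
```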
@@ -29,7 +32,7 @@ def is_duplicate_absolutely(d1, d2):
 def hash_dict(dict_obj):
     content = extract_text_from_json(dict_obj,'')
     content = content.replace('\n', '').replace('\t', '').replace(' ', '')
-    print(content)
+    # print(content)
     # m = get_minhash(content)
     m = Simhash(content)
     return m
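
The change above only silences the debug print; the hashing itself still normalizes the flattened text and fingerprints it with Simhash. A small sketch of that normalize-then-fingerprint step, using the `simhash` package's `Simhash` class as the file does:

```python
from simhash import Simhash

def hash_text_sketch(text):
    # Same normalization as hash_dict: drop newlines, tabs and spaces
    # so formatting differences do not change the fingerprint.
    text = text.replace('\n', '').replace('\t', '').replace(' ', '')
    return Simhash(text)  # 64-bit fingerprint

a = hash_text_sketch('今天 天气 不错')
b = hash_text_sketch('今天天气不错')
print(a.distance(b))  # 0 -- the two strings are identical after stripping
```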
@@ -43,10 +46,19 @@ def get_simhash(dict_obj):
     return Simhash(dict_obj)

 # Deduplicate the dict list using exact matching and MinHash
-def deduplicate_json(data_list, threshold=0.8):
+def deduplicate_json(data_list, threshold=0.8, time_print=True):
     seen_hashes = []
     keep = []
     duplicate = []
+
+    # global start
+    start = time.time()
+    last_start_seen_hashes = start
+    last_start_duplicate = start
+    stop1 = 0
+    stop2 = 0
+    print_interval = 500
+
     for item in data_list:
         if not item['conversation']:
             continue
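
The new variables above set up interval-based progress reporting controlled by the `time_print` flag: every `print_interval` items, the loop reports how long the last batch took and the elapsed time since the start. A stand-alone sketch of that pattern (the loop body is a placeholder, not the dedup logic):

```python
import time
import numpy as np

def timed_loop_sketch(items, print_interval=500, time_print=True):
    # Every `print_interval` items, print the time spent on the last batch
    # and the total elapsed time, rounded as in the diff above.
    start = time.time()
    last_start = start
    for i, _ in enumerate(items, 1):
        if i % print_interval == 0:
            if time_print:
                stop = time.time()
                print(f'processed={i} Time:',
                      np.round(stop - last_start, 5),
                      np.round(stop - start, 5))
                last_start = stop
            else:
                print(f'processed={i}')

timed_loop_sketch(range(2000), print_interval=500)
```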
@@ -60,15 +72,36 @@ def deduplicate_json(data_list, threshold=0.8):
         has_similar = False
         # for stored_min_hash, stored_text in seen_hashes:
         #     if stored_min_hash.jaccard(min_hash) > threshold:
         for stored_min_hash, stored_text in seen_hashes:
             if 1 - (stored_min_hash.distance(sim_hash)/64.0) > threshold:
                 has_similar = True
                 duplicate.append(item)
+
+                print_len_duplicate = len(duplicate)+1
+                if print_len_duplicate%print_interval == 0:
+                    if time_print:
+                        stop1 = time.time()
+                        print(f'print_len_duplicate={print_len_duplicate} Time: ', np.round(stop1 - last_start_duplicate, 5), np.round(stop1 - start, 5))
+                        last_start_duplicate = stop1
+                    else:
+                        print(f'print_len_duplicate={print_len_duplicate}')
+
                 break
         if not has_similar:
+            # seen_hashes.append((min_hash,item))
             seen_hashes.append((sim_hash,item))
             keep.append(item)
+
+            print_len_seen_hashes = len(seen_hashes)+1
+            if print_len_seen_hashes%print_interval == 0:
+                if time_print:
+                    stop2 = time.time()
+                    print(f'print_len_seen_hashes={print_len_seen_hashes} Time: ', str(np.round(stop2 - last_start_seen_hashes,5)), str(np.round(stop2 - start, 5)))
+                    last_start_seen_hashes = stop2
+                else:
+                    print(f'print_len_seen_hashes={print_len_seen_hashes}')
+
         else:
             duplicate.append(item)
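
The unchanged similarity test in this hunk is the heart of the deduplication: Simhash fingerprints are 64 bits, so the Hamming distance is normalized by 64 and one minus that value is compared against the threshold. A self-contained sketch of the same test on plain strings (the example strings are illustrative only):

```python
from simhash import Simhash

def is_near_duplicate(text_a, text_b, threshold=0.8):
    # Simhash fingerprints are 64-bit, so similarity = 1 - distance/64.
    similarity = 1 - Simhash(text_a).distance(Simhash(text_b)) / 64.0
    return similarity > threshold

print(is_near_duplicate('今天天气真不错呀', '今天天气真不错'))   # likely True
print(is_near_duplicate('今天天气真不错呀', '明天有大暴雨要来'))  # likely False
```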
@@ -77,7 +110,8 @@ def deduplicate_json(data_list, threshold=0.8):
 if __name__ == '__main__':
     DUP_THRESH = 0.8
-    data_ai = 'qwen'
+    data_ai = 'FatherLikeBF'
+    # root_dir = rf'./datasets/{data_ai}/'
     root_dir = rf'./{data_ai}/'
     dedup_output_dir = os.path.join(root_dir,'dedup')
     if not os.path.exists(dedup_output_dir):
@@ -94,8 +128,13 @@ if __name__ == '__main__':
         with open(file_path, 'r', encoding='utf-8') as f:
             data = json.load(f)
         dedup_data, duplicate = deduplicate_json(data, DUP_THRESH)
+
         with open(os.path.join(root_dir, 'dedup','dedup_' + file), 'w', encoding='utf-8') as output_file:
             json.dump(dedup_data, output_file, ensure_ascii=False, indent=4)
+
+        with open(os.path.join(root_dir, 'dedup','dup_' + file), 'w', encoding='utf-8') as output_file:
+            json.dump(duplicate, output_file, ensure_ascii=False, indent=4)
+
         for item in dedup_data:
             logger.info(f'dedup_data: {item}')
         for item in duplicate:
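
With this commit, the `__main__` loop writes two files per input: the kept records to `dedup_<file>` and the removed ones to `dup_<file>`. A hedged sketch of that per-file flow, assuming `deduplicate_json` is importable from `deduplicate.py`; the helper name `dedup_one_file` is illustrative and not part of the commit:

```python
import json
import os

from deduplicate import deduplicate_json  # the function modified in this commit

def dedup_one_file(root_dir, file, dup_thresh=0.8):
    # Illustrative helper mirroring what the __main__ block now does
    # for a single JSON file: keep both outputs so removed duplicates
    # can be inspected later.
    os.makedirs(os.path.join(root_dir, 'dedup'), exist_ok=True)
    with open(os.path.join(root_dir, file), 'r', encoding='utf-8') as f:
        data = json.load(f)
    dedup_data, duplicate = deduplicate_json(data, dup_thresh)
    for prefix, payload in (('dedup_', dedup_data), ('dup_', duplicate)):
        out_path = os.path.join(root_dir, 'dedup', prefix + file)
        with open(out_path, 'w', encoding='utf-8') as output_file:
            json.dump(payload, output_file, ensure_ascii=False, indent=4)
```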