ADD RE @a
This commit is contained in:
parent
021f0f5638
commit
b3c2607677
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import argparse
|
import argparse
|
||||||
|
import re
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -109,36 +110,38 @@ if __name__ == '__main__':
|
|||||||
print(res)
|
print(res)
|
||||||
|
|
||||||
# 一次会话
|
# 一次会话
|
||||||
for itm in res.split('\n'):
|
doctor_pattern = r'医生:(.*?)(病人:|$)'
|
||||||
if itm.startswith("病人:"):
|
|
||||||
dia_tuple.append(itm.split(":")[1])
|
|
||||||
elif itm.startswith("医生:"):
|
|
||||||
dia_tuple.append(itm.split(":")[1])
|
|
||||||
|
|
||||||
if len(dia_tuple) == 2 and len(one_conversation['conversation']) == 0:
|
doctor_matches = re.findall(doctor_pattern, res, re.DOTALL)
|
||||||
|
doctor_conversations = [match[0] for match in doctor_matches]
|
||||||
|
|
||||||
|
patient_pattern = r'病人:(.*?)医生:'
|
||||||
|
patient_matches = re.findall(patient_pattern, res, re.DOTALL)
|
||||||
|
patient_conversations = [match for match in patient_matches]
|
||||||
|
|
||||||
|
for doc, pat in zip(doctor_conversations, patient_conversations):
|
||||||
|
if len(one_conversation['conversation']) == 0:
|
||||||
one_conversation['conversation'].append(
|
one_conversation['conversation'].append(
|
||||||
{
|
{
|
||||||
"system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。",
|
"system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。",
|
||||||
"input": dia_tuple[0],
|
"input": pat,
|
||||||
"output": dia_tuple[1]
|
"output": doc
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dia_tuple = []
|
|
||||||
|
|
||||||
elif len(dia_tuple) == 2:
|
else:
|
||||||
one_conversation['conversation'].append(
|
one_conversation['conversation'].append(
|
||||||
{
|
{
|
||||||
"input": dia_tuple[0],
|
"input": pat,
|
||||||
"output": dia_tuple[1]
|
"output": doc
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dia_tuple = []
|
|
||||||
conversation_lis.append(one_conversation)
|
conversation_lis.append(one_conversation)
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
# 每生成2条数据存储一次
|
# 每生成10条数据存储一次
|
||||||
if (idx % 2 == 0):
|
if (idx % 10 == 0):
|
||||||
path = f'./{args.data}.jsonl'
|
path = f'./{args.data}.jsonl'
|
||||||
save_jsonl(data_lis=conversation_lis, file_path=path)
|
save_jsonl(data_lis=conversation_lis, file_path=path)
|
||||||
conversation_lis = [] # 清空
|
conversation_lis = [] # 清空
|
||||||
|
Loading…
Reference in New Issue
Block a user