diff --git a/generate_data/merge_json.py b/generate_data/merge_json.py deleted file mode 100644 index 48ef5a1..0000000 --- a/generate_data/merge_json.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import os - - -def save_merge_json(data_lis, file_path): - with open(file_path, 'wt', encoding='utf-8') as file: - json.dump(data_lis, file, ensure_ascii=False, separators=(',\n',':')) - - -def get_all_file_paths(folder_path, file_type='.jsonl'): - # 确保传入的是一个目录 - if not os.path.isdir(folder_path): - raise ValueError(f"{folder_path} is not a valid directory") - - # 获取文件夹下所有文件的路径 - file_paths = [os.path.join(folder_path, file) for file in os.listdir( - folder_path) if os.path.isfile(os.path.join(folder_path, file)) and (file_type in file)] - return file_paths - - -if __name__ == '__main__': - conversion_lis = [] - - folder_path = r'./' - # D:\github_repos\EmoLLM\generate_data - - merge_path = folder_path.split('/')[-1] - try: - merge_last_path = folder_path.split('/')[-2] if folder_path.split('/')[-2]!='.' else '' - except: - merge_last_path = '' - print(f'merge_path={merge_path},merge_last_path={merge_last_path}') - - - for path in get_all_file_paths(folder_path): - print(path) - - with open(path, 'rt', encoding='utf-8') as file: - for line in file: - # # 移除行尾的换行符 - # if line == '\n': - # line = line.rstrip('\n') - line = line.rstrip('\n') - # 解析JSON - try: - data = json.loads(line) - conversion_lis.append(data) - # conversion_lis.append('\n') - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") - - if merge_last_path!='': - save_merge_json_path = rf'./{merge_last_path}/{merge_path}_merge.jsonl' - elif merge_path!='': - save_merge_json_path = rf'./{merge_path}_merge.json' - else: - save_merge_json_path = rf'./curr_merge.json' - - save_merge_json(data_lis=conversion_lis, - file_path=save_merge_json_path) diff --git a/scripts/.env b/scripts/.env new file mode 100644 index 0000000..a2f1f2f --- /dev/null +++ b/scripts/.env @@ -0,0 +1 @@ +ZHIPUAI_API_KEY = '' \ No newline at end of file diff --git a/scripts/check.py b/scripts/check.py new file mode 100644 index 0000000..2557be4 --- /dev/null +++ b/scripts/check.py @@ -0,0 +1,45 @@ +import os +import json + +def get_all_file_paths(folder_path, suffix=''): + files = os.listdir(folder_path) + path = [] + for file in files: + file_path = os.path.join(folder_path, file) + if os.path.isdir(file_path): + path.extend(get_all_file_paths(file_path)) + else: + if file_path.endswith(suffix): + path.append(file_path) + return path + +def check(filepath): + with open(path, 'rt', encoding='utf-8') as file: + data = json.load(file) + for idx, item in enumerate(data): + dict_item = dict(item) + for conversation in dict_item: + if conversation != 'conversation': + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + try: + if len(dict_item[conversation]) == 0: + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + except: + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + for in_out in dict_item[conversation]: + for key in in_out: + if key != 'system' and key != 'input' and key != 'output': + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + try : + if len(in_out[key]) == 0: + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + except: + return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx) + return 'no error in file: ' + filepath + + +if __name__ == '__main__': + dir_path = '.' + paths = get_all_file_paths(dir_path, suffix='.json') + for path in paths: + print(check(filepath=path)) \ No newline at end of file diff --git a/generate_data/gen_metafile.py b/scripts/gen_metafile.py similarity index 100% rename from generate_data/gen_metafile.py rename to scripts/gen_metafile.py diff --git a/scripts/merge_json.py b/scripts/merge_json.py new file mode 100644 index 0000000..714befb --- /dev/null +++ b/scripts/merge_json.py @@ -0,0 +1,40 @@ +import json +import os + + +def save_merge_json(data_lis, file_path): + import json + + with open(file_path, 'wt', encoding='utf-8') as file: + json.dump(data_lis, file, ensure_ascii=False) + + +def get_all_file_paths(folder_path): + # 确保传入的是一个目录 + if not os.path.isdir(folder_path): + raise ValueError(f"{folder_path} is not a valid directory") + + # 获取文件夹下所有文件的路径 + file_paths = [os.path.join(folder_path, file) for file in os.listdir( + folder_path) if os.path.isfile(os.path.join(folder_path, file))] + return file_paths + + +if __name__ == '__main__': + conversion_lis = [] + + for path in get_all_file_paths(r'data\res-aiwei'): + print(path) + + with open(path, 'rt', encoding='utf-8') as file: + for line in file: + # 移除行尾的换行符 + line = line.rstrip('\n') + # 解析JSON + try: + data = json.loads(line) + conversion_lis.append(data) + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}") + save_merge_json(data_lis=conversion_lis, + file_path=r'.\merge.json') diff --git a/generate_data/pdf2txt.py b/scripts/pdf2txt.py similarity index 100% rename from generate_data/pdf2txt.py rename to scripts/pdf2txt.py diff --git a/generate_data/process.py b/scripts/process.py similarity index 100% rename from generate_data/process.py rename to scripts/process.py diff --git a/generate_data/qa_generation/README.md b/scripts/qa_generation/README.md similarity index 100% rename from generate_data/qa_generation/README.md rename to scripts/qa_generation/README.md diff --git a/generate_data/qa_generation/README_EN.md b/scripts/qa_generation/README_EN.md similarity index 100% rename from generate_data/qa_generation/README_EN.md rename to scripts/qa_generation/README_EN.md diff --git a/generate_data/qa_generation/config/__init__.py b/scripts/qa_generation/config/__init__.py similarity index 100% rename from generate_data/qa_generation/config/__init__.py rename to scripts/qa_generation/config/__init__.py diff --git a/generate_data/qa_generation/config/config.py b/scripts/qa_generation/config/config.py similarity index 100% rename from generate_data/qa_generation/config/config.py rename to scripts/qa_generation/config/config.py diff --git a/generate_data/qa_generation/main.py b/scripts/qa_generation/main.py similarity index 100% rename from generate_data/qa_generation/main.py rename to scripts/qa_generation/main.py diff --git a/generate_data/qa_generation/model/__init__.py b/scripts/qa_generation/model/__init__.py similarity index 100% rename from generate_data/qa_generation/model/__init__.py rename to scripts/qa_generation/model/__init__.py diff --git a/generate_data/qa_generation/model/gemini.py b/scripts/qa_generation/model/gemini.py similarity index 100% rename from generate_data/qa_generation/model/gemini.py rename to scripts/qa_generation/model/gemini.py diff --git a/generate_data/qa_generation/model/glm.py b/scripts/qa_generation/model/glm.py similarity index 100% rename from generate_data/qa_generation/model/glm.py rename to scripts/qa_generation/model/glm.py diff --git a/generate_data/qa_generation/model/gpt.py b/scripts/qa_generation/model/gpt.py similarity index 100% rename from generate_data/qa_generation/model/gpt.py rename to scripts/qa_generation/model/gpt.py diff --git a/generate_data/qa_generation/model/qwen.py b/scripts/qa_generation/model/qwen.py similarity index 100% rename from generate_data/qa_generation/model/qwen.py rename to scripts/qa_generation/model/qwen.py diff --git a/generate_data/qa_generation/requirements.txt b/scripts/qa_generation/requirements.txt similarity index 100% rename from generate_data/qa_generation/requirements.txt rename to scripts/qa_generation/requirements.txt diff --git a/generate_data/qa_generation/system_prompt_v1.md b/scripts/qa_generation/system_prompt_v1.md similarity index 100% rename from generate_data/qa_generation/system_prompt_v1.md rename to scripts/qa_generation/system_prompt_v1.md diff --git a/generate_data/qa_generation/system_prompt_v1_EN.md b/scripts/qa_generation/system_prompt_v1_EN.md similarity index 100% rename from generate_data/qa_generation/system_prompt_v1_EN.md rename to scripts/qa_generation/system_prompt_v1_EN.md diff --git a/generate_data/qa_generation/system_prompt_v2.md b/scripts/qa_generation/system_prompt_v2.md similarity index 100% rename from generate_data/qa_generation/system_prompt_v2.md rename to scripts/qa_generation/system_prompt_v2.md diff --git a/generate_data/qa_generation/system_prompt_v2_EN.md b/scripts/qa_generation/system_prompt_v2_EN.md similarity index 100% rename from generate_data/qa_generation/system_prompt_v2_EN.md rename to scripts/qa_generation/system_prompt_v2_EN.md diff --git a/generate_data/qa_generation/util/__init__.py b/scripts/qa_generation/util/__init__.py similarity index 100% rename from generate_data/qa_generation/util/__init__.py rename to scripts/qa_generation/util/__init__.py diff --git a/generate_data/qa_generation/util/data_loader.py b/scripts/qa_generation/util/data_loader.py similarity index 100% rename from generate_data/qa_generation/util/data_loader.py rename to scripts/qa_generation/util/data_loader.py diff --git a/generate_data/qa_generation/util/logger.py b/scripts/qa_generation/util/logger.py similarity index 100% rename from generate_data/qa_generation/util/logger.py rename to scripts/qa_generation/util/logger.py diff --git a/generate_data/qa_generation/util/prompt_loader.py b/scripts/qa_generation/util/prompt_loader.py similarity index 100% rename from generate_data/qa_generation/util/prompt_loader.py rename to scripts/qa_generation/util/prompt_loader.py diff --git a/scripts/trans_process.py b/scripts/trans_process.py new file mode 100644 index 0000000..3999114 --- /dev/null +++ b/scripts/trans_process.py @@ -0,0 +1,78 @@ +import json +from tqdm import tqdm + + +def qwen_api(prompt): + import dashscope + from http import HTTPStatus + + dashscope.api_key = "your key" + prompt = "你是一位非常擅长将英文翻译成中文的专家。请你将下面的英文翻译成正确地道的中文,要求只返回翻译的中文句子:\n" + prompt + response = dashscope.Generation.call( + model='qwen-max', + prompt=prompt, + history=[], + ) + + if response.status_code == HTTPStatus.OK: + result = response.output.text + # print(result) + else: + result = 'ERROR' + return result + + +def get_conversation_list(): + with open('./ESConv.json', 'rt', encoding='utf-8') as file: + data = json.load(file) + + idx = 0 + conversation_list = [] + for itm in tqdm(data): + one_conversation = { + "conversation": [] + } + dia_tuple = [] + for dia in tqdm(itm['dialog']): + # print(dia['speaker'], dia['content']) + if dia['speaker'] == 'seeker': + dia_tuple.append(qwen_api(dia['content'])) + elif dia['speaker'] == 'supporter': + dia_tuple.append(qwen_api(dia['content'])) + else: + exit("不存在角色!") + + if len(dia_tuple) == 2 and len(one_conversation['conversation']) == 0: + one_conversation['conversation'].append( + { + "system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。", + "input": dia_tuple[0], + "output": dia_tuple[1] + }, + ) + dia_tuple = [] + + elif len(dia_tuple) == 2: + one_conversation['conversation'].append( + { + "input": dia_tuple[0], + "output": dia_tuple[1] + }, + ) + dia_tuple = [] + + conversation_list.append(one_conversation) + idx += 1 + + # if (idx == 1): + # print(conversation_list) + # break + print(idx) + return conversation_list + + +if __name__ == '__main__': + conversation_list = get_conversation_list() + # 将conversation_list保存为一个json文件 + with open('conversation_list.json', 'wt', encoding='utf-8') as f: + json.dump(conversation_list, f, ensure_ascii=False) diff --git a/scripts/upload_openxlab.py b/scripts/upload_openxlab.py new file mode 100644 index 0000000..252fd3b --- /dev/null +++ b/scripts/upload_openxlab.py @@ -0,0 +1,3 @@ +import os + +os.system("openxlab model create --model-repo='jujimeizuo/EmoLLM_Model' -s ./metafile.yml") \ No newline at end of file