chore: update script

This commit is contained in:
jujimeizuo 2024-01-19 15:49:10 +08:00
parent 051bc08dc1
commit 92f1272ed3
3 changed files with 27 additions and 21 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
ESConv.json
.DS_Store
__pycache__/
tmp/
tmp/
data/zhipuai/

View File

@ -1 +1,6 @@
# EmoLLM
## 🌟 Contributors
[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=2000)](https://github.com/aJupyter/EmoLLM/graphs/contributors)

View File

@ -1,4 +1,5 @@
import os
import random
import json
from tqdm import tqdm
from dotenv import load_dotenv
@ -22,10 +23,12 @@ def zhipu_api(data, emo):
医生医生的安抚和建议
'''
top_p = round(random.uniform(0.1, 0.9), 2)
messages = getText('user', prompt)
response = client.chat.completions.create(
model='glm-4',
messages=messages,
top_p=top_p,
)
return response.choices[0].message.content
@ -47,6 +50,8 @@ def convert(conversation):
def save_jsonl(data_lis, file_path):
if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, 'w', encoding='utf-8') as f:
for item in data_lis:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
@ -67,7 +72,7 @@ if __name__ == '__main__':
"渴望",
"厌恶",
"同情",
"痛苦"
"痛苦",
"着迷",
"嫉妒",
"兴奋",
@ -80,7 +85,6 @@ if __name__ == '__main__':
"悲伤",
"满意",
"性欲",
"同情",
"满足"
]
areas_of_life = [
@ -103,22 +107,18 @@ if __name__ == '__main__':
]
conversation_lis = []
idx = 0
for area in areas_of_life:
j = 0
for idx in tqdm(range(len(emotions_lis)), desc=f'data:{area}, emo:{emotions_lis[j]}'):
emo = emotions_lis[j]
res = zhipu_api(area, emo)
print(res)
if res == 'null':
print(area, emo, 'error')
for emo in emotions_lis:
for area in areas_of_life:
if os.path.exists(f'./zhipuai/{area}/{emo}.jsonl'):
print(f'./zhipuai/{area}/{emo}.jsonl exists')
continue
conversation_lis.append(convert(res))
if idx % 2 == 1:
save_jsonl(conversation_lis, f'./zhipuai_{idx}.jsonl')
conversation_lis = []
idx += 1
j += 1
if len(conversation_lis) > 0:
save_jsonl(conversation_lis, f'./zhipuai.jsonl')
conversation_lis = []
for i in tqdm(range(5), desc='{emo}, {area}'.format(emo=emo, area=area)):
res = zhipu_api(area, emo)
print(res)
if res == 'null':
print(area, emo, 'error')
continue
conversation_lis.append(convert(res))
save_jsonl(conversation_lis, f'./zhipuai/{area}/{emo}.jsonl')
print(f'generate ./zhipuai/{area}/{emo}.jsonl')
conversation_lis = []