Merge pull request #12 from 8baby8/main

add data_process
This commit is contained in:
xzwang 2024-01-23 21:29:22 +08:00 committed by GitHub
commit f47a360720
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1676 additions and 0 deletions

45
generate_data/check.py Normal file
View File

@ -0,0 +1,45 @@
import os
import json
def get_all_file_paths(folder_path, suffix=''):
files = os.listdir(folder_path)
path = []
for file in files:
file_path = os.path.join(folder_path, file)
if os.path.isdir(file_path):
path.extend(get_all_file_paths(file_path))
else:
if file_path.endswith(suffix):
path.append(file_path)
return path
def check(filepath):
with open(path, 'rt', encoding='utf-8') as file:
data = json.load(file)
for idx, item in enumerate(data):
dict_item = dict(item)
for conversation in dict_item:
if conversation != 'conversation':
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
try:
if len(dict_item[conversation]) == 0:
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
except:
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
for in_out in dict_item[conversation]:
for key in in_out:
if key != 'system' and key != 'input' and key != 'output':
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
try :
if len(in_out[key]) == 0:
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
except:
return 'found error in file: ' + filepath + ' at conversation index: ' + str(idx)
return 'no error in file: ' + filepath
if __name__ == '__main__':
dir_path = '.'
paths = get_all_file_paths(dir_path, suffix='.json')
for path in paths:
print(check(filepath=path))

59
generate_data/config.yml Normal file
View File

@ -0,0 +1,59 @@
aistudio _token : '{your_token}' # 文心一言的token
dashscope_api_key : '{your_api_key}' #通义千问的api_key
zhiouai_api_key : '{your_api_key}' # 智浦AI的密钥
# 星火大模型配置
appid : "{}" # 填写控制台中获取的 APPID 信息
api_secret : "{}" # 填写控制台中获取的 APISecret 信息
api_key : "{}" # 填写控制台中获取的 APIKey 信息
system : '现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决'
emotions_list : [
"钦佩",
"崇拜",
"欣赏",
"娱乐",
"焦虑",
"敬畏",
"尴尬",
"厌倦",
"冷静",
"困惑",
"渴望",
"厌恶",
"同情",
"痛苦",
"着迷",
"嫉妒",
"兴奋",
"恐惧",
"痛恨",
"有趣",
"快乐",
"怀旧",
"浪漫",
"悲伤",
"满意",
"性欲",
"同情",
"满足"
]
areas_of_life : [
"工作",
"学业",
"生活",
"身体",
"家人",
"朋友",
"社交",
"恋爱",
"就业",
"责任",
"爱好",
"环境",
"隐私",
"安全",
"梦想",
"自由"
]

View File

@ -0,0 +1,60 @@
import json
import random
import yaml
import erniebot
with open('config.yml', 'r', encoding='utf-8') as f:
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
erniebot.api_type = 'aistudio'
#此处需要将你的token也就是AIstudio主页的访问令牌放到下方
erniebot.access_token = configs['aistudio _token']
system = configs['system']
areas_of_life = configs['areas_of_life']
emotions_list = configs['emotions_list']
words = ''
# prompt = '''
# 你是一个研究过无数具有心理健康问题的病人与心理健康医生对话案例的专家,请你构造一些符合实际情况的具有心理健康问题的病人和心理健康医生的多轮对话案例。要求医生的回复尽可能包含心理辅导知识,并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案。注意,构造的数据必须以医生的陈述为结束语。请以如下格式返回生成的数据:
# 病人:病人的咨询或陈述
# 医生:医生的安抚和建议
# '''
for data in areas_of_life:
for emo in emotions_list:
res = []
print(f'正在为{data}_{emo}场景生成对应数据集')
prompt = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健康问题的病人和心理健康医生的连续的多轮对话记录。
要求病人的问题属于{data}场景具有{emo}情感医生的回复尽可能包含心理辅导知识并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案
注意构造的数据必须以医生的陈述为结束语每次只需要构造一个案例并且不需要写案例一二等等请返回完整的对话内容
请以如下格式返回生成的数据
病人病人的咨询或陈述
医生医生的安抚和建议
'''
for i in range(15):
response = erniebot.ChatCompletion.create(
model='ernie-3.5',
messages=[{'role': 'user', 'content': f"{prompt}"}],
# top_p=random.uniform(0.5, 0.99),
# penalty_score = random.uniform(1.0, 2.0)
)
tmp = response.result
print(tmp)
ls = tmp.split('\n')
conversation = {'conversation':[]}
for j in range(0, len(ls)-1, 2):
# print(j)
q_a = {}
if j == 0:
q_a = {'system':system, 'input':ls[j].split("")[-1], 'output':ls[j+1].split("")[-1]}
else:
q_a = {'input':ls[j].split("")[-1], 'output':ls[j+1].split("")[-1]}
# print(q_a)
conversation['conversation'].append(q_a)
res.append(conversation)
print(f'{i}条数据生成完成!!')
print('================================')
print(f'{data}_{emo}场景对应数据集生成完毕')
# 将数据写入JSON文件
with open('./data/Ernie_{data}_{emo}.json', 'w', encoding='utf-8') as file:
json.dump(res, file, ensure_ascii=False, indent=4)

785
generate_data/main.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,104 @@
import json
import random
import argparse
import yaml
import re
from tqdm import tqdm
with open('config.yml', 'r', encoding='utf-8') as f:
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
def qwen_api(data, emo):
import dashscope
from http import HTTPStatus
dashscope.api_key = configs['dashscope_api_key']
prompt = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
康问题的病人和心理健康医生的连续的多轮对话记录要求病人的问题属于{data}场景具有{emo}情感医生的回复尽可能包含心理辅导知识并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案注意构造的数据必须以医生的陈述为结束语请只返回完整的对话内容请以如下格式返回生成的数据
病人病人的咨询或陈述
医生医生的安抚和建议
'''
response = dashscope.Generation.call(
model='qwen-max',
prompt=prompt,
history=[],
)
if response.status_code == HTTPStatus.OK:
result = response.output.text
print(result)
else:
result = 'ERROR'
return result
def save_jsonl(data_lis, file_path):
import json
# 将字典列表写入文件,每一行一个字典
with open(file_path, 'at', encoding='utf-8') as file:
for item in data_lis:
json_string = json.dumps(item, ensure_ascii=False) + '\n'
file.write(json_string)
if __name__ == '__main__':
idx = 0
parser = argparse.ArgumentParser(description='数据生成参数')
parser.add_argument('--data', type=str, help='生活场景')
# 解析命令行参数
args = parser.parse_args()
emotions_lis = configs['emotions_list']
areas_of_life = configs['areas_of_life']
conversation_lis = []
for i in tqdm(range(100)):
one_conversation = {
"conversation": []
}
dia_tuple = []
emo = random.choice(emotions_lis)
res = qwen_api(data=args.data, emo=emo)
print(res)
# 一次会话
doctor_pattern = r'医生:(.*?)(病人:|$)'
doctor_matches = re.findall(doctor_pattern, res, re.DOTALL)
doctor_conversations = [match[0] for match in doctor_matches]
patient_pattern = r'病人:(.*?)医生:'
patient_matches = re.findall(patient_pattern, res, re.DOTALL)
patient_conversations = [match for match in patient_matches]
for doc, pat in zip(doctor_conversations, patient_conversations):
if len(one_conversation['conversation']) == 0:
one_conversation['conversation'].append(
{
"system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。",
"input": pat,
"output": doc
},
)
else:
one_conversation['conversation'].append(
{
"input": pat,
"output": doc
},
)
conversation_lis.append(one_conversation)
idx += 1
# 每生成10条数据存储一次
if (idx % 10 == 0):
path = f'./{args.data}.jsonl'
save_jsonl(data_lis=conversation_lis, file_path=path)
conversation_lis = [] # 清空

View File

@ -0,0 +1,4 @@
erniebot #文心一言
dashscope # 通义千问
zhipuai # 智浦
websocket #调用星火大模型的时候会使用

View File

@ -0,0 +1,27 @@
#!/bin/bash
# 定义生活领域的列表
areas_of_life=(
"工作"
"学业"
"生活"
"身体"
"家人"
"朋友"
"社交"
"恋爱"
"就业"
"责任"
"爱好"
"环境"
"隐私"
"安全"
"梦想"
"自由"
)
# 使用for循环遍历数组
for area in "${areas_of_life[@]}"; do
echo "当前生活领域: $area"
python qwen_gen_data.py --data $area
done

94
generate_data/tutorial.md Normal file
View File

@ -0,0 +1,94 @@
# EMO 心理大模型 微调数据生成教程
**一、目标与背景**
为了使我们的心理大模型有更好的表达效果我们必须要有高质量的数据集。为了达到这一目标我们决定利用四种强大的人工智能大模型文心一言、通义千问、讯飞星火和智浦AI来生成对话数据。此外我们还将增强数据集的认知深度通过加入少量自我认知数据集来提高模型的泛化能力。
**二、数据集生成方法**
1. **模型选择与数据准备**
选择文心一言、通义千问、讯飞星火和智浦这四种大语言模型获取调用相应接口的API并准备用于生成对话数据。
2. **单轮与多轮对话数据生成**
利用这四种模型我们生成了10000条单轮和多轮对话数据。在这一过程中我们确保了数据的多样性、复杂性和有效性。
因为心理活动往往是复杂的为了保证数据的多样性。我们选择了16 * 28 共`448`个场景进行数据集生成具体场景名称请参考config.yml中的 `emotions_list 和 areas_of_life`两个参数的配置。
3. **自我认知数据集的加入**
为了增强模型的认知能力,我们特意加入了一部分自我认知数据集。这些数据集有助于模型更好地理解上下文,提高对话的自然度和连贯性。
**三、实践步骤**
1. **初始化**
* 安装所需的软件和库。
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
* 准备输入数据和配置参数。
可参见`config.yml`均有注释
2. **模型选择与配置**
* 根据需求选择适合的模型。
为了使大家都能够玩上大模型我们选用InterLLM2-7B作为我们的基线模型消费级显卡也可部署微调的哦
* 对模型进行必要的配置和调整。
根据我们的数据集以及配置策略使用XTuner进行微调
3. **数据生成**
* 使用通义千问大模型进行数据生成。
```bash
# 终端运行
bash run_qwen.bash
```
* 使用百度文心大模型进行数据生成。
```bash
# 终端运行
python ernie_gen_data.py
```
* 使用智浦AI大模型进行数据生成。
```bash
# 终端运行
python zhipuai_gen_data.py
```
* 使用讯飞星火大模型进行数据生成。
```bash
# 终端运行
python ./xinghuo/gen_data.py
```
4. **自我认知数据集的整合**
* 自我认知数据集这个就需要按照格式手动生成的哈~,如下格式即可。
```json
[
{
"conversation": [
{
"input": "请介绍一下你自己",
"output": "我是大佬的emo小助手可以帮助你解决心理上的问题哦"
}
]
},
{
"conversation": [
{
"input": "请做一下自我介绍",
"output": "我是大佬的emo小助手可以帮助你解决心理上的问题哦"
}
]
}
]
```
5. **数据集整合。**
在进行数据集整合之前我们要检查生成的数据是否存在格式错误类型不符合等情况。我们需要check.py进行检查数据。最后再使用merge_json.py将所有的json整合为一个总的json文件。
6. **评估与优化**
* 使用适当的评估指标对生成的数据集进行评估。
* 根据评估结果进行必要的优化和调整。
7. **测试与部署**
* 使用独立测试集对训练好的模型进行评估。
* 根据测试结果进行必要的调整和优化。
* 将最终的模型部署到实际应用中。

View File

@ -0,0 +1,136 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
answer = ""
class Ws_Param(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数生成url
url = self.Spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)
# 收到websocket关闭的处理
def on_close(ws,one,two):
print(" ")
# 收到websocket连接建立的处理
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
ws.send(data)
# 收到websocket消息的处理
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"temperature": 0.5,
"max_tokens": 2048
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def main(appid, api_key, api_secret, Spark_url,domain, question):
# print("星火:")
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

View File

@ -0,0 +1,60 @@
import SparkApi
from prompt import *
from tqdm import tqdm
# 以下密钥信息从控制台获取
appid = "" # 填写控制台中获取的 APPID 信息
api_secret = "" # 填写控制台中获取的 APISecret 信息
api_key = "" # 填写控制台中获取的 APIKey 信息
# 用于配置大模型版本默认“general/generalv2”
domain = "general" # v1.5版本
# domain = "generalv2" # v2.0版本
# 云端环境的服务地址
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
text = []
# length = 0
def getText(role, content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
if __name__ == '__main__':
text.clear
file_name = 'train3.jsonl'
conversations = []
for i in tqdm(range(200)):
Input = prompt(random.randint(0, 16))
question = checklen(getText("user", Input))
SparkApi.answer = ""
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
getText("assistant", SparkApi.answer)
conversations.append(ChatGLM3_6B(SparkApi.answer))
for item in conversations:
save_jsonl(item, file_name)
conversations.clear()

View File

@ -0,0 +1,64 @@
import SparkApi
from prompt import *
from tqdm import tqdm
with open('config.yml', 'r', encoding='utf-8') as f:
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
# 以下密钥信息从控制台获取
appid = configs['appid'] # 填写控制台中获取的 APPID 信息
api_secret = configs['api_secret'] # 填写控制台中获取的 APISecret 信息
api_key = configs['api_key'] # 填写控制台中获取的 APIKey 信息
#用于配置大模型版本默认“general/generalv2”
domain = "general" # v1.5版本
# domain = "generalv2" # v2.0版本
#云端环境的服务地址
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
if __name__ == '__main__':
text.clear
file_name = 'a2.jsonl'
conversations = []
for i in range(1):
for j in tqdm(range(10)):
Input = prompt(i)
question = checklen(getText("user",Input))
SparkApi.answer =""
SparkApi.main(appid,api_key, api_secret, Spark_url, domain, question)
getText("assistant", SparkApi.answer)
conversations.append(xinghuo_api(SparkApi.answer))
if i % 2 == 0 :
save_jsonl(conversations, file_name)
conversations.clear()

View File

@ -0,0 +1,151 @@
import json
import random
import re
import copy
# new_data = [{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'content': '你好!有什么我可以帮助您的吗?'}, {'role': 'user', 'content': '生成一段心理医生和病人的对话'}, {'role': 'assistant', 'content': '心理医生:你好,我是你的心理医生。请问你的名字是?\n\n病人我叫李明。\n\n心理医生很高兴见到你李明先生。你来咨询是因为什么问题呢\n\n病人最近我总是感觉很焦虑睡眠也不好。\n\n心理医生好的可以跟我具体说说你的情况吗你有什么压力或者担忧的事情吗\n\n病人其实我一直在担心工作上的表现觉得自己做得不够好还有家庭的问题。\n\n心理医生这些都是很常见的问题。你可以告诉我你在工作中遇到了什么困难吗我们可以一起探讨一下如何解决。\n\n病人我觉得自己的工作能力不够强经常被领导批评。而且我家里的情况也不是很好父母经常吵架让我很难受。\n\n心理医生我理解你的感受。这些问题确实会让人感到压力和焦虑。不过我们可以通过一些方法来缓解这种情况。比如说你可以尝试一些放松的活动比如瑜伽或者冥想来减轻压力和焦虑。同时你也可以考虑寻求家人或者朋友的帮助让他们给你提供一些支持和鼓励。\n\n病人好的我会试试的。谢谢你的建议。\n\n心理医生不用客气如果你有任何问题或者需要进一步的帮助随时可以联系我。'}]
# text2 = []
# data = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
# for val in new_data:
# if val['role'] == 'user':
# continue
#
# print(text2)
def save_jsonl(conversations, path_file):
# 把对话写入文件
with open(path_file, 'a+', encoding='utf-8') as f:
for conversation in conversations:
Json_String = json.dumps(conversation, ensure_ascii=False) + '\n'
f.write(Json_String)
# 生成输入提示词
def prompt(life_type=0):
emotions_lis = [
"钦佩",
"崇拜",
"欣赏",
"娱乐",
"焦虑",
"敬畏",
"尴尬",
"厌倦",
"冷静",
"困惑",
"渴望",
"厌恶",
"同情",
"痛苦",
"着迷",
"嫉妒",
"兴奋",
"恐惧",
"痛恨",
"有趣",
"快乐",
"怀旧",
"浪漫",
"悲伤",
"满意",
"性欲",
"同情",
"满足"
]
areas_of_life = [
"工作",
"学业(小学,初中,高中,大学,研究生,博士)",
"生活(衣,食,住,行等等)",
"身体",
"家人",
"朋友",
"社交",
"恋爱",
"就业",
"责任",
"爱好",
"环境",
"隐私",
"安全",
"梦想",
"自由"
]
# 输入数据处理
if life_type < 0:
raise ValueError('life_type must > 0')
emo = random.choice(emotions_lis)
life_type %= 16
Input = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
康问题的病人和心理健康医生的连续的一段多轮对话记录要求病人的问题属于{areas_of_life[life_type]}场景具有{emo}情感医生的回复尽可能包含心理辅导知识并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案注意构造的数据必须以医生的陈述为结束语请只返回完整的对话内容请以如下格式返回生成的数据
病人病人的咨询或陈述
医生医生的安抚和建议
'''
return Input
def xinghuo_api(content):
# 对话格式
conversation1 = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
conversation = {'input':'', 'output':''}
conversations = {'conversation':[]}
# temp = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
# 划分对话形式
dialogue = re.split('医生:|病人:', content)
# 对话前的数据处理
if dialogue[0] == '':
dialogue.pop(0)
# 一次对话
flag = False
for ind, item in enumerate(dialogue):
if flag == False:
if (ind + 1) % 2 == 1:
conversation1['input'] = dialogue[ind]
else:
conversation1['output'] = dialogue[ind]
if (ind + 1) % 2 == 0 or ind + 1 == len(dialogue):
temp = copy.deepcopy(conversation1)
conversations['conversation'].append(temp)
flag = True
continue
else:
if (ind+1)%2 == 1:
conversation['input'] = dialogue[ind]
else:
conversation['output'] = dialogue[ind]
if (ind+1)%2 == 0 or ind+1 == len(dialogue):
# 浅赋值只会是同一个变量必须要copy.deepcopy
# 若conversations['conversation'].append(conversation)后面改的话,~s里面的conversation也会改动
# 就会变成n个一样的数据这是我们不想看到的
temp = copy.deepcopy(conversation)
conversations['conversation'].append(temp)
return conversations
def ChatGLM3_6B(content):
# 对话格式
conversation = {'system': '现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input': '',
'output': ''}
conversations = []
# temp = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
# 划分对话形式
dialogue = re.split('医生:|病人:', content)
# 对话前的数据处理
if dialogue[0] == '':
dialogue.pop(0)
# 一次对话
for ind, item in enumerate(dialogue):
if (ind + 1) % 2 == 1:
conversation['input'] = dialogue[ind]
else:
conversation['output'] = dialogue[ind]
if (ind + 1) % 2 == 0 or ind + 1 == len(dialogue):
# 浅赋值只会是同一个变量必须要copy.deepcopy
# 若conversations['conversation'].append(conversation)后面改的话,~s里面的conversation也会改动
# 就会变成n个一样的数据这是我们不想看到的
temp = copy.deepcopy(conversation)
conversations.append(temp)
return conversations

View File

@ -0,0 +1,3 @@
gen_Chat 使用于生成ChatGLM3-6B的数据集
gen_data 适用于生成InternLM所需要的数据集
但是需要注意~火大模型用1.5生成时会有{"system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。", "input": "抱歉,我不能完成这个任务。作为一个认知智能模型,我不会提供任何与性欲情感相关的回答或建议。这种问题需要由专业的心理健康医生进行处理和解决。如果您有任何心理健康方面的问题,请寻求专业医生的帮助。", "output": ""}类似这样的数据集,要注意数据处理

View File

@ -0,0 +1,84 @@
import os
import random
import json
import yaml
from tqdm import tqdm
# from dotenv import load_dotenv
from zhipuai import ZhipuAI
with open('config.yml', 'r', encoding='utf-8') as f:
configs = yaml.load(f.read(), Loader=yaml.FullLoader)
load_dotenv()
client = ZhipuAI(api_key=configs['zhiouai_api_key'])
def zhipu_api(data, emo):
def getText(role, content, text = []):
jsoncon = {}
jsoncon['role'] = role
jsoncon['content'] = content
text.append(jsoncon)
return text
prompt = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
康问题的病人和心理健康医生的连续的多轮对话记录要求病人的问题属于{data}场景具有{emo}情感医生的回复尽可能包含心理辅导知识并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案注意构造的数据必须以医生的陈述为结束语每次只需要构造一个案例并且不需要写案例一二等等请只返回完整的对话内容请以如下格式返回生成的数据
病人病人的咨询或陈述
医生医生的安抚和建议
'''
top_p = round(random.uniform(0.1, 0.9), 2)
messages = getText('user', prompt)
response = client.chat.completions.create(
model='glm-4',
messages=messages,
top_p=top_p,
)
return response.choices[0].message.content
def convert(conversation):
ret, one_conversation = {}, {}
ret['conversation'] = []
one_conversation['system'] = '现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。'
while '病人:' in conversation and '医生:' in conversation:
one_conversation['input'] = conversation.split('病人:')[1].split('医生:')[0]
one_conversation['output'] = conversation.split('病人:')[1].split('医生:')[1].split('病人:')[0]
conversation = '病人:' + '病人:'.join(conversation.split('病人:')[2:])
ret['conversation'].append(one_conversation)
one_conversation = {}
return ret
def save_jsonl(data_lis, file_path):
if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, 'w', encoding='utf-8') as f:
for item in data_lis:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
if __name__ == '__main__':
emotions_lis = configs['emotions_list']
areas_of_life = configs['areas_of_life']
conversation_lis = []
for emo in emotions_lis:
for area in areas_of_life:
if os.path.exists(f'./zhipuai/{area}/{emo}.jsonl'):
print(f'./zhipuai/{area}/{emo}.jsonl exists')
continue
for i in tqdm(range(5), desc='{emo}, {area}'.format(emo=emo, area=area)):
res = zhipu_api(area, emo)
print(res)
if res == 'null':
print(area, emo, 'error')
continue
conversation_lis.append(convert(res))
save_jsonl(conversation_lis, f'./zhipuai/{area}/{emo}.jsonl')
print(f'generate ./zhipuai/{area}/{emo}.jsonl')
conversation_lis = []