Merge pull request #1 from SmartFlowAI/main

Update main
This commit is contained in:
Anooyman 2024-03-21 20:03:19 +08:00 committed by GitHub
commit 6c2c7496ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 52061 additions and 373 deletions

View File

@ -5,7 +5,7 @@
</div>
<p align="center">
<a href="https://github.com/aJupyter/EmoLLM/">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/logo.jpeg" alt="Logo" width="30%">
</a>
@ -28,14 +28,14 @@
简体中文| <a href="README_EN.md" >English</a>
<br />
<br />
<a href="https://github.com/aJupyter/EmoLLM"><strong>探索本项目的文档 »</strong></a>
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>探索本项目的文档 »</strong></a>
<br />
<br />
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">体验EmoLLM 2.0</a>
·
<a href="https://github.com/aJupyter/EmoLLM/issues">报告Bug</a>
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">报告Bug</a>
·
<a href="https://github.com/aJupyter/EmoLLM/issues">提出新特性</a>
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">提出新特性</a>
</div>
@ -46,7 +46,7 @@
<div align="center">
| 模型 | 类型 |
| :-------------------: | :------: |
| :-------------------: | :--------: |
| InternLM2_7B_chat | QLORA |
| InternLM2_7B_chat | 全量微调 |
| InternLM2_1_8B_chat | 全量微调 |
@ -94,13 +94,13 @@
- 【2024.2.6】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!
<p align="center">
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
</p>
- 【2024.2.5】 项目荣获公众号**NLP工程化**推文宣传[推文链接](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A),为博主推广一波,欢迎大家关注!!🥳🥳
<p align="center">
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
</p>
- 【2024.2.3】 [项目宣传视频](https://www.bilibili.com/video/BV1N7421N76X/)完成 😊
@ -109,24 +109,35 @@
</details>
### 🏆荣誉栏
- 项目荣获上海人工智能实验室举办的**2024浦源大模型系列挑战赛春季赛*****50强***
<p align="center">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/浦语挑战赛TOP50.jpg" alt="浦语挑战赛TOP50">
</p>
- 项目荣获公众号**NLP工程化**[推文宣传](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A)
### 🎯路线图
<p align="center">
<a href="https://github.com/aJupyter/EmoLLM/">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/Roadmap_ZH.png" alt="Roadmap_ZH">
</a>
### 🎯框架图
<p align="center">
<a href="https://github.com/aJupyter/EmoLLM/">
<img src="assets/框架图.png" alt="Roadmap_ZH">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/框架图.png" alt="Framework_ZH">
</a>
## 目录
- [EmoLLM-心理健康大模型](#emollm-心理健康大模型)
- [🎇最近更新](#最近更新)
- [🏆荣誉栏](#荣誉栏)
- [🎯路线图](#路线图)
- [🎯框架图](#框架图)
- [目录](#目录)
@ -211,13 +222,13 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
### 作者(排名不分先后)
| 用户名 | 学校/组织 | 备注 | 贡献 |
|:-------------------------------------------------------------:|:--------------------------------------------------:| :-------------------: | :----------: |
|:--------------------------------------------------------------------:|:--------------------------------------------------:| :-------------------: |:---------:|
| [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 |
| [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | LLM x Psychology 研究者 | 项目联合负责人 |
| [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | |
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | |
| [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | |
| [zxazys](https://github.com/zxazys) | 南开大学在读硕士 | | |
| [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | | |
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | swufe | | |
| [MrCatAI](https://github.com/MrCatAI) | AI搬运工 | | |
| [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | |
@ -231,6 +242,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
| [zealot52099](https://github.com/zealot52099) | AI搬运工 | | 清洗数据、RAG |
| [wwwyfff](https://github.com/wwwyfff) | 复旦大学在读硕士 | ||
| [jkhumor](https://github.com/jkhumor) | 南开大学在读硕士 | | RAG |
| [lll997150986](https://github.com/lll997150986) | 南开大学在读硕士 | | 微调 |
### 版权说明

View File

@ -5,7 +5,7 @@
</div>
<p align="center">
<a href="https://github.com/aJupyter/EmoLLM/">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/logo.jpeg" alt="Logo" width="30%">
</a>
@ -28,14 +28,14 @@
<a href="README.md">简体中文</a> | English
<br />
<br />
<a href="https://github.com/aJupyter/EmoLLM"><strong>Explore the documentation of this project »</strong></a>
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>Explore the documentation of this project »</strong></a>
<br />
<br />
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">EmoLLM 2.0 Demo</a>
·
<a href="https://github.com/aJupyter/EmoLLM/issues">Report a Bug</a>
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">Report a Bug</a>
·
<a href="https://github.com/aJupyter/EmoLLM/issues">Propose a New Feature</a>
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">Propose a New Feature</a>
</p>
</p>
@ -97,13 +97,13 @@ The Model aims to fully understand and promote the mental health of individuals,
- 【2024.2.6】 [Open-sourced a fully fine-tuned version based on Qwen1_5-0_5B-Chat](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary), so friends with limited compute can start experimenting~
<p align="center">
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
</p>
- 【2024.2.5】 The project has been promoted by the official WeChat account NLP Engineering. Here's the [link](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A) to the article. Welcome everyone to follow!! 🥳🥳
<p align="center">
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
</p>
- 【2024.2.3】 [Project Video](https://www.bilibili.com/video/BV1N7421N76X/) on Bilibili 😊
@ -112,10 +112,21 @@ The Model aims to fully understand and promote the mental health of individuals,
</details>
### Honor
- The project ranked in the ***Top 50*** of the **2024 Puyuan Large Model Series Challenge (Spring Competition)** held by the Shanghai Artificial Intelligence Laboratory
<p align="center">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/浦语挑战赛TOP50.jpg" alt="浦语挑战赛TOP50">
</p>
- The project has been promoted by the official WeChat account **NLP Engineering**. Here's the [link](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A).
### Roadmap
<p align="center">
<a href="https://github.com/aJupyter/EmoLLM/">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/Roadmap_EN.png" alt="Roadmap_EN">
</a>
@ -123,6 +134,7 @@ The Model aims to fully understand and promote the mental health of individuals,
- [EmoLLM - Large Language Model for Mental Health](#emollm---large-language-model-for-mental-health)
- [Recent Updates](#recent-updates)
- [Honor](#honor)
- [Roadmap](#roadmap)
- [Contents](#contents)
- [Pre-development Configuration Requirements.](#pre-development-configuration-requirements)
@ -227,13 +239,13 @@ This project uses Git for version control. You can see the currently available v
### Authors (in no particular order)
| Username | School/Organization | Remarks | Contributions |
|:-------------------------------------------------------------:|:--------------------------------------------------------------------:| :------------------: | :--------: |
|:-----------------------------------------------------------------------:|:--------------------------------------------------------------------:| :------------------: |:----------------------------------:|
| [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator |
| [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | LLM X Psychology researcher | Project co-leader |
| [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | |
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | |
| [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | |
| [zxazys](https://github.com/zxazys) | Nankai University, Master's student | | |
| [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | | |
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | SWUFE (Southwestern University of Finance and Economics) | | |
| [MrCatAI](https://github.com/MrCatAI) | AI Mover | | |
| [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | |
@ -247,11 +259,12 @@ This project uses Git for version control. You can see the currently available v
| [zealot52099](https://github.com/zealot52099) | AI Mover | | Data Processing and RAG |
| [wwwyfff](https://github.com/wwwyfff) | Fudan University, Master's student | ||
| [jkhumor](https://github.com/jkhumor) | Nankai University, Master's student | | RAG |
| [lll997150986](https://github.com/lll997150986) | Nankai University, Master's student | | Fine Tuning |
### Copyright Notice
The project is licensed under the MIT License. For details, please refer to
[LICENSE](https://github.com/aJupyter/EmoLLM/blob/master/LICENSE)
[LICENSE](https://github.com/SmartFlowAI/EmoLLM/blob/master/LICENSE)
### Acknowledgments

Binary file not shown.

Size: 3.6 MiB

View File

@ -40,3 +40,20 @@
* 数据集 aiwei 来自本项目
* 数据集 tiangou 来自本项目
* 数据集 SoulStar 来源 [SoulStar](https://github.com/Nobody-ML/SoulStar)
## 数据集去重

结合绝对匹配以及模糊匹配(Simhash)算法,对数据集进行去重以提升微调模型的效果。在确保数据集高质量的同时,通过调整阈值减少因错误匹配而丢失重要数据的风险。

**Simhash算法介绍**

Simhash(相似性哈希)是一种用于检测大量数据中相似或重复项的算法。它通过将文本转换为一组数值指纹来工作,这些指纹对相似的文本具有高度的相似性。Simhash算法对于处理文本数据特别有效,尤其是在处理大量数据时。

**Simhash实现步骤**(示例代码见下文)

* 文本预处理:将文本数据转换为适合Simhash处理的格式。这可能包括分词、去除停用词、词干提取等。
* 生成Simhash指纹:对预处理后的文本应用Simhash算法,生成一组数值指纹。每个指纹代表文本内容的一个哈希值。
* 比较指纹:通过比较哈希值的相似性来识别重复或相似的记录。Simhash的特点是,即使文本只有少量差异,生成的哈希值也具有较高的相似性。
* 确定阈值:设置一个相似性阈值,只有当两个指纹的相似度超过这个阈值时,才认为它们代表相似或重复的记录。
* 处理相似记录:对于被标记为相似的记录,可以进一步人工审查或自动合并,以消除重复。
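下面给出一个最小示例(仅为示意:假设使用 `simhash` 库,阈值取 0.8;实际去重逻辑以 `deduplicate.py` 为准),演示如何用 Simhash 指纹做相似去重:

```python
from simhash import Simhash

def simhash_similarity(text_a, text_b):
    # Simhash 指纹为 64 位,相似度 = 1 - 汉明距离 / 64
    a, b = Simhash(text_a), Simhash(text_b)
    return 1 - a.distance(b) / 64.0

THRESHOLD = 0.8  # 阈值可根据数据调整
corpus = [
    "今天天气很好,适合出门散步。",
    "今天天气很好,很适合出门散步!",  # 与上一句高度相似
    "我最近睡眠很差,总是做噩梦。",
]
kept = []
for text in corpus:
    # 仅保留与所有已保留样本的相似度都不超过阈值的文本
    if all(simhash_similarity(text, seen) <= THRESHOLD for seen in kept):
        kept.append(text)
print(kept)
```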
## 用法
### deduplicate.py
`deduplicate.py` 用于对 datasets 下以模型命名的文件夹(例如 `datasets/qwen`)中的 .json 数据进行去重,并将去重后的数据输出到 `datasets/qwen/dedup` 文件夹下。

View File

@ -41,3 +41,8 @@
* dataset `aiwei` from this repo
* dataset `tiangou` from this repo
* dataset `SoulStar` from [SoulStar](https://github.com/Nobody-ML/SoulStar)
**Dataset Deduplication**
Combine absolute matching with fuzzy matching (Simhash) to deduplicate the dataset and thereby improve the fine-tuned model. While keeping the dataset high-quality, the risk of losing important data to false matches can be reduced by adjusting the similarity threshold. A short sketch follows the link below.

For more on Simhash, see <https://algonotes.readthedocs.io/en/latest/Simhash.html>
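As a rough sketch (assuming the `simhash` Python package that the dedup script uses; the 0.8 threshold is illustrative), two 64-bit fingerprints can be compared like this:

```python
from simhash import Simhash

a = Simhash("The quick brown fox jumps over the lazy dog")
b = Simhash("The quick brown fox jumped over the lazy dog")

# fingerprints are 64 bits wide, so similarity = 1 - hamming_distance / 64
similarity = 1 - a.distance(b) / 64.0
print(similarity, similarity > 0.8)  # treat as duplicates above the threshold
```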

View File

@ -7552,9 +7552,6 @@
}
]
},
{
"conversation": []
},
{
"conversation": [
{
@ -8540,9 +8537,6 @@
}
]
},
{
"conversation": []
},
{
"conversation": [
{
@ -13389,9 +13383,6 @@
}
]
},
{
"conversation": []
},
{
"conversation": [
{
@ -19973,9 +19964,6 @@
}
]
},
{
"conversation": []
},
{
"conversation": [
{

View File

@ -3,49 +3,80 @@ from loguru import logger
import os
from datasketch import MinHash
from hashlib import md5
from simhash import Simhash
def extract_text_from_json(obj, content):
    # recursively collect every string value in a JSON object
    if isinstance(obj, dict):
        for key, value in obj.items():
            content = extract_text_from_json(value, content + f".{key}")
    elif isinstance(obj, list):
        for index, item in enumerate(obj):
            content = extract_text_from_json(item, content)
    elif isinstance(obj, str):
        content += obj
    return content

def is_json_file(filename):
    return filename.endswith('.json')

# absolute match: two texts are duplicates if their MD5 digests are identical
def is_duplicate_absolutely(d1, d2):
    return md5(d1.encode('utf-8')).hexdigest() == md5(d2.encode('utf-8')).hexdigest()
# compute a Simhash signature over all text contained in the dict
def hash_dict(dict_obj):
    content = extract_text_from_json(dict_obj, '')
    content = content.replace('\n', '').replace('\t', '').replace(' ', '')
    m = Simhash(content)
    return m

def get_minhash(text):
    m = MinHash()
    for word in text.split():
        m.update(word.encode('utf-8'))
    return m

def get_simhash(dict_obj):
    return Simhash(dict_obj)
# deduplicate a list of dicts using absolute matching plus Simhash similarity
def deduplicate_json(data_list, threshold=0.8):
    seen_hashes = []
    keep = []
    duplicate = []
    for item in data_list:
        if not item['conversation']:
            continue
        sim_hash = hash_dict(item)
        # absolute-match deduplication
        if not any(is_duplicate_absolutely(str(item), str(existing)) for existing in keep):
            # Simhash similarity deduplication: fingerprints are 64-bit,
            # so similarity = 1 - hamming_distance / 64
            has_similar = False
            for stored_sim_hash, stored_text in seen_hashes:
                if 1 - (stored_sim_hash.distance(sim_hash) / 64.0) > threshold:
                    has_similar = True
                    duplicate.append(item)
                    break
            if not has_similar:
                seen_hashes.append((sim_hash, item))
                keep.append(item)
        else:
            duplicate.append(item)
    return keep, duplicate
if __name__ == '__main__':
    DUP_THRESH = 0.8
    data_ai = 'qwen'
    root_dir = rf'./{data_ai}/'
    dedup_output_dir = os.path.join(root_dir, 'dedup')
@ -62,7 +93,79 @@ if __name__ == '__main__':
if is_json_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    dedup_data, duplicate = deduplicate_json(data, DUP_THRESH)
    with open(os.path.join(root_dir, 'dedup', 'dedup_' + file), 'w', encoding='utf-8') as output_file:
        json.dump(dedup_data, output_file, ensure_ascii=False, indent=4)
    for item in dedup_data:
        logger.info(f'dedup_data: {item}')
    for item in duplicate:
        logger.info(f'duplicate_data: {item}')
# Sanity check, kept as a comment: three near-identical multi-turn conversations
# (aa with 8 turns, bb = the first 7 turns of aa, cc = the first 6 turns of aa)
# give pairwise Simhash similarities of 0.9375 (aa/bb), 0.921875 (bb/cc) and
# 0.9375 (aa/cc), all above the 0.8 threshold, so bb and cc would be dropped
# as duplicates of aa.

View File

@ -1,12 +1,25 @@
import json

# file_name = 'multi_turn_dataset_1.json'
# file_name = 'multi_turn_dataset_2.json'
# file_name = 'data_pro.json'
file_name = 'data.json'

# open the JSON file and read its contents
with open(f'/root/StableCascade/emollm2/EmoLLM/datasets/{file_name}', 'rt', encoding='utf-8') as file:
    data = json.load(file)

n = 0
for i in data:
    try:
        i['conversation'][0]['system'] = "你是心理健康助手EmoLLM由EmoLLM团队打造。你旨在通过专业心理咨询协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术一步步帮助来访者解决心理问题。"
    except (KeyError, IndexError):
        print(n, i)  # 4 empty conversations in data.json: 425 483 742 1120
    n += 1

with open(f'processed_{file_name}', 'wt', encoding='utf-8') as file:
    json.dump(data, file, ensure_ascii=False, indent=4)

print(data[0])

View File

@ -0,0 +1,34 @@
import os
import json

# directory containing the JSON files (here: the current directory)
directory_path = './'

# collect the data from all JSON files into one list
combined_list = []

# walk through every file in the directory
for filename in os.listdir(directory_path):
    # only process .json files
    if filename.endswith('.json'):
        file_path = os.path.join(directory_path, filename)
        with open(file_path, 'r', encoding='utf-8') as json_file:
            data = json.load(json_file)
        # each file is expected to contain a list; append single objects as-is
        if isinstance(data, list):
            combined_list.extend(data)
        else:
            combined_list.append(data)

# print(combined_list)  # very large and slow

# save the merged list into a new JSON file
with open('combined_data.json', 'w', encoding='utf-8') as combined_json_file:
    json.dump(combined_list, combined_json_file, ensure_ascii=False, indent=4)

View File

@ -0,0 +1,27 @@
import json

# open the JSON file and read its contents
# file_name = 'single_turn_dataset_1.json'
file_name = 'single_turn_dataset_2.json'
with open(f'/root/StableCascade/emollm2/EmoLLM/datasets/{file_name}', 'rt', encoding='utf-8') as file:
    format1_data = json.load(file)

system = "你是心理健康助手EmoLLM由EmoLLM团队打造。你旨在通过专业心理咨询协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术一步步帮助来访者解决心理问题。"

# convert prompt/completion pairs into the conversation format
format2_data = []
for item in format1_data:
    conversation = {
        "system": system,
        "input": item["prompt"],
        "output": item["completion"]
    }
    format2_data.append({"conversation": [conversation]})

# write the converted data as JSON
with open(f'./processed_{file_name}', 'wt', encoding='utf-8') as file:
    json.dump(format2_data, file, ensure_ascii=False, indent=4)

print(format2_data[0])

datasets/scientist.json (new file, 51261 lines added) — file diff suppressed because it is too large

View File

@ -1,114 +1,155 @@
import json
import os
import pickle
import platform

import faiss
import torch
import streamlit as st
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from openxlab.model import download

from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_community.llms import Cohere
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from langchain_core.documents.base import Document

from BCEmbedding import EmbeddingModel, RerankerModel
from FlagEmbedding import FlagReranker

from util.pipeline import EmoLLMRAG
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
class Data_process():
    def __init__(self):
        self.vector_db_dir = vector_db_dir
        self.doc_dir = doc_dir
        self.qa_dir = qa_dir
        self.knowledge_pkl_path = knowledge_pkl_path
        self.chunk_size: int = 1000
        self.chunk_overlap: int = 100

    '''
    1. Generate embeddings from QA pairs / TXT documents
    2. Build the vector DB with the langchain FAISS interface
    3. Store it to openxlab.dataset for later use
    4. Provide an embedding interface for downstream calls
    5. Provide a rerank interface for downstream calls
    '''
    def load_embedding_model(self, model_name="BAAI/bge-small-zh-v1.5", device='cpu', normalize_embeddings=True):
        """
        Load the embedding model.

        Args:
            model_name: model name, defaults to "BAAI/bge-small-zh-v1.5"
            device: device to load the model on, 'cpu' or 'cuda', defaults to 'cpu'
            normalize_embeddings: whether to normalize the embeddings, defaults to True
        """
        logger.info('Loading embedding model...')
        try:
            embeddings = HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs={'device': device},
                encode_kwargs={'normalize_embeddings': normalize_embeddings}
            )
        except Exception as e:
            logger.error(f'Failed to load embedding model: {e}')
            return None
        logger.info('Embedding model loaded.')
        return embeddings
    def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
        """
        Load the rerank model.

        Args:
            model_name (str): model name, defaults to 'BAAI/bge-reranker-large'

        Returns:
            a FlagReranker instance

        Raises:
            Exception: if anything goes wrong while loading the model
        """
        logger.info('Loading rerank model...')
        try:
            reranker_model = FlagReranker(model_name, use_fp16=True)
        except Exception as e:
            logger.error(f'Failed to load rerank model: {e}')
            raise
        logger.info('Rerank model loaded.')
        return reranker_model
    def extract_text_from_json(self, obj, content=None):
        """
        Extract all text from a JSON object, used to build the vector store.

        Args:
            obj: dict, list or str
            content: str accumulator

        Returns:
            content: str
        """
        if content is None:
            content = ''
        if isinstance(obj, dict):
            for key, value in obj.items():
                content = self.extract_text_from_json(value, content)
        elif isinstance(obj, list):
            for index, item in enumerate(obj):
                content = self.extract_text_from_json(item, content)
        elif isinstance(obj, str):
            content += obj
        return content
    def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
        """
        Split all txt files under data_path into chunks.

        Args:
            data_path: str
            chunk_size: int
            chunk_overlap: int

        Returns:
            split_docs: list
        """
        # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        split_docs = []
        logger.info(f'Loading txt files from {data_path}')
        if os.path.isdir(data_path):
            loader = DirectoryLoader(data_path, glob="**/*.txt", show_progress=True)
            docs = loader.load()
            split_docs = text_spliter.split_documents(docs)
        elif data_path.endswith('.txt'):
            file_path = data_path
            logger.info(f'splitting file {file_path}')
            text_loader = TextLoader(file_path, encoding='utf-8')
            text = text_loader.load()
            splits = text_spliter.split_documents(text)
            split_docs = splits
        logger.info(f'split_docs size {len(split_docs)}')
        return split_docs
    def split_conversation(self, path):
        """
        Split all json files under path into whole conversation blocks.
        Expected data format:
        [
            {"conversation": [{"input": Q1, "output": A1}, {"input": Q2, "output": A2}, ...]},
            ...
        ]
        ##TODO limit the sequence length
        """
        # json_spliter = RecursiveJsonSplitter(max_chunk_size=500)
        logger.info(f'Loading json files from {path}')
        split_qa = []
        if os.path.isdir(path):
            # loader = DirectoryLoader(path, glob="**/*.json", show_progress=True)
            # jsons = loader.load()
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.json'):
                        file_path = os.path.join(root, file)
                        logger.info(f'splitting file {file_path}')
                        with open(file_path, 'r', encoding='utf-8') as f:
                            data = json.load(f)
                        for conversation in data:
                            # split per conversation block: each block becomes one
                            # langchain Document (splitting per QA turn is also possible)
                            content = self.extract_text_from_json(conversation['conversation'], '')
                            split_qa.append(Document(page_content=content))
        # logger.info(f'split_qa size====={len(split_qa)}')
        return split_qa
    # read the knowledge base
    def load_knowledge(self, knowledge_pkl_path):
        '''
        Read the knowledge chunks from .pkl, creating the file first if needed.
        '''
        if not os.path.exists(knowledge_pkl_path):
            split_doc = self.split_document(doc_dir)
            split_qa = self.split_conversation(qa_dir)
            knowledge_chunks = split_doc + split_qa
            with open(knowledge_pkl_path, 'wb') as file:
                pickle.dump(knowledge_chunks, file)
        else:
            with open(knowledge_pkl_path, 'rb') as f:
                knowledge_chunks = pickle.load(f)
        return knowledge_chunks
    def create_vector_db(self, emb_model):
        '''
        Create and save the vector DB.
        '''
        logger.info(f'Creating index...')
        split_doc = self.split_document(self.doc_dir)
        split_qa = self.split_conversation(self.qa_dir)
        # logger.info(f'split_doc == {len(split_doc)}')
        # logger.info(f'split_qa == {len(split_qa)}')
        db = FAISS.from_documents(split_doc + split_qa, emb_model)
        db.save_local(vector_db_dir)
        return db
    def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir, qa_dir=qa_dir):
        '''
        Load the vector DB, creating it first if it does not exist yet.
        '''
        emb_model = self.load_embedding_model()
        if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
            db = self.create_vector_db(emb_model)
        else:
            db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
        return db
    def retrieve(self, query, vector_db, k=5):
        '''
        Retrieve the top-k documents from the vector DB for the query.
        '''
        retriever = vector_db.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(query)
        return docs, retriever
    ## FlashrankRerank gives mediocre results
    # def rerank(self, query, retriever):
    #     compressor = FlashrankRerank()
    #     compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
    #     compressed_docs = compression_retriever.get_relevant_documents(query)
    #     return compressed_docs
    def rerank(self, query, docs):
        # score each (query, passage) pair and sort passages by score, descending
        reranker = self.load_rerank_model()
        passages = []
        for doc in docs:
            passages.append(str(doc.page_content))
        scores = reranker.compute_score([[query, passage] for passage in passages])
        sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
        sorted_passages, sorted_scores = zip(*sorted_pairs)
        return sorted_passages, sorted_scores


@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer
if __name__ == "__main__":
logger.info(data_dir)
if not os.path.exists(data_dir):
os.mkdir(data_dir)
faiss_index, knowledge_chunks = load_index_and_knowledge()
dp = Data_process()
# faiss_index, knowledge_chunks = dp.load_index_and_knowledge(knowledge_pkl_path='')
vector_db = dp.load_vector_db()
# 按照query进行查询
# query = "她要阻挠姐姐的婚姻,即使她自己的尸体在房门跟前"
# query = "肯定的。我最近睡眠很差,总是做噩梦。而且我吃得也不好,体重一直在下降"
# query = "序言 (一) 变态心理学是心理学本科生的必修课程之一,教材更新的问题一直在困扰着我们。"
query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
distances, indices = find_top_k(query, faiss_index, 5)
logger.info(f'distances==={distances}')
logger.info(f'indices==={indices}')
# rerank无法返回id先实现按整个问答对排序
rerank_results = rerank(query, indices, knowledge_chunks)
for passage, score in zip(rerank_results['rerank_passages'], rerank_results['rerank_scores']):
print(str(score)+'\n')
print(passage+'\n')
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版1979年再版。"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}')
logger.info("Retrieve results:")
for i, doc in enumerate(docs):
logger.info(str(i) + '\n')
logger.info(doc)
# print(f'get num of docs:{len(docs)}')
# print(docs)
passages,scores = dp.rerank(query, docs)
logger.info("After reranking...")
for i in range(len(scores)):
logger.info(str(scores[i]) + '\n')
logger.info(passages[i])

View File

@ -13,9 +13,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
from data_processing import Data_process
'''
1. Build the complete RAG pipeline: take the user query as input and return the answer
2. Call the embedding interface to vectorize the query
@ -42,30 +41,23 @@ def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    return model, tokenizer
def get_prompt():
    pass


def get_prompt_template():
    pass


def main(query, system_prompt=''):
    logger.info(data_dir)
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    dp = Data_process()
    vector_db = dp.load_vector_db()
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    logger.info(f'Query: {query}')
    logger.info("Retrieve results===============================")
    for i, doc in enumerate(docs):
        logger.info(doc)
    passages, scores = dp.rerank(query, docs)
    logger.info("After reranking===============================")
    for i in range(len(scores)):
        logger.info(passages[i])
        logger.info(f'score: {str(scores[i])}')


if __name__ == "__main__":
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    main(query)

View File

@ -0,0 +1,203 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)
from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
# pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-base-7b'
# Data
# data_path = 'merge.json'
data_path = '/root/StableCascade/emollm2/EmoLLM/datasets/processed/combined_data.json'
# https://github.com/InternLM/xtuner/blob/main/xtuner/utils/templates.py#L24C25-L24C25
prompt_template = PROMPT_TEMPLATE.internlm2_chat  # there is no internlm2_base template
max_length = 2048
pack_to_max_length = True
# Scheduler & Optimizer
# batch_size = 8 # per_device
# accumulative_counts = 2
batch_size = 16 # per_device
accumulative_counts = 1
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.03
# Evaluate the generation performance during the training
evaluation_freq = 500
# SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
SYSTEM = "你是心理健康助手EmoLLM由EmoLLM团队打造。你旨在通过专业心理咨询协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术一步步帮助来访者解决心理问题。"
evaluation_inputs = [
    '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
    '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
]
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]
# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print a log entry every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save a checkpoint every epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set the sampler seed in distributed environments.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi-process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)
# set visualizer
visualizer = None
# set log level
log_level = 'INFO'
# load from which checkpoint
load_from = None
# whether to resume training from the loaded checkpoint
resume = False
# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

View File

@ -1,11 +1,23 @@
datasets==2.16.1
deepspeed==0.13.1
einops==0.7.0
openxlab==0.0.34
peft==0.7.1
sentencepiece==0.1.99
torch==2.1.2
transformers==4.36.2

# modified versions
# xtuner==0.1.11
# mmengine==0.10.2
mmengine==0.10.3
xtuner==0.1.15

# flash_attn==2.5.0  # building from source is very slow (about 2 hours?)
# method 1: install a prebuilt wheel from https://github.com/Dao-AILab/flash-attention/releases
#           e.g. flash_attn-2.5.0+cu122torch2.1cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# method 2:
# pip install /root/share/wheels/flash_attn-2.4.2+cu118torch2.0cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

# mpi4py==3.1.5  # conda install mpi4py