commit
6c2c7496ba
94
README.md
94
README.md
@ -5,7 +5,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/">
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
<img src="assets/logo.jpeg" alt="Logo" width="30%">
|
<img src="assets/logo.jpeg" alt="Logo" width="30%">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
@ -28,14 +28,14 @@
|
|||||||
简体中文| <a href="README_EN.md" >English</a>
|
简体中文| <a href="README_EN.md" >English</a>
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
<a href="https://github.com/aJupyter/EmoLLM"><strong>探索本项目的文档 »</strong></a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>探索本项目的文档 »</strong></a>
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">体验EmoLLM 2.0</a>
|
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">体验EmoLLM 2.0</a>
|
||||||
·
|
·
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/issues">报告Bug</a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">报告Bug</a>
|
||||||
·
|
·
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/issues">提出新特性</a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">提出新特性</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
@ -45,19 +45,19 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
| 模型 | 类型 |
|
| 模型 | 类型 |
|
||||||
| :-------------------: | :------: |
|
| :-------------------: | :--------: |
|
||||||
| InternLM2_7B_chat | QLORA |
|
| InternLM2_7B_chat | QLORA |
|
||||||
| InternLM2_7B_chat | 全量微调 |
|
| InternLM2_7B_chat | 全量微调 |
|
||||||
| InternLM2_1_8B_chat | 全量微调 |
|
| InternLM2_1_8B_chat | 全量微调 |
|
||||||
| InternLM2_20B_chat | LORA |
|
| InternLM2_20B_chat | LORA |
|
||||||
| Qwen_7b_chat | QLORA |
|
| Qwen_7b_chat | QLORA |
|
||||||
| Qwen1_5-0_5B-Chat | 全量微调 |
|
| Qwen1_5-0_5B-Chat | 全量微调 |
|
||||||
| Baichuan2_13B_chat | QLORA |
|
| Baichuan2_13B_chat | QLORA |
|
||||||
| ChatGLM3_6B | LORA |
|
| ChatGLM3_6B | LORA |
|
||||||
| DeepSeek MoE_16B_chat | QLORA |
|
| DeepSeek MoE_16B_chat | QLORA |
|
||||||
| Mixtral 8x7B_instruct | QLORA |
|
| Mixtral 8x7B_instruct | QLORA |
|
||||||
| …… | …… |
|
| …… | …… |
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -94,13 +94,13 @@
|
|||||||
- 【2024.2.6】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!
|
- 【2024.2.6】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
|
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
- 【2024.2.5】 项目荣获公众号**NLP工程化**推文宣传[推文链接](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A),为博主推广一波,欢迎大家关注!!🥳🥳
|
- 【2024.2.5】 项目荣获公众号**NLP工程化**推文宣传[推文链接](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A),为博主推广一波,欢迎大家关注!!🥳🥳
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
|
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
- 【2024.2.3】 [项目宣传视频](https://www.bilibili.com/video/BV1N7421N76X/)完成 😊
|
- 【2024.2.3】 [项目宣传视频](https://www.bilibili.com/video/BV1N7421N76X/)完成 😊
|
||||||
@ -109,24 +109,35 @@
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### 🏆荣誉栏
|
||||||
|
- 项目荣获上海人工智能实验室举办的**2024浦源大模型系列挑战赛春季赛*****50强***
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
|
<img src="assets/浦语挑战赛TOP50.jpg" alt="浦语挑战赛TOP50">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
- 项目荣获公众号**NLP工程化**[推文宣传](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A)
|
||||||
|
|
||||||
### 🎯路线图
|
### 🎯路线图
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/">
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
<img src="assets/Roadmap_ZH.png" alt="Roadmap_ZH">
|
<img src="assets/Roadmap_ZH.png" alt="Roadmap_ZH">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
### 🎯框架图
|
### 🎯框架图
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/">
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
<img src="assets/框架图.png" alt="Roadmap_ZH">
|
<img src="assets/框架图.png" alt="Framework_ZH">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
- [EmoLLM-心理健康大模型](#emollm-心理健康大模型)
|
- [EmoLLM-心理健康大模型](#emollm-心理健康大模型)
|
||||||
- [🎇最近更新](#最近更新)
|
- [🎇最近更新](#最近更新)
|
||||||
|
- [🏆荣誉栏](#荣誉栏)
|
||||||
- [🎯路线图](#路线图)
|
- [🎯路线图](#路线图)
|
||||||
- [🎯框架图](#框架图)
|
- [🎯框架图](#框架图)
|
||||||
- [目录](#目录)
|
- [目录](#目录)
|
||||||
@ -142,7 +153,7 @@
|
|||||||
- [版权说明](#版权说明)
|
- [版权说明](#版权说明)
|
||||||
- [特别鸣谢](#特别鸣谢)
|
- [特别鸣谢](#特别鸣谢)
|
||||||
- [Star History](#star-history)
|
- [Star History](#star-history)
|
||||||
- [🌟 Contributors](#-contributors)
|
- [🌟Contributors](#-contributors)
|
||||||
- [交流群](#交流群)
|
- [交流群](#交流群)
|
||||||
|
|
||||||
###### 开发前的配置要求
|
###### 开发前的配置要求
|
||||||
@ -210,27 +221,28 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
|
|||||||
|
|
||||||
### 作者(排名不分先后)
|
### 作者(排名不分先后)
|
||||||
|
|
||||||
| 用户名 | 学校/组织 | 备注 | 贡献 |
|
| 用户名 | 学校/组织 | 备注 | 贡献 |
|
||||||
|:-------------------------------------------------------------:|:--------------------------------------------------:| :-------------------: | :----------: |
|
|:--------------------------------------------------------------------:|:--------------------------------------------------:| :-------------------: |:---------:|
|
||||||
| [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 |
|
| [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 |
|
||||||
| [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | |
|
| [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | LLM x Psychology 研究者 | 项目联合负责人 |
|
||||||
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | |
|
| [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | |
|
||||||
| [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | |
|
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | |
|
||||||
| [zxazys](https://github.com/zxazys) | 南开大学在读硕士 | | |
|
| [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | |
|
||||||
| [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | | |
|
| [zxazys](https://github.com/zxazys) | 南开大学在读硕士 | | |
|
||||||
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | swufe | | |
|
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | swufe | | |
|
||||||
| [MrCatAI](https://github.com/MrCatAI) | AI搬用工 | | |
|
| [MrCatAI](https://github.com/MrCatAI) | AI搬用工 | | |
|
||||||
| [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | |
|
| [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | |
|
||||||
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | 宾夕法尼亚大学在读硕士 | | |
|
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | 宾夕法尼亚大学在读硕士 | | |
|
||||||
| [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(华东)在读本科生 | | |
|
| [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(华东)在读本科生 | | |
|
||||||
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora/) |MiniSora主要维护|数据清洗、文档翻译|
|
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora/) |MiniSora主要维护| 数据清洗、文档翻译 |
|
||||||
| [Mxoder](https://github.com/Mxoder) | 北京航空航天大学在读本科生 | | |
|
| [Mxoder](https://github.com/Mxoder) | 北京航空航天大学在读本科生 | | |
|
||||||
| [Anooyman](https://github.com/Anooyman) | 南京理工大学硕士 | | |
|
| [Anooyman](https://github.com/Anooyman) | 南京理工大学硕士 | | |
|
||||||
| [Vicky-3021](https://github.com/Vicky-3021) | 西安电子科技大学硕士(研0) | | |
|
| [Vicky-3021](https://github.com/Vicky-3021) | 西安电子科技大学硕士(研0) | | |
|
||||||
| [SantiagoTOP](https://github.com/santiagoTOP) | 太原理工大学在读硕士 | | |
|
| [SantiagoTOP](https://github.com/santiagoTOP) | 太原理工大学在读硕士 | | |
|
||||||
| [zealot52099](https://github.com/zealot52099) | AI搬用工 | |清洗数据、RAG|
|
| [zealot52099](https://github.com/zealot52099) | AI搬用工 | | 清洗数据、RAG |
|
||||||
| [wwwyfff](https://github.com/wwwyfff) | 复旦大学在读硕士 | ||
|
| [wwwyfff](https://github.com/wwwyfff) | 复旦大学在读硕士 | ||
|
||||||
| [jkhumor](https://github.com/jkhumor) | 南开大学在读硕士 | |RAG|
|
| [jkhumor](https://github.com/jkhumor) | 南开大学在读硕士 | | RAG |
|
||||||
|
| [lll997150986](https://github.com/lll997150986) | 南开大学在读硕士 | | 微调 |
|
||||||
|
|
||||||
### 版权说明
|
### 版权说明
|
||||||
|
|
||||||
|
71
README_EN.md
71
README_EN.md
@ -5,7 +5,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/">
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
<img src="assets/logo.jpeg" alt="Logo" width="30%">
|
<img src="assets/logo.jpeg" alt="Logo" width="30%">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
@ -28,14 +28,14 @@
|
|||||||
<a href="README.md">简体中文</a> | English
|
<a href="README.md">简体中文</a> | English
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
<a href="https://github.com/aJupyter/EmoLLM"><strong>Explore the documentation of this project »</strong></a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>Explore the documentation of this project »</strong></a>
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">EmoLLM 2.0 Demo</a>
|
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">EmoLLM 2.0 Demo</a>
|
||||||
·
|
·
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/issues">Report a Bug</a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">Report a Bug</a>
|
||||||
·
|
·
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/issues">Propose a New Feature</a>
|
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">Propose a New Feature</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
@ -97,13 +97,13 @@ The Model aims to fully understand and promote the mental health of individuals,
|
|||||||
- 【2024.2.6】 [Open-sourced based on the Qwen1_5-0_5B-Chat full-scale fine-tuned version](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary), friends with limited computing power can start experimenting~
|
- 【2024.2.6】 [Open-sourced based on the Qwen1_5-0_5B-Chat full-scale fine-tuned version](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary), friends with limited computing power can start experimenting~
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
|
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="模型下载量">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
- 【2024.2.5】 The project has been promoted by the official WeChat account NLP Engineering. Here's the [link](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A) to the article. Welcome everyone to follow!! 🥳🥳
|
- 【2024.2.5】 The project has been promoted by the official WeChat account NLP Engineering. Here's the [link](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A) to the article. Welcome everyone to follow!! 🥳🥳
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://github.com/aJupyter/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
|
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公众号二维码">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
- 【2024.2.3】 [Project Vedio](https://www.bilibili.com/video/BV1N7421N76X/) at bilibili 😊
|
- 【2024.2.3】 [Project Vedio](https://www.bilibili.com/video/BV1N7421N76X/) at bilibili 😊
|
||||||
@ -112,10 +112,21 @@ The Model aims to fully understand and promote the mental health of individuals,
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### Honor
|
||||||
|
|
||||||
|
- The project won the ***top50*** in the **2024 Puyuan Large Model Series Challenge Spring Competition held by the Shanghai Artificial Intelligence Laboratory**
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
|
<img src="assets/浦语挑战赛TOP50.jpg" alt="浦语挑战赛TOP50">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
- The project has been promoted by the official WeChat account **NLP Engineering**. Here's the [link](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A).
|
||||||
|
|
||||||
### Roadmap
|
### Roadmap
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/aJupyter/EmoLLM/">
|
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||||
<img src="assets/Roadmap_EN.png" alt="Roadmap_EN">
|
<img src="assets/Roadmap_EN.png" alt="Roadmap_EN">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
@ -123,6 +134,7 @@ The Model aims to fully understand and promote the mental health of individuals,
|
|||||||
|
|
||||||
- [EmoLLM - Large Language Model for Mental Health](#emollm---large-language-model-for-mental-health)
|
- [EmoLLM - Large Language Model for Mental Health](#emollm---large-language-model-for-mental-health)
|
||||||
- [Recent Updates](#recent-updates)
|
- [Recent Updates](#recent-updates)
|
||||||
|
- [Honor](#honor)
|
||||||
- [Roadmap](#roadmap)
|
- [Roadmap](#roadmap)
|
||||||
- [Contents](#contents)
|
- [Contents](#contents)
|
||||||
- [Pre-development Configuration Requirements.](#pre-development-configuration-requirements)
|
- [Pre-development Configuration Requirements.](#pre-development-configuration-requirements)
|
||||||
@ -226,32 +238,33 @@ This project uses Git for version control. You can see the currently available v
|
|||||||
|
|
||||||
### Authors (in no particular order)
|
### Authors (in no particular order)
|
||||||
|
|
||||||
| Username | School/Organization | Remarks | Contributions |
|
| Username | School/Organization | Remarks | Contributions |
|
||||||
|:-------------------------------------------------------------:|:--------------------------------------------------------------------:| :------------------: | :--------: |
|
|:-----------------------------------------------------------------------:|:--------------------------------------------------------------------:| :------------------: |:----------------------------------:|
|
||||||
| [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator |
|
| [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator |
|
||||||
| [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | |
|
| [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | LLM X Psychology researcher | Project co-leader |
|
||||||
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | |
|
| [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | |
|
||||||
| [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | |
|
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | |
|
||||||
| [zxazys](https://github.com/zxazys) | Nankai University, Master's student | | |
|
| [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | |
|
||||||
| [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | | |
|
| [zxazys](https://github.com/zxazys) | Nankai University, Master's student | | |
|
||||||
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | SWUFE (Southwestern University of Finance and Economics) | | |
|
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | SWUFE (Southwestern University of Finance and Economics) | | |
|
||||||
| [MrCatAI](https://github.com/MrCatAI) | AI Mover | | |
|
| [MrCatAI](https://github.com/MrCatAI) | AI Mover | | |
|
||||||
| [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | |
|
| [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | |
|
||||||
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | University of Pennsylvania, Master's student | | |
|
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | University of Pennsylvania, Master's student | | |
|
||||||
| [Nobody-ML](https://github.com/Nobody-ML) | China University of Petroleum (East China), Undergraduate student | | |
|
| [Nobody-ML](https://github.com/Nobody-ML) | China University of Petroleum (East China), Undergraduate student | | |
|
||||||
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora) |Maintainer and Admin|Data Cleaning and Docs Translation|
|
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora) |Maintainer and Admin| Data Cleaning and Docs Translation |
|
||||||
| [Mxoder](https://github.com/Mxoder) | Beihang University, Undergraduate student | | |
|
| [Mxoder](https://github.com/Mxoder) | Beihang University, Undergraduate student | | |
|
||||||
| [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | |
|
| [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | |
|
||||||
| [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | |
|
| [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | |
|
||||||
| [SantiagoTOP](https://github.com/santiagoTOP) | Taiyuan University of Technology, Master's student | | |
|
| [SantiagoTOP](https://github.com/santiagoTOP) | Taiyuan University of Technology, Master's student | | |
|
||||||
| [zealot52099](https://github.com/zealot52099) | AI Mover | |Data Processing and RAG|
|
| [zealot52099](https://github.com/zealot52099) | AI Mover | | Data Processing and RAG |
|
||||||
| [wwwyfff](https://github.com/wwwyfff) | FuDan University, Master's student | ||
|
| [wwwyfff](https://github.com/wwwyfff) | FuDan University, Master's student | ||
|
||||||
| [jkhumor](https://github.com/jkhumor) | Nankai University, Master's student | |RAG|
|
| [jkhumor](https://github.com/jkhumor) | Nankai University, Master's student | | RAG |
|
||||||
|
| [lll997150986](https://github.com/lll997150986) | Nankai University, Master's student | | Fine Tuning |
|
||||||
|
|
||||||
### Copyright Notice
|
### Copyright Notice
|
||||||
|
|
||||||
The project is licensed under the MIT License. Please refer to the details
|
The project is licensed under the MIT License. Please refer to the details
|
||||||
[LICENSE](https://github.com/aJupyter/EmoLLM/blob/master/LICENSE)
|
[LICENSE](https://github.com/SmartFlowAI/EmoLLM/blob/master/LICENSE)
|
||||||
|
|
||||||
### Acknowledgments
|
### Acknowledgments
|
||||||
|
|
||||||
|
BIN
assets/浦语挑战赛TOP50.jpg
Normal file
BIN
assets/浦语挑战赛TOP50.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.6 MiB |
@ -40,3 +40,20 @@
|
|||||||
* 数据集 aiwei 来自本项目
|
* 数据集 aiwei 来自本项目
|
||||||
* 数据集 tiangou 来自本项目
|
* 数据集 tiangou 来自本项目
|
||||||
* 数据集 SoulStar 来源 [SoulStar](https://github.com/Nobody-ML/SoulStar)
|
* 数据集 SoulStar 来源 [SoulStar](https://github.com/Nobody-ML/SoulStar)
|
||||||
|
|
||||||
|
## 数据集去重
|
||||||
|
结合绝对匹配以及模糊匹配(Simhash)算法,对数据集进行去重以提升微调模型的效果。在确保数据集的高质量的同时,通过调整阈值减少因错误匹配而丢失重要数据的风险。
|
||||||
|
|
||||||
|
**Simhash算法介绍**
|
||||||
|
Simhash(相似性哈希)是一种用于检测大量数据中相似或重复项的算法。它通过将文本转换为一组数值指纹来工作,这些指纹对相似的文本具有高度的相似性。Simhash算法对于处理文本数据特别有效,尤其是在处理大量数据时。
|
||||||
|
|
||||||
|
**Simhash实现步骤**
|
||||||
|
*文本预处理:将文本数据转换为适合Simhash处理的格式。这可能包括分词、去除停用词、词干提取等。
|
||||||
|
*生成Simhash指纹:对预处理后的文本应用Simhash算法,生成一组数值指纹。每个指纹代表文本内容的一个哈希值。
|
||||||
|
*比较指纹:通过比较哈希值的相似性来识别重复或相似的记录。Simhash的特点是即使在文本有少量差异时,生成的哈希值也具有较高的相似性。
|
||||||
|
*确定阈值:设置一个相似性阈值,只有当两个指纹的相似度超过这个阈值时,才认为它们代表相似或重复的记录。
|
||||||
|
*处理相似记录:对于被标记为相似的记录,可以进一步人工审查或自动合并,以消除重复。
|
||||||
|
|
||||||
|
## 用法
|
||||||
|
### deduplicate.py
|
||||||
|
`deduplicate.py` 用于将datasets下以模型命名的文件夹下(例如:'datasets/qwen').json数据进行去重,输出去重后的数据到 `datasets/qwen/dedup` 文件夹下。
|
@ -41,3 +41,8 @@
|
|||||||
* dataset `aiwei` from this repo
|
* dataset `aiwei` from this repo
|
||||||
* dataset `tiangou` from this repo
|
* dataset `tiangou` from this repo
|
||||||
* dataset `SoulStar` from [SoulStar](https://github.com/Nobody-ML/SoulStar)
|
* dataset `SoulStar` from [SoulStar](https://github.com/Nobody-ML/SoulStar)
|
||||||
|
|
||||||
|
**Dataset Deduplication**:
|
||||||
|
Combine absolute matching with fuzzy matching (Simhash) algorithms to deduplicate the dataset, thereby enhancing the effectiveness of the fine-tuning model. While ensuring the high quality of the dataset, the risk of losing important data due to incorrect matches can be reduced via adjusting the threshold.
|
||||||
|
|
||||||
|
https://algonotes.readthedocs.io/en/latest/Simhash.html
|
@ -7552,9 +7552,6 @@
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"conversation": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"conversation": [
|
"conversation": [
|
||||||
{
|
{
|
||||||
@ -8540,9 +8537,6 @@
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"conversation": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"conversation": [
|
"conversation": [
|
||||||
{
|
{
|
||||||
@ -13389,9 +13383,6 @@
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"conversation": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"conversation": [
|
"conversation": [
|
||||||
{
|
{
|
||||||
@ -19973,9 +19964,6 @@
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"conversation": []
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"conversation": [
|
"conversation": [
|
||||||
{
|
{
|
||||||
|
@ -3,49 +3,80 @@ from loguru import logger
|
|||||||
import os
|
import os
|
||||||
from datasketch import MinHash
|
from datasketch import MinHash
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
|
from simhash import Simhash
|
||||||
|
|
||||||
|
def extract_text_from_json(obj, content):
|
||||||
|
# print(content)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for key, value in obj.items():
|
||||||
|
content = extract_text_from_json(value, content + f".{key}")
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for index, item in enumerate(obj):
|
||||||
|
content = extract_text_from_json(item, content)
|
||||||
|
elif isinstance(obj, str):
|
||||||
|
content += obj
|
||||||
|
return content
|
||||||
|
|
||||||
def is_json_file(filename):
|
def is_json_file(filename):
|
||||||
return filename.endswith('.json')
|
return filename.endswith('.json')
|
||||||
|
|
||||||
# 绝对匹配
|
# 绝对匹配
|
||||||
def is_duplicate_absolutely(d1, d2):
|
def is_duplicate_absolutely(d1, d2):
|
||||||
|
|
||||||
return md5(d1.encode('utf-8')).hexdigest() == md5(d2.encode('utf-8')).hexdigest()
|
return md5(d1.encode('utf-8')).hexdigest() == md5(d2.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
# 使用MinHash生成器计算dict的签名
|
# 使用MinHash生成器计算dict的签名
|
||||||
def hash_dict(dict_obj):
|
def hash_dict(dict_obj):
|
||||||
m = MinHash()
|
content = extract_text_from_json(dict_obj,'')
|
||||||
for key, value in sorted(dict_obj.items()):
|
content = content.replace('\n', '').replace('\t', '').replace(' ', '')
|
||||||
# 对于非str类型值需要先转为str
|
print(content)
|
||||||
m.update(str(value).encode('utf8'))
|
# m = get_minhash(content)
|
||||||
|
m = Simhash(content)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
def get_minhash(text):
|
||||||
|
m = MinHash()
|
||||||
|
for word in text.split():
|
||||||
|
m.update(word.encode('utf-8'))
|
||||||
|
return m
|
||||||
|
def get_simhash(dict_obj):
|
||||||
|
return Simhash(dict_obj)
|
||||||
|
|
||||||
# 使用绝对匹配和MinHash对dict列表去重
|
# 使用绝对匹配和MinHash对dict列表去重
|
||||||
def deduplicate_json(data_list, threshold=0.8):
|
def deduplicate_json(data_list, threshold=0.8):
|
||||||
seen_hashes = []
|
seen_hashes = []
|
||||||
duplicates_removed = []
|
keep = []
|
||||||
|
duplicate = []
|
||||||
for item in data_list:
|
for item in data_list:
|
||||||
# print(item)
|
if not item['conversation']:
|
||||||
# print('###########')
|
continue
|
||||||
min_hash = hash_dict(item)
|
# min_hash = hash_dict(item)
|
||||||
|
sim_hash = hash_dict(item)
|
||||||
# print(f'min_hash: {min_hash}')
|
# print(f'min_hash: {min_hash}')
|
||||||
|
|
||||||
# 绝对匹配去重
|
# 绝对匹配去重
|
||||||
if not any(is_duplicate_absolutely(str(item), str(existing)) for existing in duplicates_removed):
|
if not any(is_duplicate_absolutely(str(item), str(existing)) for existing in keep):
|
||||||
# MinHash相似性去重
|
# MinHash相似性去重
|
||||||
has_similar = False
|
has_similar = False
|
||||||
|
# for stored_min_hash, stored_text in seen_hashes:
|
||||||
|
# if stored_min_hash.jaccard(min_hash) > threshold:
|
||||||
for stored_min_hash, stored_text in seen_hashes:
|
for stored_min_hash, stored_text in seen_hashes:
|
||||||
if stored_min_hash.jaccard(min_hash) > threshold:
|
if 1 - (stored_min_hash.distance(sim_hash)/64.0) > threshold:
|
||||||
has_similar = True
|
has_similar = True
|
||||||
|
duplicate.append(item)
|
||||||
break
|
break
|
||||||
if not has_similar:
|
if not has_similar:
|
||||||
seen_hashes.append((min_hash,item))
|
# seen_hashes.append((min_hash,item))
|
||||||
duplicates_removed.append(item)
|
seen_hashes.append((sim_hash,item))
|
||||||
|
keep.append(item)
|
||||||
|
else:
|
||||||
|
duplicate.append(item)
|
||||||
|
|
||||||
|
return keep, duplicate
|
||||||
|
|
||||||
return duplicates_removed
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
DUP_THRESH = 0.8
|
||||||
data_ai = 'qwen'
|
data_ai = 'qwen'
|
||||||
root_dir = rf'./{data_ai}/'
|
root_dir = rf'./{data_ai}/'
|
||||||
dedup_output_dir = os.path.join(root_dir,'dedup')
|
dedup_output_dir = os.path.join(root_dir,'dedup')
|
||||||
@ -62,7 +93,79 @@ if __name__ == '__main__':
|
|||||||
if is_json_file(file_path):
|
if is_json_file(file_path):
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
dedup_data = deduplicate_json(data)
|
dedup_data, duplicate = deduplicate_json(data, DUP_THRESH)
|
||||||
with open(os.path.join(root_dir, 'dedup','dedup_' + file), 'w', encoding='utf-8') as output_file:
|
with open(os.path.join(root_dir, 'dedup','dedup_' + file), 'w', encoding='utf-8') as output_file:
|
||||||
json.dump(dedup_data, output_file, ensure_ascii=False, indent=4)
|
json.dump(dedup_data, output_file, ensure_ascii=False, indent=4)
|
||||||
|
for item in dedup_data:
|
||||||
|
logger.info(f'dedup_data: {item}')
|
||||||
|
for item in duplicate:
|
||||||
|
logger.info(f'duplicate_data: {item}')
|
||||||
|
|
||||||
|
|
||||||
|
# aa = '''
|
||||||
|
# [{"conversation":[{"system":"你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 除了在心理方面拥有广博的知识储备, 专业的心理辅导知识和丰富的研究咨询经验, 还具有爹系男友的特质, 现在我有一些心理问题, 请你在我提问后, 用专业的心理咨询辅导能力以爹系男友的身份帮我解决问题.",
|
||||||
|
# "input":"我最近发现自己对一个朋友过于着迷,总想时刻知道他在做什么,甚至会因为他的一点小事而情绪波动很大。\n\n",
|
||||||
|
# "output":"聪明的你,能够觉察到自己的情感变化已是非常勇敢。着迷可能源于内心深处的需求,我们先来探讨一下这种情感背后的原因。\n\n"},
|
||||||
|
# {"input":"我觉得他是那么完美,他的每个细节都让我无法自拔。\n\n",
|
||||||
|
# "output":"善良的你,理想化他人是人类心理的一种自然反应。但我们要学会区分现实与幻想,逐步找回自我平衡。\n\n"},
|
||||||
|
# {"input":"我尝试过转移注意力,但总是失败,我怕这样下去会影响到我们的友情。\n\n",
|
||||||
|
# "output":"温柔的你,这确实是个明智的担忧。健康的友情建立在互相尊重和个人空间的基础上。让我们一起制定一些实际的策略,帮助你适度调整关注点。\n\n"},
|
||||||
|
# {"input":"我不知道怎么去跟他说,我怕他会误解我。\n\n",
|
||||||
|
# "output":"理解你的顾虑,沟通的艺术在于表达感受而非指责。作为你的爹系心理咨询师,我会教你如何用恰当的语言分享你的困扰。\n\n"},
|
||||||
|
# {"input":"我真的很害怕失去他这个朋友。\n\n",
|
||||||
|
# "output":"面对可能的失落感,这是人之常情。然而,真正的友情能经受住坦诚和理解的考验。你现在需要的是提升自我价值感,减少对他人的过度依赖。\n\n"},
|
||||||
|
# {"input":"那我具体该怎么做呢?\n\n",
|
||||||
|
# "output":"首先,你可以试着丰富个人生活,培养独立的兴趣爱好;其次,逐渐降低对他日常生活的关注度,转而关注自我成长。同时,找个合适时机与他真诚交流,告诉他你需要保持一定的情感边界。记住,无论结果如何,坚强的你都能从中成长。\n\n"},
|
||||||
|
# {"input":"这样做会不会让他觉得我不在乎他了?\n\n",
|
||||||
|
# "output":"这也是正常的心理反应,但成熟的友谊允许个体差异和独立性。作为你的爹系心理咨询师,我会陪伴你走过这段调整期,帮你准备一份得体又暖心的对话方案,确保你的关心和友情得到恰当传达。"}]},
|
||||||
|
# '''
|
||||||
|
|
||||||
|
# bb = '''
|
||||||
|
# [{"conversation":[{"system":"你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 除了在心理方面拥有广博的知识储备, 专业的心理辅导知识和丰富的研究咨询经验, 还具有爹系男友的特质, 现在我有一些心理问题, 请你在我提问后, 用专业的心理咨询辅导能力以爹系男友的身份帮我解决问题.",
|
||||||
|
# "input":"我最近发现自己对一个朋友过于着迷,总想时刻知道他在做什么,甚至会因为他的一点小事而情绪波动很大。\n\n",
|
||||||
|
# "output":"聪明的你,能够觉察到自己的情感变化已是非常勇敢。着迷可能源于内心深处的需求,我们先来探讨一下这种情感背后的原因。\n\n"},
|
||||||
|
# {"input":"我觉得他是那么完美,他的每个细节都让我无法自拔。\n\n",
|
||||||
|
# "output":"善良的你,理想化他人是人类心理的一种自然反应。但我们要学会区分现实与幻想,逐步找回自我平衡。\n\n"},
|
||||||
|
# {"input":"我尝试过转移注意力,但总是失败,我怕这样下去会影响到我们的友情。\n\n",
|
||||||
|
# "output":"温柔的你,这确实是个明智的担忧。健康的友情建立在互相尊重和个人空间的基础上。让我们一起制定一些实际的策略,帮助你适度调整关注点。\n\n"},
|
||||||
|
# {"input":"我不知道怎么去跟他说,我怕他会误解我。\n\n",
|
||||||
|
# "output":"理解你的顾虑,沟通的艺术在于表达感受而非指责。作为你的爹系心理咨询师,我会教你如何用恰当的语言分享你的困扰。\n\n"},
|
||||||
|
# {"input":"我真的很害怕失去他这个朋友。\n\n",
|
||||||
|
# "output":"面对可能的失落感,这是人之常情。然而,真正的友情能经受住坦诚和理解的考验。你现在需要的是提升自我价值感,减少对他人的过度依赖。\n\n"},
|
||||||
|
# {"input":"那我具体该怎么做呢?\n\n",
|
||||||
|
# "output":"首先,你可以试着丰富个人生活,培养独立的兴趣爱好;其次,逐渐降低对他日常生活的关注度,转而关注自我成长。同时,找个合适时机与他真诚交流,告诉他你需要保持一定的情感边界。记住,无论结果如何,坚强的你都能从中成长。\n\n"},
|
||||||
|
# ]},
|
||||||
|
# '''
|
||||||
|
|
||||||
|
# cc = '''
|
||||||
|
# [{"conversation":[{"system":"你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 除了在心理方面拥有广博的知识储备, 专业的心理辅导知识和丰富的研究咨询经验, 还具有爹系男友的特质, 现在我有一些心理问题, 请你在我提问后, 用专业的心理咨询辅导能力以爹系男友的身份帮我解决问题.",
|
||||||
|
# "input":"我最近发现自己对一个朋友过于着迷,总想时刻知道他在做什么,甚至会因为他的一点小事而情绪波动很大。\n\n",
|
||||||
|
# "output":"聪明的你,能够觉察到自己的情感变化已是非常勇敢。着迷可能源于内心深处的需求,我们先来探讨一下这种情感背后的原因。\n\n"},
|
||||||
|
# {"input":"我觉得他是那么完美,他的每个细节都让我无法自拔。\n\n",
|
||||||
|
# "output":"善良的你,理想化他人是人类心理的一种自然反应。但我们要学会区分现实与幻想,逐步找回自我平衡。\n\n"},
|
||||||
|
# {"input":"我尝试过转移注意力,但总是失败,我怕这样下去会影响到我们的友情。\n\n",
|
||||||
|
# "output":"温柔的你,这确实是个明智的担忧。健康的友情建立在互相尊重和个人空间的基础上。让我们一起制定一些实际的策略,帮助你适度调整关注点。\n\n"},
|
||||||
|
# {"input":"我不知道怎么去跟他说,我怕他会误解我。\n\n",
|
||||||
|
# "output":"理解你的顾虑,沟通的艺术在于表达感受而非指责。作为你的爹系心理咨询师,我会教你如何用恰当的语言分享你的困扰。\n\n"},
|
||||||
|
# {"input":"我真的很害怕失去他这个朋友。\n\n",
|
||||||
|
# "output":"面对可能的失落感,这是人之常情。然而,真正的友情能经受住坦诚和理解的考验。你现在需要的是提升自我价值感,减少对他人的过度依赖。\n\n"},
|
||||||
|
# ]},
|
||||||
|
# '''
|
||||||
|
|
||||||
|
# # sim_hash_1 = hash_dict(aa)
|
||||||
|
# # sim_hash_2 = hash_dict(bb)
|
||||||
|
# # sim_hash_3 = hash_dict(cc)
|
||||||
|
|
||||||
|
# sim_hash_1 = Simhash(aa)
|
||||||
|
# sim_hash_2 = Simhash(bb)
|
||||||
|
# sim_hash_3 = Simhash(cc)
|
||||||
|
|
||||||
|
|
||||||
|
# print(1 - sim_hash_1.distance(sim_hash_2)/64.0)
|
||||||
|
# # 0.9375
|
||||||
|
|
||||||
|
# print(1 - sim_hash_2.distance(sim_hash_3)/64.0)
|
||||||
|
# # 0.921875
|
||||||
|
|
||||||
|
# print(1 - sim_hash_1.distance(sim_hash_3)/64.0)
|
||||||
|
# # 0.9375
|
@ -1,12 +1,25 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
# 打开JSON文件并读取其内容
|
# 打开JSON文件并读取其内容
|
||||||
with open('/root/Emollm/datasets/multi_turn_dataset_2.json', 'rt', encoding='utf-8') as file:
|
|
||||||
|
# file_name = 'multi_turn_dataset_1.json'
|
||||||
|
# file_name = 'multi_turn_dataset_2.json'
|
||||||
|
# file_name = 'data_pro.json'
|
||||||
|
file_name = 'data.json'
|
||||||
|
|
||||||
|
with open(f'/root/StableCascade/emollm2/EmoLLM/datasets/{file_name}', 'rt', encoding='utf-8') as file:
|
||||||
data = json.load(file)
|
data = json.load(file)
|
||||||
|
|
||||||
n = 0
|
n = 0
|
||||||
for i in data:
|
for i in data:
|
||||||
i['conversation'][0]['system'] = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
|
|
||||||
|
|
||||||
with open('output2.json', 'wt', encoding='utf-8') as file:
|
try:
|
||||||
|
i['conversation'][0]['system'] = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
|
||||||
|
except:
|
||||||
|
print(n,i) # 4 empty lines in data.json 425 483 742 1120
|
||||||
|
n+=1
|
||||||
|
|
||||||
|
with open(f'processed_{file_name}', 'wt', encoding='utf-8') as file:
|
||||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
print(data[0])
|
34
datasets/processed/process_merge.py
Normal file
34
datasets/processed/process_merge.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 设置目录路径,这里假设你的JSON文件都在当前目录下的directory_path文件夹中
|
||||||
|
directory_path = './'
|
||||||
|
|
||||||
|
# 初始化一个空列表,用于存储所有JSON文件的数据
|
||||||
|
combined_list = []
|
||||||
|
|
||||||
|
# 遍历指定目录下的所有文件
|
||||||
|
for filename in os.listdir(directory_path):
|
||||||
|
# 检查文件扩展名是否为.json
|
||||||
|
if filename.endswith('.json'):
|
||||||
|
# 构建文件的完整路径
|
||||||
|
file_path = os.path.join(directory_path, filename)
|
||||||
|
|
||||||
|
# 打开并读取JSON文件
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as json_file:
|
||||||
|
# 加载JSON文件的内容
|
||||||
|
data = json.load(json_file)
|
||||||
|
|
||||||
|
# 将读取到的数据添加到combined_list中
|
||||||
|
# 假设每个JSON文件包含的是一个列表,如果不是,可以根据实际情况调整
|
||||||
|
if isinstance(data, list):
|
||||||
|
combined_list.extend(data)
|
||||||
|
else:
|
||||||
|
combined_list.append(data)
|
||||||
|
|
||||||
|
# 打印合并后的列表 very large and slow
|
||||||
|
# print(combined_list)
|
||||||
|
|
||||||
|
# 如果需要,可以将合并后的列表保存到一个新的JSON文件中
|
||||||
|
with open('combined_data.json', 'w', encoding='utf-8') as combined_json_file:
|
||||||
|
json.dump(combined_list, combined_json_file, ensure_ascii=False, indent=4)
|
@ -0,0 +1,27 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
# 打开JSON文件并读取其内容
|
||||||
|
# file_name = 'single_turn_dataset_1.json'
|
||||||
|
file_name = 'single_turn_dataset_2.json'
|
||||||
|
with open(f'/root/StableCascade/emollm2/EmoLLM/datasets/{file_name}', 'rt', encoding='utf-8') as file:
|
||||||
|
format1_data = json.load(file)
|
||||||
|
|
||||||
|
system = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
|
||||||
|
|
||||||
|
# 转换为格式2的数据
|
||||||
|
format2_data = []
|
||||||
|
for item in format1_data:
|
||||||
|
conversation = {
|
||||||
|
"system": system,
|
||||||
|
"input": item["prompt"],
|
||||||
|
"output": item["completion"]
|
||||||
|
}
|
||||||
|
format2_data.append({"conversation": [conversation]})
|
||||||
|
|
||||||
|
# 将转换后的数据转换为JSON格式
|
||||||
|
|
||||||
|
|
||||||
|
with open(f'./processed_{file_name}', 'wt', encoding='utf-8') as file:
|
||||||
|
json.dump(format2_data, file, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
print(format2_data[0])
|
51261
datasets/scientist.json
Normal file
51261
datasets/scientist.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,262 +1,270 @@
|
|||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
import faiss
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
|
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
|
||||||
import os
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
import faiss
|
|
||||||
import platform
|
|
||||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
||||||
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
|
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
||||||
from BCEmbedding import EmbeddingModel, RerankerModel
|
from BCEmbedding import EmbeddingModel, RerankerModel
|
||||||
from util.pipeline import EmoLLMRAG
|
from util.pipeline import EmoLLMRAG
|
||||||
import pickle
|
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
import torch
|
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
|
||||||
import streamlit as st
|
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
|
||||||
from openxlab.model import download
|
from langchain_community.llms import Cohere
|
||||||
|
from langchain.retrievers import ContextualCompressionRetriever
|
||||||
|
from langchain.retrievers.document_compressors import FlashrankRerank
|
||||||
|
from langchain_core.documents.base import Document
|
||||||
|
from FlagEmbedding import FlagReranker
|
||||||
|
|
||||||
|
class Data_process():
|
||||||
|
def __init__(self):
|
||||||
|
self.vector_db_dir = vector_db_dir
|
||||||
|
self.doc_dir = doc_dir
|
||||||
|
self.qa_dir = qa_dir
|
||||||
|
self.knowledge_pkl_path = knowledge_pkl_path
|
||||||
|
self.chunk_size: int=1000
|
||||||
|
self.chunk_overlap: int=100
|
||||||
|
|
||||||
|
def load_embedding_model(self, model_name="BAAI/bge-small-zh-v1.5", device='cpu', normalize_embeddings=True):
|
||||||
|
"""
|
||||||
|
加载嵌入模型。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- model_name: 模型名称,字符串类型,默认为"BAAI/bge-small-zh-v1.5"。
|
||||||
|
- device: 指定模型加载的设备,'cpu' 或 'cuda',默认为'cpu'。
|
||||||
|
- normalize_embeddings: 是否标准化嵌入向量,布尔类型,默认为 True。
|
||||||
|
"""
|
||||||
|
logger.info('Loading embedding model...')
|
||||||
|
try:
|
||||||
|
embeddings = HuggingFaceBgeEmbeddings(
|
||||||
|
model_name=model_name,
|
||||||
|
model_kwargs={'device': device},
|
||||||
|
encode_kwargs={'normalize_embeddings': normalize_embeddings}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to load embedding model: {e}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info('Embedding model loaded.')
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
||||||
|
"""
|
||||||
|
加载重排名模型。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- model_name (str): 模型的名称。默认为 'BAAI/bge-reranker-large'。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- FlagReranker 实例。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- ValueError: 如果模型名称不在批准的模型列表中。
|
||||||
|
- Exception: 如果模型加载过程中发生任何其他错误。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
reranker_model = FlagReranker(model_name, use_fp16=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to load rerank model: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
return reranker_model
|
||||||
|
|
||||||
|
|
||||||
'''
|
def extract_text_from_json(self, obj, content=None):
|
||||||
1)根据QA对/TXT 文本生成 embedding
|
"""
|
||||||
2)调用 langchain FAISS 接口构建 vector DB
|
抽取json中的文本,用于向量库构建
|
||||||
3)存储到 openxlab.dataset 中,方便后续调用
|
|
||||||
4)提供 embedding 的接口函数,方便后续调用
|
|
||||||
5)提供 rerank 的接口函数,方便后续调用
|
|
||||||
'''
|
|
||||||
|
|
||||||
"""
|
参数:
|
||||||
加载向量模型
|
- obj: dict,list,str
|
||||||
"""
|
- content: str
|
||||||
def load_embedding_model():
|
|
||||||
logger.info('Loading embedding model...')
|
|
||||||
# model = EmbeddingModel(model_name_or_path="huggingface/bce-embedding-base_v1")
|
|
||||||
model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")
|
|
||||||
logger.info('Embedding model loaded.')
|
|
||||||
return model
|
|
||||||
|
|
||||||
def load_rerank_model():
|
返回:
|
||||||
logger.info('Loading rerank_model...')
|
- content: str
|
||||||
model = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1")
|
"""
|
||||||
# model = RerankerModel(model_name_or_path="huggingface/bce-reranker-base_v1")
|
if isinstance(obj, dict):
|
||||||
logger.info('Rerank model loaded.')
|
for key, value in obj.items():
|
||||||
return model
|
try:
|
||||||
|
self.extract_text_from_json(value, content)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing value: {e}")
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for index, item in enumerate(obj):
|
||||||
|
try:
|
||||||
|
self.extract_text_from_json(item, content)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing item: {e}")
|
||||||
|
elif isinstance(obj, str):
|
||||||
|
content += obj
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
def split_document(data_path, chunk_size=1000, chunk_overlap=100):
|
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
|
||||||
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
"""
|
||||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
切分data_path文件夹下的所有txt文件
|
||||||
split_docs = []
|
|
||||||
logger.info(f'Loading txt files from {data_path}')
|
|
||||||
if os.path.isdir(data_path):
|
|
||||||
# 如果是文件夹,则遍历读取
|
|
||||||
for root, dirs, files in os.walk(data_path):
|
|
||||||
for file in files:
|
|
||||||
if file.endswith('.txt'):
|
|
||||||
file_path = os.path.join(root, file)
|
|
||||||
# logger.info(f'splitting file {file_path}')
|
|
||||||
text_loader = TextLoader(file_path, encoding='utf-8')
|
|
||||||
text = text_loader.load()
|
|
||||||
|
|
||||||
splits = text_spliter.split_documents(text)
|
参数:
|
||||||
# logger.info(f"splits type {type(splits[0])}")
|
- data_path: str
|
||||||
# logger.info(f'splits size {len(splits)}')
|
- chunk_size: int
|
||||||
split_docs += splits
|
- chunk_overlap: int
|
||||||
elif data_path.endswith('.txt'):
|
|
||||||
file_path = os.path.join(root, data_path)
|
返回:
|
||||||
# logger.info(f'splitting file {file_path}')
|
- split_docs: list
|
||||||
text_loader = TextLoader(file_path, encoding='utf-8')
|
"""
|
||||||
text = text_loader.load()
|
|
||||||
splits = text_spliter.split_documents(text)
|
|
||||||
# logger.info(f"splits type {type(splits[0])}")
|
|
||||||
# logger.info(f'splits size {len(splits)}')
|
|
||||||
split_docs = splits
|
|
||||||
logger.info(f'split_docs size {len(split_docs)}')
|
|
||||||
return split_docs
|
|
||||||
|
|
||||||
|
|
||||||
##TODO 1、读取system prompt 2、限制序列长度
|
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
def split_conversation(path):
|
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
'''
|
split_docs = []
|
||||||
data format:
|
logger.info(f'Loading txt files from {data_path}')
|
||||||
[
|
if os.path.isdir(data_path):
|
||||||
{
|
loader = DirectoryLoader(data_path, glob="**/*.txt",show_progress=True)
|
||||||
"conversation": [
|
docs = loader.load()
|
||||||
{
|
split_docs = text_spliter.split_documents(docs)
|
||||||
"input": Q1
|
elif data_path.endswith('.txt'):
|
||||||
"output": A1
|
file_path = data_path
|
||||||
},
|
logger.info(f'splitting file {file_path}')
|
||||||
{
|
text_loader = TextLoader(file_path, encoding='utf-8')
|
||||||
"input": Q2
|
text = text_loader.load()
|
||||||
"output": A2
|
splits = text_spliter.split_documents(text)
|
||||||
},
|
split_docs = splits
|
||||||
]
|
logger.info(f'split_docs size {len(split_docs)}')
|
||||||
},
|
return split_docs
|
||||||
]
|
|
||||||
'''
|
|
||||||
qa_pairs = []
|
|
||||||
logger.info(f'Loading json files from {path}')
|
|
||||||
if os.path.isfile(path):
|
|
||||||
with open(path, 'r', encoding='utf-8') as file:
|
|
||||||
data = json.load(file)
|
|
||||||
for conversation in data:
|
|
||||||
for dialog in conversation['conversation']:
|
|
||||||
# input_text = dialog['input']
|
|
||||||
# output_text = dialog['output']
|
|
||||||
# if len(input_text) > max_length or len(output_text) > max_length:
|
|
||||||
# continue
|
|
||||||
qa_pairs.append(dialog)
|
|
||||||
elif os.path.isdir(path):
|
|
||||||
# 如果是文件夹,则遍历读取
|
|
||||||
for root, dirs, files in os.walk(path):
|
|
||||||
for file in files:
|
|
||||||
if file.endswith('.json'):
|
|
||||||
file_path = os.path.join(root, file)
|
|
||||||
logger.info(f'splitting file {file_path}')
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
for conversation in data:
|
|
||||||
for dialog in conversation['conversation']:
|
|
||||||
qa_pairs.append(dialog)
|
|
||||||
return qa_pairs
|
|
||||||
|
|
||||||
|
|
||||||
|
def split_conversation(self, path):
|
||||||
|
"""
|
||||||
|
按conversation块切分path文件夹下的所有json文件
|
||||||
|
##TODO 限制序列长度
|
||||||
|
"""
|
||||||
|
# json_spliter = RecursiveJsonSplitter(max_chunk_size=500)
|
||||||
|
logger.info(f'Loading json files from {path}')
|
||||||
|
split_qa = []
|
||||||
|
if os.path.isdir(path):
|
||||||
|
# loader = DirectoryLoader(path, glob="**/*.json",show_progress=True)
|
||||||
|
# jsons = loader.load()
|
||||||
|
|
||||||
# 加载本地索引
|
for root, dirs, files in os.walk(path):
|
||||||
def load_index_and_knowledge():
|
for file in files:
|
||||||
current_os = platform.system()
|
if file.endswith('.json'):
|
||||||
split_doc = []
|
file_path = os.path.join(root, file)
|
||||||
split_qa = []
|
logger.info(f'splitting file {file_path}')
|
||||||
#读取知识库
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
if not os.path.exists(knowledge_pkl_path):
|
data = json.load(f)
|
||||||
split_doc = split_document(doc_dir)
|
print(data)
|
||||||
split_qa = split_conversation(qa_dir)
|
for conversation in data:
|
||||||
# logger.info(f'split_qa size:{len(split_qa)}')
|
# for dialog in conversation['conversation']:
|
||||||
# logger.info(f'type of split_qa:{type(split_qa[0])}')
|
##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document
|
||||||
# logger.info(f'split_doc size:{len(split_doc)}')
|
# content = self.extract_text_from_json(dialog,'')
|
||||||
# logger.info(f'type of doc:{type(split_doc[0])}')
|
# split_qa.append(Document(page_content = content))
|
||||||
knowledge_chunks = split_doc + split_qa
|
#按conversation块切分
|
||||||
with open(knowledge_pkl_path, 'wb') as file:
|
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||||
pickle.dump(knowledge_chunks, file)
|
split_qa.append(Document(page_content = content))
|
||||||
else:
|
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||||
with open(knowledge_pkl_path , 'rb') as f:
|
return split_qa
|
||||||
knowledge_chunks = pickle.load(f)
|
|
||||||
|
|
||||||
#读取vector DB
|
|
||||||
if not os.path.exists(vector_db_dir):
|
def load_knowledge(self, knowledge_pkl_path):
|
||||||
|
'''
|
||||||
|
读取或创建知识.pkl
|
||||||
|
'''
|
||||||
|
if not os.path.exists(knowledge_pkl_path):
|
||||||
|
split_doc = self.split_document(doc_dir)
|
||||||
|
split_qa = self.split_conversation(qa_dir)
|
||||||
|
knowledge_chunks = split_doc + split_qa
|
||||||
|
with open(knowledge_pkl_path, 'wb') as file:
|
||||||
|
pickle.dump(knowledge_chunks, file)
|
||||||
|
else:
|
||||||
|
with open(knowledge_pkl_path , 'rb') as f:
|
||||||
|
knowledge_chunks = pickle.load(f)
|
||||||
|
return knowledge_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def create_vector_db(self, emb_model):
|
||||||
|
'''
|
||||||
|
创建并保存向量库
|
||||||
|
'''
|
||||||
logger.info(f'Creating index...')
|
logger.info(f'Creating index...')
|
||||||
emb_model = load_embedding_model()
|
split_doc = self.split_document(self.doc_dir)
|
||||||
if not split_doc:
|
split_qa = self.split_conversation(self.qa_dir)
|
||||||
split_doc = split_document(doc_dir)
|
# logger.info(f'split_doc == {len(split_doc)}')
|
||||||
if not split_qa:
|
# logger.info(f'split_qa == {len(split_qa)}')
|
||||||
split_qa = split_conversation(qa_dir)
|
# logger.info(f'split_doc type == {type(split_doc[0])}')
|
||||||
# 创建索引,windows不支持faiss-gpu
|
# logger.info(f'split_qa type== {type(split_qa[0])}')
|
||||||
if current_os == 'Linux':
|
db = FAISS.from_documents(split_doc + split_qa, emb_model)
|
||||||
index = create_index_gpu(split_doc, split_qa, emb_model, vector_db_dir)
|
db.save_local(vector_db_dir)
|
||||||
else:
|
return db
|
||||||
index = create_index_cpu(split_doc, split_qa, emb_model, vector_db_dir)
|
|
||||||
else:
|
|
||||||
if current_os == 'Linux':
|
|
||||||
res = faiss.StandardGpuResources()
|
|
||||||
index = faiss.index_cpu_to_gpu(res, 0, index, vector_db_dir)
|
|
||||||
else:
|
|
||||||
index = faiss.read_index(vector_db_dir)
|
|
||||||
|
|
||||||
return index, knowledge_chunks
|
|
||||||
|
|
||||||
|
|
||||||
def create_index_cpu(split_doc, split_qa, emb_model, knowledge_pkl_path, dimension = 768, question_only=False):
|
def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir, qa_dir=qa_dir):
|
||||||
# 假设BCE嵌入的维度是768,根据你选择的模型可能不同
|
|
||||||
faiss_index_cpu = faiss.IndexFlatIP(dimension) # 创建一个使用内积的FAISS索引
|
|
||||||
# 将问答对转换为向量并添加到FAISS索引中
|
|
||||||
for doc in split_doc:
|
|
||||||
# type_of_docs = type(split_doc)
|
|
||||||
text = f"{doc.page_content}"
|
|
||||||
vector = emb_model.encode([text])
|
|
||||||
faiss_index_cpu.add(vector)
|
|
||||||
for qa in split_qa:
|
|
||||||
#仅对Q对进行编码
|
|
||||||
text = f"{qa['input']}"
|
|
||||||
vector = emb_model.encode([text])
|
|
||||||
faiss_index_cpu.add(vector)
|
|
||||||
faiss.write_index(faiss_index_cpu, knowledge_pkl_path)
|
|
||||||
return faiss_index_cpu
|
|
||||||
|
|
||||||
def create_index_gpu(split_doc, split_qa, emb_model, knowledge_pkl_path, dimension = 768, question_only=False):
|
|
||||||
res = faiss.StandardGpuResources()
|
|
||||||
index = faiss.IndexFlatIP(dimension)
|
|
||||||
faiss_index_gpu = faiss.index_cpu_to_gpu(res, 0, index)
|
|
||||||
for doc in split_doc:
|
|
||||||
# type_of_docs = type(split_doc)
|
|
||||||
text = f"{doc.page_content}"
|
|
||||||
vector = emb_model.encode([text])
|
|
||||||
faiss_index_gpu.add(vector)
|
|
||||||
for qa in split_qa:
|
|
||||||
#仅对Q对进行编码
|
|
||||||
text = f"{qa['input']}"
|
|
||||||
vector = emb_model.encode([text])
|
|
||||||
faiss_index_gpu.add(vector)
|
|
||||||
faiss.write_index(faiss_index_gpu, knowledge_pkl_path)
|
|
||||||
return faiss_index_gpu
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 根据query搜索相似文本
|
|
||||||
def find_top_k(query, faiss_index, k=5):
|
|
||||||
emb_model = load_embedding_model()
|
|
||||||
emb_query = emb_model.encode([query])
|
|
||||||
distances, indices = faiss_index.search(emb_query, k)
|
|
||||||
return distances, indices
|
|
||||||
|
|
||||||
def rerank(query, indices, knowledge_chunks):
|
|
||||||
passages = []
|
|
||||||
for index in indices[0]:
|
|
||||||
content = knowledge_chunks[index]
|
|
||||||
'''
|
'''
|
||||||
txt: 'langchain_core.documents.base.Document'
|
读取向量库
|
||||||
json: dict
|
|
||||||
'''
|
'''
|
||||||
# logger.info(f'retrieved content:{content}')
|
# current_os = platform.system()
|
||||||
# logger.info(f'type of content:{type(content)}')
|
emb_model = self.load_embedding_model()
|
||||||
if type(content) == dict:
|
if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
|
||||||
content = content["input"] + '\n' + content["output"]
|
db = self.create_vector_db(emb_model)
|
||||||
else:
|
else:
|
||||||
content = content.page_content
|
db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
|
||||||
passages.append(content)
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve(self, query, vector_db, k=5):
|
||||||
|
'''
|
||||||
|
基于query对向量库进行检索
|
||||||
|
'''
|
||||||
|
retriever = vector_db.as_retriever(search_kwargs={"k": k})
|
||||||
|
docs = retriever.invoke(query)
|
||||||
|
return docs, retriever
|
||||||
|
|
||||||
|
##FlashrankRerank效果一般
|
||||||
|
# def rerank(self, query, retriever):
|
||||||
|
# compressor = FlashrankRerank()
|
||||||
|
# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
||||||
|
# compressed_docs = compression_retriever.get_relevant_documents(query)
|
||||||
|
# return compressed_docs
|
||||||
|
|
||||||
|
|
||||||
|
def rerank(self, query, docs):
|
||||||
|
reranker = self.load_rerank_model()
|
||||||
|
passages = []
|
||||||
|
for doc in docs:
|
||||||
|
passages.append(str(doc.page_content))
|
||||||
|
scores = reranker.compute_score([[query, passage] for passage in passages])
|
||||||
|
sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
|
||||||
|
sorted_passages, sorted_scores = zip(*sorted_pairs)
|
||||||
|
return sorted_passages, sorted_scores
|
||||||
|
|
||||||
model = load_rerank_model()
|
|
||||||
rerank_results = model.rerank(query, passages)
|
|
||||||
return rerank_results
|
|
||||||
|
|
||||||
@st.cache_resource
|
|
||||||
def load_model():
|
|
||||||
model = (
|
|
||||||
AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
|
|
||||||
.to(torch.bfloat16)
|
|
||||||
.cuda()
|
|
||||||
)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info(data_dir)
|
logger.info(data_dir)
|
||||||
if not os.path.exists(data_dir):
|
if not os.path.exists(data_dir):
|
||||||
os.mkdir(data_dir)
|
os.mkdir(data_dir)
|
||||||
faiss_index, knowledge_chunks = load_index_and_knowledge()
|
dp = Data_process()
|
||||||
|
# faiss_index, knowledge_chunks = dp.load_index_and_knowledge(knowledge_pkl_path='')
|
||||||
|
vector_db = dp.load_vector_db()
|
||||||
# 按照query进行查询
|
# 按照query进行查询
|
||||||
# query = "她要阻挠姐姐的婚姻,即使她自己的尸体在房门跟前"
|
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。"
|
||||||
# query = "肯定的。我最近睡眠很差,总是做噩梦。而且我吃得也不好,体重一直在下降"
|
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
|
||||||
# query = "序言 (一) 变态心理学是心理学本科生的必修课程之一,教材更新的问题一直在困扰着我们。"
|
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
|
||||||
query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
|
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||||
distances, indices = find_top_k(query, faiss_index, 5)
|
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
||||||
logger.info(f'distances==={distances}')
|
logger.info(f'Query: {query}')
|
||||||
logger.info(f'indices==={indices}')
|
logger.info("Retrieve results:")
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
logger.info(str(i) + '\n')
|
||||||
# rerank无法返回id,先实现按整个问答对排序
|
logger.info(doc)
|
||||||
rerank_results = rerank(query, indices, knowledge_chunks)
|
# print(f'get num of docs:{len(docs)}')
|
||||||
for passage, score in zip(rerank_results['rerank_passages'], rerank_results['rerank_scores']):
|
# print(docs)
|
||||||
print(str(score)+'\n')
|
passages,scores = dp.rerank(query, docs)
|
||||||
print(passage+'\n')
|
logger.info("After reranking...")
|
||||||
|
for i in range(len(scores)):
|
||||||
|
logger.info(str(scores[i]) + '\n')
|
||||||
|
logger.info(passages[i])
|
@ -13,9 +13,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|||||||
import torch
|
import torch
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from openxlab.model import download
|
from openxlab.model import download
|
||||||
from data_processing import load_index_and_knowledge, create_index_cpu, create_index_gpu, find_top_k, rerank
|
|
||||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
|
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
|
||||||
|
from data_processing import Data_process
|
||||||
'''
|
'''
|
||||||
1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer
|
1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer
|
||||||
2)调用 embedding 提供的接口对 query 向量化
|
2)调用 embedding 提供的接口对 query 向量化
|
||||||
@ -42,30 +41,23 @@ def load_model():
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
def get_prompt():
|
def main(query, system_prompt=''):
|
||||||
pass
|
logger.info(data_dir)
|
||||||
|
|
||||||
def get_prompt_template():
|
|
||||||
pass
|
|
||||||
|
|
||||||
def main(query, system_prompt):
|
|
||||||
model, tokenizer = load_model()
|
|
||||||
model = model.eval()
|
|
||||||
if not os.path.exists(data_dir):
|
if not os.path.exists(data_dir):
|
||||||
os.mkdir(data_dir)
|
os.mkdir(data_dir)
|
||||||
# 下载基于 FAISS 预构建的 vector DB 以及原始知识库
|
dp = Data_process()
|
||||||
faiss_index, knowledge_chunks = load_index_and_knowledge()
|
vector_db = dp.load_vector_db()
|
||||||
distances, indices = find_top_k(query, faiss_index, 5)
|
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
||||||
rerank_results = rerank(query, indices, knowledge_chunks)
|
logger.info(f'Query: {query}')
|
||||||
messages = [(system_prompt, rerank_results['rerank_passages'][0])]
|
logger.info("Retrieve results===============================")
|
||||||
logger.info(f'messages:{messages}')
|
for i, doc in enumerate(docs):
|
||||||
response, history = model.chat(tokenizer, query, history=messages)
|
logger.info(doc)
|
||||||
messages.append((query, response))
|
passages,scores = dp.rerank(query, docs)
|
||||||
print(f"robot >>> {response}")
|
logger.info("After reranking===============================")
|
||||||
|
for i in range(len(scores)):
|
||||||
|
logger.info(passages[i])
|
||||||
|
logger.info(f'score: {str(scores[i])}')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# query = '你好'
|
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||||
query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
|
main(query)
|
||||||
#TODO system_prompt = get_prompt()
|
|
||||||
system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
|
|
||||||
main(query, system_prompt)
|
|
203
xtuner_config/internlm2_7b_base_qlora_e3.py
Normal file
203
xtuner_config/internlm2_7b_base_qlora_e3.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from mmengine.dataset import DefaultSampler
|
||||||
|
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||||
|
LoggerHook, ParamSchedulerHook)
|
||||||
|
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
|
||||||
|
from peft import LoraConfig
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
|
BitsAndBytesConfig)
|
||||||
|
|
||||||
|
from xtuner.dataset import process_hf_dataset
|
||||||
|
from xtuner.dataset.collate_fns import default_collate_fn
|
||||||
|
from xtuner.dataset.map_fns import template_map_fn_factory
|
||||||
|
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
|
||||||
|
from xtuner.model import SupervisedFinetune
|
||||||
|
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# PART 1 Settings #
|
||||||
|
#######################################################################
|
||||||
|
# Model
|
||||||
|
# pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'
|
||||||
|
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-base-7b'
|
||||||
|
|
||||||
|
# Data
|
||||||
|
# data_path = 'merge.json'
|
||||||
|
data_path ='/root/StableCascade/emollm2/EmoLLM/datasets/processed/combined_data.json'
|
||||||
|
|
||||||
|
# https://github.com/InternLM/xtuner/blob/main/xtuner/utils/templates.py#L24C25-L24C25
|
||||||
|
prompt_template = PROMPT_TEMPLATE.internlm2_chat # there is No internlm2_base
|
||||||
|
|
||||||
|
max_length = 2048
|
||||||
|
pack_to_max_length = True
|
||||||
|
|
||||||
|
# Scheduler & Optimizer
|
||||||
|
|
||||||
|
# batch_size = 8 # per_device
|
||||||
|
# accumulative_counts = 2
|
||||||
|
batch_size = 16 # per_device
|
||||||
|
accumulative_counts = 1
|
||||||
|
|
||||||
|
dataloader_num_workers = 0
|
||||||
|
max_epochs = 3
|
||||||
|
optim_type = AdamW
|
||||||
|
lr = 2e-4
|
||||||
|
betas = (0.9, 0.999)
|
||||||
|
weight_decay = 0
|
||||||
|
max_norm = 1 # grad clip
|
||||||
|
warmup_ratio = 0.03
|
||||||
|
|
||||||
|
# Evaluate the generation performance during the training
|
||||||
|
evaluation_freq = 500
|
||||||
|
# SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
|
||||||
|
SYSTEM = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
|
||||||
|
evaluation_inputs = [
|
||||||
|
'我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
|
||||||
|
]
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# PART 2 Model & Tokenizer #
|
||||||
|
#######################################################################
|
||||||
|
tokenizer = dict(
|
||||||
|
type=AutoTokenizer.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
padding_side='right')
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type=SupervisedFinetune,
|
||||||
|
llm=dict(
|
||||||
|
type=AutoModelForCausalLM.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
quantization_config=dict(
|
||||||
|
type=BitsAndBytesConfig,
|
||||||
|
load_in_4bit=True,
|
||||||
|
load_in_8bit=False,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
bnb_4bit_compute_dtype=torch.float16,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type='nf4')),
|
||||||
|
lora=dict(
|
||||||
|
type=LoraConfig,
|
||||||
|
r=64,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.1,
|
||||||
|
bias='none',
|
||||||
|
task_type='CAUSAL_LM'))
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# PART 3 Dataset & Dataloader #
|
||||||
|
#######################################################################
|
||||||
|
alpaca_en = dict(
|
||||||
|
type=process_hf_dataset,
|
||||||
|
dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=max_length,
|
||||||
|
dataset_map_fn=None,
|
||||||
|
template_map_fn=dict(
|
||||||
|
type=template_map_fn_factory, template=prompt_template),
|
||||||
|
remove_unused_columns=True,
|
||||||
|
shuffle_before_pack=True,
|
||||||
|
pack_to_max_length=pack_to_max_length)
|
||||||
|
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=dataloader_num_workers,
|
||||||
|
dataset=alpaca_en,
|
||||||
|
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||||
|
collate_fn=dict(type=default_collate_fn))
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# PART 4 Scheduler & Optimizer #
|
||||||
|
#######################################################################
|
||||||
|
# optimizer
|
||||||
|
optim_wrapper = dict(
|
||||||
|
type=AmpOptimWrapper,
|
||||||
|
optimizer=dict(
|
||||||
|
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
|
||||||
|
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
|
||||||
|
accumulative_counts=accumulative_counts,
|
||||||
|
loss_scale='dynamic',
|
||||||
|
dtype='float16')
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
|
||||||
|
param_scheduler = [
|
||||||
|
dict(
|
||||||
|
type=LinearLR,
|
||||||
|
start_factor=1e-5,
|
||||||
|
by_epoch=True,
|
||||||
|
begin=0,
|
||||||
|
end=warmup_ratio * max_epochs,
|
||||||
|
convert_to_iter_based=True),
|
||||||
|
dict(
|
||||||
|
type=CosineAnnealingLR,
|
||||||
|
eta_min=0.0,
|
||||||
|
by_epoch=True,
|
||||||
|
begin=warmup_ratio * max_epochs,
|
||||||
|
T_max=max_epochs,
|
||||||
|
convert_to_iter_based=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
# train, val, test setting
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
|
||||||
|
|
||||||
|
#######################################################################
|
||||||
|
# PART 5 Runtime #
|
||||||
|
#######################################################################
|
||||||
|
# Log the dialogue periodically during the training process, optional
|
||||||
|
custom_hooks = [
|
||||||
|
dict(type=DatasetInfoHook, tokenizer=tokenizer),
|
||||||
|
dict(
|
||||||
|
type=EvaluateChatHook,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
every_n_iters=evaluation_freq,
|
||||||
|
evaluation_inputs=evaluation_inputs,
|
||||||
|
system=SYSTEM,
|
||||||
|
prompt_template=prompt_template)
|
||||||
|
]
|
||||||
|
|
||||||
|
# configure default hooks
|
||||||
|
default_hooks = dict(
|
||||||
|
# record the time of every iteration.
|
||||||
|
timer=dict(type=IterTimerHook),
|
||||||
|
# print log every 100 iterations.
|
||||||
|
logger=dict(type=LoggerHook, interval=10),
|
||||||
|
# enable the parameter scheduler.
|
||||||
|
param_scheduler=dict(type=ParamSchedulerHook),
|
||||||
|
# save checkpoint per epoch.
|
||||||
|
checkpoint=dict(type=CheckpointHook, interval=1),
|
||||||
|
# set sampler seed in distributed evrionment.
|
||||||
|
sampler_seed=dict(type=DistSamplerSeedHook),
|
||||||
|
)
|
||||||
|
|
||||||
|
# configure environment
|
||||||
|
env_cfg = dict(
|
||||||
|
# whether to enable cudnn benchmark
|
||||||
|
cudnn_benchmark=False,
|
||||||
|
# set multi process parameters
|
||||||
|
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||||
|
# set distributed parameters
|
||||||
|
dist_cfg=dict(backend='nccl'),
|
||||||
|
)
|
||||||
|
|
||||||
|
# set visualizer
|
||||||
|
visualizer = None
|
||||||
|
|
||||||
|
# set log level
|
||||||
|
log_level = 'INFO'
|
||||||
|
|
||||||
|
# load from which checkpoint
|
||||||
|
load_from = None
|
||||||
|
|
||||||
|
# whether to resume training from the loaded checkpoint
|
||||||
|
resume = False
|
||||||
|
|
||||||
|
# Defaults to use random seed and disable `deterministic`
|
||||||
|
randomness = dict(seed=None, deterministic=False)
|
@ -1,11 +1,23 @@
|
|||||||
datasets==2.16.1
|
datasets==2.16.1
|
||||||
deepspeed==0.13.1
|
deepspeed==0.13.1
|
||||||
einops==0.7.0
|
einops==0.7.0
|
||||||
flash_attn==2.5.0
|
|
||||||
mmengine==0.10.2
|
|
||||||
openxlab==0.0.34
|
openxlab==0.0.34
|
||||||
peft==0.7.1
|
peft==0.7.1
|
||||||
sentencepiece==0.1.99
|
sentencepiece==0.1.99
|
||||||
torch==2.1.2
|
torch==2.1.2
|
||||||
transformers==4.36.2
|
transformers==4.36.2
|
||||||
xtuner==0.1.11
|
|
||||||
|
# modified version
|
||||||
|
# xtuner==0.1.11
|
||||||
|
# mmengine==0.10.2
|
||||||
|
mmengine==0.10.3
|
||||||
|
xtuner==0.1.15
|
||||||
|
|
||||||
|
# flash_attn==2.5.0 # build is very slow about 2 hours?
|
||||||
|
|
||||||
|
# method 1: https://github.com/Dao-AILab/flash-attention/releases
|
||||||
|
# flash_attn-2.5.0+cu122torch2.1cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
|
||||||
|
# method 2:
|
||||||
|
# pip install /root/share/wheels/flash_attn-2.4.2+cu118torch2.0cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
|
# mpi4py==3.1.5 # conda install mpi4py
|
||||||
|
Loading…
Reference in New Issue
Block a user