Compare commits
No commits in common. "1125b67f50e037e3147b4fb59c5449be2af8e170" and "a64a3c3cbe49005d4f18b75a42ec2d0e9a707538" have entirely different histories.
1125b67f50 ... a64a3c3cbe
.gitignore (vendored, 2 changes)

@@ -6,8 +6,6 @@ data/
pdf/
.idea/
logs/
.vscode/
work_dirs/

# *.jsonl
# *.json

.vscode/settings.json (vendored, new file, 5 changes)

@@ -0,0 +1,5 @@
{
    "[python]": {
        "editor.defaultFormatter": null
    }
}

README.md (41 changes)

@@ -6,7 +6,7 @@

<p align="center">
<a href="https://github.com/SmartFlowAI/EmoLLM/">
<img src="assets/EmoLLM_logo_L.png" alt="Logo" width="50%">
<img src="assets/EmoLLM_transparent.png" alt="Logo" width="50%">
</a>

<div align="center">
@@ -25,7 +25,7 @@
<h3 align="center">EmoLLM</h3>

<div align="center">
简体中文| <a href="README_EN.md" >English</a> | <a href="README_JP.md" >日本語</a>
简体中文| <a href="README_EN.md" >English</a>
<br />
<br />
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>探索本项目的文档 »</strong></a>
@@ -59,9 +59,8 @@
| ChatGLM3_6B | LoRA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | |
| DeepSeek MoE_16B_chat | QLoRA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | |
| Mixtral 8x7B_instruct | QLoRA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | |
| LLaMA3_8B_instruct | QLoRA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/tree/main), [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/files) |
| LLaMA3_8B_instruct | QLoRA | [llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary) |
| Qwen2-7B-Instruct | LoRA | [Qwen2-7B-Instruct_lora.py](./xtuner_config/Qwen2-7B-Instruct_lora.py) |[ModelScope](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen2-7B-Instruct_lora/) |
| LLaMA3_8b_instruct | QLoRA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/tree/main), [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/files) |
| LLaMA3_8b_instruct | QLoRA | [llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary) |
| …… | …… | …… | …… |

</div>
@@ -101,8 +100,6 @@
</table>

## 🎇最近更新
- 【2024.09.14】基于Qwen2-7B-Instruct模型的Lora微调模型开源,微调配置文件地址:[Qwen2-7B-Instruct_lora.py](./xtuner_config/Qwen2-7B-Instruct_lora.py) ,模型权重链接:[ModelScope](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen2-7B-Instruct_lora/)
- 【2024.08】基于GLM4-9B-chat微调Lora模型开源(基于LLaMA-Factory),详情见[微调教程](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md) ,模型权重链接:[ModelScope](https://www.modelscope.cn/models/wwewwt/EmoLLM-glm-4-9b-chat/summary)
- 【2024.07.16】欢迎大家体验 EmoLLM V3.0 ,该模型是基于InternLM2.5-7B-Chat模型的全量微调,微调配置文件地址:[internlm2_5_chat_7b_full.py](./xtuner_config/internlm2_5_chat_7b_full.py) ,模型权重链接:[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM_V3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLMV3.0) ,WebDemo地址: [OpenXLab apps](https://openxlab.org.cn/apps/detail/chg0901/EmoLLMV3.0), [配套全量微调知乎教程](https://zhuanlan.zhihu.com/p/708931911)。
- 【2024.07】欢迎大家使用稳定版 EmoLLM V2.0 进行日常使用和学术研究,模型权重链接:[OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full/tree/main)。
- 【2024.07】新增基于InternLM2_5_7B_chat[微调配置](./xtuner_config/internlm2_5_chat_7b_qlora_oasst1_e3.py)、模型文件发布在 [ModelScope](https://www.modelscope.cn/models/z342994309/emollm_interlm2_5/)。
@@ -120,15 +117,15 @@
- 【2024.03.11】 **EmoLLM V2.0 相比 EmoLLM V1.0 全面提升,已超越 Role-playing ChatGPT 在心理咨询任务上的能力!**[点击体验EmoLLM V2.0](https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0),更新[数据集统计及详细信息](./datasets/)、[路线图](./assets/Roadmap_ZH.png)
- 【2024.03.09】 新增并发功能加速 [QA 对生成](./scripts/qa_generation/)、[RAG pipeline](./rag/)
- 【2024.03.03】 [基于InternLM2-7B-chat全量微调版本EmoLLM V2.0开源](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full),需要两块A100*80G,更新专业评估,详见[evaluate](./evaluate/),更新基于PaddleOCR的PDF转txt工具脚本,详见[scripts](./scripts/)

<details>
<summary>查看更多</summary>

- 【2024.02.29】更新客观评估计算,详见[evaluate](./evaluate/),更新一系列数据集,详见[datasets](./datasets/)
- 【2024.02.27】更新英文readme和一系列数据集(舔狗和单轮对话)
- 【2024.02.23】推出基于InternLM2_7B_chat_qlora的 `温柔御姐心理医生艾薇`,[点击获取模型权重](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei),[配置文件](xtuner_config/aiwei-internlm2_chat_7b_qlora.py),[在线体验链接](https://openxlab.org.cn/apps/detail/ajupyter/EmoLLM-aiwei)
- 【2024.02.23】更新[若干微调配置](/xtuner_config/),新增 [data_pro.json](/datasets/data_pro.json)(数量更多、场景更全、更丰富)和 [aiwei.json](/datasets/aiwei.json)(温柔御姐角色扮演专用,带有Emoji表情),即将推出 `温柔御姐心理医生艾薇`
- 【2024.02.18】 [基于Qwen1_5-0_5B-Chat全量微调版本开源](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary),算力有限的道友可以玩起来~

<details>
<summary>查看更多</summary>

- 【2024.02.06】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!

<p align="center">
@@ -188,7 +185,7 @@
- [使用指南](#使用指南)
- [🍪快速体验](#快速体验)
- [📌数据构建](#数据构建)
- [🎨增量预训练、微调指南](#增量预训练微调指南)
- [🎨微调指南](#微调指南)
- [🔧部署指南](#部署指南)
- [⚙RAG(检索增强生成)](#rag检索增强生成)
- [🎓评测指南](#评测指南)
@@ -207,7 +204,6 @@
###### 开发前的配置要求

- 硬件:A100 40G(仅针对InternLM2_7B_chat+qlora微调+deepspeed zero2优化)
- todo:发布更多硬件消耗细节

###### 使用指南

@@ -220,7 +216,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
2. 依次阅读或者选择感兴趣的部分阅读:
- [快速体验](#快速体验)
- [数据构建](#数据构建)
- [增量预训练、微调指南](#增量预训练微调指南)
- [微调指南](#微调指南)
- [部署指南](#部署指南)
- [RAG](#rag检索增强生成)
- [评测指南](#评测指南)
@@ -234,21 +230,19 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git


### 📌数据构建

- 请阅读[数据构建指南](generate_data/tutorial.md)查阅

- 微调用到的数据集见[datasets](datasets/data.json)

### 🎨增量预训练、微调指南
- 增量预训练详见[增量预训练指南](./xtuner_config/pt/README.md)
- 【基于xtuner】全量、LoRA、QLoRA微调详见[微调指南](./xtuner_config/README.md)
- 【基于ms-swift】全量、LoRA、QLoRA微调详见[微调指南](./swift/README.md)
- 【基于LLaMA-Factory】全量、LoRA、QLoRA微调详见[微调指南](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md)
- todo:待更新DPO训练
### 🎨微调指南

详见[微调指南](xtuner_config/README.md)

### 🔧部署指南

- Demo部署:详见[部署指南](demo/README.md)
- 基于[LMDeploy](https://github.com/InternLM/lmdeploy/)的量化部署:详见[deploy](./deploy/lmdeploy.md)
- todo: 基于VLLM部署指南

### ⚙RAG(检索增强生成)

@@ -263,14 +257,13 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git

### 使用到的框架

- [xtuner](https://github.com/InternLM/xtuner):用于微调
- [Xtuner](https://github.com/InternLM/xtuner):用于微调
- [Transformers](https://github.com/huggingface/transformers)
- [Pytorch](https://pytorch.org/)
- [LMDeploy](https://github.com/InternLM/lmdeploy/):用于量化部署
- [Stremlit](https://streamlit.io/):用于构建Demo
- [DeepSpeed](https://github.com/microsoft/DeepSpeed):并行训练
- [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main):训练框架
- [ms-swift](https://github.com/modelscope/ms-swift):训练框架
- …

#### 如何参与本项目

README_EN.md (47 changes)
@ -25,7 +25,7 @@
|
||||
<h3 align="center">EmoLLM</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md">简体中文</a> | English | <a href="README_JP.md">日本語</a>
|
||||
<a href="README.md">简体中文</a> | English
|
||||
<br />
|
||||
<br />
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>Explore the documentation of this project »</strong></a>
|
||||
@ -63,7 +63,6 @@
|
||||
| Mixtral 8x7B_instruct | QLoRA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | |
|
||||
| LLaMA3_8b_instruct | QLoRA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/tree/main), [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/files) |
|
||||
| LLaMA3_8b_instruct | QLoRA | [llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary) |
|
||||
| Qwen2-7B-Instruct | LoRA | [Qwen2-7B-Instruct_lora.py](./xtuner_config/Qwen2-7B-Instruct_lora.py) |[ModelScope](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen2-7B-Instruct_lora/) |
|
||||
| …… | …… | …… | …… |
|
||||
|
||||
|
||||
@ -105,13 +104,11 @@ The Model aims to fully understand and promote the mental health of individuals,
|
||||
</table>
|
||||
|
||||
## Recent Updates
|
||||
- [2024.09.14] The Lora fine-tuned model based on the Qwen2-7B-Instruct model is open-sourced. Fine-tuning configuration file address: [Qwen2-7B-Instruct_lora.py](./xtuner_config/Qwen2-7B-Instruct_lora.py), model weight link: [ModelScope](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen2-7B-Instruct_lora/)
|
||||
- [2024.08] The Lora fine-tuned model based on GLM4-9B-chat is open-sourced (based on Llama-factory). For details, see [Fine-tuning Tutorial](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md), model weight link: [ModelScope](https://www.modelscope.cn/models/wwewwt/EmoLLM-glm-4-9b-chat/summary)
|
||||
- [2024.07.16] Welcome everyone to experience EmoLLM V3.0. This model is a fully fine-tuned version based on the InternLM2.5-7B-Chat model. The fine-tuning configuration file can be found at: [internlm2_5_chat_7b_full.py](./xtuner_config/internlm2_5_chat_7b_full.py). Model weights are available at: [OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM_V3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLMV3.0). WebDemo is available at: [OpenXLab apps](https://openxlab.org.cn/apps/detail/chg0901/EmoLLMV3.0), [Full fine-tuning tutorial on Zhihu](https://zhuanlan.zhihu.com/p/708931911).
|
||||
- [2024.07] Welcome to use the stable version of EmoLLM V2.0 for daily use and academic research. Model weight link: [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full/tree/main).
|
||||
- [2024.07] Added InternLM2_5_7B_chat[fine-tuning configuration](./xtuner_config/internlm2_5_chat_7b_qlora_oasst1_e3.py)、model file [ModelScope](https://www.modelscope.cn/models/z342994309/emollm_interlm2_5/)。
|
||||
- [2024.06] Added [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)[GLM4-9B-chat fine-tuning guide](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md), added [swift-based fine-tuning guide](./swift/), the paper [ESC-Eval: Evaluating Emotion Support Conversations in Large Language Models](https://arxiv.org/abs/2406.14952) cited EmoLLM and EmoLLM achieved good results.
|
||||
- [2024.05.28] The multi-turn dialogue dataset **CPsyCunD** and **professional evaluation method** used by EmoLLM have been released. For details, please see the 2024 ACL findings[《CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling》](https://arxiv.org/abs/2405.16433)!
|
||||
- 【2024.07.16】 Welcome everyone to experience EmoLLM V3.0. This model is a fully fine-tuned version based on the InternLM2.5-7B-Chat model. The fine-tuning configuration file can be found at: [internlm2_5_chat_7b_full.py](./xtuner_config/internlm2_5_chat_7b_full.py). Model weights are available at: [OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM_V3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLMV3.0). WebDemo is available at: [OpenXLab apps](https://openxlab.org.cn/apps/detail/chg0901/EmoLLMV3.0), [Full fine-tuning tutorial on Zhihu](https://zhuanlan.zhihu.com/p/708931911).
|
||||
- 【2024.07】Welcome to use the stable version of EmoLLM V2.0 for daily use and academic research. Model weight link: [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full/tree/main).
|
||||
- 【2024.07】Added InternLM2_5_7B_chat[fine-tuning configuration](./xtuner_config/internlm2_5_chat_7b_qlora_oasst1_e3.py)、model file [ModelScope](https://www.modelscope.cn/models/z342994309/emollm_interlm2_5/)。
|
||||
- 【2024.06】 Added [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)[GLM4-9B-chat fine-tuning guide](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md), added [swift-based fine-tuning guide](./swift/), the paper [ESC-Eval: Evaluating Emotion Support Conversations in Large Language Models](https://arxiv.org/abs/2406.14952) cited EmoLLM and EmoLLM achieved good results.
|
||||
- 【2024.05.28】The multi-turn dialogue dataset **CPsyCunD** and **professional evaluation method** used by EmoLLM have been released. For details, please see the 2024 ACL findings[《CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling》](https://arxiv.org/abs/2405.16433)!
|
||||
- [2024.05.08] EmoLLM**Daddy-like BF V0.1** is public now in [1. **Baidu AppBuilder**](https://appbuilder.baidu.com/s/4cLyw) and [2. **OpenXLab**](https://openxlab.org.cn/apps/detail/chg0901/EmoLLM3.0_Gradio_Llama3-8B-Instruct3.0), welcome to like and add it to your collections!
|
||||
- [2024.05.07] [Incremental Pre-training Guide](xtuner_config/pt/README.md)
|
||||
- [2024.05.04] [EmoLLM3.0 OpenXLab Demo](https://st-app-center-006861-9746-jlroxvg.openxlab.space/) based on LLaMA3_8b_instruct is available now ([restart link]((https://openxlab.org.cn/apps/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0))), [LLAMA3 fine-tuning guide](xtuner_config/README_llama3_8b_instruct_qlora_alpaca_e3_M.md) is updated, LLaMA3_8b_instruct-8B QLoRA fine-tuning model EmoLLM3.0 weights are released on [**OpenXLab**](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0) and [**ModelScope**](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary) platforms
|
||||
@ -125,11 +122,6 @@ The Model aims to fully understand and promote the mental health of individuals,
|
||||
- [2024.03.11] **EmoLLM V2.0 is greatly improved in all scores compared to EmoLLM V1.0. Surpasses the performance of Role-playing ChatGPT on counseling tasks!** [Click to experience EmoLLM V2.0](https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0), update [dataset statistics and details](./datasets/), [Roadmap](./assets/Roadmap_ZH.png)
|
||||
- [2024.03.09] Add concurrency acceleration [QA pair generation](./scripts/qa_generation/), [RAG pipeline](./rag/)
|
||||
- [2024.03.03] [Based on InternLM2-7B-chat full fine-tuned version EmoLLM V2.0 open sourced](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full), need two A100*80G, update professional evaluation, see [evaluate](./evaluate/), update PaddleOCR-based PDF to txt tool scripts, see [scripts](./scripts/).
|
||||
|
||||
|
||||
<details>
|
||||
<summary>View More</summary>
|
||||
|
||||
- [2024.02.29] Updated objective assessment calculations, see [evaluate](./evaluate/) for details. A series of datasets have also been updated, see [datasets](./datasets/) for details.
|
||||
- [2024.02.27] Updated English README and a series of datasets (licking dogs and one-round dialogue)
|
||||
- [2024.02.23]The "Gentle Lady Psychologist Ai Wei" based on InternLM2_7B_chat_qlora was launched. [Click here to obtain the model weights](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei), [configuration file](xtuner_config/aiwei-internlm2_chat_7b_qlora.py), [online experience link](https://openxlab.org.cn/apps/detail/ajupyter/EmoLLM-aiwei)
|
||||
@ -137,7 +129,11 @@ The Model aims to fully understand and promote the mental health of individuals,
|
||||
- [2024.02.23]Updated [several fine-tuning configurations](/xtuner_config/), added [data_pro.json](/datasets/data_pro.json) (more quantity, more comprehensive scenarios, richer content) and [aiwei.json](/datasets/aiwei.json) (dedicated to the gentle lady role-play, featuring Emoji expressions), the "Gentle Lady Psychologist Ai Wei" is coming soon.
|
||||
|
||||
- [2024.02.18] The full fine-tuned version based on Qwen1_5-0_5B-Chat has been [open-sourced](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary). Friends with limited computational resources can now dive in and explore it.
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>View More</summary>
|
||||
|
||||
- [2024.02.06] [Open-sourced based on the Qwen1_5-0_5B-Chat full-scale fine-tuned version](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary), friends with limited computing power can start experimenting~
|
||||
|
||||
<p align="center">
|
||||
@ -191,7 +187,7 @@ The Model aims to fully understand and promote the mental health of individuals,
|
||||
- [User Guide](#user-guide)
|
||||
- [🍪Quick start](#quick-start)
|
||||
- [📌Data Construction](#data-construction)
|
||||
- [🎨Incremental Pre-training and Fine-tuning Guide](#incremental-pre-training-and-fine-tuning-guide)
|
||||
- [🎨Fine-tuning Guide](#fine-tuning-guide)
|
||||
- [🔧Deployment Guide](#deployment-guide)
|
||||
- [⚙RAG (Retrieval Augmented Generation)](#rag-retrieval-augmented-generation)
|
||||
- [🎓Evaluation Guide](#evaluation-guide)
|
||||
@ -210,7 +206,6 @@ The Model aims to fully understand and promote the mental health of individuals,
|
||||
###### Pre-development Configuration Requirements.
|
||||
|
||||
- A100 40G (specifically for InternLM2_7B_chat + qlora fine-tuning + deepspeed zero2 optimization)
|
||||
- **[TODO]**: Publish more details about hardware consumption.
|
||||
|
||||
###### User Guide
|
||||
|
||||
@ -223,7 +218,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
|
||||
1. Read in sequence or read sections you're interested in:
|
||||
- [Quick Start](#quick-start)
|
||||
- [Data Construction](#data-construction)
|
||||
- [Fine-tuning Guide](#incremental-pre-training-and-fine-tuning-guide)
|
||||
- [Fine-tuning Guide](#fine-tuning-guide)
|
||||
- [Deployment Guide](#deployment-guide)
|
||||
- [RAG](#rag-retrieval-augmented-generation)
|
||||
- [Evaluation Guide](#evaluation-guide)
|
||||
@ -235,22 +230,19 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
|
||||
- Quick coding: [Baby EmoLLM](quick_start/Baby_EmoLLM.ipynb)
|
||||
|
||||
### 📌Data Construction
|
||||
|
||||
- Please read the [Data Construction Guide ](generate_data/tutorial_EN.md) for reference.
|
||||
|
||||
- The dataset used for this fine-tuning can be found at [datasets](datasets/data.json)
|
||||
|
||||
### 🎨Incremental Pre-training and Fine-tuning Guide
|
||||
- For details on incremental pre-training, see [Incremental Pre-training Guide](./xtuner_config/pt/README.md).
|
||||
- For full-scale, LoRA, and QLoRA fine-tuning based on **xtuner**, see [Fine-tuning Guide](./xtuner_config/README_EN.md).
|
||||
- For full-scale, LoRA, and QLoRA fine-tuning based on **ms-swift**, see [Fine-tuning Guide](./swift/README_EN.md).
|
||||
- For full-scale, LoRA, and QLoRA fine-tuning based on **LLaMA-Factory**, see [Fine-tuning Guide](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md).
|
||||
- **[TODO]**: Update DPO training.
|
||||
### 🎨Fine-tuning Guide
|
||||
|
||||
For details, see the [fine-tuning guide](xtuner_config/README_EN.md)
|
||||
|
||||
### 🔧Deployment Guide
|
||||
|
||||
- Demo deployment: see [deployment guide](./demo/README_EN.md) for details.
|
||||
- Quantitative deployment based on [LMDeploy](https://github.com/InternLM/lmdeploy/): see [deploy](./deploy/lmdeploy_EN.md)
|
||||
- **[TODO]**: Deployment Guide for VLLM
|
||||
|
||||
### ⚙RAG (Retrieval Augmented Generation)
|
||||
|
||||
@ -271,8 +263,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
|
||||
- [LMDeploy](https://github.com/InternLM/lmdeploy/): for quantitative deployment
|
||||
- [Stremlit](https://streamlit.io/): for building demos
|
||||
- [DeepSpeed](https://github.com/microsoft/DeepSpeed): for parallel training
|
||||
- [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main)
|
||||
- [ms-swift](https://github.com/modelscope/ms-swift)
|
||||
- …
|
||||
|
||||
#### How to participate in this project
|
||||
|
||||
@ -364,7 +355,7 @@ The project is licensed under the MIT License. Please refer to the details
|
||||
[issues-shield]: https://img.shields.io/github/issues/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[issues-url]: https://img.shields.io/github/issues/SmartflowAI/EmoLLM.svg
|
||||
[license-shield]: https://img.shields.io/github/license/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[license-url]: https://github.com/SmartFlowAI/EmoLLM/blob/main/LICENSE
|
||||
[license-url]: https://github.com/SmartflowAI/EmoLLM/blob/main/LICENSE
|
||||
|
||||
[OpenXLab_App-image]: https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg
|
||||
[OpenXLab_Model-image]: https://cdn-static.openxlab.org.cn/header/openxlab_models.svg
|
||||
|
README_JP.md (deleted, 364 lines)

@@ -1,364 +0,0 @@
<div align="center">
|
||||
|
||||
# EmoLLM - メンタルヘルスのための大規模言語モデル
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||
<img src="assets/EmoLLM_transparent.png" alt="Logo" width="50%">
|
||||
</a>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<!-- PROJECT SHIELDS -->
|
||||
[![Contributors][contributors-shield]][contributors-url]
|
||||
[![Forks][forks-shield]][forks-url]
|
||||
[![Issues][issues-shield]][issues-url]
|
||||
[![OpenXLab_App][OpenXLab_App-image]][OpenXLab_App-url]
|
||||
[![OpenXLab_Model][OpenXLab_Model-image]][OpenXLab_Model-url]
|
||||
[![MIT License][license-shield]][license-url]
|
||||
[![Stargazers][stars-shield]][stars-url]
|
||||
|
||||
</div>
|
||||
|
||||
<h3 align="center">EmoLLM</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md">简体中文</a> | English | 日本語
|
||||
<br />
|
||||
<br />
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM"><strong>このプロジェクトのドキュメントを探索する »</strong></a>
|
||||
<br />
|
||||
<br />
|
||||
<a href="https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0">EmoLLM 2.0 デモ</a>
|
||||
·
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">バグを報告する</a>
|
||||
·
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM/issues">新機能を提案する</a>
|
||||
</p>
|
||||
|
||||
</p>
|
||||
|
||||
<!-- 本篇README.md面向开发者 -->
|
||||
|
||||
**EmoLLM** は、メンタルヘルスカウンセリングにおいて顧客を理解し、サポートし、助けるために設計された大規模言語モデルのシリーズです。LLMの指示から微調整されています。スターをいただけると嬉しいです~⭐⭐。オープンソースの構成は以下の通りです:
|
||||
|
||||
<div align="center">
|
||||
|
||||
| モデル | タイプ | ファイルリンク | モデルリンク |
|
||||
| :-------------------: | :------: | :------------------------------------------------------------------------------------------------------: |:------: |
|
||||
| InternLM2_5_7B_chat | QLORA | [internlm2_5_chat_7b_qlora_oasst1_e3.py](./xtuner_config/internlm2_5_chat_7b_qlora_oasst1_e3.py) |[ModelScope](https://www.modelscope.cn/models/z342994309/emollm_interlm2_5/) |
|
||||
| InternLM2_7B_chat | QLORA | [internlm2_7b_chat_qlora_e3.py](./xtuner_config/internlm2_7b_chat_qlora_e3.py) | [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM/files) |
|
||||
| InternLM2_7B_chat | 全量微調整 | [internlm2_chat_7b_full.py](./xtuner_config/internlm2_chat_7b_full.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full) |
|
||||
| InternLM2_7B_base | QLORA | [internlm2_7b_base_qlora_e10_M_1e4_32_64.py](./xtuner_config/internlm2_7b_base_qlora_e10_M_1e4_32_64.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e), [ModelScope](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/summary) |
|
||||
| InternLM2_1_8B_chat | 全量微調整 | [internlm2_1_8b_full_alpaca_e3.py](./xtuner_config/internlm2_1_8b_full_alpaca_e3.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_1_8b_full/tree/main), [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM_PT_InternLM1.8B-chat/files) |
|
||||
| InternLM2_20B_chat | LORA |[internlm2_20b_chat_lora_alpaca_e3.py](./xtuner_config/internlm2_20b_chat_lora_alpaca_e3.py)| |
|
||||
| Qwen_7b_chat | QLORA | [qwen_7b_chat_qlora_e3.py](./xtuner_config/qwen_7b_chat_qlora_e3.py) | |
|
||||
| Qwen1_5-0_5B-Chat | 全量微調整 | [qwen1_5_0_5_B_full.py](./xtuner_config/qwen1_5_0_5_B_full.py) | [ModelScope](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary) |
|
||||
| Baichuan2_13B_chat | QLORA | [baichuan2_13b_chat_qlora_alpaca_e3.py](./xtuner_config/baichuan2_13b_chat_qlora_alpaca_e3.py) | |
|
||||
| ChatGLM3_6B | LORA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | |
|
||||
| DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | |
|
||||
| Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | |
|
||||
| LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | [OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/tree/main), [ModelScope](https://modelscope.cn/models/aJupyter/EmoLLM-LLaMA3_8b_instruct_aiwei/files) |
|
||||
| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M_ruozhi_scM.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary) |
|
||||
| …… | …… | …… | …… |
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
皆さんのこのプロジェクトへの貢献をお待ちしています~
|
||||
|
||||
---
|
||||
|
||||
このモデルは、個人、グループ、社会のメンタルヘルスを完全に理解し、促進することを目的としています。このモデルには通常、以下の主要なコンポーネントが含まれます:
|
||||
|
||||
- 認知要因:個人の思考パターン、信念システム、認知バイアス、問題解決能力に関するもの。認知要因は、個人が人生の出来事をどのように解釈し、対応するかに影響を与えるため、メンタルヘルスに大きな影響を与えます。
|
||||
- 感情要因:感情の調整、感情の表現、感情の経験を含む。感情の健康はメンタルヘルスの重要な部分であり、個人が感情をどのように管理し、表現し、負の感情からどのように回復するかに関与します。
|
||||
- 行動要因:個人の行動パターン、習慣、対処戦略に関するもの。これには、ストレス管理スキル、社交スキル、自己効力感(自分の能力に対する自信)が含まれます。
|
||||
- 社会環境:家族、仕事、コミュニティ、文化的背景などの外部要因であり、これらは個人のメンタルヘルスに直接的および間接的な影響を与えます。
|
||||
- 身体の健康:身体の健康とメンタルヘルスは密接に関連しています。良好な身体の健康はメンタルヘルスを促進し、その逆もまた然りです。
|
||||
- 心理的レジリエンス:逆境から回復し、適応する個人の能力を指します。心理的レジリエンスが強い人は、挑戦から回復し、それから学び、成長することができます。
|
||||
- 予防および介入措置:メンタルヘルスの大規模モデルには、心理的問題を予防し、メンタルヘルスを促進するための戦略も含まれます。これには、心理教育、カウンセリング、治療、社会的支援システムが含まれます。
|
||||
- 評価および診断ツール:メンタルヘルスを効果的に促進するためには、個人の心理状態を評価し、潜在的な心理的問題を診断するための科学的なツールが必要です。
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td align="center" style="background-color: transparent">
|
||||
<img src="assets\aiwei_demo.gif" alt="占位图">
|
||||
</td>
|
||||
<td align="center" style="background-color: transparent">
|
||||
<img src="assets\aiwei_demo2.gif" alt="占位图">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" style="background-color: transparent">
|
||||
<img src="assets\aiwei_demo3.gif" alt="占位图">
|
||||
</td>
|
||||
<td align="center" style="background-color: transparent">
|
||||
<img src="assets\aiwei_demo4.gif" alt="占位图">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 最近の更新
|
||||
- 【2024.7】EmoLLM V2.0の安定版を日常使用および学術研究にご利用ください。モデルの重みリンク:[OpenXLab](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full/tree/main)。
|
||||
- 【2024.7】InternLM2_5_7B_chatの微調整構成を追加しました。[ModelScope](https://www.modelscope.cn/models/z342994309/emollm_interlm2_5/)。
|
||||
- 【2024.6】[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)の[GLM4-9B-chat微調整ガイド](./doc/GLM-4-9B-chat%20Lora%20微调(llama-factory).md)を追加しました。[swiftベースの微調整ガイド](./swift/)を追加しました。論文[ESC-Eval: Evaluating Emotion Support Conversations in Large Language Models](https://arxiv.org/abs/2406.14952)がEmoLLMを引用し、EmoLLMが良好な結果を達成しました。
|
||||
- 【2024.05.28】EmoLLMが使用するマルチターン対話データセット**CPsyCunD**と**専門評価方法**が公開されました。詳細は2024 ACL findings[《CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling》](https://arxiv.org/abs/2405.16433)をご覧ください!
|
||||
- [2024.05.08] EmoLLM**Daddy-like BF V0.1**が[1. **Baidu AppBuilder**](https://appbuilder.baidu.com/s/4cLyw)と[2. **OpenXLab**](https://openxlab.org.cn/apps/detail/chg0901/EmoLLM3.0_Gradio_Llama3-8B-Instruct3.0)で公開されました。ぜひ「いいね」と「コレクション」に追加してください!
|
||||
- [2024.05.07] [インクリメンタルプレトレーニングガイド](xtuner_config/pt/README.md)
|
||||
- [2024.05.04] [LLaMA3_8b_instructベースのEmoLLM3.0 OpenXLabデモ](https://st-app-center-006861-9746-jlroxvg.openxlab.space/)が公開されました([再起動リンク](https://openxlab.org.cn/apps/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0))、[LLAMA3微調整ガイド](xtuner_config/README_llama3_8b_instruct_qlora_alpaca_e3_M.md)が更新されました。LLaMA3_8b_instruct-8B QLoRA微調整モデルEmoLLM3.0の重みが[**OpenXLab**](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0)と[**ModelScope**](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct3.0/summary)プラットフォームで公開されました。
|
||||
- [2024.04.20] [LLAMA3微調整ガイド](xtuner_config/README_llama3_8b_instruct_qlora_alpaca_e3_M.md)と[LLaMA3_8b_instructのaiwei](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM-LLaMA3_8b_instruct_aiwei)がオープンソース化されました。
|
||||
- [2023.04.14] [クイックスタート](docs/quick_start_EN.md)とナニー級チュートリアル[BabyEmoLLM](Baby_EmoLLM.ipynb)を追加しました。
|
||||
- [2024.04.02] Huggingfaceに[Old Mother Counsellor](https://huggingface.co/brycewang2018/EmoLLM-mother/tree/main)をアップロードしました。
|
||||
- [2024.03.25] [Mother-like Therapist]がHuggingfaceで公開されました(https://huggingface.co/brycewang2018/EmoLLM-mother/tree/main)。
|
||||
- [2024.03.25] [Daddy-like Boy-Friend]がBaidu Paddle-Paddle AI Studioプラットフォームで公開されました(https://aistudio.baidu.com/community/app/68787)。
|
||||
- [2024.03.24] **InternLM2-Base-7B QLoRA微調整モデル**が**OpenXLab**と**ModelScope**プラットフォームで公開されました。詳細は[**InternLM2-Base-7B QLoRA**](./xtuner_config/README_internlm2_7b_base_qlora.md)をご覧ください。
|
||||
- [2024.03.12] [aiwei]がBaidu Paddle-Paddle AI Studioプラットフォームで公開されました(https://aistudio.baidu.com/community/app/63335)。
|
||||
- [2024.03.11] **EmoLLM V2.0はEmoLLM V1.0と比較して全体的に向上し、心理カウンセリングタスクにおいてRole-playing ChatGPTを上回る能力を持っています!** [EmoLLM V2.0を体験する](https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0)、[データセットの統計と詳細情報](./datasets/)、[ロードマップ](./assets/Roadmap_ZH.png)を更新しました。
|
||||
- [2024.03.09] 同時実行機能を追加して[QAペア生成](./scripts/qa_generation/)、[RAGパイプライン](./rag/)を加速しました。
|
||||
- [2024.03.03] [InternLM2-7B-chat全量微調整バージョンEmoLLM V2.0がオープンソース化されました](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full)、2つのA100*80Gが必要です。専門評価を更新しました。詳細は[evaluate](./evaluate/)をご覧ください。PaddleOCRベースのPDFからtxtへの変換ツールスクリプトを更新しました。詳細は[scripts](./scripts/)をご覧ください。
|
||||
- [2024.02.29] 客観的評価計算を更新しました。詳細は[evaluate](./evaluate/)をご覧ください。一連のデータセットを更新しました。詳細は[datasets](./datasets/)をご覧ください。
|
||||
- [2024.02.27] 英語のREADMEと一連のデータセット(リッキングドッグとワンターン対話)を更新しました。
|
||||
- [2024.02.23] InternLM2_7B_chat_qloraベースの「優しいお姉さん心理カウンセラーAi Wei」をリリースしました。[モデルの重みを取得する](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei)、[構成ファイル](xtuner_config/aiwei-internlm2_chat_7b_qlora.py)、[オンライン体験リンク](https://openxlab.org.cn/apps/detail/ajupyter/EmoLLM-aiwei)。
|
||||
- [2024.02.23] [いくつかの微調整構成](/xtuner_config/)を更新しました。[data_pro.json](/datasets/data_pro.json)(より多くの量、より包括的なシナリオ、より豊富な内容)と[aiwei.json](/datasets/aiwei.json)(優しいお姉さんのロールプレイ専用、Emoji表現を含む)を追加しました。「優しいお姉さん心理カウンセラーAi Wei」が近日公開予定です。
|
||||
- [2024.02.18] [Qwen1_5-0_5B-Chat全量微調整バージョンがオープンソース化されました](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary)。計算リソースが限られている方もぜひお試しください。
|
||||
|
||||
<details>
|
||||
<summary>もっと見る</summary>
|
||||
|
||||
- [2024.02.06] [Qwen1_5-0_5B-Chat全量微調整バージョンがオープンソース化されました](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary)。計算リソースが限られている方もぜひお試しください。
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/7e931682-c54d-4ded-bc67-79130c68d744" alt="モデルダウンロード数">
|
||||
</p>
|
||||
|
||||
- [2024.02.05] プロジェクトが公式WeChatアカウントNLP Engineeringで紹介されました。記事の[リンク](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A)はこちらです。皆さんのフォローをお待ちしています!! 🥳🥳
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/47868d6a-2e91-4aa9-a630-e594c14295b4" alt="公式WeChatアカウントのQRコード">
|
||||
</p>
|
||||
|
||||
- [2024.02.03] [プロジェクトビデオ](https://www.bilibili.com/video/BV1N7421N76X/)がbilibiliで公開されました 😊
|
||||
- [2024.01.27] データ構築ドキュメント、微調整ガイド、デプロイメントガイド、Readmeなどの関連ドキュメントを完成させました 👏
|
||||
- [2024.01.25] EmoLLM V1.0がオンラインでデプロイされました https://openxlab.org.cn/apps/detail/jujimeizuo/EmoLLM 😀
|
||||
|
||||
</details>
|
||||
|
||||
## 栄誉
|
||||
|
||||
- プロジェクトは、**2024浦源大模型シリーズチャレンジ春季大会**で**イノベーションとクリエイティビティ賞**を受賞しました。
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||
<img src="assets/Shusheng.png" alt="チャレンジイノベーションとクリエイティビティ賞">
|
||||
</p>
|
||||
|
||||
|
||||
- [AI-enabled university programme "National College Tour"](https://mp.weixin.qq.com/s/yyaulQ1wBzKq5cXaGl2Wag)で一等賞を受賞しました。
|
||||
- プロジェクトは公式WeChatアカウント**NLP Engineering**で紹介されました。記事の[リンク](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A)はこちらです。
|
||||
|
||||
## ロードマップ
|
||||
|
||||
- 🎉以下のメディアおよび友人の皆様に、このプロジェクトの報道とサポートに感謝します(以下、順不同!省略があれば申し訳ありませんが、感謝しています!追加を歓迎します!):[NLP工程化](https://mp.weixin.qq.com/s/78lrRl2tlXEKUfElnkVx4A)、[机智流](https://mp.weixin.qq.com/s/_wMCmssRMGd0Oz5OVVkjAA)、[爱可可爱生活](https://mp.weixin.qq.com/s/4WaCg4OpkCWXEuWHuV4r3w)、[阿郎小哥](https://mp.weixin.qq.com/s/_MSMeL1XHP0v5lDi3YaPVw)、[大模型日知路](https://mp.weixin.qq.com/s/FYYibsCXtfU6FFM9TuKILA)、[AI Code](https://mp.weixin.qq.com/s/yDWGY3S4CwCi6U_irsFmqA)など!
|
||||
|
||||
- プロジェクトビデオ[EmoLLM](https://www.bilibili.com/video/BV1N7421N76X/)が公開されました。ぜひご覧ください! 😀
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/SmartFlowAI/EmoLLM/">
|
||||
<img src="assets/Roadmap_EN.png" alt="ロードマップ_EN">
|
||||
</a>
|
||||
|
||||
## コンテンツ
|
||||
|
||||
- [EmoLLM - メンタルヘルスのための大規模言語モデル](#emollm---メンタルヘルスのための大規模言語モデル)
|
||||
- [最近の更新](#最近の更新)
|
||||
- [栄誉](#栄誉)
|
||||
- [ロードマップ](#ロードマップ)
|
||||
- [コンテンツ](#コンテンツ)
|
||||
- [開発前の構成要件](#開発前の構成要件)
|
||||
- [ユーザーガイド](#ユーザーガイド)
|
||||
- [🍪クイックスタート](#クイックスタート)
|
||||
- [📌データ構築](#データ構築)
|
||||
- [🎨微調整ガイド](#微調整ガイド)
|
||||
- [🔧デプロイメントガイド](#デプロイメントガイド)
|
||||
- [⚙RAG(検索強化生成)](#rag検索強化生成)
|
||||
- [🎓評価ガイド](#評価ガイド)
|
||||
- [使用されたフレームワーク](#使用されたフレームワーク)
|
||||
- [このプロジェクトに参加する方法](#このプロジェクトに参加する方法)
|
||||
- [バージョン管理](#バージョン管理)
|
||||
- [著者(順不同)](#著者順不同)
|
||||
- [著作権表示](#著作権表示)
|
||||
- [謝辞](#謝辞)
|
||||
- [関連プロジェクト](#関連プロジェクト)
|
||||
- [人々](#人々)
|
||||
- [スター履歴](#スター履歴)
|
||||
- [🌟 貢献者](#-貢献者)
|
||||
- [コミュニケーショングループ](#コミュニケーショングループ)
|
||||
|
||||
###### 開発前の構成要件
|
||||
|
||||
- A100 40G(特にInternLM2_7B_chat + qlora微調整 + deepspeed zero2最適化用)
|
||||
|
||||
###### ユーザーガイド
|
||||
|
||||
1. リポジトリをクローンする
|
||||
|
||||
```sh
|
||||
git clone https://github.com/SmartFlowAI/EmoLLM.git
|
||||
```
|
||||
|
||||
1. 順番に読むか、興味のあるセクションを読む:
|
||||
- [クイックスタート](#クイックスタート)
|
||||
- [データ構築](#データ構築)
|
||||
- [微調整ガイド](#微調整ガイド)
|
||||
- [デプロイメントガイド](#デプロイメントガイド)
|
||||
- [RAG](#rag検索強化生成)
|
||||
- [評価ガイド](#評価ガイド)
|
||||
- 詳細を表示
|
||||
|
||||
|
||||
### 🍪クイックスタート
|
||||
- [クイックスタート](quick_start/quick_start_EN.md)を参照してください。
|
||||
- クイックコーディング:[Baby EmoLLM](quick_start/Baby_EmoLLM.ipynb)
|
||||
|
||||
### 📌データ構築
|
||||
|
||||
- [データ構築ガイド](generate_data/tutorial_EN.md)を参照してください。
|
||||
|
||||
- この微調整に使用されたデータセットは[datasets](datasets/data.json)にあります。
|
||||
|
||||
### 🎨微調整ガイド
|
||||
|
||||
詳細は[微調整ガイド](xtuner_config/README_EN.md)を参照してください。
|
||||
|
||||
### 🔧デプロイメントガイド
|
||||
|
||||
- デモデプロイメント:詳細は[デプロイメントガイド](./demo/README_EN.md)を参照してください。
|
||||
- [LMDeploy](https://github.com/InternLM/lmdeploy/)に基づく定量デプロイメント:詳細は[deploy](./deploy/lmdeploy_EN.md)を参照してください。
|
||||
|
||||
### ⚙RAG(検索強化生成)
|
||||
|
||||
- 詳細は[RAG](rag/README_EN.md)を参照してください。
|
||||
|
||||
### 🎓評価ガイド
|
||||
|
||||
- モデル評価は**一般的な指標評価**と**専門的な指標評価**に分かれています。詳細は[評価ガイド](evaluate/README.md)を参照してください。
|
||||
|
||||
<details>
|
||||
<summary>追加の詳細</summary>
|
||||
|
||||
### 使用されたフレームワーク
|
||||
|
||||
- [Xtuner](https://github.com/InternLM/xtuner)
|
||||
- [Transformers](https://github.com/huggingface/transformers)
|
||||
- [Pytorch](https://pytorch.org/)
|
||||
- [LMDeploy](https://github.com/InternLM/lmdeploy/): 定量デプロイメント用
|
||||
- [Stremlit](https://streamlit.io/): デモ構築用
|
||||
- [DeepSpeed](https://github.com/microsoft/DeepSpeed): 並列トレーニング用
|
||||
- …
|
||||
|
||||
#### このプロジェクトに参加する方法
|
||||
|
||||
貢献は、オープンソースコミュニティを学習、インスピレーション、創造の素晴らしい場所にします。あなたの貢献は非常に感謝されます。
|
||||
|
||||
1. プロジェクトをフォークする
|
||||
2. フィーチャーブランチを作成する(`git checkout -b feature/AmazingFeature`)
|
||||
3. 変更をコミットする(`git commit -m 'Add some AmazingFeature'`)
|
||||
4. ブランチにプッシュする(`git push origin feature/AmazingFeature`)
|
||||
5. プルリクエストを開く
|
||||
|
||||
### バージョン管理
|
||||
|
||||
このプロジェクトはバージョン管理にGitを使用しています。現在利用可能なバージョンはリポジトリで確認できます。
|
||||
|
||||
</details>
|
||||
|
||||
### 著者(順不同)
|
||||
|
||||
| ユーザー名 | 学校/組織 | 備考 | 貢献 |
|
||||
| :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
|
||||
| [aJupyter](https://github.com/aJupyter) | 南開大学、修士課程在籍 | DataWhaleメンバー | プロジェクト発起人 |
|
||||
| [MING-ZCH](https://github.com/MING-ZCH) | 華中科技大学、学部生 | LLM X メンタルヘルス研究者 | プロジェクト共同リーダー |
|
||||
| [jujimeizuo](https://github.com/jujimeizuo) | 江南大学、修士課程在籍 | | |
|
||||
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | ハルビン工業大学(威海)、学部生 | | |
|
||||
| [8baby8](https://github.com/8baby8) | PaddlePaddleパイロットチーム地域ディレクター | 文心大モデルのコア開発者 | |
|
||||
| [zxazys](https://github.com/zxazys) | 南開大学、修士課程在籍 | | |
|
||||
| [JasonLLLLLLLLLLL](https://github.com/JasonLLLLLLLLLLL) | SWUFE(西南財経大学) | | |
|
||||
| [MrCatAI](https://github.com/MrCatAI) | AIムーバー | | |
|
||||
| [ZeyuBa](https://github.com/ZeyuBa) | 自動化研究所、修士課程在籍 | | |
|
||||
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | ペンシルベニア大学、修士課程在籍 | | |
|
||||
| [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(華東)、学部生 | | |
|
||||
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora) | [MiniSora](https://github.com/mini-sora/minisora)のメンテナーおよび管理者 | LLMの事前トレーニングと微調整、モデルのアップロード、データのクリーニング、ドキュメントの翻訳 |
|
||||
| [Mxoder](https://github.com/Mxoder) | 北京航空航天大学、学部生 | | |
|
||||
| [Anooyman](https://github.com/Anooyman) | 南京理工大学、修士課程在籍 | | |
|
||||
| [Vicky-3021](https://github.com/Vicky-3021) | 西安電子科技大学、修士課程在籍(研究年0) | | |
|
||||
| [SantiagoTOP](https://github.com/santiagoTOP) | 太原理工大学、修士課程在籍 | | データのクリーニング、ドキュメント管理、Baby EmoLLMのメンテナンス |
|
||||
| [zealot52099](https://github.com/zealot52099) | 個人開発者 | | データ処理、LLMの微調整とRAG |
|
||||
| [wwwyfff](https://github.com/wwwyfff) | 復旦大学、修士課程在籍 | | |
|
||||
| [jkhumor](https://github.com/jkhumor) | 南開大学、修士課程在籍 | | RAG |
|
||||
| [lll997150986](https://github.com/lll997150986) | 南開大学、修士課程在籍 | | 微調整 |
|
||||
| [nln-maker](https://github.com/nln-maker) | 南開大学、修士課程在籍 | | フロントエンドとバックエンドの開発 |
|
||||
| [dream00001](https://github.com/dream00001) | 南開大学、修士課程在籍 | | フロントエンドとバックエンドの開発 |
|
||||
| [王几行XING](zhihu.com/people/brycewang1898) | 北京大学、修士課程卒業 | | データ処理、LLMの微調整、フロントエンドとバックエンドの開発 |
|
||||
| [思在] | 北京大学、修士課程卒業(マイクロソフト) | | LLMの微調整、フロントエンドとバックエンドの開発 |
|
||||
| [TingWei](https://github.com/wwewwt) | 電子科技大学、修士課程卒業 | | LLMの微調整 |
|
||||
| [PengYu](https://github.com/hi-pengyu) | 石河子大学、修士課程在籍 | | LLMの微調整 |
|
||||
### 著作権表示
|
||||
|
||||
このプロジェクトはMITライセンスの下でライセンスされています。詳細については、[LICENSE](https://github.com/SmartFlowAI/EmoLLM/blob/master/LICENSE)を参照してください。
|
||||
|
||||
### 謝辞
|
||||
#### 関連プロジェクト
|
||||
- [CPsyCoun](https://github.com/CAS-SIAT-XinHai/CPsyCoun)
|
||||
- [Smile](https://github.com/qiuhuachuan/smile)
|
||||
- [SoulChat](https://github.com/scutcyr/SoulChat)
|
||||
|
||||
#### 人々
|
||||
- [上海人工知能研究所](https://www.shlab.org.cn/)
|
||||
- [Vansin](https://github.com/vansin)
|
||||
- A.bu(心理学修士、北京大学)
|
||||
- [Sanbuphy](https://github.com/sanbuphy)
|
||||
- [HatBoy](https://github.com/hatboy)
|
||||
|
||||
<!-- links -->
|
||||
|
||||
<!-- [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555 -->
|
||||
|
||||
<!-- [linkedin-url]: https://linkedin.com/in/aJupyter -->
|
||||
|
||||
<!-- 太少了,没必要放 -->
|
||||
|
||||
## スター履歴
|
||||
|
||||
[](https://star-history.com/#SmartFlowAI/EmoLLM&Date)
|
||||
|
||||
## 🌟 貢献者
|
||||
|
||||
[](https://github.com/SmartFlowAI/EmoLLM/graphs/contributors)
|
||||
|
||||
[your-project-path]: SmartflowAI/EmoLLM
|
||||
[contributors-shield]: https://img.shields.io/github/contributors/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[contributors-url]: https://github.com/SmartflowAI/EmoLLM/graphs/contributors
|
||||
[forks-shield]: https://img.shields.io/github/forks/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[forks-url]: https://github.com/SmartflowAI/EmoLLM/network/members
|
||||
[stars-shield]: https://img.shields.io/github/stars/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[stars-url]: https://github.com/SmartflowAI/EmoLLM/stargazers
|
||||
[issues-shield]: https://img.shields.io/github/issues/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[issues-url]: https://img.shields.io/github/issues/SmartflowAI/EmoLLM.svg
|
||||
[license-shield]: https://img.shields.io/github/license/SmartflowAI/EmoLLM.svg?style=flat-square
|
||||
[license-url]: https://github.com/SmartFlowAI/EmoLLM/blob/main/LICENSE
|
||||
|
||||
[OpenXLab_App-image]: https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg
|
||||
[OpenXLab_Model-image]: https://cdn-static.openxlab.org.cn/header/openxlab_models.svg
|
||||
[OpenXLab_App-url]: https://openxlab.org.cn/apps/detail/Farewell1/EmoLLMV2.0
|
||||
[OpenXLab_Model-url]: https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_internlm2_7b_full
|
||||
|
||||
## コミュニケーショングループ
|
||||
|
||||
- 失敗した場合は、Issueセクションに移動してください。
|
||||
|
||||
<p align="center">
|
||||
<img width="30%" src="https://github.com/SmartFlowAI/EmoLLM/assets/62385492/55ecd0aa-4832-4269-ad57-4c26f9aa286b" alt="EmoLLM公式コミュニケーショングループ">
|
||||
</p>
|
@@ -1,217 +0,0 @@
import base64
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import os


class AssembleHeaderException(Exception):
    def __init__(this, msg):
        this.message = msg


class Url:
    def __init__(this, host, path, schema):
        this.host = host
        this.path = path
        this.schema = schema
        pass


# calculate sha256 and encode to base64
def sha256base64(data):
    sha256 = hashlib.sha256()
    sha256.update(data)
    digest = base64.b64encode(sha256.digest()).decode(encoding='utf-8')
    return digest


def parse_url(requset_url):
    stidx = requset_url.index("://")
    host = requset_url[stidx + 3:]
    schema = requset_url[:stidx + 3]
    edidx = host.index("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    path = host[edidx:]
    host = host[:edidx]
    u = Url(host, path, schema)
    return u


# 生成鉴权url
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
    u = parse_url(requset_url)
    host = u.host
    path = u.path
    now = datetime.now()
    date = format_date_time(time.mktime(now.timetuple()))
    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path)
    signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
    authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        api_key, "hmac-sha256", "host date request-line", signature_sha)
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
    values = {
        "host": host,
        "date": date,
        "authorization": authorization
    }

    return requset_url + "?" + urlencode(values)


def get_Body(appid, text, style):
    org_content = json.dumps(text).encode('utf-8')
    body = {
        "header": {
            "app_id": appid,
            "uid": "39769795890",
            "status": 3
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8"
                }
            }
        },
        "payload": {
            "messages": {
                "text": base64.b64encode(json.dumps(text).encode('utf-8')).decode()
            }
        }
    }
    return body


# 发起请求并返回结果
def get_embp_embedding(text, appid, apikey, apisecret):
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    response = requests.post(url, json=content, headers={'content-type': "application/json"}).text
    return response


# 解析结果并输出
def parser_Message(message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        return None
    else:
        text_base = data["payload"]["feature"]["text"]
        text_data = base64.b64decode(text_base)
        dt = np.dtype(np.float32).newbyteorder("<")
        text = np.frombuffer(text_data, dtype=dt)
        return text


# 加载问答对数据
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# 保存embedding到文件
def save_embeddings(embeddings, file_path):
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(embeddings, f, ensure_ascii=False)


# 获取文本的embedding
def get_embedding_for_text(text, appid, apikey, apisecret):
    desc = {"messages": [{"content": text, "role": "user"}]}
    res = get_embp_embedding(desc, appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(res)


# 逐行加载已存在的embedding
def load_embeddings(file_path):
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # 忽略空行
                    embedding_data = json.loads(line.strip())
                    embeddings.update(embedding_data)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings


# 逐行保存embedding到文件
def save_embedding_line_by_line(qa, embedding, file_path):
    if embedding is not None:
        embedding_as_list = embedding.tolist()  # 将numpy array转换为列表
        with open(file_path, 'a', encoding='utf-8') as f:
            json.dump({qa: embedding_as_list}, f, ensure_ascii=False)
            f.write("\n")  # 每行一个embedding


# 获取单个问题的embedding,并处理请求错误
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5):
    retries = 0
    while retries < max_retries:
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:
            print(f"请求错误: {e}")
        retries += 1
        print(f"重试第 {retries} 次...")
        time.sleep(5)  # 每次重试前等待 5 秒
    print(f"获取'{question}' 的embedding失败")
    return None


# 获取所有问答对的embedding并逐行保存
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    all_embeddings = load_embeddings(file_path)  # 尝试加载已存在的embedding
    interval = 1 / qps_limit  # 根据QPS限制设置间隔时间
    for qa in qa_pairs:
        question = qa['input']
        if question in all_embeddings:
            print(f"'{question}' 的embedding已存在,跳过计算")
            continue
        print(f"计算'{question}' 的embedding...")
        embedding = get_embedding_with_retry(question, appid, apikey, apisecret)  # 带重试机制的请求
        if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)  # 逐行保存
            all_embeddings[question] = embedding  # 更新已计算的embedding
        time.sleep(interval)  # 确保符合QPS限制


if __name__ == '__main__':
    # 设置路径
    qa_file = "output/train_optimized_multiple.jsonl"  # 原问答对文件
    embedding_file = "output/qa_embeddings.json"  # embedding存储文件

    appid = "f0f73de5"
    api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
    api_key = "5773f6f95563708de994d17b7ea5d414"

    # 加载数据
    qa_pairs = load_qa_data(qa_file)

    # 获取并保存embedding
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)

    print(f"已保存所有问答对的embedding到 {embedding_file}")
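For reference, a minimal sketch of the JSONL format that `load_qa_data()` above expects. The field names are taken from the code (`qa['input']` here, and `qa['output']` in the classification script below); the sample question and answer text is invented purely for illustration.

```python
# Hypothetical example of one line in output/train_optimized_multiple.jsonl.
# Only the "input" and "output" keys are relied on by the scripts; the text is made up.
import json

sample = {"input": "油橄榄适合在什么样的气候条件下栽培?", "output": "(示例回答,仅用于说明格式)"}
print(json.dumps(sample, ensure_ascii=False))  # prints exactly one line of the JSONL file
```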
@@ -1,122 +0,0 @@
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines

# 加载问答对嵌入
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
    for obj in reader:
        qa_embeddings.update(obj)  # 将每行的json对象加入到qa_embeddings

# 加载问答对
qa_pairs = []
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        qa_pairs.append(json.loads(line))

# 提取嵌入和问题
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))

# 关键词及其类别
categories = {
    "栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
    "油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
    "油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
    "油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
    "油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
    "油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
    "油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
    "土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
    "整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
    "病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"]
}

# 初始化类别关键词的嵌入字典
category_embeddings = {category: [] for category in categories}


# 假设我们有一个方法来计算关键词的嵌入,例如从qa_embeddings中获取
def get_keyword_embedding(keyword):
    return qa_embeddings.get(keyword, None)


# 为每个类别生成关键词的嵌入
for category, keywords in categories.items():
    for keyword in keywords:
        keyword_embedding = get_keyword_embedding(keyword)
        if keyword_embedding is not None:
            category_embeddings[category].append(keyword_embedding)

# 将类别关键词的嵌入转化为平均向量
for category in category_embeddings:
    if category_embeddings[category]:
        category_embeddings[category] = np.mean(category_embeddings[category], axis=0)
    else:
        category_embeddings[category] = np.zeros(embeddings.shape[1])  # 默认空向量

# 计算每个问题与类别之间的相似度
category_similarities = {}
for idx, question in enumerate(questions):
    question_embedding = embeddings[idx]
    category_similarities[question] = {}

    for category, category_embedding in category_embeddings.items():
        similarity = cosine_similarity([question_embedding], [category_embedding])[0][0]
        category_similarities[question][category] = similarity

# 将每个问题分配到相似度最高的类别
category_assignments = {category: [] for category in categories}
for question in questions:
    best_category = max(category_similarities[question], key=category_similarities[question].get)
    category_assignments[best_category].append(question)

# 整合并生成新的jsonl格式,确保每个问答对都被包括
fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for idx, question in enumerate(assigned_questions):
        history = []
        output = ""
        instruction = ""

        # 查找当前问题及其回答
        qa_pair = next((qa for qa in qa_pairs if qa['input'] == question), None)

        if qa_pair:
            instruction = qa_pair['input']  # 当前问题作为instruction
            output = qa_pair['output']  # 当前问题的回答作为output

            # 从同一类别的其他问题构建history,保证每个history与当前问题在同一类别
            history_similarities = []
            for related_question in assigned_questions:
                if related_question != question:
                    related_embedding = qa_embeddings[related_question]
                    similarity = cosine_similarity([qa_embeddings[question]], [related_embedding])[0][0]
                    history_similarities.append((related_question, similarity))

            # 按相似度排序,并选择前1~3个问题作为history
            history_similarities = sorted(history_similarities, key=lambda x: x[1], reverse=True)
            for related_question, _ in history_similarities[:3]:
                related_qa_pair = next((qa for qa in qa_pairs if qa['input'] == related_question), None)
                if related_qa_pair:
                    history.append([related_qa_pair['input'], related_qa_pair['output']])

        # 构建最终格式
        if instruction and output:
            fine_tune_entry = {
                "instruction": instruction,
                "input": "",  # input为空
                "output": output,  # 当前问题的回答
                "history": history,  # 最多包含3条相关问题
                "system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"
            }
            fine_tune_data.append(fine_tune_entry)

# 保存新的jsonl格式
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
    for entry in fine_tune_data:
        json.dump(entry, f, ensure_ascii=False)
        f.write('\n')

print("对话数据整理完成")
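The classification script above emits one fine-tuning record per QA pair into output/fine_tune_data.jsonl. A sketch of the shape of one record, with the keys and nesting mirroring the code exactly while the question and answer strings are placeholders, looks like this:

```python
# Shape of one line of output/fine_tune_data.jsonl as assembled above.
# Only the key names and the "system" string come from the script; the rest is placeholder text.
example_entry = {
    "instruction": "某个问答对的问题(占位)",
    "input": "",
    "output": "该问题对应的回答(占位)",
    "history": [["相关问题1", "相关回答1"], ["相关问题2", "相关回答2"]],  # 最多3条同类别问答
    "system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。",
}
```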
@@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 11:10
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : LDArec.py
# @Project : EmoLLM
import json
import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel
from collections import defaultdict


# 加载问答对数据
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# 加载中文停用词
def load_stopwords(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return set([line.strip() for line in f])


# 使用jieba对中文文本进行分词,并去除停用词
def preprocess_text(text, stopwords):
    words = jieba.lcut(text)  # 使用jieba进行中文分词
    words = [word for word in words if word not in stopwords and len(word) > 1]  # 去除停用词和长度为1的词
    return words


# 生成LDA主题模型
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    # 处理所有问题文本
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question, stopwords) for question in questions]

    # 创建字典和词袋模型
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # 训练LDA模型
    lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus


# 打印每个主题的关键词
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"主题 {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # 问答对文件
    stopwords_file = "chinese_stopwords.txt"  # 停用词文件

    # 加载问答对
    qa_pairs = load_qa_data(qa_file)

    # 加载停用词
    stopwords = load_stopwords(stopwords_file)

    # 构建LDA主题模型
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)

    # 打印主题及其关键词
    print_topics(lda_model)
@ -1,70 +0,0 @@
# -*- coding: utf-8 -*-
import json
import random


# Generate a synthetic dataset of 5000 records
def generate_dataset(num_samples=5000):
    dataset = []
    invoke_types = [1, 2, 3]
    area_codes = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    parameters = [
        {"name": "土壤湿度", "unit": "%", "min": 10, "max": 100},
        {"name": "土壤温度", "unit": "℃", "min": 5, "max": 40},
        {"name": "空气温度", "unit": "℃", "min": -10, "max": 45},
        {"name": "电导率", "unit": "mS/cm", "min": 0.1, "max": 5.0}
    ]

    for _ in range(num_samples):
        invoke_type = random.choice(invoke_types)
        area_code = random.choice(area_codes)
        parameter = random.choice(parameters)

        # Both integer and float ranges are sampled uniformly and rounded to one decimal,
        # so the original if/else with two identical branches is collapsed into one line
        value = round(random.uniform(parameter["min"], parameter["max"]), 1)

        # Vary the phrasing of the questions to make the data more natural
        instruction_templates = [
            f"现在{area_code}种植区内{parameter['name']}如何?",
            f"请告诉我{area_code}区的{parameter['name']}情况。",
            f"{area_code}区当前的{parameter['name']}是多少?",
            f"我想知道{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在是多少?",
            f"{area_code}种植区目前的{parameter['name']}是多少?",
            f"能提供{area_code}区的{parameter['name']}数据吗?",
            f"{area_code}种植区的{parameter['name']}是多少?",
            f"请查询{area_code}区的{parameter['name']}。",
            f"{area_code}区现在的{parameter['name']}数据是多少?",
            f"帮我看看{area_code}区{parameter['name']}的情况。",
            f"{area_code}区的{parameter['name']}值是多少?",
            f"帮我查一下{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在什么情况?",
            f"请帮我查一下{area_code}种植区的{parameter['name']}是多少?",
            f"我需要知道{area_code}区的{parameter['name']}数据。",
            f"请问{area_code}区的{parameter['name']}如何?",
            f"帮我查询{area_code}区的{parameter['name']}情况。",
            f"现在{area_code}区的{parameter['name']}值是多少?"
        ]
        instruction = random.choice(instruction_templates)
        output = f"{area_code}区现在{parameter['name']}{value}{parameter['unit']}"

        data = {
            "instruction": instruction,
            "invokeType": str(invoke_type),
            "areaCode": area_code,
            "output": output
        }
        dataset.append(data)

    return dataset


# Generate the data and save it as a json file
if __name__ == '__main__':
    dataset = generate_dataset()
    output_file = 'output/synthetic_dataset.json'

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)

    print(f"Generated {output_file} with {len(dataset)} records.")
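One element of the resulting JSON array would look roughly like the record below; area, parameter and value are drawn at random, so the concrete numbers are only illustrative:

{"instruction": "请告诉我A区的土壤湿度情况。", "invokeType": "2", "areaCode": "A", "output": "A区现在土壤湿度45.3%"}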
@ -1,136 +0,0 @@
import _thread as thread
import base64
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # requires websocket-client

answer = ""


class Ws_Param(object):
    # Initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # Build the signed websocket url
    def create_url(self):
        # RFC1123-formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # Assemble the string to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # Sign it with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Collect the authentication parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # Append the authentication parameters to build the final url
        url = self.Spark_url + '?' + urlencode(v)
        # For debugging, the url can be printed here and compared with the one
        # generated by the official demo for the same parameters
        return url


# Handle websocket errors
def on_error(ws, error):
    print("### error:", error)


# Handle websocket close
def on_close(ws, one, two):
    print(" ")


# Handle websocket connection established
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
    ws.send(data)


# Handle incoming websocket messages
def on_message(ws, message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'Request error: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content, end="")
        global answer
        answer += content
        if status == 2:
            ws.close()


def gen_params(appid, domain, question):
    """
    Build the request parameters from the appid and the user's question.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 2048
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url, domain, question):
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
@ -1,24 +0,0 @@
import requests
import json

url = "https://chatapi.midjourney-vip.cn/v1/chat/completions"

payload = json.dumps({
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "测试"
        }
    ]
})
headers = {
    'Accept': 'application/json',
    'Authorization': 'sk-ATDf2Ax1YTGeeTaBD9Be2a7bE0064618Ae3378EaF0Df6f24',
    'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
    'Content-Type': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload)

print(response.text)
File diff suppressed because it is too large
@ -1,66 +0,0 @@
栽培油橄榄
经济价值
引种
油橄榄属
植物分类
植物种
原产地
根系类型
土壤关系
花芽分化
花序
授粉特性
果实发育
油脂形成
气候条件
温度
光照
水分
土壤生态
海拔高度
坡度
佛奥
莱星
皮削利
阿斯
配多灵
果大尔
皮瓜尔
科拉蒂
克里
爱桑
贝拉
实生种
育苗场地
种子繁殖
实生苗
嫁接繁殖
砧木
接穗
扦插繁殖
组织培养
园地选择
种植密度
栽植方式
栽后管理
土壤管理
矿质营养
果园灌溉
果实采收
整形修剪
生物学原理
结果习性
树形
幼树修剪
复壮修剪
孔雀斑病
炭疽病
黄萎病
肿瘤病
根腐病
云斑天牛
油橄榄片盾
大粒横沟象
引进品种名录
中英对照品种名称
病虫害判定表
@ -1,116 +0,0 @@
# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm

import SparkApi

# Input file path
input_file = 'output/train_expanded.jsonl'
# Output file path
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint file path
checkpoint_file = 'output/e2_progress_checkpoint.txt'


# Call the API to generate QA pairs
def generate_qa_via_api(content):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
        f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        f"以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        f"请按如下格式生成输出:\n"
        f"问题1:<生成第一个问题>\n"
        f"回答1:<生成第一个回答>\n"
        f"问题2:<生成第二个问题>\n"
        f"回答2:<生成第二个回答>\n"
        f"问题3:<生成第三个问题>\n"
        f"回答3:<生成第三个回答>\n\n"
        f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )

    question = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
    return SparkApi.answer.strip()


# Load the checkpoint
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # no checkpoint, start from 0


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Parse the returned text, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # A lookahead is used so the "问题" that starts the next pair is not consumed,
    # otherwise every second pair would be skipped by findall
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


if __name__ == '__main__':
    # Load the original dataset
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]

    # Load the checkpoint
    start_index = load_checkpoint()

    # Resume QA generation from the checkpoint
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']

            try:
                # Generate new QA pairs via the API
                api_generated_qa = generate_qa_via_api(input_content)

                # Parse the generated QA pairs and add them to the dataset
                qa_pairs = parse_multiple_qa(api_generated_qa)
                expanded_data = [{"input": qa_pair['input'], "output": qa_pair['output']} for qa_pair in qa_pairs]

                # Save the generated QA pairs
                for qa in expanded_data:
                    json.dump(qa, f, ensure_ascii=False)
                    f.write('\n')

                # Save the current progress index
                save_checkpoint(i)

            except Exception as e:
                print(f"Error processing item {i}: {e}")
                # Skip this item and continue
                save_checkpoint(i)
                continue

    print(f"Generated {output_file} with the expanded QA pairs.")
@ -1,153 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM

import os
import re
from tqdm import tqdm
import SparkApi
import json


appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"

# Spark service address and version
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds the cleaned text
text_data = []

# Checkpoint file storing the index of the last processed paragraph
checkpoint_file = "output/progress_checkpoint.txt"

# Load the preprocessed text file
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()


# Custom splitter: group whole sentences into chunks of at most max_length characters
def split_text_to_sentences(text, max_length=300):
    sentences = re.split('(?<=。)', text)
    grouped_sentences = []
    current_group = ""

    for sentence in sentences:
        if len(current_group) + len(sentence) <= max_length:
            current_group += sentence
        else:
            grouped_sentences.append(current_group.strip())
            current_group = sentence

    if current_group:
        grouped_sentences.append(current_group.strip())

    return grouped_sentences


# Load the checkpoint
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed paragraph
    return 0  # no checkpoint, start from 0


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Split the text into chunks of the requested length
paragraphs = split_text_to_sentences(cleaned_text, 300)


# Build the detailed prompt for the LLM; the model may generate several QA pairs
def create_prompt(content):
    prompt = (
        f"你是一位油橄榄栽培专家。"
        f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
        f"内容:{content}\n\n"
        f"请以如下格式生成输出:\n"
        f"问题1:<在这里生成第一个问题>\n"
        f"回答1:<在这里生成第一个回答>\n"
        f"问题2:<在这里生成第二个问题(如有)>\n"
        f"回答2:<在这里生成第二个回答(如有)>\n"
        f"..."
    )
    return prompt


# Parse the returned text, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # A lookahead is used so the "问题" that starts the next pair is not consumed,
    # otherwise every second pair would be skipped by findall
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


# Crude guard against overly long requests: drop the oldest messages while the list is too long
def checklen(text):
    while len(text) > 8000:
        del text[0]
    return text


if __name__ == '__main__':
    text_data.clear()
    file_name = 'output/train_optimized_multiple.jsonl'
    conversations = []

    # Load the previous progress
    start_index = load_checkpoint()

    # Resume QA generation from the checkpoint
    for i in tqdm(range(start_index, len(paragraphs))):  # process all remaining paragraphs
        content = paragraphs[i].strip()  # strip surrounding whitespace
        print("====================\ncontent:", content, "\n==================\n")
        if len(content) == 0:
            continue

        # Build the LLM prompt
        prompt = create_prompt(content)
        question = checklen([{"role": "user", "content": prompt}])

        # Call the LLM to generate QA pairs
        SparkApi.answer = ""  # clear the previous answer
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)  # call the API

        # Take the generated text
        answer_text = SparkApi.answer.strip()

        # Parse the (possibly multiple) QA pairs
        qa_pairs = parse_multiple_qa(answer_text)

        for qa_pair in qa_pairs:
            conversation = {
                "input": qa_pair['input'],
                "output": qa_pair['output']
            }

            # Append the conversation to the output file
            with open(file_name, 'a', encoding='utf-8') as file:
                json.dump(conversation, file, ensure_ascii=False)
                file.write("\n")

        # Save the progress index after each paragraph
        save_checkpoint(i)

    print(f"Generated {file_name} with the QA pairs.")
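As a quick sanity check of the parsing step: with the lookahead pattern above, a reply in the prompted 问题n:/回答n: format is split into one dict per pair. The reply string below is invented purely for illustration:

reply = "问题1:油橄榄喜欢什么气候?回答1:温和、阳光充足的地中海型气候。问题2:常用哪些繁殖方式?回答2:扦插和嫁接繁殖。"
print(parse_multiple_qa(reply))
# [{'input': '油橄榄喜欢什么气候?', 'output': '温和、阳光充足的地中海型气候。'},
#  {'input': '常用哪些繁殖方式?', 'output': '扦插和嫁接繁殖。'}]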
@ -1,32 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 20:47
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : jsonl2json.py
# @Project : EmoLLM
import json


input_file = 'output/fine_tune_data.jsonl'
output_file = 'output/fine_tune_data.json'


data_list = []
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        entry = json.loads(line.strip())

        new_entry = {
            "instruction": entry.get("instruction", ""),
            "input": entry.get("input", ""),
            "output": entry.get("output", ""),
            "system": entry.get("system", ""),
            "history": entry.get("history", [])
        }
        data_list.append(new_entry)


with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(data_list, f, ensure_ascii=False, indent=4)

print(f"Saved {output_file}")
File diff suppressed because it is too large
@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/18 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : OCR.py
# @Project : EmoLLM
import cv2
from paddleocr import PaddleOCR
import os
import glob

# Initialize the OCR model
ocr = PaddleOCR(use_angle_cls=True, lang='ch')


image_dir = "output"
output_txt_dir = "output_txt"


if not os.path.exists(output_txt_dir):
    os.makedirs(output_txt_dir)

image_list = glob.glob(os.path.join(image_dir, "*.png"))

# Recognize the images in batch
for img_path in image_list:
    # Read the image
    img = cv2.imread(img_path)

    # Run OCR on the image
    result = ocr.ocr(img)

    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(img_path))[0]

    # Path of the text file that will hold the OCR result
    txt_file_path = os.path.join(output_txt_dir, f"{img_name}.txt")

    # Write the OCR result to the text file
    with open(txt_file_path, 'w', encoding='utf-8') as f:
        for line in result:
            for word_info in line:
                # Extract the recognized text and its confidence
                word, confidence = word_info[1][0], word_info[1][1]

                f.write(f"{word}\n")

                print(f"Word: {word}, Confidence: {confidence}")

    print(f"{txt_file_path}")
@ -1,39 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/21 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : PDF2Pic.py
# @Project : EmoLLM
import fitz  # PyMuPDF
from PIL import Image
import os

# PDF file path and output image directory
pdf_file_path = "input.pdf"
output_image_dir = "output"

# Create the output directory
if not os.path.exists(output_image_dir):
    os.makedirs(output_image_dir)

# Open the PDF file
pdf_document = fitz.open(pdf_file_path)

# Iterate over every page and save it as an image
for page_number in range(len(pdf_document)):
    # Get the current page
    page = pdf_document.load_page(page_number)

    # Render the page as an image
    zoom = 4
    mat = fitz.Matrix(zoom, zoom)
    pix = page.get_pixmap(matrix=mat)

    image_path = os.path.join(output_image_dir, f"{page_number + 1}.png")
    pix.save(image_path)

    print(f"Saved {image_path}")


pdf_document.close()
@ -1,25 +0,0 @@
import os
import re
import natsort

folder_path = "output_txt"
combined_text = ""

# Read the files in natural sort order
for filename in natsort.natsorted(os.listdir(folder_path)):
    if filename.endswith(".txt"):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            combined_text += file.read()


combined_text = combined_text.replace('\n', '')

# Collapse runs of three or more identical punctuation marks
combined_text = re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', combined_text)

# Save the cleaned text to a new file
with open("cleaned_data.txt", 'w', encoding='utf-8') as file:
    file.write(combined_text)

print("Data cleaning finished")
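The backreference substitution above is the whole clean-up step; a minimal check of its behaviour, assuming the same pattern, looks like this (the sample string is made up):

import re
sample = "油橄榄。。。。适应性强,,,,耐旱"
print(re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', sample))
# 油橄榄。适应性强,耐旱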
@ -1,84 +0,0 @@
# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi

# Input file path
input_file = 'output/train_expanded.jsonl'
# Checkpoint file path
checkpoint_file = 'output/expand_checkpoint.txt'
# Temporary file path
temp_file = 'output/tmp_train_expanded.jsonl'


# Call the API to generate an answer
def generate_answer_via_api(question):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        f"请生成一个详细回答。"
    )

    question_data = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()


# Load the checkpoint
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # no checkpoint, start from 0


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


if __name__ == '__main__':
    # Load the checkpoint
    start_index = load_checkpoint()

    with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
        for i, line in enumerate(tqdm(f)):
            item = json.loads(line)

            # Only process lines from the checkpoint onwards
            if i >= start_index:
                input_content = item['input']
                output_content = item['output']

                # # 检查是否是未提供回答的问答对
                # if "未给" in output_content:
                #     # 使用API生成新的回答
                #     new_answer = generate_answer_via_api(input_content)
                #     item['output'] = new_answer

                if len(output_content) < 11:
                    # Regenerate answers that are too short via the API
                    new_answer = generate_answer_via_api(input_content)
                    item['output'] = new_answer

                # Save the current progress index
                save_checkpoint(i)

            # Write the (possibly updated) record to the temporary file
            json.dump(item, temp_f, ensure_ascii=False)
            temp_f.write('\n')

    # Replace the original file
    os.replace(temp_file, input_file)
    print(f"Updated {input_file} with the regenerated answers.")
@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/23 23:16
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : topic_model.py
# @Project : EmoLLM
import json
import gensim
from gensim import corpora
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from collections import defaultdict


# Load the QA pairs
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Text preprocessing
# Note: this uses NLTK's English tokenizer and stopword list, so Chinese questions
# are not segmented properly; LDArec.py is the jieba-based variant for Chinese text.
def preprocess_text(text):
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(text.lower())
    tokens = [word for word in tokens if word.isalnum() and word not in stop_words]
    return tokens


# Build the LDA topic model
def build_lda_model(qa_pairs, num_topics=5):
    # Preprocess all question texts
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question) for question in questions]

    # Build the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = gensim.models.ldamodel.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus


# Print the keywords of each topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"Topic {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pair file

    # Load QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, num_topics=5)

    # Print topics and their keywords
    print_topics(lda_model)
@ -3,15 +3,15 @@ from prompt import *
from tqdm import tqdm

# 以下密钥信息从控制台获取
appid = "f0f73de5"  # 填写控制台中获取的 APPID 信息
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"  # 填写控制台中获取的 APISecret 信息
api_key = "5773f6f95563708de994d17b7ea5d414"  # 填写控制台中获取的 APIKey 信息
appid = ""  # 填写控制台中获取的 APPID 信息
api_secret = ""  # 填写控制台中获取的 APISecret 信息
api_key = ""  # 填写控制台中获取的 APIKey 信息

# 用于配置大模型版本,默认“general/generalv2”
domain = "4.0Ultra"  # v1.5版本
domain = "general"  # v1.5版本
# domain = "generalv2"  # v2.0版本
# 云端环境的服务地址
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"  # v1.5环境的地址
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # v1.5环境的地址
# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # v2.0环境的地址
@ -44,7 +44,7 @@ def checklen(text):
if __name__ == '__main__':
    text.clear()
    text.clear
    file_name = 'train3.jsonl'
    conversations = []
    for i in tqdm(range(200)):
@ -5,34 +5,8 @@ streamlit==1.24.0
sentencepiece==0.1.99
accelerate==0.24.1
transformers_stream_generator==0.0.4
openxlab~=0.0.11
openxlab
tiktoken
einops
oss2
requests~=2.32.3

pyjwt~=2.8.0
loguru~=0.6.0
yaml~=0.2.5
pyyaml~=6.0.1
tqdm~=4.66.2
langchain~=0.0.352
torch~=2.5.0
metagpt~=0.8.1
erniebot~=0.5.9
python-dotenv~=1.0.0
zhipuai~=2.0.1
uvicorn~=0.32.0
fastapi~=0.115.2
opencv-python~=4.10.0.84
paddleocr~=2.9.0
dashscope~=1.14.1
numpy~=1.24.3
jieba~=0.42.1
nltk~=3.9.1
setuptools~=65.6.3
websocket~=0.2.1
websocket-client~=1.6.2
gensim~=4.3.3
pillow~=9.5.0
natsort~=8.4.0
requests
@ -1,11 +1,11 @@
from modelscope.hub.api import HubApi

YOUR_ACCESS_TOKEN = ''  # 输入你的modelscope access token
YOUR_ACCESS_TOKEN = ''  # 输入你的modelscope access token

api = HubApi()
api.login(YOUR_ACCESS_TOKEN)
api.push_model(
    model_id="zealot5209/EmoLLM-Scientist",  # your_name/model_id
    model_dir="./merged"  # 本地模型目录,要求目录中必须包含configuration.json
)
    model_id="zealot5209/EmoLLM-Scientist",  # your_name/model_id
    model_dir="./merged"  # 本地模型目录,要求目录中必须包含configuration.json
)
@ -1,213 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'Qwen2-7B-Instruct'  # your model path
use_varlen_attn = False

# Data
alpaca_en_path = '../datasets/aiwei.json'
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 1024
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 8  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 4
max_epochs = 3
optim_type = AdamW
lr = 1e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 100
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 100
SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
evaluation_inputs = [
    '我压力很大', '生活没意思', "非常容易羡慕别人啊"
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ),
    lora=dict(
        type=LoraConfig,
        r=32,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json',
                 data_files=dict(train=alpaca_en_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
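Assuming the file keeps its place under xtuner_config/, a config like this would normally be launched through xtuner's standard CLI, for example "xtuner train xtuner_config/Qwen2-7B-Instruct_lora.py" (optionally with "--deepspeed deepspeed_zero2"); this is a typical invocation given as orientation, not a command recorded in this diff.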