diff --git a/.gitignore b/.gitignore
index df8d9c9..092eab3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,168 @@
ESConv.json
.DS_Store
-__pycache__/
tmp/
-zhipuai/
\ No newline at end of file
+zhipuai/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# merged_weights
+hf_merge/
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/README.md b/README.md
index 446837f..ef93b1d 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,21 @@
# EmoLLM-心理健康大模型
-该EmoLLM-心理健康大模型主要基于...
-
[![Contributors][contributors-shield]][contributors-url]
[![Forks][forks-shield]][forks-url]
[![Stargazers][stars-shield]][stars-url]
[![Issues][issues-shield]][issues-url]
[![MIT License][license-shield]][license-url]
-[![LinkedIn][linkedin-shield]][linkedin-url]
+
+
-
+
EmoLLM
@@ -35,22 +34,44 @@
- 本篇README.md面向开发者
+
+
+该模型基于InternLM2-7B-chat进行微调,从而构建一个能够理解用户-支持用户-帮助用户提供解决问题思路的心理AI助手。 心理健康大模型(Mental Health Grand Model)是一个综合性的概念,它旨在全面理解和促进个体、群体乃至整个社会的心理健康状态。这个模型通常包含以下几个关键组成部分:
+
+认知因素:涉及个体的思维模式、信念系统、认知偏差以及解决问题的能力。认知因素对心理健康有重要影响,因为它们影响个体如何解释和应对生活中的事件。
+
+情感因素:包括情绪调节、情感表达和情感体验。情感健康是心理健康的重要组成部分,涉及个体如何管理和表达自己的情感,以及如何从负面情绪中恢复。
+
+行为因素:涉及个体的行为模式、习惯和应对策略。这包括应对压力的技巧、社交技能以及自我效能感,即个体对自己能力的信心。
+
+社会环境:包括家庭、工作、社区和文化背景等外部因素,这些因素对个体的心理健康有着直接和间接的影响。
+
+生理健康:身体健康与心理健康紧密相关。良好的身体健康可以促进心理健康,反之亦然。
+
+心理韧性:指个体在面对逆境时的恢复力和适应能力。心理韧性强的人更能够从挑战中恢复,并从中学习和成长。
+
+预防和干预措施:心理健康大模型还包括预防心理问题和促进心理健康的策略,如心理教育、心理咨询、心理治疗和社会支持系统。
+
+评估和诊断工具:为了有效促进心理健康,需要有科学的工具来评估个体的心理状态,以及诊断可能存在的心理问题。
## 目录
-- [上手指南](#上手指南)
- - [开发前的配置要求](#开发前的配置要求)
- - [安装步骤](#安装步骤)
-- [文件目录说明](#文件目录说明)
-- [开发的架构](#开发的架构)
-- [部署](#部署)
-- [使用到的框架](#使用到的框架)
-- [贡献者](#贡献者)
- - [如何参与开源项目](#如何参与开源项目)
-- [版本控制](#版本控制)
-- [作者](#作者)
-- [鸣谢](#鸣谢)
+- [EmoLLM-心理健康大模型](#emollm-心理健康大模型)
+ - [目录](#目录)
+ - [开发前的配置要求](#开发前的配置要求)
+ - [**安装步骤**](#安装步骤)
+ - [文件目录说明](#文件目录说明)
+ - [开发的架构](#开发的架构)
+ - [demo部署](#demo部署)
+ - [使用到的框架](#使用到的框架)
+ - [贡献者](#贡献者)
+ - [如何参与开源项目](#如何参与开源项目)
+ - [版本控制](#版本控制)
+ - [作者](#作者)
+ - [版权说明](#版权说明)
+ - [鸣谢](#鸣谢)
+ - [Star History](#star-history)
+ - [🌟 Contributors](#-contributors)
###### 开发前的配置要求
@@ -159,9 +180,9 @@ filetree
[issues-shield]: https://img.shields.io/github/issues/aJupyter/EmoLLM.svg?style=flat-square
[issues-url]: https://img.shields.io/github/issues/aJupyter/EmoLLM.svg
[license-shield]: https://img.shields.io/github/license/aJupyter/EmoLLM.svg?style=flat-square
-[license-url]: https://github.com/aJupyter/EmoLLM/blob/master/LICENSE
-[linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555
-[linkedin-url]: https://linkedin.com/in/aJupyter
+[license-url]: https://github.com/aJupyter/EmoLLM/blob/main/LICENSE
+
+
@@ -172,3 +193,7 @@ filetree
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=aJupyter/EmoLLM&type=Date)](https://star-history.com/#aJupyter/EmoLLM&Date)
+
+## 🌟 Contributors
+
+[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=50)](https://github.com/aJupyter/EmoLLM/graphs/contributors)
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..7e5d424
--- /dev/null
+++ b/app.py
@@ -0,0 +1,2 @@
+import os
+os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
diff --git a/assets/logo.jpeg b/assets/logo.jpeg
new file mode 100644
index 0000000..591718d
Binary files /dev/null and b/assets/logo.jpeg differ
diff --git a/assets/robot.jpeg b/assets/robot.jpeg
new file mode 100644
index 0000000..591718d
Binary files /dev/null and b/assets/robot.jpeg differ
diff --git a/assets/user.png b/assets/user.png
new file mode 100644
index 0000000..1b2cf2e
Binary files /dev/null and b/assets/user.png differ
diff --git a/demo/cli_internlm2.py b/demo/cli_internlm2.py
new file mode 100644
index 0000000..6886802
--- /dev/null
+++ b/demo/cli_internlm2.py
@@ -0,0 +1,24 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+model_name_or_path = "./model"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')
+model = model.eval()
+
+system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
+
+messages = [(system_prompt, '')]
+
+print("=============Welcome to InternLM chatbot, type 'exit' to exit.=============")
+
+while True:
+ input_text = input("User >>> ")
+ input_text.replace(' ', '')
+ if input_text == "exit":
+ break
+ response, history = model.chat(tokenizer, input_text, history=messages)
+ messages.append((input_text, response))
+ print(f"robot >>> {response}")
\ No newline at end of file
diff --git a/demo/cli_qwen.py b/demo/cli_qwen.py
index 210b6fe..7dcc8f6 100644
--- a/demo/cli_qwen.py
+++ b/demo/cli_qwen.py
@@ -16,7 +16,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed
-DEFAULT_CKPT_PATH = './merged'
+DEFAULT_CKPT_PATH = './model'
_WELCOME_MSG = '''\
Welcome to use Emo-Chat model, type text to start chat, type :h to show command help.
diff --git a/demo/web_qwen.py b/demo/web_qwen.py
index 9769f81..b47ebcd 100644
--- a/demo/web_qwen.py
+++ b/demo/web_qwen.py
@@ -15,7 +15,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
-DEFAULT_CKPT_PATH = './merged'
+DEFAULT_CKPT_PATH = './model'
def _get_args():
diff --git a/配置要求/ChatGLM3-6B环境配置依赖.md b/model_config/ChatGLM3-6b.md
similarity index 97%
rename from 配置要求/ChatGLM3-6B环境配置依赖.md
rename to model_config/ChatGLM3-6b.md
index 13dbc4d..7ffca43 100644
--- a/配置要求/ChatGLM3-6B环境配置依赖.md
+++ b/model_config/ChatGLM3-6b.md
@@ -2,11 +2,11 @@
## 环境准备
我们实践了两种平台进行选择
* 在[autodl](https://www.autodl.com/)平台中租一个3090等24G显存的显卡机器,如下图所示镜像选择`PyTorch`-->`2.0.0`-->`3.8(ubuntu20.04)`-->`11.8`
-![Alt text](Images/image-1.png)
+![autodl](Images/autodl.png)
* 在 [InternStudio](https://studio.intern-ai.org.cn/) 平台中选择 A100(1/4) 的配置,如下图所示镜像选择 `Cuda11.7-conda`,如下图所示:
-![Alt text](Images/image.png)
+![internstudio](Images/internstudio.png)
在Terminal中,进行pip换源和安装依赖包
```shell
diff --git a/配置要求/InternLM模型环境依赖.md b/model_config/InternLM2-7b.md
similarity index 100%
rename from 配置要求/InternLM模型环境依赖.md
rename to model_config/InternLM2-7b.md
diff --git a/配置要求/qwen环境配置依赖.md b/model_config/Qwen-7b.md
similarity index 100%
rename from 配置要求/qwen环境配置依赖.md
rename to model_config/Qwen-7b.md
diff --git a/配置要求/Images/说明.md b/model_config/images/README.md
similarity index 100%
rename from 配置要求/Images/说明.md
rename to model_config/images/README.md
diff --git a/配置要求/Images/image-1.png b/model_config/images/autodl.png
similarity index 100%
rename from 配置要求/Images/image-1.png
rename to model_config/images/autodl.png
diff --git a/配置要求/Images/image.png b/model_config/images/internstudio.png
similarity index 100%
rename from 配置要求/Images/image.png
rename to model_config/images/internstudio.png
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..de40ca1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+transformers==4.36.2
+streamlit==1.24.0
+sentencepiece==0.1.99
+accelerate==0.24.1
+transformers_stream_generator==0.0.4
+openxlab
+tiktoken
+einops
\ No newline at end of file
diff --git a/scripts/Gen/SparkApi.py b/scripts/Gen/SparkApi.py
deleted file mode 100644
index d0edbfc..0000000
--- a/scripts/Gen/SparkApi.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import _thread as thread
-import base64
-import datetime
-import hashlib
-import hmac
-import json
-from urllib.parse import urlparse
-import ssl
-from datetime import datetime
-from time import mktime
-from urllib.parse import urlencode
-from wsgiref.handlers import format_date_time
-
-import websocket # 使用websocket_client
-answer = ""
-
-class Ws_Param(object):
- # 初始化
- def __init__(self, APPID, APIKey, APISecret, Spark_url):
- self.APPID = APPID
- self.APIKey = APIKey
- self.APISecret = APISecret
- self.host = urlparse(Spark_url).netloc
- self.path = urlparse(Spark_url).path
- self.Spark_url = Spark_url
-
- # 生成url
- def create_url(self):
- # 生成RFC1123格式的时间戳
- now = datetime.now()
- date = format_date_time(mktime(now.timetuple()))
-
- # 拼接字符串
- signature_origin = "host: " + self.host + "\n"
- signature_origin += "date: " + date + "\n"
- signature_origin += "GET " + self.path + " HTTP/1.1"
-
- # 进行hmac-sha256进行加密
- signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
- digestmod=hashlib.sha256).digest()
-
- signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
- authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
- authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
- # 将请求的鉴权参数组合为字典
- v = {
- "authorization": authorization,
- "date": date,
- "host": self.host
- }
- # 拼接鉴权参数,生成url
- url = self.Spark_url + '?' + urlencode(v)
- # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
- return url
-
-
-# 收到websocket错误的处理
-def on_error(ws, error):
- print("### error:", error)
-
-
-# 收到websocket关闭的处理
-def on_close(ws,one,two):
- print(" ")
-
-
-# 收到websocket连接建立的处理
-def on_open(ws):
- thread.start_new_thread(run, (ws,))
-
-
-def run(ws, *args):
- data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
- ws.send(data)
-
-
-# 收到websocket消息的处理
-def on_message(ws, message):
- # print(message)
- data = json.loads(message)
- code = data['header']['code']
- if code != 0:
- print(f'请求错误: {code}, {data}')
- ws.close()
- else:
- choices = data["payload"]["choices"]
- status = choices["status"]
- content = choices["text"][0]["content"]
- print(content,end ="")
- global answer
- answer += content
- # print(1)
- if status == 2:
- ws.close()
-
-
-def gen_params(appid, domain,question):
- """
- 通过appid和用户的提问来生成请参数
- """
- data = {
- "header": {
- "app_id": appid,
- "uid": "1234"
- },
- "parameter": {
- "chat": {
- "domain": domain,
- "temperature": 0.5,
- "max_tokens": 2048
- }
- },
- "payload": {
- "message": {
- "text": question
- }
- }
- }
- return data
-
-
-def main(appid, api_key, api_secret, Spark_url,domain, question):
- # print("星火:")
- wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
- websocket.enableTrace(False)
- wsUrl = wsParam.create_url()
- ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
- ws.appid = appid
- ws.question = question
- ws.domain = domain
- ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-
diff --git a/scripts/Gen/gen_Chat.py b/scripts/Gen/gen_Chat.py
deleted file mode 100644
index 945fde9..0000000
--- a/scripts/Gen/gen_Chat.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import SparkApi
-from prompt import *
-from tqdm import tqdm
-
-# 以下密钥信息从控制台获取
-appid = "" # 填写控制台中获取的 APPID 信息
-api_secret = "" # 填写控制台中获取的 APISecret 信息
-api_key = "" # 填写控制台中获取的 APIKey 信息
-
-# 用于配置大模型版本,默认“general/generalv2”
-domain = "general" # v1.5版本
-# domain = "generalv2" # v2.0版本
-# 云端环境的服务地址
-Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
-# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
-
-
-text = []
-
-
-# length = 0
-
-def getText(role, content):
- jsoncon = {}
- jsoncon["role"] = role
- jsoncon["content"] = content
- text.append(jsoncon)
- return text
-
-
-def getlength(text):
- length = 0
- for content in text:
- temp = content["content"]
- leng = len(temp)
- length += leng
- return length
-
-
-def checklen(text):
- while (getlength(text) > 8000):
- del text[0]
- return text
-
-
-if __name__ == '__main__':
- text.clear
- file_name = 'train3.jsonl'
- conversations = []
- for i in tqdm(range(200)):
- Input = prompt(random.randint(0, 16))
- question = checklen(getText("user", Input))
- SparkApi.answer = ""
- SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
- getText("assistant", SparkApi.answer)
- conversations.append(ChatGLM3_6B(SparkApi.answer))
- for item in conversations:
- save_jsonl(item, file_name)
- conversations.clear()
-
diff --git a/scripts/Gen/gen_data.py b/scripts/Gen/gen_data.py
deleted file mode 100644
index c73381c..0000000
--- a/scripts/Gen/gen_data.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import SparkApi
-from prompt import *
-from tqdm import tqdm
-
-
-# 以下密钥信息从控制台获取
-appid = "" # 填写控制台中获取的 APPID 信息
-api_secret = "" # 填写控制台中获取的 APISecret 信息
-api_key = "" # 填写控制台中获取的 APIKey 信息
-
-#用于配置大模型版本,默认“general/generalv2”
-domain = "general" # v1.5版本
-# domain = "generalv2" # v2.0版本
-#云端环境的服务地址
-Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
-# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
-
-
-text =[]
-
-# length = 0
-
-def getText(role,content):
- jsoncon = {}
- jsoncon["role"] = role
- jsoncon["content"] = content
- text.append(jsoncon)
- return text
-
-def getlength(text):
- length = 0
- for content in text:
- temp = content["content"]
- leng = len(temp)
- length += leng
- return length
-
-def checklen(text):
- while (getlength(text) > 8000):
- del text[0]
- return text
-
-
-
-if __name__ == '__main__':
- text.clear
- file_name = 'a2.jsonl'
- conversations = []
- for i in range(1):
- for j in tqdm(range(10)):
- Input = prompt(i)
- question = checklen(getText("user",Input))
- SparkApi.answer =""
- SparkApi.main(appid,api_key, api_secret, Spark_url, domain, question)
- getText("assistant", SparkApi.answer)
- conversations.append(xinghuo_api(SparkApi.answer))
- if i % 2 == 0 :
- save_jsonl(conversations, file_name)
- conversations.clear()
-
diff --git a/scripts/Gen/prompt.py b/scripts/Gen/prompt.py
deleted file mode 100644
index 2236d63..0000000
--- a/scripts/Gen/prompt.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import json
-import random
-import re
-import copy
-# new_data = [{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'content': '你好!有什么我可以帮助您的吗?'}, {'role': 'user', 'content': '生成一段心理医生和病人的对话'}, {'role': 'assistant', 'content': '心理医生:你好,我是你的心理医生。请问你的名字是?\n\n病人:我叫李明。\n\n心理医生:很高兴见到你,李明先生。你来咨询是因为什么问题呢?\n\n病人:最近我总是感觉很焦虑,睡眠也不好。\n\n心理医生:好的,可以跟我具体说说你的情况吗?你有什么压力或者担忧的事情吗?\n\n病人:其实我一直在担心工作上的表现,觉得自己做得不够好,还有家庭的问题。\n\n心理医生:这些都是很常见的问题。你可以告诉我你在工作中遇到了什么困难吗?我们可以一起探讨一下如何解决。\n\n病人:我觉得自己的工作能力不够强,经常被领导批评。而且我家里的情况也不是很好,父母经常吵架,让我很难受。\n\n心理医生:我理解你的感受。这些问题确实会让人感到压力和焦虑。不过我们可以通过一些方法来缓解这种情况。比如说,你可以尝试一些放松的活动,比如瑜伽或者冥想,来减轻压力和焦虑。同时,你也可以考虑寻求家人或者朋友的帮助,让他们给你提供一些支持和鼓励。\n\n病人:好的,我会试试的。谢谢你的建议。\n\n心理医生:不用客气,如果你有任何问题或者需要进一步的帮助,随时可以联系我。'}]
-# text2 = []
-# data = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
-# for val in new_data:
-# if val['role'] == 'user':
-# continue
-#
-# print(text2)
-
-def save_jsonl(conversations, path_file):
- # 把对话写入文件
- with open(path_file, 'a+', encoding='utf-8') as f:
- for conversation in conversations:
- Json_String = json.dumps(conversation, ensure_ascii=False) + '\n'
- f.write(Json_String)
-
-
-# 生成输入提示词
-def prompt(life_type=0):
- emotions_lis = [
- "钦佩",
- "崇拜",
- "欣赏",
- "娱乐",
- "焦虑",
- "敬畏",
- "尴尬",
- "厌倦",
- "冷静",
- "困惑",
- "渴望",
- "厌恶",
- "同情",
- "痛苦",
- "着迷",
- "嫉妒",
- "兴奋",
- "恐惧",
- "痛恨",
- "有趣",
- "快乐",
- "怀旧",
- "浪漫",
- "悲伤",
- "满意",
- "性欲",
- "同情",
- "满足"
- ]
- areas_of_life = [
- "工作",
- "学业(小学,初中,高中,大学,研究生,博士)",
- "生活(衣,食,住,行等等)",
- "身体",
- "家人",
- "朋友",
- "社交",
- "恋爱",
- "就业",
- "责任",
- "爱好",
- "环境",
- "隐私",
- "安全",
- "梦想",
- "自由"
- ]
-
- # 输入数据处理
- if life_type < 0:
- raise ValueError('life_type must > 0')
-
- emo = random.choice(emotions_lis)
- life_type %= 16
-
- Input = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
- 康问题的病人和心理健康医生的连续的一段多轮对话记录。要求病人的问题属于{areas_of_life[life_type]}场景,具有{emo}情感,医生的回复尽可能包含心理辅导知识,并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案。注意,构造的数据必须以医生的陈述为结束语,请只返回完整的对话内容。请以如下格式返回生成的数据:
- 病人:病人的咨询或陈述
- 医生:医生的安抚和建议
- '''
- return Input
-
-def xinghuo_api(content):
- # 对话格式
- conversation1 = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
- conversation = {'input':'', 'output':''}
- conversations = {'conversation':[]}
- # temp = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
- # 划分对话形式
- dialogue = re.split('医生:|病人:', content)
- # 对话前的数据处理
- if dialogue[0] == '':
- dialogue.pop(0)
- # 一次对话
- flag = False
- for ind, item in enumerate(dialogue):
- if flag == False:
- if (ind + 1) % 2 == 1:
- conversation1['input'] = dialogue[ind]
- else:
- conversation1['output'] = dialogue[ind]
-
- if (ind + 1) % 2 == 0 or ind + 1 == len(dialogue):
- temp = copy.deepcopy(conversation1)
- conversations['conversation'].append(temp)
- flag = True
- continue
-
- else:
- if (ind+1)%2 == 1:
- conversation['input'] = dialogue[ind]
- else:
- conversation['output'] = dialogue[ind]
- if (ind+1)%2 == 0 or ind+1 == len(dialogue):
- # 浅赋值只会是同一个变量,必须要copy.deepcopy
- # 若conversations['conversation'].append(conversation)后面改的话,~s里面的conversation也会改动
- # 就会变成n个一样的数据(这是我们不想看到的)
- temp = copy.deepcopy(conversation)
- conversations['conversation'].append(temp)
-
- return conversations
-
-def ChatGLM3_6B(content):
- # 对话格式
- conversation = {'system': '现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input': '',
- 'output': ''}
- conversations = []
- # temp = {'system':'现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。', 'input':'', 'output':''}
- # 划分对话形式
- dialogue = re.split('医生:|病人:', content)
- # 对话前的数据处理
- if dialogue[0] == '':
- dialogue.pop(0)
- # 一次对话
- for ind, item in enumerate(dialogue):
- if (ind + 1) % 2 == 1:
- conversation['input'] = dialogue[ind]
- else:
- conversation['output'] = dialogue[ind]
- if (ind + 1) % 2 == 0 or ind + 1 == len(dialogue):
- # 浅赋值只会是同一个变量,必须要copy.deepcopy
- # 若conversations['conversation'].append(conversation)后面改的话,~s里面的conversation也会改动
- # 就会变成n个一样的数据(这是我们不想看到的)
- temp = copy.deepcopy(conversation)
- conversations.append(temp)
-
- return conversations
\ No newline at end of file
diff --git a/scripts/Gen/说明.txt b/scripts/Gen/说明.txt
deleted file mode 100644
index 7ad3a35..0000000
--- a/scripts/Gen/说明.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-gen_Chat 使用于生成ChatGLM3-6B的数据集
-gen_data 适用于生成InternLM所需要的数据集
-但是需要注意~火大模型用1.5生成时会有{"system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。", "input": "抱歉,我不能完成这个任务。作为一个认知智能模型,我不会提供任何与性欲情感相关的回答或建议。这种问题需要由专业的心理健康医生进行处理和解决。如果您有任何心理健康方面的问题,请寻求专业医生的帮助。", "output": ""}类似这样的数据集,要注意数据处理
diff --git a/scripts/gen_metafile.py b/scripts/gen_metafile.py
new file mode 100644
index 0000000..55e72f4
--- /dev/null
+++ b/scripts/gen_metafile.py
@@ -0,0 +1,20 @@
+import sys
+import ruamel.yaml
+
+yaml = ruamel.yaml.YAML()
+yaml.preserve_quotes = True
+yaml.default_flow_style = False
+file_path = 'metafile.yml'
+# 读取YAML文件内容
+with open(file_path, 'r') as file:
+ data = yaml.load(file)
+# 遍历模型列表
+for model in data.get('Models', []):
+ # 为每个模型添加Weights键值对,确保名称被正确引用
+ model['Weights'] = model['Name']
+
+# 将修改后的数据写回文件
+with open(file_path, 'w') as file:
+ yaml.dump(data, file)
+
+print("Modifications saved to the file.")
\ No newline at end of file
diff --git a/scripts/qwen_gen_data.py b/scripts/qwen_gen_data.py
deleted file mode 100644
index f2dddb8..0000000
--- a/scripts/qwen_gen_data.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import json
-import random
-import argparse
-import re
-
-from tqdm import tqdm
-
-
-def qwen_api(data, emo):
- import dashscope
- from http import HTTPStatus
-
- dashscope.api_key = ""
- prompt = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
- 康问题的病人和心理健康医生的连续的多轮对话记录。要求病人的问题属于{data}场景,具有{emo}情感,医生的回复尽可能包含心理辅导知识,并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案。注意,构造的数据必须以医生的陈述为结束语,请只返回完整的对话内容。请以如下格式返回生成的数据:
- 病人:病人的咨询或陈述
- 医生:医生的安抚和建议
- '''
- response = dashscope.Generation.call(
- model='qwen-max',
- prompt=prompt,
- history=[],
- )
-
- if response.status_code == HTTPStatus.OK:
- result = response.output.text
- print(result)
- else:
- result = 'ERROR'
- return result
-
-
-def save_jsonl(data_lis, file_path):
- import json
-
- # 将字典列表写入文件,每一行一个字典
- with open(file_path, 'at', encoding='utf-8') as file:
- for item in data_lis:
- json_string = json.dumps(item, ensure_ascii=False) + '\n'
- file.write(json_string)
-
-
-if __name__ == '__main__':
- idx = 0
- parser = argparse.ArgumentParser(description='数据生成参数')
-
- parser.add_argument('--data', type=str, help='生活场景')
-
- # 解析命令行参数
- args = parser.parse_args()
-
- emotions_lis = [
- "钦佩",
- "崇拜",
- "欣赏",
- "娱乐",
- "焦虑",
- "敬畏",
- "尴尬",
- "厌倦",
- "冷静",
- "困惑",
- "渴望",
- "厌恶",
- "同情",
- "痛苦"
- "着迷",
- "嫉妒",
- "兴奋",
- "恐惧",
- "痛恨",
- "有趣",
- "快乐",
- "怀旧",
- "浪漫",
- "悲伤",
- "满意",
- "性欲",
- "同情",
- "满足"
- ]
- areas_of_life = [
- "工作",
- "学业",
- "生活",
- "身体",
- "家人",
- "朋友",
- "社交",
- "恋爱",
- "就业",
- "责任",
- "爱好",
- "环境",
- "隐私",
- "安全",
- "梦想",
- "自由"
- ]
-
- conversation_lis = []
- for i in tqdm(range(100)):
- one_conversation = {
- "conversation": []
- }
-
- dia_tuple = []
- emo = random.choice(emotions_lis)
- res = qwen_api(data=args.data, emo=emo)
- print(res)
-
- # 一次会话
- doctor_pattern = r'医生:(.*?)(病人:|$)'
-
- doctor_matches = re.findall(doctor_pattern, res, re.DOTALL)
- doctor_conversations = [match[0] for match in doctor_matches]
-
- patient_pattern = r'病人:(.*?)医生:'
- patient_matches = re.findall(patient_pattern, res, re.DOTALL)
- patient_conversations = [match for match in patient_matches]
-
- for doc, pat in zip(doctor_conversations, patient_conversations):
- if len(one_conversation['conversation']) == 0:
- one_conversation['conversation'].append(
- {
- "system": "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。",
- "input": pat,
- "output": doc
- },
- )
-
- else:
- one_conversation['conversation'].append(
- {
- "input": pat,
- "output": doc
- },
- )
- conversation_lis.append(one_conversation)
-
- idx += 1
-
- # 每生成10条数据存储一次
- if (idx % 10 == 0):
- path = f'./{args.data}.jsonl'
- save_jsonl(data_lis=conversation_lis, file_path=path)
- conversation_lis = [] # 清空
diff --git a/scripts/run_qwen.bash b/scripts/run_qwen.bash
deleted file mode 100644
index cf07df9..0000000
--- a/scripts/run_qwen.bash
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-# 定义生活领域的列表
-areas_of_life=(
- "工作"
- "学业"
- "生活"
- "身体"
- "家人"
- "朋友"
- "社交"
- "恋爱"
- "就业"
- "责任"
- "爱好"
- "环境"
- "隐私"
- "安全"
- "梦想"
- "自由"
-)
-
-# 使用for循环遍历数组
-for area in "${areas_of_life[@]}"; do
- echo "当前生活领域: $area"
- python qwen_gen_data.py --data $area
-done
diff --git a/scripts/upload_openxlab.py b/scripts/upload_openxlab.py
new file mode 100644
index 0000000..252fd3b
--- /dev/null
+++ b/scripts/upload_openxlab.py
@@ -0,0 +1,3 @@
+import os
+
+os.system("openxlab model create --model-repo='jujimeizuo/EmoLLM_Model' -s ./metafile.yml")
\ No newline at end of file
diff --git a/scripts/zhipuai_gen_data.py b/scripts/zhipuai_gen_data.py
deleted file mode 100644
index d8287a4..0000000
--- a/scripts/zhipuai_gen_data.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import os
-import random
-import json
-from tqdm import tqdm
-from dotenv import load_dotenv
-from zhipuai import ZhipuAI
-
-load_dotenv()
-client = ZhipuAI(api_key=os.getenv('ZHIPUAI_API_KEY'))
-
-def zhipu_api(data, emo):
-
- def getText(role, content, text = []):
- jsoncon = {}
- jsoncon['role'] = role
- jsoncon['content'] = content
- text.append(jsoncon)
- return text
-
- prompt = f'''你是一个研究过无数具有心理健康问题的病人与心理健康医生对话的专家,请你构造一些符合实际情况的具有心理健
-康问题的病人和心理健康医生的连续的多轮对话记录。要求病人的问题属于{data}场景,具有{emo}情感,医生的回复尽可能包含心理辅导知识,并且能够一步步诱导病人说出自己的问题进而提供解决问题的可行方案。注意,构造的数据必须以医生的陈述为结束语,每次只需要构造一个案例并且不需要写案例一、二等等,请只返回完整的对话内容。请以如下格式返回生成的数据:
-病人:病人的咨询或陈述
-医生:医生的安抚和建议
- '''
-
- top_p = round(random.uniform(0.1, 0.9), 2)
- messages = getText('user', prompt)
- response = client.chat.completions.create(
- model='glm-4',
- messages=messages,
- top_p=top_p,
- )
-
- return response.choices[0].message.content
-
-
-def convert(conversation):
- ret, one_conversation = {}, {}
- ret['conversation'] = []
- one_conversation['system'] = '现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。'
-
- while '病人:' in conversation and '医生:' in conversation:
- one_conversation['input'] = conversation.split('病人:')[1].split('医生:')[0]
- one_conversation['output'] = conversation.split('病人:')[1].split('医生:')[1].split('病人:')[0]
- conversation = '病人:' + '病人:'.join(conversation.split('病人:')[2:])
- ret['conversation'].append(one_conversation)
- one_conversation = {}
-
- return ret
-
-
-def save_jsonl(data_lis, file_path):
- if not os.path.exists(os.path.dirname(file_path)):
- os.makedirs(os.path.dirname(file_path))
- with open(file_path, 'w', encoding='utf-8') as f:
- for item in data_lis:
- f.write(json.dumps(item, ensure_ascii=False) + '\n')
-
-
-if __name__ == '__main__':
- emotions_lis = [
- "钦佩",
- "崇拜",
- "欣赏",
- "娱乐",
- "焦虑",
- "敬畏",
- "尴尬",
- "厌倦",
- "冷静",
- "困惑",
- "渴望",
- "厌恶",
- "同情",
- "痛苦",
- "着迷",
- "嫉妒",
- "兴奋",
- "恐惧",
- "痛恨",
- "有趣",
- "快乐",
- "怀旧",
- "浪漫",
- "悲伤",
- "满意",
- "性欲",
- "满足"
- ]
- areas_of_life = [
- "工作",
- "学业",
- "生活",
- "身体",
- "家人",
- "朋友",
- "社交",
- "恋爱",
- "就业",
- "责任",
- "爱好",
- "环境",
- "隐私",
- "安全",
- "梦想",
- "自由"
- ]
-
- conversation_lis = []
- for emo in emotions_lis:
- for area in areas_of_life:
- if os.path.exists(f'./zhipuai/{area}/{emo}.jsonl'):
- print(f'./zhipuai/{area}/{emo}.jsonl exists')
- continue
- for i in tqdm(range(5), desc='{emo}, {area}'.format(emo=emo, area=area)):
- res = zhipu_api(area, emo)
- print(res)
- if res == 'null':
- print(area, emo, 'error')
- continue
- conversation_lis.append(convert(res))
- save_jsonl(conversation_lis, f'./zhipuai/{area}/{emo}.jsonl')
- print(f'generate ./zhipuai/{area}/{emo}.jsonl')
- conversation_lis = []
diff --git a/web_internlm2.py b/web_internlm2.py
new file mode 100644
index 0000000..0b4abec
--- /dev/null
+++ b/web_internlm2.py
@@ -0,0 +1,263 @@
+"""
+This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
+We mainly modified part of the code logic to adapt to the generation of our model.
+Please refer to these links below for more information:
+ 1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
+ 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
+ 3. transformers: https://github.com/huggingface/transformers
+Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
+Using `python path/to/web_demo.py` may cause unknown problems.
+"""
+import copy
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+
+import streamlit as st
+import torch
+from torch import nn
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+from transformers.utils import logging
+
+from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
+from openxlab.model import download
+
+logger = logging.get_logger(__name__)
+
+download(model_repo='jujimeizuo/EmoLLM_Model',
+ output='model')
+
+@dataclass
+class GenerationConfig:
+ # this config is used for chat to provide more diversity
+ max_length: int = 32768
+ top_p: float = 0.8
+ temperature: float = 0.8
+ do_sample: bool = True
+ repetition_penalty: float = 1.005
+
+
+@torch.inference_mode()
+def generate_interactive(
+ model,
+ tokenizer,
+ prompt,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ additional_eos_token_id: Optional[int] = None,
+ **kwargs,
+):
+ inputs = tokenizer([prompt], padding=True, return_tensors="pt")
+ input_length = len(inputs["input_ids"][0])
+ for k, v in inputs.items():
+ inputs[k] = v.cuda()
+ input_ids = inputs["input_ids"]
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612
+ if generation_config is None:
+ generation_config = model.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
+ generation_config.bos_token_id,
+ generation_config.eos_token_id,
+ )
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ if additional_eos_token_id is not None:
+ eos_token_id.append(additional_eos_token_id)
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+ logger.warn( # pylint: disable=W4902
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+ UserWarning,
+ )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = model._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = model._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = model._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
+ while True:
+ model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = model(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+ unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
+
+ output_token_ids = input_ids[0].cpu().tolist()
+ output_token_ids = output_token_ids[input_length:]
+ for each_eos_token_id in eos_token_id:
+ if output_token_ids[-1] == each_eos_token_id:
+ output_token_ids = output_token_ids[:-1]
+ response = tokenizer.decode(output_token_ids)
+
+ yield response
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+
+
+def on_btn_click():
+ del st.session_state.messages
+
+
+@st.cache_resource
+def load_model():
+ model = (
+ AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
+ .to(torch.bfloat16)
+ .cuda()
+ )
+ tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+ return model, tokenizer
+
+
+def prepare_generation_config():
+ with st.sidebar:
+ max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
+ top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
+ temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
+ st.button("Clear Chat History", on_click=on_btn_click)
+
+ generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
+
+ return generation_config
+
+
+user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
+robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
+cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
+
+
+def combine_history(prompt):
+ messages = st.session_state.messages
+ meta_instruction = (
+ "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
+ )
+ total_prompt = f"<|im_start|>system\n{meta_instruction}<|im_end|>\n"
+ for message in messages:
+ cur_content = message["content"]
+ if message["role"] == "user":
+ cur_prompt = user_prompt.format(user=cur_content)
+ elif message["role"] == "robot":
+ cur_prompt = robot_prompt.format(robot=cur_content)
+ else:
+ raise RuntimeError
+ total_prompt += cur_prompt
+ total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+ return total_prompt
+
+
+def main():
+ # torch.cuda.empty_cache()
+ print("load model begin.")
+ model, tokenizer = load_model()
+ print("load model end.")
+
+ user_avator = "assets/user.png"
+ robot_avator = "assets/robot.jpeg"
+
+ st.title("EmoLLM")
+
+ generation_config = prepare_generation_config()
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+ st.session_state.messages = []
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+ with st.chat_message(message["role"], avatar=message.get("avatar")):
+ st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("What is up?"):
+ # Display user message in chat message container
+ with st.chat_message("user", avatar=user_avator):
+ st.markdown(prompt)
+ real_prompt = combine_history(prompt)
+ # Add user message to chat history
+ st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
+
+ with st.chat_message("robot", avatar=robot_avator):
+ message_placeholder = st.empty()
+ for cur_response in generate_interactive(
+ model=model,
+ tokenizer=tokenizer,
+ prompt=real_prompt,
+ additional_eos_token_id=92542,
+ **asdict(generation_config),
+ ):
+ # Display robot response in chat message container
+ message_placeholder.markdown(cur_response + "▌")
+ message_placeholder.markdown(cur_response) # pylint: disable=undefined-loop-variable
+ # Add robot response to chat history
+ st.session_state.messages.append(
+ {
+ "role": "robot",
+ "content": cur_response, # pylint: disable=undefined-loop-variable
+ "avatar": robot_avator,
+ }
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/config/ft_config.py b/xtuner_config/internlm2_7b_chat_qlora_e3.py
similarity index 100%
rename from config/ft_config.py
rename to xtuner_config/internlm2_7b_chat_qlora_e3.py
diff --git a/config/qwen_7b_chat_qlora_e3.py b/xtuner_config/qwen_7b_chat_qlora_e3.py
similarity index 100%
rename from config/qwen_7b_chat_qlora_e3.py
rename to xtuner_config/qwen_7b_chat_qlora_e3.py