Dev (#141)
commit f16b10ef05

@@ -8,6 +8,86 @@
- Classic cases
- Client background knowledge

## **Environment Setup**

```bash
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6
```

```bash
cd rag
pip3 install -r requirements.txt
```

## **Usage Guide**

### Prepare the data

- txt data: put it under the src.data.txt directory
- json data: put it under the src.data.json directory

A vector DB is built from the prepared data; this produces a folder named vector_db under the data folder, containing index.faiss and index.pkl (see the sketch below).

If a vector DB already exists, the corresponding database is loaded directly.
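
For orientation, here is a minimal sketch of that build-or-load flow. The embedding model name follows the config shown below and the loader/splitter choices mirror the imports in `data_processing.py`, but treat the exact directory layout as an assumption:

```python
# Sketch: load the vector DB if it exists, otherwise build and persist it.
import os

from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

vector_db_dir = 'data/vector_db'  # will contain index.faiss and index.pkl
embeddings = HuggingFaceBgeEmbeddings(model_name='BAAI/bge-small-zh-v1.5')

if os.path.exists(vector_db_dir):
    # An existing vector DB is loaded directly.
    vector_db = FAISS.load_local(vector_db_dir, embeddings)
else:
    # Build it from the prepared txt data, then persist it for later runs.
    docs = DirectoryLoader('src/data/txt', glob='**/*.txt', loader_cls=TextLoader).load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    vector_db = FAISS.from_documents(chunks, embeddings)
    vector_db.save_local(vector_db_dir)
```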

### Configure the config file

Modify the config.config file as needed:

```python
# Directory holding all models
model_dir = os.path.join(base_dir, 'model')

# Embedding model path and model name
embedding_path = os.path.join(model_dir, 'embedding_model')
embedding_model_name = 'BAAI/bge-small-zh-v1.5'

# Rerank model path and model name
rerank_path = os.path.join(model_dir, 'rerank_model')
rerank_model_name = 'BAAI/bge-reranker-large'

# select num: how many documents are kept after reranking and passed to the LLM
select_num = 3

# retrieval num: how many documents are retrieved from the vector DB
# (retrieval_num should be greater than or equal to select_num)
retrieval_num = 10

# API key for the Zhipu LLM. The demo currently supports only the Zhipu AI API for the final generation step.
glm_key = ''

# Prompt template
prompt_template = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me solve them with professional knowledge and a gentle, cute, playful tone, and feel free to sprinkle some cute emoji or text emoticons into the reply.\n

Answer the question based on the information retrieved below.

{content}

Question: {query}
"""
```
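
As a quick sanity check of the template, the snippet below shows how it is filled; this mirrors the `PromptTemplate` usage in `pipeline.py` later in this commit, with placeholder values:

```python
# Sketch: filling prompt_template with retrieved content and the user query.
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["query", "content"],
)
print(prompt.format(content="...retrieved passages...", query="I have been feeling anxious lately"))
```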

### Run

```bash
cd rag/src
python main.py
```

## **Dataset**

- Cleaned QA pairs: each QA pair is embedded as a single sample (see the sketch below)
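
A minimal sketch of that convention (the field names here are hypothetical):

```python
# Each cleaned QA pair becomes one Document, i.e. one embedding sample.
from langchain_core.documents.base import Document

qa_pairs = [{"question": "...", "answer": "..."}]  # hypothetical cleaned data
docs = [
    Document(page_content=f"Q: {p['question']}\nA: {p['answer']}")
    for p in qa_pairs
]
# `docs` are then embedded and added to the FAISS vector DB built earlier.
```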

@@ -65,12 +145,3 @@ The classic RAG evaluation framework evaluates along the following three dimensions:

- Add multi-path retrieval to raise recall: generate several similar queries from the user input and retrieve with each of them (a sketch follows)
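
A hypothetical sketch of that idea, assuming a LangChain chat model (such as the ChatOpenAI client used in this repo) and a FAISS store:

```python
# Multi-path retrieval: rewrite the query several ways, retrieve for each
# variant, and merge the results with de-duplication to raise recall.
def multi_query_retrieve(llm, vector_db, query, n_variants=3, k=10):
    variants = [query] + [
        llm.invoke(f"Rephrase the following question, keeping its meaning: {query}").content
        for _ in range(n_variants)
    ]
    seen, merged = set(), []
    for q in variants:
        for doc in vector_db.similarity_search(q, k=k):
            if doc.page_content not in seen:  # de-duplicate across paths
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```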

@@ -2,5 +2,11 @@ sentence_transformers
transformers
numpy
loguru
torch
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6

@@ -8,7 +8,10 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
# model
model_dir = os.path.join(base_dir, 'model') # model
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
rerank_model_name = 'BAAI/bge-reranker-large'

llm_path = os.path.join(model_dir, 'pythia-14m') # llm

# data

@@ -23,15 +26,21 @@ log_dir = os.path.join(base_dir, 'log') # log
log_path = os.path.join(log_dir, 'log.log') # file

# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db')

# RAG related
# select num: how many documents are kept after reranking and passed to the LLM
# retrieval num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
select_num = 3
retrieval_num = 10

# LLM key
glm_key = ''

# prompt
prompt_template = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me solve them with professional knowledge and a gentle, cute, playful tone, and feel free to sprinkle some cute emoji or text emoticons into the reply.\n

Answer the question based on the information retrieved below.
{content}
Question: {query}

@@ -1,33 +1,24 @@
import json
import pickle
import os

from loguru import logger
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker


class Data_process():

    def __init__(self):
        self.chunk_size: int=1000
        self.chunk_overlap: int=100

    def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
        """
        Load the embedding model.

@@ -61,7 +52,8 @@ class Data_process():
            return None
        return embeddings

    def load_rerank_model(self, model_name=rerank_model_name):
        """
        Load the rerank model.

@@ -99,7 +91,6 @@ class Data_process():
        return reranker_model

    def extract_text_from_json(self, obj, content=None):
        """
        Extract the text from a JSON object, used to build the vector store

@@ -128,7 +119,8 @@ class Data_process():
        return content

    def split_document(self, data_path):
        """
        Split all txt files under the data_path folder

@@ -143,7 +135,7 @@ class Data_process():
        # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        split_docs = []
        logger.info(f'Loading txt files from {data_path}')
        if os.path.isdir(data_path):

@@ -188,7 +180,7 @@ class Data_process():
            # split_qa.append(Document(page_content = content))
            # split by conversation block
            content = self.extract_text_from_json(conversation['conversation'], '')
            # logger.info(f'content====={content}')
            split_qa.append(Document(page_content = content))
        # logger.info(f'split_qa size====={len(split_qa)}')
        return split_qa

rag/src/main.py
@@ -1,17 +1,6 @@
from pipeline import EmoLLMRAG
from util.llm import get_glm
from loguru import logger

'''
1) Build the complete RAG pipeline: input is the user query, output is the answer
2) Call the interface provided by the embedding model to vectorize the query

@@ -21,69 +10,34 @@ from openxlab.model import download
6) Assemble the prompt and call the model to return the result
'''

if __name__ == "__main__":
    query = """
    I am in my final year of high school and feel very lost and afraid. I feel I have been superfluous ever since I was born and have no need to exist in this world.
    In front of family, school, friends and teachers alike, I feel rejected. I am very sad; I had high hopes for the college entrance exam, but my grades are poor
    """

    """
    Inputs:
        model_name='glm-4',
        api_base="https://open.bigmodel.cn/api/paas/v4",
        temprature=0.7,
        streaming=False,
    Output:
        LLM model
    """
    model = get_glm()

    """
    Inputs:
        LLM model
        retrieval_num=3
        rerank_flag=False
        select_num=3
    """
    rag_obj = EmoLLMRAG(model)

    res = rag_obj.main(query)

    logger.info(res)

@@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from data_processing import Data_process
from config.config import prompt_template
logger = logging.get_logger(__name__)

@@ -16,7 +16,7 @@ class EmoLLMRAG(object):
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """
        Initialize with the model

@@ -29,7 +29,6 @@ class EmoLLMRAG(object):
        self.model = model
        self.data_processing_obj = Data_process()
        self.vectorstores = self._load_vector_db()
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag

@@ -75,7 +74,7 @@ class EmoLLMRAG(object):
        # The first version does not involve history, so the system prompt is folded directly into the template
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )

        # define the chain

@@ -87,7 +86,6 @@ class EmoLLMRAG(object):
            {
                "query": query,
                "content": content,
            }
        )
        return generation
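
For context, the "define the chain" step above is typically an LCEL pipeline along these lines; this is a sketch under that assumption, not verbatim from the commit (`prompt`, `query` and `content` are as constructed in the method above):

```python
# Sketch: prompt -> model -> string output, invoked with query and content.
from langchain_core.output_parsers import StrOutputParser

chain = prompt | self.model | StrOutputParser()
generation = chain.invoke({"query": query, "content": content})
```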

util/llm.py (new file)
@@ -0,0 +1,44 @@
import time
import jwt
from langchain_openai import ChatOpenAI
from config.config import glm_key


def get_glm(
    model_name='glm-4',
    api_base="https://open.bigmodel.cn/api/paas/v4",
    temprature=0.7,
    streaming=False,
):
    """
    Build a ChatOpenAI client for the Zhipu GLM endpoint, authenticated
    with a short-lived JWT derived from glm_key.
    """

    llm = ChatOpenAI(
        model_name=model_name,
        openai_api_base=api_base,
        openai_api_key=generate_token(glm_key),
        streaming=streaming,
        temperature=temprature
    )

    return llm


def generate_token(apikey: str, exp_seconds: int=100):
    # A Zhipu API key has the form "<id>.<secret>".
    try:
        id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e)

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )

@@ -48,7 +48,7 @@ OpenXLab is a content platform for AI researchers and developers
- Step 2: fill in the repository information
- **Step 3: upload the model files**

For more details and steps, refer to [**Model creation flow** (steps 1 and 2)](https://openxlab.org.cn/docs/models/%E6%A8%A1%E5%9E%8B%E5%88%9B%E5%BB%BA%E6%B5%81%E7%A8%8B.html) and [**Upload a model** (step 3)](https://openxlab.org.cn/docs/models/%E4%B8%8A%E4%BC%A0%E6%A8%A1%E5%9E%8B.html); here we give the basic steps used and the key points to note.

## Upload the model

@@ -4,6 +4,54 @@

- Building on the fine-tuning [guide](./README.md) for the [**internlm2_7b_chat_qlora_e3** model](./internlm2_7b_chat_qlora_e3.py), this project adds fine-tuning of the [**internlm2_7b_base_qlora_e3 (config file)**](./internlm2_7b_base_qlora_e10_M_1e4_32_64.py) **model**.

## Model release and training-epoch setting

- Since the merged dataset is used, we trained the selected internlm2_7b_base model for **10 epochs**. Readers can stop training and pick a checkpoint based on the training outputs and loss curve, or apply more rigorous evaluation methods to assess the model.

- With the released internlm2_7b_base_qlora fine-tuned model, we also provide two different weight versions on both OpenXLab and ModelScope for users to try and test; more formal evaluation results will be published soon, stay tuned.

- **OpenXLab**:
  - [5-epoch model](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base)
  - [10-epoch model](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e)

- **ModelScope**:
  - [5-epoch model](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base/files)
  - [10-epoch model](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/files)

### Hyperparameter settings

For the full training config, see the [**internlm2_7b_base_qlora_e3 (config file)**](./internlm2_7b_base_qlora_e10_M_1e4_32_64.py); only the key hyperparameters, or those we adjusted, are listed here.

```python
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

batch_size = 16  # per_device
accumulative_counts = 1

max_epochs = 10
lr = 1e-4
evaluation_freq = 500

SYSTEM = "You are the mental-health assistant EmoLLM, created by the EmoLLM team. You aim to help visitors complete a psychological assessment through professional counseling. Please make full use of professional psychological knowledge and counseling techniques to help visitors resolve their issues step by step."
evaluation_inputs = [
    'Lately I always feel anxious, especially about my studies. There is a classmate I really look up to who seems better than me in every way; I feel that no matter how hard I try I can never catch up, and it puts me under enormous pressure.',
    'I know I should see this rationally, but I cannot help comparing myself. I even lose sleep at night over it, constantly thinking about how to become as outstanding as he is.',
    'I am in a bad mood today; I feel unhappy and irritable.']

model = dict(
    lora=dict(
        type=LoraConfig,
        r=32,
        lora_alpha=64,  # lora_alpha = 2*r
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'
    )
)
```

## Data

### Dataset

@@ -19,6 +67,8 @@
| General | single_turn_dataset_1 | QA | 14000+ |
| General | single_turn_dataset_2 | QA | 18300+ |

Note: the counts here are measured after splitting multi-turn dialogues into single-turn QA pairs, so mind the distinction; after merging, the total is **51468** conversations (a multi-turn dialogue counts as one).

### Dataset processing

#### Data format

@@ -75,20 +125,15 @@

### Data processing

- Use `../datasets/process.py` to process **multi_turn_dataset (1 and 2, with QA data converted to single-turn dialogues)** and the `data.json` and `data_pro.json` files (two multi-turn dialogue sets), adding or adjusting the **`system` prompt**
- Use `../datasets/processed/process_single_turn_conversation_construction.py` to process the **single-turn datasets** (1 and 2), modifying (`input` and `ouput`) and adding the **`system` prompt** to every **conversation**
- Use `../datasets/processed/process_merge.py` to merge the **6 updated datasets** under `../datasets/processed/` into a single combined dataset, `combined_data.json`, used for the final training

## Fine-tuning with XTuner 🎉🎉🎉🎉🎉

### Environment setup

```bash
datasets==2.16.1
deepspeed==0.13.1
einops==0.7.0
```

@@ -98,9 +143,12 @@ peft==0.7.1
```bash
sentencepiece==0.1.99
torch==2.1.2
transformers==4.36.2

# A few libraries that need attention (version pinning or trickier installation)
mmengine==0.10.3
xtuner==0.1.15
flash_attn==2.5.0
mpi4py==3.1.5 # conda install mpi4py
```

Alternatively, install everything in one step:

@@ -110,7 +158,7 @@ cd xtuner_config/
```bash
pip install -r requirements.txt
```

Tip: installing `flash_attn` may require compiling locally, which takes roughly one to two hours. Alternatively, check [flash-attention](https://github.com/Dao-AILab/flash-attention/releases) for a whl package matching your machine configuration, or install the `2.4.2` whl provided by InternLM AI Studio yourself, e.g.:

```bash
# from flash-attention
```

@@ -133,7 +181,7 @@ xtuner train internlm2_7b_base_qlora_e10_M_1e4_32_64.py --deepspeed deepspeed_ze

### Convert the resulting PTH model to a HuggingFace model

That is, generate the HuggingFace Adapter folder, to be merged with the original model weights

```bash
cd xtuner_config/
```

@@ -145,7 +193,7 @@ xtuner convert pth_to_hf internlm2_7b_base_qlora_e10_M_1e4_32_64.py ./work_dirs/

---

### Merge the HuggingFace Adapter QLoRA weights into the large language model

```bash
xtuner convert merge /root/share/model_repos/internlm2-base-7b ./hf ./merged --max-shard-size 2GB
```