Dev (#141)
This commit is contained in:
commit
f16b10ef05
@ -8,6 +8,86 @@
|
||||
- 经典案例
|
||||
- 客户背景知识
|
||||
|
||||
## **环境准备**
|
||||
|
||||
```python
|
||||
|
||||
langchain==0.1.13
|
||||
langchain_community==0.0.29
|
||||
langchain_core==0.1.33
|
||||
langchain_openai==0.0.8
|
||||
langchain_text_splitters==0.0.1
|
||||
FlagEmbedding==1.2.8
|
||||
unstructured==0.12.6
|
||||
```
|
||||
|
||||
```python
|
||||
|
||||
cd rag
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
```
|
||||
|
||||
## **使用指南**
|
||||
|
||||
### 准备数据
|
||||
|
||||
- txt数据:放入到 src.data.txt 目录下
|
||||
- json 数据:放入到 src.data.json 目录下
|
||||
|
||||
会根据准备的数据构建vector DB,最终会在 data 文件夹下产生名为 vector_db 的文件夹包含 index.faiss 和 index.pkl
|
||||
|
||||
如果已经有 vector DB 则会直接加载对应数据库
|
||||
|
||||
|
||||
### 配置 config 文件
|
||||
|
||||
根据需要改写 config.config 文件:
|
||||
|
||||
```python
|
||||
|
||||
# 存放所有 model
|
||||
model_dir = os.path.join(base_dir, 'model')
|
||||
|
||||
# embedding model 路径以及 model name
|
||||
embedding_path = os.path.join(model_dir, 'embedding_model')
|
||||
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||
|
||||
|
||||
# rerank model 路径以及 model name
|
||||
rerank_path = os.path.join(model_dir, 'rerank_model')
|
||||
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||
|
||||
|
||||
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
|
||||
select_num = 3
|
||||
|
||||
# retrieval num: 代表从 vector db 中检索多少 documents。(retrieval num 应该大于等于 select num)
|
||||
retrieval_num = 10
|
||||
|
||||
# 智谱 LLM 的 API key。目前 demo 仅支持智谱 AI api 作为最后生成
|
||||
glm_key = ''
|
||||
|
||||
# Prompt template: 定义
|
||||
prompt_template = """
|
||||
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||
|
||||
根据下面检索回来的信息,回答问题。
|
||||
|
||||
{content}
|
||||
|
||||
问题:{query}
|
||||
"""
|
||||
```
|
||||
|
||||
### 调用
|
||||
|
||||
```python
|
||||
cd rag/src
|
||||
python main.py
|
||||
```
|
||||
|
||||
|
||||
## **数据集**
|
||||
|
||||
- 经过清洗的QA对: 每一个QA对作为一个样本进行 embedding
|
||||
@ -65,12 +145,3 @@ RAG的经典评估框架,通过以下三个方面进行评估:
|
||||
- 增加多路检索以增加召回率。即根据用户输入生成多个类似的query进行检索
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -2,5 +2,11 @@ sentence_transformers
|
||||
transformers
|
||||
numpy
|
||||
loguru
|
||||
langchain
|
||||
torch
|
||||
langchain==0.1.13
|
||||
langchain_community==0.0.29
|
||||
langchain_core==0.1.33
|
||||
langchain_openai==0.0.8
|
||||
langchain_text_splitters==0.0.1
|
||||
FlagEmbedding==1.2.8
|
||||
unstructured==0.12.6
|
@ -8,7 +8,10 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
|
||||
# model
|
||||
model_dir = os.path.join(base_dir, 'model') # model
|
||||
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
|
||||
rerank_path = os.path.join(model_dir, 'rerank_model') # embedding
|
||||
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||
rerank_path = os.path.join(model_dir, 'rerank_model') # embedding
|
||||
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||
|
||||
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
|
||||
|
||||
# data
|
||||
@ -23,15 +26,21 @@ log_dir = os.path.join(base_dir, 'log') # log
|
||||
log_path = os.path.join(log_dir, 'log.log') # file
|
||||
|
||||
# vector DB
|
||||
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
|
||||
vector_db_dir = os.path.join(data_dir, 'vector_db')
|
||||
|
||||
# RAG related
|
||||
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
|
||||
# retrieval num: 代表从 vector db 中检索多少 documents。(retrieval num 应该大于等于 select num)
|
||||
select_num = 3
|
||||
retrieval_num = 10
|
||||
system_prompt = """
|
||||
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||
"""
|
||||
|
||||
# LLM key
|
||||
glm_key = ''
|
||||
|
||||
# prompt
|
||||
prompt_template = """
|
||||
{system_prompt}
|
||||
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||
|
||||
根据下面检索回来的信息,回答问题。
|
||||
{content}
|
||||
问题:{query}
|
||||
|
@ -1,33 +1,24 @@
|
||||
import json
|
||||
import pickle
|
||||
import faiss
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
|
||||
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
||||
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
||||
from BCEmbedding import EmbeddingModel, RerankerModel
|
||||
# from util.pipeline import EmoLLMRAG
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
|
||||
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
|
||||
from langchain_community.llms import Cohere
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import FlashrankRerank
|
||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain.document_loaders import DirectoryLoader
|
||||
from langchain_core.documents.base import Document
|
||||
from FlagEmbedding import FlagReranker
|
||||
|
||||
class Data_process():
|
||||
|
||||
def __init__(self):
|
||||
self.chunk_size: int=1000
|
||||
self.chunk_overlap: int=100
|
||||
|
||||
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
|
||||
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
|
||||
"""
|
||||
加载嵌入模型。
|
||||
|
||||
@ -61,7 +52,8 @@ class Data_process():
|
||||
return None
|
||||
return embeddings
|
||||
|
||||
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
||||
def load_rerank_model(self, model_name=rerank_model_name):
|
||||
|
||||
"""
|
||||
加载重排名模型。
|
||||
|
||||
@ -99,7 +91,6 @@ class Data_process():
|
||||
|
||||
return reranker_model
|
||||
|
||||
|
||||
def extract_text_from_json(self, obj, content=None):
|
||||
"""
|
||||
抽取json中的文本,用于向量库构建
|
||||
@ -128,7 +119,8 @@ class Data_process():
|
||||
return content
|
||||
|
||||
|
||||
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
|
||||
def split_document(self, data_path):
|
||||
|
||||
"""
|
||||
切分data_path文件夹下的所有txt文件
|
||||
|
||||
@ -143,7 +135,7 @@ class Data_process():
|
||||
|
||||
|
||||
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||
split_docs = []
|
||||
logger.info(f'Loading txt files from {data_path}')
|
||||
if os.path.isdir(data_path):
|
||||
@ -188,7 +180,7 @@ class Data_process():
|
||||
# split_qa.append(Document(page_content = content))
|
||||
#按conversation块切分
|
||||
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||
logger.info(f'content====={content}')
|
||||
#logger.info(f'content====={content}')
|
||||
split_qa.append(Document(page_content = content))
|
||||
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||
return split_qa
|
||||
|
104
rag/src/main.py
104
rag/src/main.py
@ -1,17 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
import jwt
|
||||
|
||||
from config.config import base_dir, data_dir
|
||||
from data_processing import Data_process
|
||||
from pipeline import EmoLLMRAG
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from util.llm import get_glm
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
import streamlit as st
|
||||
from openxlab.model import download
|
||||
'''
|
||||
1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer
|
||||
2)调用 embedding 提供的接口对 query 向量化
|
||||
@ -21,69 +10,34 @@ from openxlab.model import download
|
||||
6)拼接 prompt 并调用模型返回结果
|
||||
|
||||
'''
|
||||
def get_glm(temprature):
|
||||
llm = ChatOpenAI(
|
||||
model_name="glm-4",
|
||||
openai_api_base="https://open.bigmodel.cn/api/paas/v4",
|
||||
openai_api_key=generate_token("api-key"),
|
||||
streaming=False,
|
||||
temperature=temprature
|
||||
)
|
||||
return llm
|
||||
|
||||
def generate_token(apikey: str, exp_seconds: int=100):
|
||||
try:
|
||||
id, secret = apikey.split(".")
|
||||
except Exception as e:
|
||||
raise Exception("invalid apikey", e)
|
||||
|
||||
payload = {
|
||||
"api_key": id,
|
||||
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
|
||||
"timestamp": int(round(time.time() * 1000)),
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
secret,
|
||||
algorithm="HS256",
|
||||
headers={"alg": "HS256", "sign_type": "SIGN"},
|
||||
)
|
||||
|
||||
@st.cache_resource
|
||||
def load_model():
|
||||
model_dir = os.path.join(base_dir,'../model')
|
||||
logger.info(f'Loading model from {model_dir}')
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
|
||||
.to(torch.bfloat16)
|
||||
.cuda()
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
def main(query, system_prompt=''):
|
||||
logger.info(data_dir)
|
||||
if not os.path.exists(data_dir):
|
||||
os.mkdir(data_dir)
|
||||
dp = Data_process()
|
||||
vector_db = dp.load_vector_db()
|
||||
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
||||
logger.info(f'Query: {query}')
|
||||
logger.info("Retrieve results===============================")
|
||||
for i, doc in enumerate(docs):
|
||||
logger.info(doc)
|
||||
passages,scores = dp.rerank(query, docs)
|
||||
logger.info("After reranking===============================")
|
||||
for i in range(len(scores)):
|
||||
logger.info(passages[i])
|
||||
logger.info(f'score: {str(scores[i])}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||
main(query)
|
||||
#model = get_glm(0.7)
|
||||
#rag_obj = EmoLLMRAG(model, 3)
|
||||
#res = rag_obj.main(query)
|
||||
#logger.info(res)
|
||||
query = """
|
||||
我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。
|
||||
无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想
|
||||
"""
|
||||
|
||||
"""
|
||||
输入:
|
||||
model_name='glm-4',
|
||||
api_base="https://open.bigmodel.cn/api/paas/v4",
|
||||
temprature=0.7,
|
||||
streaming=False,
|
||||
输出:
|
||||
LLM Model
|
||||
"""
|
||||
model = get_glm()
|
||||
|
||||
"""
|
||||
输入:
|
||||
LLM model
|
||||
retrieval_num=3
|
||||
rerank_flag=False
|
||||
select_num-3
|
||||
"""
|
||||
rag_obj = EmoLLMRAG(model)
|
||||
|
||||
res = rag_obj.main(query)
|
||||
|
||||
logger.info(res)
|
||||
|
||||
|
@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
|
||||
from transformers.utils import logging
|
||||
|
||||
from data_processing import Data_process
|
||||
from config.config import system_prompt, prompt_template
|
||||
from config.config import prompt_template
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ class EmoLLMRAG(object):
|
||||
4. 将 query 和检索回来的 content 传入 LLM 中
|
||||
"""
|
||||
|
||||
def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
|
||||
def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
|
||||
"""
|
||||
输入 Model 进行初始化
|
||||
|
||||
@ -29,7 +29,6 @@ class EmoLLMRAG(object):
|
||||
self.model = model
|
||||
self.data_processing_obj = Data_process()
|
||||
self.vectorstores = self._load_vector_db()
|
||||
self.system_prompt = system_prompt
|
||||
self.prompt_template = prompt_template
|
||||
self.retrieval_num = retrieval_num
|
||||
self.rerank_flag = rerank_flag
|
||||
@ -75,7 +74,7 @@ class EmoLLMRAG(object):
|
||||
# 第一版不涉及 history 信息,因此将 system prompt 直接纳入到 template 之中
|
||||
prompt = PromptTemplate(
|
||||
template=self.prompt_template,
|
||||
input_variables=["query", "content", "system_prompt"],
|
||||
input_variables=["query", "content"],
|
||||
)
|
||||
|
||||
# 定义 chain
|
||||
@ -87,7 +86,6 @@ class EmoLLMRAG(object):
|
||||
{
|
||||
"query": query,
|
||||
"content": content,
|
||||
"system_prompt": self.system_prompt
|
||||
}
|
||||
)
|
||||
return generation
|
||||
|
@ -0,0 +1,44 @@
|
||||
import time
|
||||
import jwt
|
||||
from langchain_openai import ChatOpenAI
|
||||
from config.config import glm_key
|
||||
|
||||
|
||||
def get_glm(
|
||||
model_name='glm-4',
|
||||
api_base="https://open.bigmodel.cn/api/paas/v4",
|
||||
temprature=0.7,
|
||||
streaming=False,
|
||||
):
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model_name=model_name,
|
||||
openai_api_base=api_base,
|
||||
openai_api_key=generate_token(glm_key),
|
||||
streaming=streaming,
|
||||
temperature=temprature
|
||||
)
|
||||
|
||||
return llm
|
||||
|
||||
def generate_token(apikey: str, exp_seconds: int=100):
|
||||
try:
|
||||
id, secret = apikey.split(".")
|
||||
except Exception as e:
|
||||
raise Exception("invalid apikey", e)
|
||||
|
||||
payload = {
|
||||
"api_key": id,
|
||||
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
|
||||
"timestamp": int(round(time.time() * 1000)),
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
secret,
|
||||
algorithm="HS256",
|
||||
headers={"alg": "HS256", "sign_type": "SIGN"},
|
||||
)
|
@ -48,7 +48,7 @@ OpenXLab浦源 内容平台 是面向 AI 研究员和开发者提供 AI 领域
|
||||
- 步骤2:填写仓库相关信息
|
||||
- **步骤3:上传模型相关文件**
|
||||
|
||||
更多详情和操作步骤请查看, 请参考[**模型创建流程 **(步骤1和2)](https://openxlab.org.cn/docs/models/%E6%A8%A1%E5%9E%8B%E5%88%9B%E5%BB%BA%E6%B5%81%E7%A8%8B.html)和[**上传模型**(步骤3)](https://openxlab.org.cn/docs/models/%E4%B8%8A%E4%BC%A0%E6%A8%A1%E5%9E%8B.html), 这里我们将给出所用到的基本步骤和需要注意的操作要点.
|
||||
更多详情和操作步骤请查看, 请参考[**模型创建流程**(步骤1和2)](https://openxlab.org.cn/docs/models/%E6%A8%A1%E5%9E%8B%E5%88%9B%E5%BB%BA%E6%B5%81%E7%A8%8B.html)和[**上传模型**(步骤3)](https://openxlab.org.cn/docs/models/%E4%B8%8A%E4%BC%A0%E6%A8%A1%E5%9E%8B.html), 这里我们将给出所用到的基本步骤和需要注意的操作要点.
|
||||
|
||||
## 上传模型
|
||||
|
||||
|
@ -4,6 +4,54 @@
|
||||
|
||||
- 本项目在[**internlm2_7b_chat_qlora_e3**模型](./internlm2_7b_chat_qlora_e3.py)微调[指南](./README.md)的基础上,更新了对[**internlm2_7b_base_qlora_e3(配置文件)**](./internlm2_7b_base_qlora_e10_M_1e4_32_64.py)**模型**的微调。
|
||||
|
||||
## 模型公布和训练epoch数设置
|
||||
|
||||
- 由于采用了合并后的数据集,我们对选用的internlm2_7b_base模型进行了**10 epoch**的训练,读者可以根据训练过程中的输出和loss变化,进行训练的终止和模型的挑选,也可以采用更加专业的评估方法,来对模型评测。
|
||||
|
||||
- 在我们公布的internlm2_7b_base_qlora微调模型时,也分别在OpenXLab和ModelScope中提供了两个不同的权重版本供用户使用和测试,更多专业测评结果将会在近期更新, 敬请期待。
|
||||
|
||||
- **OpenXLab**:
|
||||
- [5 epoch 模型](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base)
|
||||
- [10 epoch 模型](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e)
|
||||
|
||||
- **ModelScope**:
|
||||
- [5 epoch 模型](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base/files)
|
||||
- [10 epoch 模型](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/files)
|
||||
|
||||
### 超参数设置
|
||||
|
||||
训练config设置详情,请查看[**internlm2_7b_base_qlora_e3(配置文件)**](./internlm2_7b_base_qlora_e10_M_1e4_32_64.py),这里我们只列出了关键的超参数或者我们做过调整的超参数。
|
||||
|
||||
```python
|
||||
prompt_template = PROMPT_TEMPLATE.internlm2_chat
|
||||
max_length = 2048
|
||||
pack_to_max_length = True
|
||||
|
||||
batch_size = 16 # per_device
|
||||
accumulative_counts = 1
|
||||
|
||||
max_epochs = 10
|
||||
lr = 1e-4
|
||||
evaluation_freq = 500
|
||||
|
||||
SYSTEM = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
|
||||
evaluation_inputs = [
|
||||
'我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
|
||||
'我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。',
|
||||
'我今天心情不好,感觉不开心,很烦。']
|
||||
|
||||
model = dict(
|
||||
lora=dict(
|
||||
type=LoraConfig,
|
||||
r=32,
|
||||
lora_alpha=64, # lora_alpha=2*r
|
||||
lora_dropout=0.1,
|
||||
bias='none',
|
||||
task_type='CAUSAL_LM'
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
## 数据
|
||||
|
||||
### 数据集
|
||||
@ -19,6 +67,8 @@
|
||||
| General | single_turn_dataset_1 | QA | 14000+ |
|
||||
| General | single_turn_dataset_2 | QA | 18300+ |
|
||||
|
||||
注意:此处的数据量计数是将多轮对话拆成单轮问答后的数据量,请注意联系区别,合并后总数据量为**51468**个对话(多轮对话算一个)。
|
||||
|
||||
### 数据集处理
|
||||
|
||||
#### 数据格式
|
||||
@ -75,20 +125,15 @@
|
||||
|
||||
### 数据处理
|
||||
|
||||
- 使用 `../datasets/process.py` 以处理 **multi_turn_dataset(1 和 2,QA数据转单轮对话)**, `data.json` 和 `data_pro.json` 文件(两个多轮对话),以添加或者调整 **`system` prompt**
|
||||
- 使用 `../datasets/processed/process_single_turn_conversation_construction.py` 处理 **single-turn dataset** (1 和 2),修改 (`input` 和 `ouput`) ,并在每次 **conversation** 中添加 **`system` prompt**
|
||||
- 使用 `../datasets/processed/process_merge.py` 用于合并 `../datasets/processed/` 目录下**6个更新后的数据集**,生成一个合并后的数据集 `combined_data.json`用于最终训练
|
||||
|
||||
### 数据量与训练epochs设置
|
||||
|
||||
- 由于采用了更大的数据集,我们对模型进行了**10 epoch**的训练,读者可以根据训练过程中的输出和loss变化,进行训练的终止和模型的挑选,也可以采用更加专业的评估方法,来对模型评测。
|
||||
- 在我们公布的托管于OpenXlab微调后的 internlm2_7b_chat_qlora微调模型中,我们保留了两个版本,一个是[5 epoch模型](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base/tree/main),另一个是[10 epoch模型](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e/tree/main)版本(**ModelScope**模型:[5 epoch模型](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base/files)和[10 epoch模型](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/files))。
|
||||
- 使用 `../datasets/process.py` 以处理 **multi_turn_dataset(1 和 2,QA数据转单轮对话)**, `data.json` 和 `data_pro.json` 文件(两个多轮对话),以添加或者调整 **`system` prompt**
|
||||
- 使用 `../datasets/processed/process_single_turn_conversation_construction.py` 处理 **single-turn dataset** (1 和 2),修改 (`input` 和 `ouput`) ,并在每次 **conversation** 中添加 **`system` prompt**
|
||||
- 使用 `../datasets/processed/process_merge.py` 用于合并 `../datasets/processed/` 目录下**6个更新后的数据集**,生成一个合并后的数据集 `combined_data.json`用于最终训练
|
||||
|
||||
## 基于XTuner的微调🎉🎉🎉🎉🎉
|
||||
|
||||
### 环境准备
|
||||
|
||||
```markdown
|
||||
```bash
|
||||
datasets==2.16.1
|
||||
deepspeed==0.13.1
|
||||
einops==0.7.0
|
||||
@ -98,9 +143,12 @@ peft==0.7.1
|
||||
sentencepiece==0.1.99
|
||||
torch==2.1.2
|
||||
transformers==4.36.2
|
||||
|
||||
# 需要注意的几个库(版本调整或者安装较麻烦)
|
||||
mmengine==0.10.3
|
||||
xtuner==0.1.15
|
||||
flash_attn==2.5.0
|
||||
mpi4py==3.1.5 # conda install mpi4py
|
||||
```
|
||||
|
||||
也可以一键安装
|
||||
@ -110,7 +158,7 @@ cd xtuner_config/
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
温馨提示:flash_attn的安装可能需要在本地编译,大约需要一到两小时,可以去[flash-attention](https://github.com/Dao-AILab/flash-attention/releases)中,查找和自己机器配置匹配的whl安装包或者采用InternLM AI studio提供的2.4.2版本whl安装包,自行安装,如:
|
||||
温馨提示:`flash_attn`的安装可能需要在本地编译,大约需要一到两小时,可以去[flash-attention](https://github.com/Dao-AILab/flash-attention/releases)中,查找和自己机器配置匹配的whl安装包或者采用InternLM AI studio提供的`2.4.2`版本whl安装包,自行安装,如:
|
||||
|
||||
```bash
|
||||
# from flash-attention
|
||||
@ -133,7 +181,7 @@ xtuner train internlm2_7b_base_qlora_e10_M_1e4_32_64.py --deepspeed deepspeed_ze
|
||||
|
||||
### 将得到的 PTH 模型转换为 HuggingFace 模型
|
||||
|
||||
**即:生成 Adapter 文件夹**
|
||||
即:生成 HuggingFace Adapter 文件夹, 用于和原模型权重合并
|
||||
|
||||
```bash
|
||||
cd xtuner_config/
|
||||
@ -145,7 +193,7 @@ xtuner convert pth_to_hf internlm2_7b_base_qlora_e10_M_1e4_32_64.py ./work_dirs/
|
||||
|
||||
---
|
||||
|
||||
### 将 HuggingFace adapter 合并到大语言模型
|
||||
### 将 HuggingFace Adapter QLoRA权重合并到大语言模型
|
||||
|
||||
```bash
|
||||
xtuner convert merge /root/share/model_repos/internlm2-base-7b ./hf ./merged --max-shard-size 2GB
|
||||
|
Loading…
Reference in New Issue
Block a user