Anooyman 2024-03-24 15:48:59 +08:00
parent f44310f665
commit 8c81c222a9
5 changed files with 44 additions and 12 deletions

View File

@@ -18,22 +18,47 @@ langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6
```
```bash
cd rag
pip3 install -r requirements.txt
```
## **Usage Guide**
### Prepare the Data
- txt data: place it under the src.data.txt directory
- json data: place it under the src.data.json directory
A vector DB is built from the prepared data; this produces a folder named vector_db under the data directory, containing index.faiss and index.pkl.
If a vector DB already exists, it is loaded directly instead of being rebuilt, as in the sketch below.
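A minimal sketch of that build-or-load flow, using the imports that appear later in this commit. The directory names `src/data/txt` and `data/vector_db` are illustrative placeholders; the actual paths come from `config.config`:
```python
import os

from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

vector_db_dir = "data/vector_db"  # illustrative; the real path is defined in config.config

embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-zh-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

if os.path.exists(vector_db_dir):
    # An existing vector DB (index.faiss + index.pkl) is loaded directly.
    db = FAISS.load_local(vector_db_dir, embeddings)
else:
    # Otherwise the prepared txt files are split into chunks, embedded and persisted.
    docs = DirectoryLoader("src/data/txt", glob="*.txt", loader_cls=TextLoader).load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    db = FAISS.from_documents(chunks, embeddings)
    db.save_local(vector_db_dir)
```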
### Configure the config File
Edit the config.config file as needed:
```python
# directory that holds all the models
model_dir = os.path.join(base_dir, 'model')
# path and model name of the embedding model
embedding_path = os.path.join(model_dir, 'embedding_model')
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
# path and model name of the rerank model
rerank_path = os.path.join(model_dir, 'rerank_model')
rerank_model_name = 'BAAI/bge-reranker-large'
# select_num: how many documents are passed to the LLM after reranking
select_num = 3
@@ -52,7 +77,7 @@ prompt_template = """
{content}
问题:{query}
"
"""
```
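The rerank and prompt settings above are typically consumed along these lines. This is an illustrative sketch rather than the repository's exact pipeline; `candidates` stands in for the texts returned by the FAISS retriever:
```python
from FlagEmbedding import FlagReranker

from config.config import prompt_template, select_num  # values shown above

reranker = FlagReranker("BAAI/bge-reranker-large", use_fp16=True)

query = "..."                        # user question
candidates = ["...", "...", "..."]   # texts returned by the FAISS retriever

# Score every (query, passage) pair and keep the select_num highest-scoring passages.
scores = reranker.compute_score([[query, text] for text in candidates])
ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
selected = [text for _, text in ranked[:select_num]]

# Fill the template's placeholders and hand the finished prompt to the LLM.
prompt = prompt_template.format(
    system_prompt="",
    content="\n".join(selected),
    query=query,
)
```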
### Invocation

View File

@@ -9,3 +9,4 @@ langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6

View File

@@ -8,7 +8,9 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
# model
model_dir = os.path.join(base_dir, 'model') # model
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
rerank_model_name = 'BAAI/bge-reranker-large'
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
# data
@@ -36,7 +38,8 @@ glm_key = ''
# prompt
prompt_template = """
{system_prompt}
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇，我有一些心理问题，请你用专业的知识和温柔、可爱、俏皮的口吻帮我解决，回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
根据下面检索回来的信息回答问题
{content}
问题:{query}

View File

@ -4,7 +4,7 @@ import os
from loguru import logger
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -13,11 +13,12 @@ from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker
class Data_process():
def __init__(self):
self.chunk_size: int=1000
self.chunk_overlap: int=100
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
"""
Load the embedding model.
@@ -51,7 +52,7 @@ class Data_process():
return None
return embeddings
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
def load_rerank_model(self, model_name=rerank_model_name):
"""
Load the rerank model.
@@ -89,7 +90,6 @@ class Data_process():
return reranker_model
def extract_text_from_json(self, obj, content=None):
"""
Extract text from a JSON object for building the vector store.
@@ -118,7 +118,7 @@ class Data_process():
return content
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
def split_document(self, data_path):
"""
Split all txt files under the data_path folder.
@@ -133,7 +133,7 @@ class Data_process():
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
split_docs = []
logger.info(f'Loading txt files from {data_path}')
if os.path.isdir(data_path):
@@ -178,7 +178,7 @@ class Data_process():
# split_qa.append(Document(page_content = content))
# split by conversation block
content = self.extract_text_from_json(conversation['conversation'], '')
logger.info(f'content====={content}')
#logger.info(f'content====={content}')
split_qa.append(Document(page_content = content))
# logger.info(f'split_qa size====={len(split_qa)}')
return split_qa
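The conversation-level split above only relies on each JSON record exposing a top-level `conversation` key; a minimal, illustrative record and the resulting Document might look like the sketch below (the inner field names are placeholders, not taken from this commit):
```python
from langchain_core.documents.base import Document  # same import as in this module

# Illustrative record shape; only the top-level "conversation" key is used by the code above.
record = {
    "conversation": [
        {"input": "...", "output": "..."},
        {"input": "...", "output": "..."},
    ]
}

dp = Data_process()                                            # class defined in this file
text = dp.extract_text_from_json(record["conversation"], "")   # flatten the nested values into one string
doc = Document(page_content=text)                              # one Document per conversation block
```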

View File

@@ -32,7 +32,10 @@ def main(query, system_prompt=''):
logger.info(f'score: {str(scores[i])}')
if __name__ == "__main__":
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
query = """
我现在处于高三阶段，感到非常迷茫和害怕。我觉得自己从出生以来就是多余的，没有必要存在于这个世界。
无论是在家庭、学校、朋友还是老师面前，我都感到被否定。我非常难过，对高考充满期望但成绩却不理想
"""
"""
Input: