From 8c81c222a985b509019f0152251b9cd2764b4139 Mon Sep 17 00:00:00 2001 From: Anooyman <875734078@qq.com> Date: Sun, 24 Mar 2024 15:48:59 +0800 Subject: [PATCH] Update --- rag/README.md | 31 ++++++++++++++++++++++++++++--- rag/requirements.txt | 1 + rag/src/config/config.py | 5 ++++- rag/src/data_processing.py | 14 +++++++------- rag/src/main.py | 5 ++++- 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/rag/README.md b/rag/README.md index e247c8a..d969c57 100644 --- a/rag/README.md +++ b/rag/README.md @@ -18,22 +18,47 @@ langchain_core==0.1.33 langchain_openai==0.0.8 langchain_text_splitters==0.0.1 FlagEmbedding==1.2.8 - +unstructured==0.12.6 ``` ```python -cd rag +cd rag pip3 install -r requirements.txt + ``` ## **使用指南** +### 准备数据 + +- txt数据:放入到 src.data.txt 目录下 +- json 数据:放入到 src.data.json 目录下 + +会根据准备的数据构建vector DB,最终会在 data 文件夹下产生名为 vector_db 的文件夹包含 index.faiss 和 index.pkl + +如果已经有 vector DB 则会直接加载对应数据库 + + ### 配置 config 文件 根据需要改写 config.config 文件: ```python + +# 存放所有 model +model_dir = os.path.join(base_dir, 'model') + +# embedding model 路径以及 model name +embedding_path = os.path.join(model_dir, 'embedding_model') +embedding_model_name = 'BAAI/bge-small-zh-v1.5' + + +# rerank model 路径以及 model name +rerank_path = os.path.join(model_dir, 'rerank_model') +rerank_model_name = 'BAAI/bge-reranker-large' + + # select num: 代表rerank 之后选取多少个 documents 进入 LLM select_num = 3 @@ -52,7 +77,7 @@ prompt_template = """ {content} 问题:{query} -" +""" ``` ### 调用 diff --git a/rag/requirements.txt b/rag/requirements.txt index ef8a833..0dd7fe6 100644 --- a/rag/requirements.txt +++ b/rag/requirements.txt @@ -9,3 +9,4 @@ langchain_core==0.1.33 langchain_openai==0.0.8 langchain_text_splitters==0.0.1 FlagEmbedding==1.2.8 +unstructured==0.12.6 \ No newline at end of file diff --git a/rag/src/config/config.py b/rag/src/config/config.py index bcb84a0..df0fc36 100644 --- a/rag/src/config/config.py +++ b/rag/src/config/config.py @@ -8,7 +8,9 @@ model_repo = 'ajupyter/EmoLLM_aiwei' # model model_dir = os.path.join(base_dir, 'model') # model embedding_path = os.path.join(model_dir, 'embedding_model') # embedding +embedding_model_name = 'BAAI/bge-small-zh-v1.5' rerank_path = os.path.join(model_dir, 'rerank_model') # embedding +rerank_model_name = 'BAAI/bge-reranker-large' llm_path = os.path.join(model_dir, 'pythia-14m') # llm # data @@ -36,7 +38,8 @@ glm_key = '' # prompt prompt_template = """ - {system_prompt} + 你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n + 根据下面检索回来的信息,回答问题。 {content} 问题:{query} diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py index 82aa628..6dbcd25 100644 --- a/rag/src/data_processing.py +++ b/rag/src/data_processing.py @@ -4,7 +4,7 @@ import os from loguru import logger from langchain_community.vectorstores import FAISS -from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path +from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain_community.document_loaders import DirectoryLoader, TextLoader from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -13,11 +13,12 @@ from langchain_core.documents.base import Document from FlagEmbedding import FlagReranker class Data_process(): + def __init__(self): self.chunk_size: int=1000 self.chunk_overlap: int=100 - def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True): + def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True): """ 加载嵌入模型。 @@ -51,7 +52,7 @@ class Data_process(): return None return embeddings - def load_rerank_model(self, model_name='BAAI/bge-reranker-large'): + def load_rerank_model(self, model_name=rerank_model_name): """ 加载重排名模型。 @@ -89,7 +90,6 @@ class Data_process(): return reranker_model - def extract_text_from_json(self, obj, content=None): """ 抽取json中的文本,用于向量库构建 @@ -118,7 +118,7 @@ class Data_process(): return content - def split_document(self, data_path, chunk_size=500, chunk_overlap=100): + def split_document(self, data_path): """ 切分data_path文件夹下的所有txt文件 @@ -133,7 +133,7 @@ class Data_process(): # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) - text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) split_docs = [] logger.info(f'Loading txt files from {data_path}') if os.path.isdir(data_path): @@ -178,7 +178,7 @@ class Data_process(): # split_qa.append(Document(page_content = content)) #按conversation块切分 content = self.extract_text_from_json(conversation['conversation'], '') - logger.info(f'content====={content}') + #logger.info(f'content====={content}') split_qa.append(Document(page_content = content)) # logger.info(f'split_qa size====={len(split_qa)}') return split_qa diff --git a/rag/src/main.py b/rag/src/main.py index 339a8a2..3fa9417 100644 --- a/rag/src/main.py +++ b/rag/src/main.py @@ -32,7 +32,10 @@ def main(query, system_prompt=''): logger.info(f'score: {str(scores[i])}') if __name__ == "__main__": - query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" + query = """ + 我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。 + 无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想 + """ """ 输入: