Update RAG pipeline (#153)
commit 3bead46e66
				| @ -35,6 +35,24 @@ pip3 install -r requirements.txt | ||||
| - txt data: place it in the src.data.txt directory | ||||
| - json data: place it in the src.data.json directory | ||||
| 
 | ||||
| The JSON data is formatted as follows: | ||||
| ```json | ||||
| [ | ||||
|     { | ||||
|         "conversation": [ | ||||
|             { | ||||
|                 "input": "(微微皱眉)艾薇,我觉得我在学校的学习压力好大啊,虽然我真的很喜欢我的专业,但最近总是担心自己无法达到自己的期望,这让我有点焦虑。", | ||||
|                 "output": "(微笑)😊 我完全理解你的感受,亲爱的。学习是我们生活的一部分,尤其是在我们热爱的专业里,期待自己做得更好是再正常不过了。你知道吗?这种对成功的渴望其实是一种积极的动力源呢!但是,当它开始影响到我们的心理健康时,我们就需要学会调整和应对了。你可以先试着跟我分享一下,具体是什么样的学习任务或者目标让你感到压力山大呢?" | ||||
|             }, | ||||
|             { | ||||
|                 "input": "主要是下个月的一场重要考试,我总觉得自己复习得不够充分,而且害怕失败...", | ||||
|                 "output": "💖 噢,明白了。面对重大考试的紧张和担忧是很普遍的现象。首先,我要肯定你对自己的高标准要求,这是你对自己负责的表现。不过,我们可以尝试制定一个更为实际可行的学习计划,将大目标分解成小目标,每天完成一部分,这样既能提升自信心又能减轻焦虑感。同时,也要记得适时放松,合理安排休息时间哦!比如学习45分钟后,可以适当休息10分钟,做些自己喜欢的事情,劳逸结合才是长久之计呢!💪📚\n另外,也可以尝试一些深呼吸、冥想等放松技巧来缓解焦虑情绪。如果你愿意的话,下次咨询我们可以一起练习,看看哪种方式最适合帮助你应对压力。现在,让我们一步步来,先从细化学习计划开始,你觉得怎么样呢?🌸" | ||||
|             } | ||||
|         ] | ||||
|     } | ||||
| ] | ||||
| ``` | ||||
| 
 | ||||
| A vector DB is built from the prepared data; this ultimately produces a folder named vector_db under the data directory, containing index.faiss and index.pkl | ||||
| 
 | ||||
| If a vector DB already exists, the corresponding database is loaded directly, as in the sketch below | ||||
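|  | ||||
| A minimal sketch of this load-or-create flow (it reuses the FAISS calls that appear in this commit's create_vector_db / load_vector_db methods; `docs` and `emb_model` stand in for the prepared documents and the loaded embedding model): | ||||
|  | ||||
| ```python | ||||
| import os | ||||
| from langchain_community.vectorstores import FAISS | ||||
|  | ||||
| def load_or_create_db(docs, emb_model, vector_db_dir): | ||||
|     # Reuse the persisted index (index.faiss / index.pkl) if it is already there | ||||
|     if os.path.exists(os.path.join(vector_db_dir, "index.faiss")): | ||||
|         return FAISS.load_local(vector_db_dir, emb_model, | ||||
|                                 allow_dangerous_deserialization=True) | ||||
|     # Otherwise embed the prepared documents and persist the new index | ||||
|     db = FAISS.from_documents(docs, emb_model) | ||||
|     db.save_local(vector_db_dir) | ||||
|     return db | ||||
| ``` | ||||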
| @ -91,6 +109,7 @@ python main.py | ||||
| ## **Dataset** | ||||
| 
 | ||||
| - Cleaned QA pairs: each QA pair is embedded as one sample | ||||
| - Cleaned conversations: each conversation is embedded as one sample | ||||
| - Filtered TXT texts (chunked as in the sketch below) | ||||
| 	- Generate embeddings for the TXT text directly (split by token length) | ||||
| 	- Generate embeddings for the TXT text after filtering out tables of contents and other irrelevant content (split by token length) | ||||
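|  | ||||
| A hedged sketch of that chunking step, using the RecursiveCharacterTextSplitter this commit wires up with the new `chunk_size` / `chunk_overlap` config values (note this splitter measures length in characters by default; the file path is made up): | ||||
|  | ||||
| ```python | ||||
| from langchain_community.document_loaders import TextLoader | ||||
| from langchain_text_splitters import RecursiveCharacterTextSplitter | ||||
|  | ||||
| # Mirrors the defaults added to config.py | ||||
| splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100) | ||||
|  | ||||
| docs = TextLoader("src/data/txt/sample.txt", encoding="utf-8").load()  # hypothetical file | ||||
| chunks = splitter.split_documents(docs) | ||||
| print(f"{len(chunks)} chunks, first chunk is {len(chunks[0].page_content)} chars") | ||||
| ``` | ||||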
| @ -115,7 +134,7 @@ LangChain is an open-source framework for building applications based on large language models (LLM) | ||||
| Faiss is a library for efficient similarity search and clustering of dense vectors. It contains algorithms that can search vector sets of any size. Since LangChain already integrates FAISS, this project no longer develops against the native documentation: [FAISS in Langchain](https://python.langchain.com/docs/integrations/vectorstores/faiss) | ||||
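|  | ||||
| Once an index exists, querying it through the LangChain wrapper takes only a couple of calls. A sketch assuming a loaded vector store `db` (it mirrors the retrieve method later in this commit; the query string is a placeholder): | ||||
|  | ||||
| ```python | ||||
| # db is a FAISS vector store loaded as shown above | ||||
| retriever = db.as_retriever(search_kwargs={"k": 5}) | ||||
| docs = retriever.invoke("example query")  # placeholder query | ||||
| for doc in docs: | ||||
|     print(doc.page_content[:80]) | ||||
| ``` | ||||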
| 
 | ||||
| 
 | ||||
| ### [RAGAS](https://github.com/explodinggradients/ragas) | ||||
| ### [RAGAS](https://github.com/explodinggradients/ragas) (TODO) | ||||
| 
 | ||||
| A classic evaluation framework for RAG, which evaluates along the following three dimensions: | ||||
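|  | ||||
| The metric list itself is cut off by this hunk. Since the integration is still marked TODO, the following is only a sketch of what a RAGAS run could look like, assuming the library's evaluate entry point, an LLM backend configured for it (e.g. an OpenAI key), and made-up sample data: | ||||
|  | ||||
| ```python | ||||
| from datasets import Dataset | ||||
| from ragas import evaluate | ||||
| from ragas.metrics import answer_relevancy, faithfulness | ||||
|  | ||||
| # One hypothetical sample: the user question, the retrieved contexts, and the model's answer | ||||
| data = Dataset.from_dict({ | ||||
|     "question": ["example question"], | ||||
|     "contexts": [["retrieved passage 1", "retrieved passage 2"]], | ||||
|     "answer": ["model answer"], | ||||
| }) | ||||
| print(evaluate(data, metrics=[faithfulness, answer_relevancy])) | ||||
| ``` | ||||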
| 
 | ||||
|  | ||||
| @ -25,6 +25,10 @@ qa_dir = os.path.join(data_dir, 'json') | ||||
| log_dir = os.path.join(base_dir, 'log')                             # log | ||||
| log_path = os.path.join(log_dir, 'log.log')                         # file | ||||
| 
 | ||||
| # txt embedding chunking parameters | ||||
| chunk_size = 1000 | ||||
| chunk_overlap = 100 | ||||
| 
 | ||||
| # vector DB | ||||
| vector_db_dir = os.path.join(data_dir, 'vector_db') | ||||
| 
 | ||||
|  | ||||
| @ -4,7 +4,18 @@ import os | ||||
| 
 | ||||
| from loguru import logger | ||||
| from langchain_community.vectorstores import FAISS | ||||
| from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name | ||||
| from config.config import ( | ||||
|     embedding_path, | ||||
|     embedding_model_name, | ||||
|     doc_dir, qa_dir, | ||||
|     knowledge_pkl_path, | ||||
|     data_dir, | ||||
|     vector_db_dir, | ||||
|     rerank_path, | ||||
|     rerank_model_name, | ||||
|     chunk_size, | ||||
|     chunk_overlap | ||||
| ) | ||||
| from langchain.embeddings import HuggingFaceBgeEmbeddings | ||||
| from langchain_community.document_loaders import DirectoryLoader, TextLoader | ||||
| from langchain_text_splitters import RecursiveCharacterTextSplitter | ||||
| @ -15,8 +26,9 @@ from FlagEmbedding import FlagReranker | ||||
| class Data_process(): | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.chunk_size: int=1000 | ||||
|         self.chunk_overlap: int=100     | ||||
| 
 | ||||
|         self.chunk_size: int=chunk_size | ||||
|         self.chunk_overlap: int=chunk_overlap | ||||
|          | ||||
|     def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True): | ||||
|         """ | ||||
| @ -53,7 +65,6 @@ class Data_process(): | ||||
|         return embeddings | ||||
|      | ||||
|     def load_rerank_model(self, model_name=rerank_model_name): | ||||
| 
 | ||||
|         """ | ||||
|         Load the reranking model. | ||||
|          | ||||
| @ -118,9 +129,7 @@ class Data_process(): | ||||
|             content += obj | ||||
|         return content | ||||
| 
 | ||||
| 
 | ||||
|     def split_document(self, data_path): | ||||
| 
 | ||||
|         """ | ||||
|         Split all txt files under the data_path folder. | ||||
|          | ||||
| @ -132,8 +141,6 @@ class Data_process(): | ||||
|         Returns: | ||||
|         - split_docs: list | ||||
|         """ | ||||
|          | ||||
|          | ||||
|         # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | ||||
|         text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)  | ||||
|         split_docs = [] | ||||
| @ -152,7 +159,6 @@ class Data_process(): | ||||
|         logger.info(f'split_docs size {len(split_docs)}') | ||||
|         return split_docs | ||||
|    | ||||
|    | ||||
|     def split_conversation(self, path): | ||||
|         """ | ||||
|         Split all json files under the path folder by conversation block. | ||||
| @ -171,43 +177,29 @@ class Data_process(): | ||||
|                         file_path = os.path.join(root, file) | ||||
|                         logger.info(f'splitting file {file_path}') | ||||
|                         with open(file_path, 'r', encoding='utf-8') as f: | ||||
|                             data = json.load(f) | ||||
|                             # print(data) | ||||
|                             for conversation in data: | ||||
|                                 # for dialog in conversation['conversation']: | ||||
|                                     ## Split by QA pair: convert each QA turn into a langchain_core.documents.base.Document | ||||
|                                     # content = self.extract_text_from_json(dialog,'') | ||||
|                                     # split_qa.append(Document(page_content = content)) | ||||
|                                 # Split by conversation block | ||||
|                                 content = self.extract_text_from_json(conversation['conversation'], '') | ||||
|                                 #logger.info(f'content====={content}') | ||||
|                             for line in f.readlines(): | ||||
|                                 content = self.extract_text_from_json(line,'') | ||||
|                                 split_qa.append(Document(page_content = content)) | ||||
| 
 | ||||
|                             #data = json.load(f) | ||||
|                             #for conversation in data: | ||||
|                             #    #for dialog in conversation['conversation']: | ||||
|                             #    #    # Split by QA pair: convert each QA turn into a langchain_core.documents.base.Document | ||||
|                             #    #    content = self.extract_text_from_json(dialog,'') | ||||
|                             #    #    split_qa.append(Document(page_content = content)) | ||||
|                             #    # Split by conversation block | ||||
|                             #    content = self.extract_text_from_json(conversation['conversation'], '') | ||||
|                             #    #logger.info(f'content====={content}') | ||||
|                             #    split_qa.append(Document(page_content = content))     | ||||
|             # logger.info(f'split_qa size====={len(split_qa)}') | ||||
|         return split_qa | ||||
| 
 | ||||
| 
 | ||||
|     def load_knowledge(self, knowledge_pkl_path): | ||||
|         ''' | ||||
|         Read the knowledge .pkl, or create it if it does not exist. | ||||
|         ''' | ||||
|         if not os.path.exists(knowledge_pkl_path): | ||||
|             split_doc = self.split_document(doc_dir) | ||||
|             split_qa = self.split_conversation(qa_dir) | ||||
|             knowledge_chunks = split_doc + split_qa | ||||
|             with open(knowledge_pkl_path, 'wb') as file: | ||||
|                 pickle.dump(knowledge_chunks, file) | ||||
|         else: | ||||
|             with open(knowledge_pkl_path , 'rb') as f: | ||||
|                 knowledge_chunks = pickle.load(f) | ||||
|         return knowledge_chunks | ||||
|        | ||||
|   | ||||
|     def create_vector_db(self, emb_model): | ||||
|         ''' | ||||
|         Create and save the vector store. | ||||
|         ''' | ||||
|         logger.info(f'Creating index...') | ||||
|         split_doc = self.split_document(doc_dir) | ||||
|         #split_doc = self.split_document(doc_dir) | ||||
|         split_qa = self.split_conversation(qa_dir) | ||||
|         # logger.info(f'split_doc == {len(split_doc)}') | ||||
|         # logger.info(f'split_qa == {len(split_qa)}') | ||||
| @ -217,7 +209,6 @@ class Data_process(): | ||||
|         db.save_local(vector_db_dir) | ||||
|         return db | ||||
|          | ||||
|    | ||||
|     def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir, qa_dir=qa_dir): | ||||
|         ''' | ||||
|         Load the vector store. | ||||
| @ -230,66 +221,6 @@ class Data_process(): | ||||
|             db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True) | ||||
|         return db | ||||
|      | ||||
|   | ||||
|     def retrieve(self, query, vector_db, k=5): | ||||
|         ''' | ||||
|         Retrieve from the vector store based on the query. | ||||
|         ''' | ||||
|         retriever = vector_db.as_retriever(search_kwargs={"k": k}) | ||||
|         docs = retriever.invoke(query) | ||||
|         return docs, retriever | ||||
|      | ||||
|     ## FlashrankRerank's results are mediocre | ||||
|     # def rerank(self, query, retriever): | ||||
|     #     compressor = FlashrankRerank() | ||||
|     #     compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever) | ||||
|     #     compressed_docs = compression_retriever.get_relevant_documents(query) | ||||
|     #     return compressed_docs | ||||
| 
 | ||||
|     def rerank(self, query, docs):  | ||||
|         reranker = self.load_rerank_model() | ||||
|         passages = [] | ||||
|         for doc in docs: | ||||
|             passages.append(str(doc.page_content)) | ||||
|         scores = reranker.compute_score([[query, passage] for passage in passages]) | ||||
|         sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True) | ||||
|         sorted_passages, sorted_scores = zip(*sorted_pairs) | ||||
|         return sorted_passages, sorted_scores | ||||
| 
 | ||||
| 
 | ||||
| # def create_prompt(question, context): | ||||
| #     from langchain.prompts import PromptTemplate | ||||
| #     prompt_template = f"""请基于以下内容回答问题: | ||||
| 
 | ||||
| #     {context} | ||||
| 
 | ||||
| #     问题: {question} | ||||
| #     回答:""" | ||||
| #     prompt = PromptTemplate( | ||||
| #         template=prompt_template, input_variables=["context", "question"] | ||||
| #     ) | ||||
| #     logger.info(f'Prompt: {prompt}') | ||||
| #     return prompt | ||||
|      | ||||
| def create_prompt(question, context): | ||||
|     prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:""" | ||||
|     logger.info(f'Prompt: {prompt}') | ||||
|     return prompt | ||||
|          | ||||
| def test_zhipu(prompt): | ||||
|     from zhipuai import ZhipuAI | ||||
|     api_key = ""  # fill in your own API key | ||||
|     if api_key == "": | ||||
|         raise ValueError("Please fill in the api_key") | ||||
|     client = ZhipuAI(api_key=api_key) | ||||
|     response = client.chat.completions.create( | ||||
|         model="glm-4",  # name of the model to call | ||||
|         messages=[ | ||||
|             {"role": "user", "content": prompt[:100]} | ||||
|         ], | ||||
|     ) | ||||
|     print(response.choices[0].message) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     logger.info(data_dir) | ||||
|     if not os.path.exists(data_dir): | ||||
| @ -317,5 +248,3 @@ if __name__ == "__main__": | ||||
|     for i in range(len(scores)): | ||||
|         logger.info(str(scores[i]) + '\n') | ||||
|         logger.info(passages[i]) | ||||
|     prompt = create_prompt(query, passages[0]) | ||||
|     test_zhipu(prompt) ## if this shows 'Server disconnected without sending a response.', it may be due to the context window limit | ||||
| @ -13,8 +13,7 @@ from loguru import logger | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     query = """ | ||||
|         我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。 | ||||
|         无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想 | ||||
|         我现在经常会被别人催眠,做一些我不愿意做的事情,是什么原因? | ||||
|     """ | ||||
| 
 | ||||
|     """ | ||||
|  | ||||