Update RAG pipeline (#153)
commit 3bead46e66
@@ -35,6 +35,24 @@ pip3 install -r requirements.txt
 - txt data: place it under the src.data.txt directory
 - json data: place it under the src.data.json directory
 
+The JSON data format is as follows:
+```json
+[
+    {
+        "conversation": [
+            {
+                "input": "(微微皱眉)艾薇,我觉得我在学校的学习压力好大啊,虽然我真的很喜欢我的专业,但最近总是担心自己无法达到自己的期望,这让我有点焦虑。",
+                "output": "(微笑)😊 我完全理解你的感受,亲爱的。学习是我们生活的一部分,尤其是在我们热爱的专业里,期待自己做得更好是再正常不过了。你知道吗?这种对成功的渴望其实是一种积极的动力源呢!但是,当它开始影响到我们的心理健康时,我们就需要学会调整和应对了。你可以先试着跟我分享一下,具体是什么样的学习任务或者目标让你感到压力山大呢?"
+            },
+            {
+                "input": "主要是下个月的一场重要考试,我总觉得自己复习得不够充分,而且害怕失败...",
+                "output": "💖 噢,明白了。面对重大考试的紧张和担忧是很普遍的现象。首先,我要肯定你对自己的高标准要求,这是你对自己负责的表现。不过,我们可以尝试制定一个更为实际可行的学习计划,将大目标分解成小目标,每天完成一部分,这样既能提升自信心又能减轻焦虑感。同时,也要记得适时放松,合理安排休息时间哦!比如学习45分钟后,可以适当休息10分钟,做些自己喜欢的事情,劳逸结合才是长久之计呢!💪📚\n另外,也可以尝试一些深呼吸、冥想等放松技巧来缓解焦虑情绪。如果你愿意的话,下次咨询我们可以一起练习,看看哪种方式最适合帮助你应对压力。现在,让我们一步步来,先从细化学习计划开始,你觉得怎么样呢?🌸"
+            }
+        ]
+    }
+]
+```
+
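For orientation (this note and snippet are not part of the commit), a minimal sketch of how one `conversation` block is flattened into a single text and wrapped in a `Document` before embedding, roughly following the `extract_text_from_json` and `split_conversation` code further down; the sample string and field handling are illustrative.

```python
# Illustrative sketch only: flatten one conversation block into one string,
# roughly what extract_text_from_json + Document do in the pipeline code below.
import json
from langchain_core.documents import Document

raw = '{"conversation": [{"input": "学习压力好大", "output": "我们一起制定一个学习计划吧"}]}'

content = ""
for turn in json.loads(raw)["conversation"]:
    content += turn["input"] + turn["output"]   # concatenate every input/output turn

doc = Document(page_content=content)            # one conversation -> one embedded sample
```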
 A vector DB is built from the prepared data; this produces a folder named vector_db under the data directory, containing index.faiss and index.pkl.
 
 If a vector DB already exists, the corresponding database is loaded directly.
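As a reference for the two paragraphs above (not part of the diff), a minimal sketch of how the vector DB is built and reloaded with LangChain's FAISS wrapper, following the `create_vector_db` / `load_vector_db` logic in this commit; `emb_model` and the path are assumptions for illustration.

```python
# Illustrative sketch: build the FAISS store once, then reload it on later runs.
# `emb_model` is assumed to be the embedding model returned by load_embedding_model();
# the path is illustrative.
import os
from langchain_community.vectorstores import FAISS

vector_db_dir = "data/vector_db"   # ends up containing index.faiss and index.pkl

def build_or_load(emb_model, split_docs):
    if not os.path.exists(vector_db_dir):
        db = FAISS.from_documents(split_docs, emb_model)   # embed and index the chunks
        db.save_local(vector_db_dir)
    else:
        db = FAISS.load_local(vector_db_dir, emb_model,
                              allow_dangerous_deserialization=True)
    return db
```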
@@ -91,6 +109,7 @@ python main.py
 ## **Dataset**
 
 - Cleaned QA pairs: each QA pair is embedded as a single sample
+- Cleaned conversations: each conversation is embedded as a single sample
 - Filtered TXT text
 	- Embeddings are generated directly from the TXT text (split by token length)
 	- Embeddings are generated from the TXT text after filtering out irrelevant content such as tables of contents (split by token length)
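For the TXT splitting mentioned in the list above (not part of the diff): the pipeline code in this commit uses `RecursiveCharacterTextSplitter` with the `chunk_size` / `chunk_overlap` values added to the config (1000 / 100), which splits on character counts rather than true tokens. A minimal sketch, with an illustrative file path:

```python
# Illustrative sketch: load one txt file and split it into overlapping chunks.
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

docs = TextLoader("data/txt/example.txt", encoding="utf-8").load()  # path is illustrative
split_docs = splitter.split_documents(docs)
print(len(split_docs))
```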
@@ -115,7 +134,7 @@ LangChain is an open-source framework for building applications based on large language models (LLMs)
 Faiss is a library for efficient similarity search and clustering of dense vectors; its algorithms can search vector sets of any size. Since LangChain already integrates FAISS, this project no longer builds on the native FAISS documentation: [FAISS in Langchain](https://python.langchain.com/docs/integrations/vectorstores/faiss)
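To illustrate the integration described above (not part of the diff), querying the saved store goes through LangChain's retriever interface; a minimal sketch, where the embedding model name and query text are illustrative assumptions:

```python
# Illustrative sketch: top-k retrieval over the saved FAISS store.
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceBgeEmbeddings

emb_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")  # illustrative model
db = FAISS.load_local("data/vector_db", emb_model, allow_dangerous_deserialization=True)

retriever = db.as_retriever(search_kwargs={"k": 5})   # return the 5 most similar chunks
for doc in retriever.invoke("最近学习压力很大,该怎么缓解?"):
    print(doc.page_content)
```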
 
 
-### [RAGAS](https://github.com/explodinggradients/ragas)
+### [RAGAS](https://github.com/explodinggradients/ragas) (TODO)
 
 The classic evaluation framework for RAG, which evaluates along the following three dimensions:
 
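RAGAS is still marked TODO in this commit, so nothing below exists in the repo; as a hedged illustration of what the evaluation step could look like, assuming the `ragas` and `datasets` packages and the metric/column names from the RAGAS docs (details vary between ragas versions):

```python
# Hypothetical sketch only: RAGAS evaluation is a TODO in this commit.
# Column names and metrics follow the ragas docs and may differ across versions;
# evaluate() also needs an LLM / embedding backend configured (e.g. an OpenAI key).
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy, context_precision

samples = {
    "question": ["学习压力大怎么办?"],
    "answer": ["可以把大目标拆成小目标,并安排好休息时间。"],
    "contexts": [["将大目标分解成小目标,每天完成一部分,劳逸结合。"]],
    "ground_truth": ["制定可行的学习计划并劳逸结合。"],
}
result = evaluate(Dataset.from_dict(samples),
                  metrics=[faithfulness, answer_relevancy, context_precision])
print(result)
```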
|  | |||||||
@@ -25,6 +25,10 @@ qa_dir = os.path.join(data_dir, 'json')
 log_dir = os.path.join(base_dir, 'log')                             # log
 log_path = os.path.join(log_dir, 'log.log')                         # file
 
+# txt embedding split parameters
+chunk_size=1000
+chunk_overlap=100
+
 # vector DB
 vector_db_dir = os.path.join(data_dir, 'vector_db')
 
|  | |||||||
@@ -4,7 +4,18 @@ import os
 
 from loguru import logger
 from langchain_community.vectorstores import FAISS
-from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
+from config.config import (
+    embedding_path,
+    embedding_model_name,
+    doc_dir, qa_dir,
+    knowledge_pkl_path,
+    data_dir,
+    vector_db_dir,
+    rerank_path,
+    rerank_model_name,
+    chunk_size,
+    chunk_overlap
+)
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -15,8 +26,9 @@ from FlagEmbedding import FlagReranker
 class Data_process():
 
     def __init__(self):
-        self.chunk_size: int=1000
-        self.chunk_overlap: int=100
+
+        self.chunk_size: int=chunk_size
+        self.chunk_overlap: int=chunk_overlap
         
     def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
         """
@@ -53,7 +65,6 @@ class Data_process():
         return embeddings
     
     def load_rerank_model(self, model_name=rerank_model_name):
-
         """
         Load the rerank model.
         
@@ -118,9 +129,7 @@ class Data_process():
             content += obj
         return content
 
-
     def split_document(self, data_path):
-
         """
         Split all txt files under the data_path folder
         
@@ -132,8 +141,6 @@ class Data_process():
         Returns:
         - split_docs: list
         """
-         
-         
         # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
         text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) 
         split_docs = []
@@ -152,7 +159,6 @@ class Data_process():
         logger.info(f'split_docs size {len(split_docs)}')
         return split_docs
   
-   
     def split_conversation(self, path):
         """
         Split all json files under the path folder by conversation block
@@ -171,43 +177,29 @@ class Data_process():
                         file_path = os.path.join(root, file)
                         logger.info(f'splitting file {file_path}')
                         with open(file_path, 'r', encoding='utf-8') as f:
-                            data = json.load(f)
-                            # print(data)
-                            for conversation in data:
-                                # for dialog in conversation['conversation']:
-                                    ## Split by QA pair: convert each QA round into a langchain_core.documents.base.Document
-                                    # content = self.extract_text_from_json(dialog,'')
-                                    # split_qa.append(Document(page_content = content))
-                                # Split by conversation block
-                                content = self.extract_text_from_json(conversation['conversation'], '')
-                                #logger.info(f'content====={content}')
+                            for line in f.readlines():
+                                content = self.extract_text_from_json(line,'')
                                 split_qa.append(Document(page_content = content))
+
+                            #data = json.load(f)
+                            #for conversation in data:
+                            #    #for dialog in conversation['conversation']:
+                            #    #    # Split by QA pair: convert each QA round into a langchain_core.documents.base.Document
+                            #    #    content = self.extract_text_from_json(dialog,'')
+                            #    #    split_qa.append(Document(page_content = content))
+                            #    # Split by conversation block
+                            #    content = self.extract_text_from_json(conversation['conversation'], '')
+                            #    #logger.info(f'content====={content}')
+                            #    split_qa.append(Document(page_content = content))
             # logger.info(f'split_qa size====={len(split_qa)}')
         return split_qa
 
-
-    def load_knowledge(self, knowledge_pkl_path):
-        '''
-        Read or create the knowledge .pkl
-        '''
-        if not os.path.exists(knowledge_pkl_path):
-            split_doc = self.split_document(doc_dir)
-            split_qa = self.split_conversation(qa_dir)
-            knowledge_chunks = split_doc + split_qa
-            with open(knowledge_pkl_path, 'wb') as file:
-                pickle.dump(knowledge_chunks, file)
-        else:
-            with open(knowledge_pkl_path , 'rb') as f:
-                knowledge_chunks = pickle.load(f)
-        return knowledge_chunks
-       
-  
     def create_vector_db(self, emb_model):
         '''
         Create and save the vector store
         '''
         logger.info(f'Creating index...')
-        split_doc = self.split_document(doc_dir)
+        #split_doc = self.split_document(doc_dir)
         split_qa = self.split_conversation(qa_dir)
         # logger.info(f'split_doc == {len(split_doc)}')
         # logger.info(f'split_qa == {len(split_qa)}')
@@ -217,7 +209,6 @@ class Data_process():
         db.save_local(vector_db_dir)
         return db
         
-   
     def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir, qa_dir=qa_dir):
         '''
         Load the vector store
@@ -230,66 +221,6 @@ class Data_process():
             db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
         return db
     
-  
-    def retrieve(self, query, vector_db, k=5):
-        '''
-        Retrieve from the vector store based on the query
-        '''
-        retriever = vector_db.as_retriever(search_kwargs={"k": k})
-        docs = retriever.invoke(query)
-        return docs, retriever
-     
-    ## FlashrankRerank results are mediocre
-    # def rerank(self, query, retriever):
-    #     compressor = FlashrankRerank()
-    #     compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
-    #     compressed_docs = compression_retriever.get_relevant_documents(query)
-    #     return compressed_docs
-
-    def rerank(self, query, docs): 
-        reranker = self.load_rerank_model()
-        passages = []
-        for doc in docs:
-            passages.append(str(doc.page_content))
-        scores = reranker.compute_score([[query, passage] for passage in passages])
-        sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
-        sorted_passages, sorted_scores = zip(*sorted_pairs)
-        return sorted_passages, sorted_scores
-
-
-# def create_prompt(question, context):
-#     from langchain.prompts import PromptTemplate
-#     prompt_template = f"""请基于以下内容回答问题:
-
-#     {context}
-
-#     问题: {question}
-#     回答:"""
-#     prompt = PromptTemplate(
-#         template=prompt_template, input_variables=["context", "question"]
-#     )
-#     logger.info(f'Prompt: {prompt}')
-#     return prompt
-    
-def create_prompt(question, context):
-    prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
-    logger.info(f'Prompt: {prompt}')
-    return prompt
-        
-def test_zhipu(prompt):
-    from zhipuai import ZhipuAI
-    api_key = "" # fill in your own API key
-    if api_key == "":
-        raise ValueError("请填写api_key")
-    client = ZhipuAI(api_key=api_key) 
-    response = client.chat.completions.create(
-    model="glm-4",  # name of the model to call
-    messages=[
-        {"role": "user", "content": prompt[:100]}
-    ],
-)
-    print(response.choices[0].message)
-
 if __name__ == "__main__":
     logger.info(data_dir)
     if not os.path.exists(data_dir):
@@ -317,5 +248,3 @@ if __name__ == "__main__":
     for i in range(len(scores)):
         logger.info(str(scores[i]) + '\n')
         logger.info(passages[i])
-    prompt = create_prompt(query, passages[0])
-    test_zhipu(prompt) ## If you see 'Server disconnected without sending a response.', it may be due to the context window limit
@@ -13,8 +13,7 @@ from loguru import logger
 
 if __name__ == "__main__":
     query = """
-        我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。
-        无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想
+        我现在经常会被别人催眠,做一些我不愿意做的事情,是什么原因?
     """
 
     """
|  | |||||||