Update RAG pipeline (#120)
This commit is contained in:
		
						commit
						ee6b365588
					
				| @ -33,5 +33,5 @@ prompt_template = """ | |||||||
| 	{system_prompt} | 	{system_prompt} | ||||||
| 	根据下面检索回来的信息,回答问题。 | 	根据下面检索回来的信息,回答问题。 | ||||||
| 	{content} | 	{content} | ||||||
| 	问题:{question} | 	问题:{query} | ||||||
| """ | """ | ||||||
| @ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings | |||||||
| from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader | from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader | ||||||
| from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter | from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter | ||||||
| from BCEmbedding import EmbeddingModel, RerankerModel | from BCEmbedding import EmbeddingModel, RerankerModel | ||||||
| from util.pipeline import EmoLLMRAG | # from util.pipeline import EmoLLMRAG | ||||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | from transformers import AutoTokenizer, AutoModelForCausalLM | ||||||
| from langchain.document_loaders.pdf import PyPDFDirectoryLoader | from langchain.document_loaders.pdf import PyPDFDirectoryLoader | ||||||
| from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader | from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader | ||||||
| @ -254,7 +254,8 @@ if __name__ == "__main__": | |||||||
|     # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。" |     # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。" | ||||||
|     # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" |     # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" | ||||||
|     # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" |     # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" | ||||||
|     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" |     # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" | ||||||
|  |     query = "我现在心情非常差,有什么解决办法吗?" | ||||||
|     docs, retriever = dp.retrieve(query, vector_db, k=10) |     docs, retriever = dp.retrieve(query, vector_db, k=10) | ||||||
|     logger.info(f'Query: {query}') |     logger.info(f'Query: {query}') | ||||||
|     logger.info("Retrieve results:") |     logger.info("Retrieve results:") | ||||||
|  | |||||||
| @ -1,20 +1,17 @@ | |||||||
| import os | import os | ||||||
| import json | import time | ||||||
| import pickle | import jwt | ||||||
| import numpy as np |  | ||||||
| from typing import Tuple |  | ||||||
| from sentence_transformers import SentenceTransformer |  | ||||||
| 
 | 
 | ||||||
| from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir | from config.config import base_dir, data_dir | ||||||
| from util.encode import load_embedding, encode_qa | from data_processing import Data_process | ||||||
| from util.pipeline import EmoLLMRAG | from pipeline import EmoLLMRAG | ||||||
|  | 
 | ||||||
|  | from langchain_openai import ChatOpenAI | ||||||
| from loguru import logger | from loguru import logger | ||||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | from transformers import AutoTokenizer, AutoModelForCausalLM | ||||||
| import torch | import torch | ||||||
| import streamlit as st | import streamlit as st | ||||||
| from openxlab.model import download | from openxlab.model import download | ||||||
| from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir |  | ||||||
| from data_processing import Data_process |  | ||||||
| ''' | ''' | ||||||
| 	1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer | 	1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer | ||||||
| 	2)调用 embedding 提供的接口对 query 向量化 | 	2)调用 embedding 提供的接口对 query 向量化 | ||||||
| @ -24,21 +21,45 @@ from data_processing import Data_process | |||||||
| 	6)拼接 prompt 并调用模型返回结果 | 	6)拼接 prompt 并调用模型返回结果 | ||||||
| 
 | 
 | ||||||
| ''' | ''' | ||||||
| # download( | def get_glm(temprature): | ||||||
| #     model_repo=model_repo,  |     llm = ChatOpenAI( | ||||||
| #     output='model' |         model_name="glm-4", | ||||||
| # ) |         openai_api_base="https://open.bigmodel.cn/api/paas/v4", | ||||||
|  |         openai_api_key=generate_token("api-key"), | ||||||
|  |         streaming=False, | ||||||
|  |         temperature=temprature | ||||||
|  |     ) | ||||||
|  |     return llm  | ||||||
|  | 
 | ||||||
|  | def generate_token(apikey: str, exp_seconds: int=100): | ||||||
|  |     try: | ||||||
|  |         id, secret = apikey.split(".") | ||||||
|  |     except Exception as e: | ||||||
|  |         raise Exception("invalid apikey", e) | ||||||
|  |   | ||||||
|  |     payload = { | ||||||
|  |         "api_key": id, | ||||||
|  |         "exp": int(round(time.time() * 1000)) + exp_seconds * 1000, | ||||||
|  |         "timestamp": int(round(time.time() * 1000)), | ||||||
|  |     } | ||||||
|  |   | ||||||
|  |     return jwt.encode( | ||||||
|  |         payload, | ||||||
|  |         secret, | ||||||
|  |         algorithm="HS256", | ||||||
|  |         headers={"alg": "HS256", "sign_type": "SIGN"}, | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| @st.cache_resource | @st.cache_resource | ||||||
| def load_model(): | def load_model(): | ||||||
|     model_dir = os.path.join(base_dir,'../model')  |     model_dir = os.path.join(base_dir,'../model')  | ||||||
|     logger.info(f'Loading model from {model_dir}') |     logger.info(f'Loading model from {model_dir}') | ||||||
|     model = ( |     model = ( | ||||||
|         AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) |         AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True) | ||||||
|         .to(torch.bfloat16) |         .to(torch.bfloat16) | ||||||
|         .cuda() |         .cuda() | ||||||
|     ) |     ) | ||||||
|     tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) |     tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True) | ||||||
|     return model, tokenizer |     return model, tokenizer | ||||||
| 
 | 
 | ||||||
| def main(query, system_prompt=''): | def main(query, system_prompt=''): | ||||||
| @ -61,3 +82,8 @@ def main(query, system_prompt=''): | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" |     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" | ||||||
|     main(query) |     main(query) | ||||||
|  |     #model = get_glm(0.7) | ||||||
|  |     #rag_obj = EmoLLMRAG(model, 3) | ||||||
|  |     #res = rag_obj.main(query) | ||||||
|  |     #logger.info(res) | ||||||
|  | 
 | ||||||
|  | |||||||
| @ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser | |||||||
| from langchain_core.prompts import PromptTemplate | from langchain_core.prompts import PromptTemplate | ||||||
| from transformers.utils import logging | from transformers.utils import logging | ||||||
| 
 | 
 | ||||||
| from data_processing import DataProcessing | from data_processing import Data_process | ||||||
| from config.config import retrieval_num, select_num, system_prompt, prompt_template | from config.config import system_prompt, prompt_template  | ||||||
| 
 |  | ||||||
| logger = logging.get_logger(__name__) | logger = logging.get_logger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -28,10 +27,8 @@ class EmoLLMRAG(object): | |||||||
| 
 | 
 | ||||||
|         """ |         """ | ||||||
|         self.model = model |         self.model = model | ||||||
|  |         self.data_processing_obj = Data_process() | ||||||
|         self.vectorstores = self._load_vector_db() |         self.vectorstores = self._load_vector_db() | ||||||
|         self.system_prompt = self._get_system_prompt() |  | ||||||
|         self.prompt_template = self._get_prompt_template() |  | ||||||
|         self.data_processing_obj = DataProcessing() |  | ||||||
|         self.system_prompt = system_prompt |         self.system_prompt = system_prompt | ||||||
|         self.prompt_template = prompt_template |         self.prompt_template = prompt_template | ||||||
|         self.retrieval_num = retrieval_num |         self.retrieval_num = retrieval_num | ||||||
| @ -43,8 +40,6 @@ class EmoLLMRAG(object): | |||||||
|             调用 embedding 模块给出接口 load vector DB |             调用 embedding 模块给出接口 load vector DB | ||||||
|         """ |         """ | ||||||
|         vectorstores = self.data_processing_obj.load_vector_db() |         vectorstores = self.data_processing_obj.load_vector_db() | ||||||
|         if not vectorstores: |  | ||||||
|             vectorstores = self.data_processing_obj.load_index_and_knowledge() |  | ||||||
| 
 | 
 | ||||||
|         return vectorstores  |         return vectorstores  | ||||||
| 
 | 
 | ||||||
| @ -57,13 +52,17 @@ class EmoLLMRAG(object): | |||||||
|         content = '' |         content = '' | ||||||
|         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) |         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) | ||||||
| 
 | 
 | ||||||
|         # 如果需要rerank,调用接口对 documents 进行 rerank |  | ||||||
|         if self.rerank_flag: |  | ||||||
|             documents = self.data_processing_obj.rerank(documents, self.select_num) |  | ||||||
| 
 |  | ||||||
|         for doc in documents: |         for doc in documents: | ||||||
|             content += doc.page_content |             content += doc.page_content | ||||||
| 
 | 
 | ||||||
|  |         # 如果需要rerank,调用接口对 documents 进行 rerank | ||||||
|  |         if self.rerank_flag: | ||||||
|  |             documents, _ = self.data_processing_obj.rerank(documents, self.select_num) | ||||||
|  | 
 | ||||||
|  |             content = '' | ||||||
|  |             for doc in documents: | ||||||
|  |                 content += doc | ||||||
|  |         logger.info(f'Retrieval data: {content}') | ||||||
|         return content |         return content | ||||||
|      |      | ||||||
|     def generate_answer(self, query, content) -> str: |     def generate_answer(self, query, content) -> str: | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 xzw
						xzw