diff --git a/rag/src/config/config.py b/rag/src/config/config.py index d803d64..d4dcfe3 100644 --- a/rag/src/config/config.py +++ b/rag/src/config/config.py @@ -25,4 +25,13 @@ log_path = os.path.join(log_dir, 'log.log') # file vector_db_dir = os.path.join(data_dir, 'vector_db.pkl') select_num = 3 -retrieval_num = 10 \ No newline at end of file +retrieval_num = 10 +system_prompt = """ + 你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n +""" +prompt_template = """ + {system_prompt} + 根据下面检索回来的信息,回答问题。 + {content} + 问题:{question} +""" \ No newline at end of file diff --git a/rag/src/util/pipeline.py b/rag/src/pipeline.py similarity index 76% rename from rag/src/util/pipeline.py rename to rag/src/pipeline.py index a6f2cdf..214eef3 100644 --- a/rag/src/util/pipeline.py +++ b/rag/src/pipeline.py @@ -2,7 +2,8 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from transformers.utils import logging -from config.config import retrieval_num, select_num +from data_processing import DataProcessing +from config.config import retrieval_num, select_num, system_prompt, prompt_template logger = logging.get_logger(__name__) @@ -16,7 +17,7 @@ class EmoLLMRAG(object): 4. 将 query 和检索回来的 content 传入 LLM 中 """ - def __init__(self, model) -> None: + def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None: """ 输入 Model 进行初始化 @@ -30,42 +31,35 @@ class EmoLLMRAG(object): self.vectorstores = self._load_vector_db() self.system_prompt = self._get_system_prompt() self.prompt_template = self._get_prompt_template() - - # 等待 embedding team 封装对应接口 - #self.data_process_obj = DataProcessing() + self.data_processing_obj = DataProcessing() + self.system_prompt = system_prompt + self.prompt_template = prompt_template + self.retrieval_num = retrieval_num + self.rerank_flag = rerank_flag + self.select_num = select_num def _load_vector_db(self): """ 调用 embedding 模块给出接口 load vector DB """ - return - - def _get_system_prompt(self) -> str: - """ - 加载 system prompt - """ - return '' + vectorstores = self.data_processing_obj.load_vector_db() + if not vectorstores: + vectorstores = self.data_processing_obj.load_index_and_knowledge() - def _get_prompt_template(self) -> str: - """ - 加载 prompt template - """ - return '' + return vectorstores - def get_retrieval_content(self, query, rerank_flag=False) -> str: + def get_retrieval_content(self, query) -> str: """ Input: 用户提问, 是否需要rerank ouput: 检索后并且 rerank 的内容 """ content = '' - documents = self.vectorstores.similarity_search(query, k=retrieval_num) + documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) # 如果需要rerank,调用接口对 documents 进行 rerank - if rerank_flag: - pass - # 等后续调用接口 - #documents = self.data_process_obj.rerank_documents(documents, select_num) + if self.rerank_flag: + documents = self.data_processing_obj.rerank(documents, self.select_num) for doc in documents: content += doc.page_content