commit
						a119fdb507
					
				| @ -25,4 +25,13 @@ log_path = os.path.join(log_dir, 'log.log')                         # file | ||||
| vector_db_dir = os.path.join(data_dir, 'vector_db.pkl') | ||||
| 
 | ||||
| select_num = 3 | ||||
| retrieval_num = 10 | ||||
| retrieval_num = 10 | ||||
| system_prompt = """ | ||||
| 	你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n | ||||
| """ | ||||
| prompt_template = """ | ||||
| 	{system_prompt} | ||||
| 	根据下面检索回来的信息,回答问题。 | ||||
| 	{content} | ||||
| 	问题:{question} | ||||
| """ | ||||
| @ -2,7 +2,8 @@ from langchain_core.output_parsers import StrOutputParser | ||||
| from langchain_core.prompts import PromptTemplate | ||||
| from transformers.utils import logging | ||||
| 
 | ||||
| from config.config import retrieval_num, select_num | ||||
| from data_processing import DataProcessing | ||||
| from config.config import retrieval_num, select_num, system_prompt, prompt_template | ||||
| 
 | ||||
| logger = logging.get_logger(__name__) | ||||
| 
 | ||||
| @ -16,7 +17,7 @@ class EmoLLMRAG(object): | ||||
|             4. 将 query 和检索回来的 content 传入 LLM 中 | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, model) -> None: | ||||
|     def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None: | ||||
|         """ | ||||
|             输入 Model 进行初始化  | ||||
| 
 | ||||
| @ -30,42 +31,35 @@ class EmoLLMRAG(object): | ||||
|         self.vectorstores = self._load_vector_db() | ||||
|         self.system_prompt = self._get_system_prompt() | ||||
|         self.prompt_template = self._get_prompt_template() | ||||
| 
 | ||||
|         # 等待 embedding team 封装对应接口 | ||||
|         #self.data_process_obj = DataProcessing() | ||||
|         self.data_processing_obj = DataProcessing() | ||||
|         self.system_prompt = system_prompt | ||||
|         self.prompt_template = prompt_template | ||||
|         self.retrieval_num = retrieval_num | ||||
|         self.rerank_flag = rerank_flag | ||||
|         self.select_num = select_num | ||||
| 
 | ||||
|     def _load_vector_db(self): | ||||
|         """ | ||||
|             调用 embedding 模块给出接口 load vector DB | ||||
|         """ | ||||
|         return  | ||||
|      | ||||
|     def _get_system_prompt(self) -> str: | ||||
|         """ | ||||
|             加载 system prompt | ||||
|         """ | ||||
|         return '' | ||||
|         vectorstores = self.data_processing_obj.load_vector_db() | ||||
|         if not vectorstores: | ||||
|             vectorstores = self.data_processing_obj.load_index_and_knowledge() | ||||
| 
 | ||||
|     def _get_prompt_template(self) -> str: | ||||
|         """ | ||||
|             加载 prompt template | ||||
|         """ | ||||
|         return '' | ||||
|         return vectorstores  | ||||
| 
 | ||||
|     def get_retrieval_content(self, query, rerank_flag=False) -> str: | ||||
|     def get_retrieval_content(self, query) -> str: | ||||
|         """ | ||||
|             Input: 用户提问, 是否需要rerank | ||||
|             ouput: 检索后并且 rerank 的内容         | ||||
|         """ | ||||
|      | ||||
|         content = '' | ||||
|         documents = self.vectorstores.similarity_search(query, k=retrieval_num) | ||||
|         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) | ||||
| 
 | ||||
|         # 如果需要rerank,调用接口对 documents 进行 rerank | ||||
|         if rerank_flag: | ||||
|             pass | ||||
|             # 等后续调用接口 | ||||
|             #documents = self.data_process_obj.rerank_documents(documents, select_num) | ||||
|         if self.rerank_flag: | ||||
|             documents = self.data_processing_obj.rerank(documents, self.select_num) | ||||
| 
 | ||||
|         for doc in documents: | ||||
|             content += doc.page_content | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 xzw
						xzw