Update
This commit is contained in:
parent
af53e9744d
commit
96f6ce307e
@ -25,4 +25,13 @@ log_path = os.path.join(log_dir, 'log.log') # file
|
||||
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
|
||||
|
||||
select_num = 3
|
||||
retrieval_num = 10
|
||||
retrieval_num = 10
|
||||
system_prompt = """
|
||||
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||
"""
|
||||
prompt_template = """
|
||||
{system_prompt}
|
||||
根据下面检索回来的信息,回答问题。
|
||||
{content}
|
||||
问题:{question}
|
||||
"""
|
@ -2,7 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from transformers.utils import logging
|
||||
|
||||
from config.config import retrieval_num, select_num
|
||||
from data_processing import DataProcessing
|
||||
from config.config import retrieval_num, select_num, system_prompt, prompt_template
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -16,7 +17,7 @@ class EmoLLMRAG(object):
|
||||
4. 将 query 和检索回来的 content 传入 LLM 中
|
||||
"""
|
||||
|
||||
def __init__(self, model) -> None:
|
||||
def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
|
||||
"""
|
||||
输入 Model 进行初始化
|
||||
|
||||
@ -30,42 +31,35 @@ class EmoLLMRAG(object):
|
||||
self.vectorstores = self._load_vector_db()
|
||||
self.system_prompt = self._get_system_prompt()
|
||||
self.prompt_template = self._get_prompt_template()
|
||||
|
||||
# 等待 embedding team 封装对应接口
|
||||
#self.data_process_obj = DataProcessing()
|
||||
self.data_processing_obj = DataProcessing()
|
||||
self.system_prompt = system_prompt
|
||||
self.prompt_template = prompt_template
|
||||
self.retrieval_num = retrieval_num
|
||||
self.rerank_flag = rerank_flag
|
||||
self.select_num = select_num
|
||||
|
||||
def _load_vector_db(self):
|
||||
"""
|
||||
调用 embedding 模块给出接口 load vector DB
|
||||
"""
|
||||
return
|
||||
|
||||
def _get_system_prompt(self) -> str:
|
||||
"""
|
||||
加载 system prompt
|
||||
"""
|
||||
return ''
|
||||
vectorstores = self.data_processing_obj.load_vector_db()
|
||||
if not vectorstores:
|
||||
vectorstores = self.data_processing_obj.load_index_and_knowledge()
|
||||
|
||||
def _get_prompt_template(self) -> str:
|
||||
"""
|
||||
加载 prompt template
|
||||
"""
|
||||
return ''
|
||||
return vectorstores
|
||||
|
||||
def get_retrieval_content(self, query, rerank_flag=False) -> str:
|
||||
def get_retrieval_content(self, query) -> str:
|
||||
"""
|
||||
Input: 用户提问, 是否需要rerank
|
||||
ouput: 检索后并且 rerank 的内容
|
||||
"""
|
||||
|
||||
content = ''
|
||||
documents = self.vectorstores.similarity_search(query, k=retrieval_num)
|
||||
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
|
||||
|
||||
# 如果需要rerank,调用接口对 documents 进行 rerank
|
||||
if rerank_flag:
|
||||
pass
|
||||
# 等后续调用接口
|
||||
#documents = self.data_process_obj.rerank_documents(documents, select_num)
|
||||
if self.rerank_flag:
|
||||
documents = self.data_processing_obj.rerank(documents, self.select_num)
|
||||
|
||||
for doc in documents:
|
||||
content += doc.page_content
|
Loading…
Reference in New Issue
Block a user