This commit is contained in:
Anooyman 2024-03-19 21:11:10 +08:00
parent af53e9744d
commit 96f6ce307e
2 changed files with 27 additions and 24 deletions

View File

@ -26,3 +26,12 @@ vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
select_num = 3 select_num = 3
retrieval_num = 10 retrieval_num = 10
system_prompt = """
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇我有一些心理问题请你用专业的知识和温柔可爱俏皮的口吻帮我解决回复中可以穿插一些可爱的Emoji表情符号或者文本符号\n
"""
prompt_template = """
{system_prompt}
根据下面检索回来的信息回答问题
{content}
问题{question}
"""

View File

@ -2,7 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from transformers.utils import logging from transformers.utils import logging
from config.config import retrieval_num, select_num from data_processing import DataProcessing
from config.config import retrieval_num, select_num, system_prompt, prompt_template
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -16,7 +17,7 @@ class EmoLLMRAG(object):
4. query 和检索回来的 content 传入 LLM 4. query 和检索回来的 content 传入 LLM
""" """
def __init__(self, model) -> None: def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
""" """
输入 Model 进行初始化 输入 Model 进行初始化
@ -30,42 +31,35 @@ class EmoLLMRAG(object):
self.vectorstores = self._load_vector_db() self.vectorstores = self._load_vector_db()
self.system_prompt = self._get_system_prompt() self.system_prompt = self._get_system_prompt()
self.prompt_template = self._get_prompt_template() self.prompt_template = self._get_prompt_template()
self.data_processing_obj = DataProcessing()
# 等待 embedding team 封装对应接口 self.system_prompt = system_prompt
#self.data_process_obj = DataProcessing() self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
self.rerank_flag = rerank_flag
self.select_num = select_num
def _load_vector_db(self): def _load_vector_db(self):
""" """
调用 embedding 模块给出接口 load vector DB 调用 embedding 模块给出接口 load vector DB
""" """
return vectorstores = self.data_processing_obj.load_vector_db()
if not vectorstores:
vectorstores = self.data_processing_obj.load_index_and_knowledge()
def _get_system_prompt(self) -> str: return vectorstores
"""
加载 system prompt
"""
return ''
def _get_prompt_template(self) -> str: def get_retrieval_content(self, query) -> str:
"""
加载 prompt template
"""
return ''
def get_retrieval_content(self, query, rerank_flag=False) -> str:
""" """
Input: 用户提问, 是否需要rerank Input: 用户提问, 是否需要rerank
output: 检索后并且 rerank 的内容 output: 检索后并且 rerank 的内容
""" """
content = '' content = ''
documents = self.vectorstores.similarity_search(query, k=retrieval_num) documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
# 如果需要rerank调用接口对 documents 进行 rerank # 如果需要rerank调用接口对 documents 进行 rerank
if rerank_flag: if self.rerank_flag:
pass documents = self.data_processing_obj.rerank(documents, self.select_num)
# 等后续调用接口
#documents = self.data_process_obj.rerank_documents(documents, select_num)
for doc in documents: for doc in documents:
content += doc.page_content content += doc.page_content