Update
This commit is contained in:
parent
af53e9744d
commit
96f6ce307e
@ -26,3 +26,12 @@ vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
|
|||||||
|
|
||||||
select_num = 3
|
select_num = 3
|
||||||
retrieval_num = 10
|
retrieval_num = 10
|
||||||
|
system_prompt = """
|
||||||
|
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||||
|
"""
|
||||||
|
prompt_template = """
|
||||||
|
{system_prompt}
|
||||||
|
根据下面检索回来的信息,回答问题。
|
||||||
|
{content}
|
||||||
|
问题:{question}
|
||||||
|
"""
|
@ -2,7 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
|
|||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from config.config import retrieval_num, select_num
|
from data_processing import DataProcessing
|
||||||
|
from config.config import retrieval_num, select_num, system_prompt, prompt_template
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class EmoLLMRAG(object):
|
|||||||
4. 将 query 和检索回来的 content 传入 LLM 中
|
4. 将 query 和检索回来的 content 传入 LLM 中
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model) -> None:
|
def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
|
||||||
"""
|
"""
|
||||||
输入 Model 进行初始化
|
输入 Model 进行初始化
|
||||||
|
|
||||||
@ -30,42 +31,35 @@ class EmoLLMRAG(object):
|
|||||||
self.vectorstores = self._load_vector_db()
|
self.vectorstores = self._load_vector_db()
|
||||||
self.system_prompt = self._get_system_prompt()
|
self.system_prompt = self._get_system_prompt()
|
||||||
self.prompt_template = self._get_prompt_template()
|
self.prompt_template = self._get_prompt_template()
|
||||||
|
self.data_processing_obj = DataProcessing()
|
||||||
# 等待 embedding team 封装对应接口
|
self.system_prompt = system_prompt
|
||||||
#self.data_process_obj = DataProcessing()
|
self.prompt_template = prompt_template
|
||||||
|
self.retrieval_num = retrieval_num
|
||||||
|
self.rerank_flag = rerank_flag
|
||||||
|
self.select_num = select_num
|
||||||
|
|
||||||
def _load_vector_db(self):
|
def _load_vector_db(self):
|
||||||
"""
|
"""
|
||||||
调用 embedding 模块给出接口 load vector DB
|
调用 embedding 模块给出接口 load vector DB
|
||||||
"""
|
"""
|
||||||
return
|
vectorstores = self.data_processing_obj.load_vector_db()
|
||||||
|
if not vectorstores:
|
||||||
|
vectorstores = self.data_processing_obj.load_index_and_knowledge()
|
||||||
|
|
||||||
def _get_system_prompt(self) -> str:
|
return vectorstores
|
||||||
"""
|
|
||||||
加载 system prompt
|
|
||||||
"""
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def _get_prompt_template(self) -> str:
|
def get_retrieval_content(self, query) -> str:
|
||||||
"""
|
|
||||||
加载 prompt template
|
|
||||||
"""
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def get_retrieval_content(self, query, rerank_flag=False) -> str:
|
|
||||||
"""
|
"""
|
||||||
Input: 用户提问, 是否需要rerank
|
Input: 用户提问, 是否需要rerank
|
||||||
output: 检索后并且 rerank 的内容
|
output: 检索后并且 rerank 的内容
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content = ''
|
content = ''
|
||||||
documents = self.vectorstores.similarity_search(query, k=retrieval_num)
|
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
|
||||||
|
|
||||||
# 如果需要rerank,调用接口对 documents 进行 rerank
|
# 如果需要rerank,调用接口对 documents 进行 rerank
|
||||||
if rerank_flag:
|
if self.rerank_flag:
|
||||||
pass
|
documents = self.data_processing_obj.rerank(documents, self.select_num)
|
||||||
# 等后续调用接口
|
|
||||||
#documents = self.data_process_obj.rerank_documents(documents, select_num)
|
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
content += doc.page_content
|
content += doc.page_content
|
Loading…
Reference in New Issue
Block a user