from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from config.config import retrieval_num, select_num

logger = logging.get_logger(__name__)


class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline:
    1. Embed the query
    2. Retrieve candidate documents from the vector DB
    3. Rerank the retrieved results
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model) -> None:
        """
        Initialize the pipeline with a model.

        DataProcessing obj: handles data processing, including embedding/rerank
        vectorstores: loads the vector DB; it should be rebuilt if it does not exist
        system prompt: fetches the predefined system prompt
        prompt template: defines the final template passed to the LLM
        """
        self.model = model
        self.vectorstores = self._load_vector_db()
        self.system_prompt = self._get_system_prompt()
        self.prompt_template = self._get_prompt_template()

        # Waiting for the embedding team to wrap the corresponding interface
        # self.data_process_obj = DataProcessing()

    def _load_vector_db(self):
        """
        Load the vector DB through the interface provided by the embedding module
        """
        return

    def _get_system_prompt(self) -> str:
        """
        Load the system prompt
        """
        return ''

    def _get_prompt_template(self) -> str:
        """
        Load the prompt template
        """
        return ''

    def get_retrieval_content(self, query, rerank_flag=False) -> str:
        """
        Input: user query, and whether reranking is required
        Output: retrieved (and optionally reranked) content
        """
        content = ''
        documents = self.vectorstores.similarity_search(query, k=retrieval_num)

        # If reranking is requested, call the interface to rerank the documents
        if rerank_flag:
            pass
            # Waiting for the rerank interface to be provided
            # documents = self.data_process_obj.rerank_documents(documents, select_num)
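            # Until then, a hedged sketch with a sentence-transformers
            # cross-encoder (the model name is an assumption, not a project
            # setting):
            #
            #   from sentence_transformers import CrossEncoder
            #   reranker = CrossEncoder("BAAI/bge-reranker-base")
            #   scores = reranker.predict(
            #       [(query, doc.page_content) for doc in documents])
            #   ranked = sorted(zip(scores, documents),
            #                   key=lambda pair: pair[0], reverse=True)
            #   documents = [doc for _, doc in ranked[:select_num]]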

        # Concatenate the page content of the retrieved chunks
        for doc in documents:
            content += doc.page_content

        return content

    def generate_answer(self, query, content) -> str:
        """
        Input: user query and the retrieved content
        Output: the model's generation
        """
        # Build the template.
        # v1 does not involve history, so the system prompt is folded directly
        # into the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content", "system_prompt"],
        )

        # Define the chain; the output is parsed to a string
        rag_chain = prompt | self.model | StrOutputParser()

        # Run
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
                "system_prompt": self.system_prompt,
            }
        )
        return generation

    def main(self, query) -> str:
        """
        Input: user query
        Output: the LLM's generation

        Defines the end-to-end RAG pipeline and orchestrates the individual modules
        TODO:
        Add the RAGAS evaluation system
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)

        return response
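
# A minimal usage sketch (hedged): assumes the stubs above are implemented and
# `model` is any LangChain-compatible LLM. The HuggingFacePipeline wrapper and
# model id below are illustrative assumptions, not the project's actual setup:
#
#   from langchain_community.llms import HuggingFacePipeline
#
#   model = HuggingFacePipeline.from_model_id(
#       model_id="internlm/internlm2-chat-7b",
#       task="text-generation",
#   )
#   rag = EmoLLMRAG(model)
#   print(rag.main("How can I cope with exam anxiety?"))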