OliveSensorAPI/rag/src/util/pipeline.py

115 lines
3.4 KiB
Python
Raw Normal View History

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging
from config.config import retrieval_num, select_num
logger = logging.get_logger(__name__)
class EmoLLMRAG(object):
"""
EmoLLM RAG Pipeline
1. 根据 query 进行 embedding
2. vector DB 中检索数据
3. rerank 检索后的结果
4. query 和检索回来的 content 传入 LLM
"""
def __init__(self, model) -> None:
"""
输入 Model 进行初始化
DataProcessing obj: 进行数据处理包括数据 embedding/rerank
vectorstores: 加载vector DB如果没有应该重新创建
system prompt: 获取预定义的 system prompt
prompt template: 定义最后的输入到 LLM 中的 template
"""
self.model = model
self.vectorstores = self._load_vector_db()
self.system_prompt = self._get_system_prompt()
self.prompt_template = self._get_prompt_template()
# 等待 embedding team 封装对应接口
#self.data_process_obj = DataProcessing()
def _load_vector_db(self):
"""
调用 embedding 模块给出接口 load vector DB
"""
return
def _get_system_prompt(self) -> str:
"""
加载 system prompt
"""
return ''
def _get_prompt_template(self) -> str:
"""
加载 prompt template
"""
return ''
def get_retrieval_content(self, query, rerank_flag=False) -> str:
"""
Input: 用户提问, 是否需要rerank
ouput: 检索后并且 rerank 的内容
"""
content = ''
documents = self.vectorstores.similarity_search(query, k=retrieval_num)
# 如果需要rerank调用接口对 documents 进行 rerank
if rerank_flag:
pass
# 等后续调用接口
#documents = self.data_process_obj.rerank_documents(documents, select_num)
for doc in documents:
content += doc.page_content
return content
def generate_answer(self, query, content) -> str:
"""
Input: 用户提问 检索返回的内容
Output: 模型生成结果
"""
# 构建 template
# 第一版不涉及 history 信息,因此将 system prompt 直接纳入到 template 之中
prompt = PromptTemplate(
template=self.prompt_template,
input_variables=["query", "content", "system_prompt"],
)
# 定义 chain
# output格式为 string
rag_chain = prompt | self.model | StrOutputParser()
# Run
generation = rag_chain.invoke(
{
"query": query,
"content": content,
"system_prompt": self.system_prompt
}
)
return generation
def main(self, query) -> str:
"""
Input: 用户提问
output: LLM 生成的结果
定义整个 RAG pipeline 流程调度各个模块
TODO:
加入 RAGAS 评分系统
"""
content = self.get_retrieval_content(query)
response = self.generate_answer(query, content)
return response