# OliveSensorAPI/rag/src/pipeline.py
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from rag.src.data_processing import Data_process
from rag.src.config.config import prompt_template

logger = logging.get_logger(__name__)


class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline:
    1. Embed the incoming query.
    2. Retrieve matching documents from the vector DB.
    3. Rerank the retrieved results (optional).
    4. Pass the query and the retrieved content to the LLM.
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """
        Initialize the pipeline with an already-loaded model.
        data_processing_obj: handles data processing, including embedding and reranking.
        vectorstores: loads the vector DB; it should be rebuilt if it does not exist.
        prompt_template: the template for the final input sent to the LLM
            (the predefined system prompt is folded into it).
        """
        self.model = model
        self.data_processing_obj = Data_process()
        self.vectorstores = self._load_vector_db()
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag
        self.select_num = select_num

    def _load_vector_db(self):
        """
        Load the vector DB through the interface exposed by the embedding module.
        """
        vectorstores = self.data_processing_obj.load_vector_db()
        return vectorstores

    def get_retrieval_content(self, query) -> list:
        """
        Input: the user's question.
        Output: the retrieved (and, if enabled, reranked) content.
        """
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
        content = [doc.page_content for doc in documents]
        # If reranking is enabled, rerank the retrieved documents and keep
        # the top `select_num` results.
        if self.rerank_flag:
            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
            content = list(documents)
        logger.info(f'Retrieval data: {content}')
        return content

    def generate_answer(self, query, content) -> str:
        """
        Input: the user's question and the retrieved content.
        Output: the model's generated answer.
        """
        # Build the prompt template. The first version does not handle chat
        # history, so the system prompt is folded directly into the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )
        # Define the chain; the output is parsed into a plain string.
        rag_chain = prompt | self.model | StrOutputParser()
        # Run the chain.
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
            }
        )
        return generation
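
    # For reference, any template assigned to prompt_template must expose the
    # two input variables declared above. The real template lives in
    # rag/src/config/config.py; the sketch below is only an illustration:
    #
    #   prompt_template = (
    #       "You are a caring mental-health assistant.\n"
    #       "Reference material:\n{content}\n"
    #       "User question: {query}\n"
    #       "Answer:"
    #   )
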
    def main(self, query) -> str:
        """
        Input: the user's question.
        Output: the LLM-generated answer.
        Defines the end-to-end RAG pipeline and orchestrates each module.
        TODO:
            integrate a RAGAS evaluation step
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)
        return response
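

# A minimal usage sketch, assuming the vector DB has already been built by the
# data-processing module and that the model passed in is any LangChain-compatible
# Runnable. The HuggingFacePipeline wiring and model id below are illustrative
# assumptions, not the loader this repository actually uses.
if __name__ == "__main__":
    from langchain_community.llms import HuggingFacePipeline

    # Wrap a local Hugging Face model as a LangChain Runnable.
    llm = HuggingFacePipeline.from_model_id(
        model_id="internlm/internlm2-chat-7b",
        task="text-generation",
    )
    rag = EmoLLMRAG(llm, retrieval_num=3, rerank_flag=False, select_num=3)
    print(rag.main("I have been sleeping badly lately. What should I do?"))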