# OliveSensorAPI/rag/src/pipeline.py
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from rag.src.config.config import prompt_template
from rag.src.data_processing import Data_process

logger = logging.get_logger(__name__)
class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline.

    1. Embed the incoming query.
    2. Retrieve candidate documents from the vector DB.
    3. Optionally rerank the retrieved results.
    4. Pass the query plus retrieved content to the LLM.
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """
        Initialize the pipeline around *model*.

        Args:
            model: LLM used as the generation stage of the chain (must be
                composable with LangChain's ``|`` operator).
            retrieval_num: number of documents fetched by similarity search.
            rerank_flag: when True, rerank retrieved docs before generation.
            select_num: number of documents kept after reranking.
        """
        self.model = model
        # Data_process handles embedding/rerank and vector-DB loading.
        self.data_processing_obj = Data_process()
        self.vectorstores = self._load_vector_db()
        # Predefined template; v1 has no chat history, so the system prompt
        # is folded directly into this template (see generate_answer).
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag
        self.select_num = select_num

    def _load_vector_db(self):
        """Load the vector DB via the data-processing module's interface."""
        return self.data_processing_obj.load_vector_db()

    def get_retrieval_content(self, query) -> list:
        """
        Retrieve (and optionally rerank) content relevant to *query*.

        Args:
            query: the user's question.

        Returns:
            List of retrieved text snippets. (Fixed: the original annotated
            ``-> str`` but always returned a list.)
        """
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
        content = [doc.page_content for doc in documents]

        # When reranking is enabled, replace the similarity-search results
        # with the reranked top-`select_num` items.
        # NOTE(review): the original appended reranked items directly (not
        # `.page_content`), implying rerank() yields plain strings — confirm
        # against Data_process.rerank.
        if self.rerank_flag:
            reranked, _ = self.data_processing_obj.rerank(documents, self.select_num)
            content = list(reranked)

        # Lazy %-style args avoid formatting when INFO is disabled.
        logger.info('Retrieval data: %s', content)
        return content

    def generate_answer(self, query, content) -> str:
        """
        Generate the final answer from the query and retrieved content.

        Args:
            query: the user's question.
            content: retrieved context injected into the prompt.

        Returns:
            The model's generated answer as a string.
        """
        # v1 keeps no history, so the system prompt lives inside the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )
        # Chain: prompt -> model -> plain-string output.
        rag_chain = prompt | self.model | StrOutputParser()
        return rag_chain.invoke(
            {
                "query": query,
                "content": content,
            }
        )

    def main(self, query) -> str:
        """
        Run the full RAG pipeline for *query* and return the LLM's answer.

        Orchestrates retrieval then generation.
        TODO: integrate a RAGAS evaluation step.
        """
        content = self.get_retrieval_content(query)
        return self.generate_answer(query, content)