from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from rag.src.data_processing import Data_process
from rag.src.config.config import prompt_template

logger = logging.get_logger(__name__)

class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline:
    1. Embed the query
    2. Retrieve data from the vector DB
    3. Rerank the retrieved results
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """
        Initialize with the given model.

        DataProcessing obj: handles data processing, including embedding/rerank
        vectorstores: loads the vector DB; it should be recreated if missing
        system prompt: fetches the predefined system prompt
        prompt template: defines the final template fed to the LLM
        """
        self.model = model
        self.data_processing_obj = Data_process()
        self.vectorstores = self._load_vector_db()
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag
        self.select_num = select_num

    def _load_vector_db(self):
        """
        Load the vector DB via the interface exposed by the embedding module.
        """
        vectorstores = self.data_processing_obj.load_vector_db()

        return vectorstores
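
    # For illustration only: Data_process.load_vector_db() lives in
    # rag/src/data_processing.py and its details are not shown in this file.
    # A minimal FAISS-backed loader in LangChain might look like the sketch
    # below (paths and embedding model are assumptions, not the repo's code):
    #
    #   from langchain_community.embeddings import HuggingFaceEmbeddings
    #   from langchain_community.vectorstores import FAISS
    #
    #   embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")
    #   vectorstores = FAISS.load_local(
    #       "data/vector_db",  # hypothetical path
    #       embeddings,
    #       allow_dangerous_deserialization=True,
    #   )
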
    def get_retrieval_content(self, query) -> list:
        """
        Input: user query (rerank behavior is controlled by self.rerank_flag)
        Output: retrieved content, reranked if enabled
        """
        content = []
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

        for doc in documents:
            content.append(doc.page_content)

        # If rerank is enabled, call the rerank interface on the documents
        if self.rerank_flag:
            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)

            content = []
            for doc in documents:
                content.append(doc)
        logger.info(f'Retrieval data: {content}')
        return content
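
    # For illustration only: the rerank step above delegates to
    # Data_process.rerank(documents, select_num), implemented in
    # rag/src/data_processing.py. A typical cross-encoder rerank (an assumed
    # stand-in, not necessarily what the repo does) might look like:
    #
    #   from sentence_transformers import CrossEncoder
    #
    #   reranker = CrossEncoder("BAAI/bge-reranker-base")  # assumed model
    #   scores = reranker.predict(
    #       [(query, doc.page_content) for doc in documents]
    #   )
    #   ranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    #   documents = [doc for doc, _ in ranked[:select_num]]
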
    def generate_answer(self, query, content) -> str:
        """
        Input: user query, retrieved content
        Output: model-generated answer
        """
        # Build the template.
        # The first version does not handle history, so the system prompt
        # is folded directly into the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )

        # Define the chain; the output format is a string
        rag_chain = prompt | self.model | StrOutputParser()

        # Run
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
            }
        )
        return generation
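
    # Note: rag_chain is a LangChain LCEL runnable, so token streaming is
    # also available when the underlying model supports it. A hedged sketch:
    #
    #   for chunk in rag_chain.stream({"query": query, "content": content}):
    #       print(chunk, end="", flush=True)
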
    def main(self, query) -> str:
        """
        Input: user query
        Output: LLM-generated result

        Defines the full RAG pipeline and orchestrates each module.
        TODO:
            Add the RAGAS scoring system
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)

        return response
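

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file):
    # any LangChain-compatible LLM object works as `model`, e.g. a
    # HuggingFacePipeline wrapping local EmoLLM weights. The model path
    # below is a placeholder, not a real ID.
    from langchain_community.llms import HuggingFacePipeline

    model = HuggingFacePipeline.from_model_id(
        model_id="path/to/emollm-weights",  # hypothetical local path
        task="text-generation",
    )
    rag = EmoLLMRAG(model, retrieval_num=3, rerank_flag=False, select_num=3)
    # "I've been having trouble sleeping lately, what should I do?"
    print(rag.main("最近总是失眠,该怎么办?"))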