2024-03-17 10:31:11 +08:00
|
|
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
|
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
2024-04-14 12:22:35 +08:00
|
|
|
|
from rag.src.data_processing import Data_process
|
|
|
|
|
from rag.src.config.config import prompt_template
|
2024-03-17 10:31:11 +08:00
|
|
|
|
# Module-level logger; uses transformers' logging wrapper (imported above as
# `from transformers.utils import logging`), not the stdlib logging module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmoLLMRAG(object):
    """EmoLLM RAG pipeline.

    Steps:
        1. Embed the user query.
        2. Retrieve candidate documents from the vector DB.
        3. Optionally rerank the retrieved results.
        4. Pass the query plus retrieved content to the LLM.
    """

    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
        """Initialize the pipeline around an LLM.

        Args:
            model: LLM used for generation; must be composable in a LangChain
                runnable chain (it is piped between the prompt and the parser).
            retrieval_num: number of documents fetched by similarity search.
            rerank_flag: whether to rerank the retrieved documents.
            select_num: number of documents kept after reranking.
        """
        self.model = model
        # Data_process provides embedding/rerank utilities and vector-DB loading.
        self.data_processing_obj = Data_process()
        # Load (or create) the vector store once at construction time.
        self.vectorstores = self._load_vector_db()
        # Predefined prompt template; the first version carries no chat history,
        # so the system prompt is folded directly into this template.
        self.prompt_template = prompt_template
        self.retrieval_num = retrieval_num
        self.rerank_flag = rerank_flag
        self.select_num = select_num

    def _load_vector_db(self):
        """Load the vector DB through the data-processing module's interface."""
        return self.data_processing_obj.load_vector_db()

    def get_retrieval_content(self, query) -> list:
        """Retrieve (and optionally rerank) context passages for *query*.

        Args:
            query: the user question.

        Returns:
            A list of retrieved passages: `page_content` strings when no
            reranking is requested, otherwise the reranker's output items.
        """
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

        if self.rerank_flag:
            # Keep only the top `select_num` items after reranking.
            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
            # NOTE(review): reranked items are collected as-is (not via
            # `.page_content`) — presumably rerank() already yields plain text;
            # confirm against Data_process.rerank.
            content = list(documents)
        else:
            content = [doc.page_content for doc in documents]

        logger.info(f'Retrieval data: {content}')
        return content

    def generate_answer(self, query, content) -> str:
        """Generate the final answer from the LLM.

        Args:
            query: the user question.
            content: retrieved (and possibly reranked) context passages.

        Returns:
            The model's generation as a string.
        """
        # Build the template; no history in this first version, so the system
        # prompt is already part of `self.prompt_template`.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content"],
        )

        # Chain: prompt -> model -> string output.
        rag_chain = prompt | self.model | StrOutputParser()

        # Run the chain with the user query and retrieved context.
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
            }
        )
        return generation

    def main(self, query) -> str:
        """Run the whole RAG pipeline for one user question.

        Orchestrates retrieval then generation.

        Args:
            query: the user question.

        Returns:
            The LLM-generated answer.

        TODO: add a RAGAS scoring/evaluation step.
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)
        return response
|