Update basic RAG pipeline

Only the basic pipeline has been added and it has not been tested yet; debugging will follow once the concrete interfaces are finalized.
edward_ke 2024-03-17 10:31:11 +08:00
parent c08e4dccd6
commit 50a5129c77
4 changed files with 147 additions and 3 deletions


@@ -2,3 +2,5 @@ sentence_transformers
transformers
numpy
loguru
langchain
torch


@@ -3,6 +3,7 @@ import os
cur_dir = os.path.dirname(os.path.abspath(__file__)) # config
src_dir = os.path.dirname(cur_dir) # src
base_dir = os.path.dirname(src_dir) # base
model_repo = 'ajupyter/EmoLLM_aiwei'
# model
model_dir = os.path.join(base_dir, 'model') # model
@@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pickle
# log
log_dir = os.path.join(base_dir, 'log') # log
log_path = os.path.join(log_dir, 'log.log') # file
select_num = 3     # number of documents kept after rerank
retrieval_num = 10 # number of documents retrieved from the vector DB


@@ -5,8 +5,19 @@ import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
# Download the model weights from OpenXLab into the local 'model' directory
download(
    model_repo=model_repo,
    output='model'
)
"""
@@ -62,6 +73,19 @@ def main():
    ## 2. Concatenate the retrieved contents into the prompt and pass it to the LLM as the {known content}
    ## 3. Ask the LLM to answer based on the known content

@st.cache_resource
def load_model():
    # Load the downloaded weights in bfloat16 on the GPU, plus the matching tokenizer
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer

if __name__ == '__main__':
    main()
    #main()
    query = ''
    model, tokenizer = load_model()
    rag_obj = EmoLLMRAG(model)
    response = rag_obj.main(query)
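
The LCEL chain inside EmoLLMRAG (prompt | self.model | StrOutputParser(), see pipeline.py below) expects a LangChain-compatible LLM, while load_model() returns a raw transformers model. A minimal sketch of one possible wiring, assuming a transformers text-generation pipeline wrapped in LangChain's HuggingFacePipeline; the wrapper, the generation settings, and the example query are assumptions, not part of this commit:

from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline  # import path depends on the installed langchain version

model, tokenizer = load_model()
hf_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,       # assumed generation budget
    return_full_text=False,   # return only the newly generated answer
)
llm = HuggingFacePipeline(pipeline=hf_pipe)

rag_obj = EmoLLMRAG(llm)      # the LCEL chain can now invoke the wrapped model
response = rag_obj.main("I have been sleeping badly lately, what can I do?")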

rag/src/util/pipeline.py (new file, 114 lines)

@@ -0,0 +1,114 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from config.config import retrieval_num, select_num

logger = logging.get_logger(__name__)

class EmoLLMRAG(object):
    """
    EmoLLM RAG Pipeline
    1. Embed the query
    2. Retrieve data from the vector DB
    3. Rerank the retrieved results
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model) -> None:
        """
        Initialize with the model.
        DataProcessing obj: handles data processing, including embedding / rerank
        vectorstores: load the vector DB; if it does not exist, it should be rebuilt
        system prompt: fetch the predefined system prompt
        prompt template: define the final template that is fed to the LLM
        """
        self.model = model
        self.vectorstores = self._load_vector_db()
        self.system_prompt = self._get_system_prompt()
        self.prompt_template = self._get_prompt_template()

        # Waiting for the embedding team to finalize the corresponding interface
        # self.data_process_obj = DataProcessing()

    def _load_vector_db(self):
        """
        Load the vector DB through the interface provided by the embedding module
        """
        return
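
_load_vector_db is still a stub pending the embedding team's interface. A minimal local sketch of what it could do, assuming a FAISS index (faiss-cpu is not yet in the requirements), a sentence-transformers embedding model served through langchain, and an assumed index directory; the helper name, model name, and paths are all hypothetical:

def load_vector_db_sketch():
    # Hypothetical stand-in for EmoLLMRAG._load_vector_db: load a local FAISS
    # index if it exists, otherwise build it from the knowledge JSON.
    import json
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from config.config import knowledge_json_path

    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")  # assumed embedding model
    db_dir = "data/vector_db"                                                # assumed index location
    try:
        return FAISS.load_local(db_dir, embeddings)
    except Exception:
        with open(knowledge_json_path, encoding="utf-8") as f:
            texts = [json.dumps(item, ensure_ascii=False) for item in json.load(f)]
        db = FAISS.from_texts(texts, embeddings)
        db.save_local(db_dir)
        return db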

    def _get_system_prompt(self) -> str:
        """
        Load the system prompt
        """
        return ''

    def _get_prompt_template(self) -> str:
        """
        Load the prompt template
        """
        return ''

    def get_retrieval_content(self, query, rerank_flag=False) -> str:
        """
        Input: user query, whether rerank is needed
        Output: content after retrieval (and rerank, if enabled)
        """
        content = ''
        documents = self.vectorstores.similarity_search(query, k=retrieval_num)

        # If rerank is needed, call the interface to rerank the documents
        if rerank_flag:
            pass
            # Waiting for the interface to become available
            # documents = self.data_process_obj.rerank_documents(documents, select_num)

        for doc in documents:
            content += doc.page_content
        return content
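
The rerank branch above is a pass until the DataProcessing interface lands. A possible stand-in, assuming a sentence-transformers CrossEncoder scores each (query, document) pair and only the top select_num documents are kept; the helper and the reranker model name are hypothetical:

def rerank_documents_sketch(query, documents, select_num):
    # Hypothetical reranker: score (query, passage) pairs with a cross-encoder
    # and keep the select_num highest-scoring documents.
    from sentence_transformers import CrossEncoder

    reranker = CrossEncoder("BAAI/bge-reranker-base")  # assumed reranker checkpoint
    scores = reranker.predict([(query, doc.page_content) for doc in documents])
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked[:select_num]]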

    def generate_answer(self, query, content) -> str:
        """
        Input: user query, content returned by retrieval
        Output: text generated by the model
        """
        # Build the template.
        # The first version does not involve history, so the system prompt is folded
        # directly into the template.
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content", "system_prompt"],
        )

        # Define the chain; the output format is a string
        rag_chain = prompt | self.model | StrOutputParser()

        # Run
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
                "system_prompt": self.system_prompt
            }
        )
        return generation
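
generate_answer expects _get_prompt_template to return a template exposing the three input_variables above, but the getter currently returns an empty string. One possible template, with wording that is only an assumption:

def get_prompt_template_sketch() -> str:
    # Hypothetical template with the query / content / system_prompt slots
    # that PromptTemplate fills in generate_answer.
    return (
        "{system_prompt}\n"
        "The following is the known content retrieved for this question:\n"
        "{content}\n"
        "Please answer the user's question based on the known content above: {query}\n"
    )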

    def main(self, query) -> str:
        """
        Input: user query
        Output: result generated by the LLM

        Defines the whole RAG pipeline and orchestrates the individual modules.
        TODO:
            add a RAGAS evaluation system
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)
        return response