Update basic RAG pipeline
只加了基本的 pipeline,还未进行测试,等具体接口确定之后进行调试
This commit is contained in:
parent
c08e4dccd6
commit
50a5129c77
@ -2,3 +2,5 @@ sentence_transformers
|
||||
transformers
|
||||
numpy
|
||||
loguru
|
||||
langchain
|
||||
torch
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
cur_dir = os.path.dirname(os.path.abspath(__file__)) # config
|
||||
src_dir = os.path.dirname(cur_dir) # src
|
||||
base_dir = os.path.dirname(src_dir) # base
|
||||
model_repo = 'ajupyter/EmoLLM_aiwei'
|
||||
|
||||
# model
|
||||
model_dir = os.path.join(base_dir, 'model') # model
|
||||
@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pickle
|
||||
# log
|
||||
log_dir = os.path.join(base_dir, 'log') # log
|
||||
log_path = os.path.join(log_dir, 'log.log') # file
|
||||
|
||||
select_num = 3
|
||||
retrieval_num = 10
|
@ -5,8 +5,19 @@ import numpy as np
|
||||
from typing import Tuple
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from config.config import knowledge_json_path, knowledge_pkl_path
|
||||
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
|
||||
from util.encode import load_embedding, encode_qa
|
||||
from util.pipeline import EmoLLMRAG
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
import streamlit as st
|
||||
from openxlab.model import download
|
||||
|
||||
download(
|
||||
model_repo=model_repo,
|
||||
output='model'
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
@ -62,6 +73,19 @@ def main():
|
||||
## 2. 将 contents 拼接为 prompt,传给 LLM,作为 {已知内容}
|
||||
## 3. 要求 LLM 根据已知内容回复
|
||||
|
||||
@st.cache_resource
|
||||
def load_model():
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
|
||||
.to(torch.bfloat16)
|
||||
.cuda()
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
#main()
|
||||
query = ''
|
||||
model, tokenizer = load_model()
|
||||
rag_obj = EmoLLMRAG(model)
|
||||
response = rag_obj.main(query)
|
114
rag/src/util/pipeline.py
Normal file
114
rag/src/util/pipeline.py
Normal file
@ -0,0 +1,114 @@
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from transformers.utils import logging
|
||||
|
||||
from config.config import retrieval_num, select_num
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class EmoLLMRAG(object):
|
||||
"""
|
||||
EmoLLM RAG Pipeline
|
||||
1. 根据 query 进行 embedding
|
||||
2. 从 vector DB 中检索数据
|
||||
3. rerank 检索后的结果
|
||||
4. 将 query 和检索回来的 content 传入 LLM 中
|
||||
"""
|
||||
|
||||
def __init__(self, model) -> None:
|
||||
"""
|
||||
输入 Model 进行初始化
|
||||
|
||||
DataProcessing obj: 进行数据处理,包括数据 embedding/rerank
|
||||
vectorstores: 加载vector DB。如果没有应该重新创建
|
||||
system prompt: 获取预定义的 system prompt
|
||||
prompt template: 定义最后的输入到 LLM 中的 template
|
||||
|
||||
"""
|
||||
self.model = model
|
||||
self.vectorstores = self._load_vector_db()
|
||||
self.system_prompt = self._get_system_prompt()
|
||||
self.prompt_template = self._get_prompt_template()
|
||||
|
||||
# 等待 embedding team 封装对应接口
|
||||
#self.data_process_obj = DataProcessing()
|
||||
|
||||
def _load_vector_db(self):
|
||||
"""
|
||||
调用 embedding 模块给出接口 load vector DB
|
||||
"""
|
||||
return
|
||||
|
||||
def _get_system_prompt(self) -> str:
|
||||
"""
|
||||
加载 system prompt
|
||||
"""
|
||||
return ''
|
||||
|
||||
def _get_prompt_template(self) -> str:
|
||||
"""
|
||||
加载 prompt template
|
||||
"""
|
||||
return ''
|
||||
|
||||
def get_retrieval_content(self, query, rerank_flag=False) -> str:
|
||||
"""
|
||||
Input: 用户提问, 是否需要rerank
|
||||
ouput: 检索后并且 rerank 的内容
|
||||
"""
|
||||
|
||||
content = ''
|
||||
documents = self.vectorstores.similarity_search(query, k=retrieval_num)
|
||||
|
||||
# 如果需要rerank,调用接口对 documents 进行 rerank
|
||||
if rerank_flag:
|
||||
pass
|
||||
# 等后续调用接口
|
||||
#documents = self.data_process_obj.rerank_documents(documents, select_num)
|
||||
|
||||
for doc in documents:
|
||||
content += doc.page_content
|
||||
|
||||
return content
|
||||
|
||||
def generate_answer(self, query, content) -> str:
|
||||
"""
|
||||
Input: 用户提问, 检索返回的内容
|
||||
Output: 模型生成结果
|
||||
"""
|
||||
|
||||
# 构建 template
|
||||
# 第一版不涉及 history 信息,因此将 system prompt 直接纳入到 template 之中
|
||||
prompt = PromptTemplate(
|
||||
template=self.prompt_template,
|
||||
input_variables=["query", "content", "system_prompt"],
|
||||
)
|
||||
|
||||
# 定义 chain
|
||||
# output格式为 string
|
||||
rag_chain = prompt | self.model | StrOutputParser()
|
||||
|
||||
# Run
|
||||
generation = rag_chain.invoke(
|
||||
{
|
||||
"query": query,
|
||||
"content": content,
|
||||
"system_prompt": self.system_prompt
|
||||
}
|
||||
)
|
||||
return generation
|
||||
|
||||
def main(self, query) -> str:
|
||||
"""
|
||||
Input: 用户提问
|
||||
output: LLM 生成的结果
|
||||
|
||||
定义整个 RAG 的 pipeline 流程,调度各个模块
|
||||
TODO:
|
||||
加入 RAGAS 评分系统
|
||||
"""
|
||||
content = self.get_retrieval_content(query)
|
||||
response = self.generate_answer(query, content)
|
||||
|
||||
return response
|
Loading…
Reference in New Issue
Block a user