Update basic RAG pipeline
Only the basic pipeline has been added; it has not been tested yet and will be debugged once the concrete interfaces are finalized.
commit 50a5129c77
parent c08e4dccd6
@@ -2,3 +2,5 @@ sentence_transformers
 transformers
 numpy
 loguru
+langchain
+torch
@@ -3,6 +3,7 @@ import os
 cur_dir = os.path.dirname(os.path.abspath(__file__))    # config
 src_dir = os.path.dirname(cur_dir)                      # src
 base_dir = os.path.dirname(src_dir)                     # base
+model_repo = 'ajupyter/EmoLLM_aiwei'
 
 # model
 model_dir = os.path.join(base_dir, 'model')             # model
@@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')  # pickle
 # log
 log_dir = os.path.join(base_dir, 'log')                 # log
 log_path = os.path.join(log_dir, 'log.log')             # file
+
+select_num = 3
+retrieval_num = 10
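The two new knobs split retrieval into a wide recall step and a narrow selection step: `retrieval_num` candidates come back from the vector DB, and `select_num` of them survive reranking. A minimal sketch of the intended interplay (the `rerank` callable here is hypothetical until the embedding team's interface is settled):

from config.config import retrieval_num, select_num

def retrieve_then_select(vectorstore, rerank, query: str) -> list:
    # Sketch only: `rerank` stands in for the embedding team's future interface.
    candidates = vectorstore.similarity_search(query, k=retrieval_num)  # wide recall: 10 docs
    return rerank(candidates)[:select_num]                              # narrow selection: keep 3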
@@ -5,8 +5,19 @@ import numpy as np
 from typing import Tuple
 from sentence_transformers import SentenceTransformer
 
-from config.config import knowledge_json_path, knowledge_pkl_path
+from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
 from util.encode import load_embedding, encode_qa
+from util.pipeline import EmoLLMRAG
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import streamlit as st
+from openxlab.model import download
+
+download(
+    model_repo=model_repo,
+    output='model'
+)
 
 
 """
@@ -62,6 +73,19 @@ def main():
     ## 2. Concatenate the retrieved contents into a prompt and pass it to the LLM as the {known content}
     ## 3. Ask the LLM to answer based on the known content
 
+
+@st.cache_resource
+def load_model():
+    model = (
+        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
+        .to(torch.bfloat16)
+        .cuda()
+    )
+    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    return model, tokenizer
+
 if __name__ == '__main__':
-    main()
+    #main()
+    query = ''
+    model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
+    response = rag_obj.main(query)
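One integration gap this commit leaves for the debugging pass: load_model() returns a raw transformers model, while EmoLLMRAG.generate_answer pipes it into a LangChain chain (prompt | self.model | StrOutputParser()), which expects a LangChain runnable. A minimal sketch of one way to bridge the two, assuming the HuggingFacePipeline wrapper from langchain_community (not part of this commit):

from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline

def as_langchain_llm(model, tokenizer):
    # Wrap the raw HF model/tokenizer in a text-generation pipeline so it can
    # sit in the middle of `prompt | llm | StrOutputParser()`.
    text_gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,      # illustrative cap, not a project setting
        return_full_text=False,  # return only the generated continuation
    )
    return HuggingFacePipeline(pipeline=text_gen)

With such a wrapper, rag_obj = EmoLLMRAG(as_langchain_llm(model, tokenizer)) would keep the chain in pipeline.py unchanged.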
rag/src/util/pipeline.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from transformers.utils import logging
+
+from config.config import retrieval_num, select_num
+
+logger = logging.get_logger(__name__)
+
+
+class EmoLLMRAG(object):
+    """
+    EmoLLM RAG pipeline:
+    1. Embed the query
+    2. Retrieve candidates from the vector DB
+    3. Rerank the retrieved results
+    4. Pass the query and the retrieved content to the LLM
+    """
+
+    def __init__(self, model) -> None:
+        """
+        Initialize with the model.
+
+        DataProcessing obj: handles data processing, including embedding/rerank
+        vectorstores: loads the vector DB; it should be rebuilt if missing
+        system prompt: fetches the predefined system prompt
+        prompt template: defines the final template fed to the LLM
+        """
+        self.model = model
+        self.vectorstores = self._load_vector_db()
+        self.system_prompt = self._get_system_prompt()
+        self.prompt_template = self._get_prompt_template()
+
+        # Waiting for the embedding team to wrap up the corresponding interface
+        #self.data_process_obj = DataProcessing()
+
+    def _load_vector_db(self):
+        """
+        Load the vector DB through the interface provided by the embedding module
+        """
+        return
+
+    def _get_system_prompt(self) -> str:
+        """
+        Load the system prompt
+        """
+        return ''
+
+    def _get_prompt_template(self) -> str:
+        """
+        Load the prompt template
+        """
+        return ''
+
+    def get_retrieval_content(self, query, rerank_flag=False) -> str:
+        """
+        Input: user query, whether reranking is needed
+        Output: retrieved (and optionally reranked) content
+        """
+        content = ''
+        documents = self.vectorstores.similarity_search(query, k=retrieval_num)
+
+        # If reranking is requested, call the interface to rerank the documents
+        if rerank_flag:
+            pass
+            # Pending the downstream interface
+            #documents = self.data_process_obj.rerank_documents(documents, select_num)
+
+        for doc in documents:
+            content += doc.page_content
+
+        return content
+
+    def generate_answer(self, query, content) -> str:
+        """
+        Input: user query, retrieved content
+        Output: model generation
+        """
+        # Build the template.
+        # The first version carries no history, so the system prompt is folded
+        # directly into the template.
+        prompt = PromptTemplate(
+            template=self.prompt_template,
+            input_variables=["query", "content", "system_prompt"],
+        )
+
+        # Define the chain; the output format is a string
+        rag_chain = prompt | self.model | StrOutputParser()
+
+        # Run
+        generation = rag_chain.invoke(
+            {
+                "query": query,
+                "content": content,
+                "system_prompt": self.system_prompt
+            }
+        )
+        return generation
+
+    def main(self, query) -> str:
+        """
+        Input: user query
+        Output: the LLM-generated result
+
+        Defines the whole RAG pipeline flow and orchestrates the individual modules.
+        TODO:
+        Add the RAGAS evaluation system
+        """
+        content = self.get_retrieval_content(query)
+        response = self.generate_answer(query, content)
+
+        return response
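_load_vector_db is still a stub, but get_retrieval_content already calls similarity_search(query, k=...), i.e. it expects a LangChain vectorstore. A sketch of one possible loader, assuming the embedding team persists a FAISS index; the index path and embedding model name below are placeholders, not project decisions:

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

def load_vector_db(index_dir: str = 'data/vector_db'):
    # Placeholder embedding model; swap in whatever the embedding team standardizes on.
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    # Reads an index previously saved with FAISS.save_local(index_dir).
    return FAISS.load_local(index_dir, embeddings)

Similarly, _get_prompt_template must eventually return a template string that references the three input_variables declared in generate_answer, i.e. one containing {system_prompt}, {content}, and {query}.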