From 50a5129c77049c2fa756711a2eeb34c3346e1ce0 Mon Sep 17 00:00:00 2001
From: edward_ke
Date: Sun, 17 Mar 2024 10:31:11 +0800
Subject: [PATCH] Update basic RAG pipeline
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Only the basic pipeline is added; it has not been tested yet. It will be
debugged once the concrete interfaces are finalized.
---
 rag/requirements.txt     |   4 +-
 rag/src/config/config.py |   4 ++
 rag/src/main.py          |  28 +++++++++-
 rag/src/util/pipeline.py | 114 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 147 insertions(+), 3 deletions(-)
 create mode 100644 rag/src/util/pipeline.py

diff --git a/rag/requirements.txt b/rag/requirements.txt
index 08289b2..15f915c 100644
--- a/rag/requirements.txt
+++ b/rag/requirements.txt
@@ -1,4 +1,6 @@
 sentence_transformers
 transformers
 numpy
-loguru
\ No newline at end of file
+loguru
+langchain
+torch

diff --git a/rag/src/config/config.py b/rag/src/config/config.py
index 4c7e335..b84327f 100644
--- a/rag/src/config/config.py
+++ b/rag/src/config/config.py
@@ -3,6 +3,7 @@ import os
 cur_dir = os.path.dirname(os.path.abspath(__file__))    # config
 src_dir = os.path.dirname(cur_dir)                      # src
 base_dir = os.path.dirname(src_dir)                     # base
+model_repo = 'ajupyter/EmoLLM_aiwei'
 
 # model
 model_dir = os.path.join(base_dir, 'model')             # model
@@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')    # pickle
 # log
 log_dir = os.path.join(base_dir, 'log')        # log
 log_path = os.path.join(log_dir, 'log.log')    # file
+
+select_num = 3
+retrieval_num = 10
\ No newline at end of file

diff --git a/rag/src/main.py b/rag/src/main.py
index 219ce85..97f60a0 100644
--- a/rag/src/main.py
+++ b/rag/src/main.py
@@ -5,8 +5,19 @@ import numpy as np
 from typing import Tuple
 from sentence_transformers import SentenceTransformer
 
-from config.config import knowledge_json_path, knowledge_pkl_path
+from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
 from util.encode import load_embedding, encode_qa
+from util.pipeline import EmoLLMRAG
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import streamlit as st
+from openxlab.model import download
+
+download(
+    model_repo=model_repo,
+    output='model'
+)
 
 
 """
@@ -62,6 +73,19 @@ def main():
 ## 2. Concatenate the retrieved contents into a prompt and pass it to the LLM as the {known content}
 ## 3. Ask the LLM to reply based on the known content
 
+@st.cache_resource
+def load_model():
+    model = (
+        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
+        .to(torch.bfloat16)
+        .cuda()
+    )
+    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    return model, tokenizer
 
 if __name__ == '__main__':
-    main()
+    # main()
+    query = ''
+    model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
+    response = rag_obj.main(query)
\ No newline at end of file

diff --git a/rag/src/util/pipeline.py b/rag/src/util/pipeline.py
new file mode 100644
index 0000000..a6f2cdf
--- /dev/null
+++ b/rag/src/util/pipeline.py
@@ -0,0 +1,114 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from transformers.utils import logging
+
+from config.config import retrieval_num, select_num
+
+logger = logging.get_logger(__name__)
+
+
+class EmoLLMRAG(object):
+    """
+    EmoLLM RAG pipeline:
+    1. Embed the query
+    2. Retrieve candidate documents from the vector DB
+    3. Rerank the retrieved results
+    4. Pass the query and the retrieved content to the LLM
+    """
+
+    def __init__(self, model) -> None:
+        """
+        Initialize with the model.
+
+        DataProcessing obj: handles data processing, including embedding/rerank
+        vectorstores: loads the vector DB; it should be recreated if missing
+        system prompt: fetches the predefined system prompt
+        prompt template: defines the final template passed to the LLM
+        """
+        self.model = model
+        self.vectorstores = self._load_vector_db()
+        self.system_prompt = self._get_system_prompt()
+        self.prompt_template = self._get_prompt_template()
+
+        # Waiting for the embedding team to expose the corresponding interface
+        # self.data_process_obj = DataProcessing()
+
+    def _load_vector_db(self):
+        """
+        Load the vector DB via the interface provided by the embedding module
+        """
+        return
+
+    def _get_system_prompt(self) -> str:
+        """
+        Load the system prompt
+        """
+        return ''
+
+    def _get_prompt_template(self) -> str:
+        """
+        Load the prompt template
+        """
+        return ''
+
+    def get_retrieval_content(self, query, rerank_flag=False) -> str:
+        """
+        Input: user query, whether reranking is required
+        Output: retrieved (and, if requested, reranked) content
+        """
+        content = ''
+        documents = self.vectorstores.similarity_search(query, k=retrieval_num)
+
+        # If reranking is requested, call the interface to rerank the documents
+        if rerank_flag:
+            pass
+            # Waiting for the interface to be available
+            # documents = self.data_process_obj.rerank_documents(documents, select_num)
+
+        for doc in documents:
+            content += doc.page_content
+
+        return content
+
+    def generate_answer(self, query, content) -> str:
+        """
+        Input: user query, retrieved content
+        Output: model-generated answer
+        """
+        # Build the template.
+        # The first version carries no history, so the system prompt is folded
+        # directly into the template.
+        prompt = PromptTemplate(
+            template=self.prompt_template,
+            input_variables=["query", "content", "system_prompt"],
+        )
+
+        # Define the chain; the output format is a string
+        rag_chain = prompt | self.model | StrOutputParser()
+
+        # Run
+        generation = rag_chain.invoke(
+            {
+                "query": query,
+                "content": content,
+                "system_prompt": self.system_prompt,
+            }
+        )
+        return generation
+
+    def main(self, query) -> str:
+        """
+        Input: user query
+        Output: LLM-generated answer
+
+        Defines the end-to-end RAG pipeline and orchestrates the modules.
+        TODO: add the RAGAS evaluation system
+        """
+        content = self.get_retrieval_content(query)
+        response = self.generate_answer(query, content)
+
+        return response
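
Note on generate_answer(): `rag_chain = prompt | self.model | StrOutputParser()` only works if `self.model` is a LangChain Runnable, but main.py passes a raw `AutoModelForCausalLM`. A minimal sketch of one possible adapter, assuming the team keeps the LCEL chain (the wrapper name and the generation parameters below are illustrative, not settled project choices):

    from transformers import pipeline
    from langchain_community.llms import HuggingFacePipeline

    def wrap_for_langchain(model, tokenizer):
        """Wrap a transformers model/tokenizer pair as a LangChain-compatible LLM."""
        hf_pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,       # generation budget; tune per deployment
            return_full_text=False,   # return only the completion, not the echoed prompt
        )
        return HuggingFacePipeline(pipeline=hf_pipe)

With such a wrapper, the entry point would become `rag_obj = EmoLLMRAG(wrap_for_langchain(model, tokenizer))` instead of passing the bare model.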
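
Note on the stubs: `_load_vector_db`, `_get_system_prompt`, and `_get_prompt_template` currently return placeholders until the embedding team's interface lands, yet `get_retrieval_content` already calls `self.vectorstores.similarity_search(...)`, so the pipeline cannot run end to end. A hedged sketch of the intended shape, assuming a FAISS index persisted on disk (the `vector_db` path and the embedding model name are placeholders, not confirmed choices):

    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    def load_vector_db(db_dir="vector_db", embedding_model="BAAI/bge-small-zh-v1.5"):
        """Load a persisted FAISS index with a sentence-embedding model."""
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        return FAISS.load_local(db_dir, embeddings)

    # Whatever _get_prompt_template() ends up returning must expose exactly the
    # variables generate_answer() declares in input_variables:
    PROMPT_TEMPLATE = (
        "{system_prompt}\n"
        "Known content:\n{content}\n"
        "Question: {query}\n"
        "Answer:"
    )

Keeping the template variables in sync with `input_variables=["query", "content", "system_prompt"]` is what lets `rag_chain.invoke({...})` fill the prompt without extra glue code.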