From 50a5129c77049c2fa756711a2eeb34c3346e1ce0 Mon Sep 17 00:00:00 2001 From: edward_ke Date: Sun, 17 Mar 2024 10:31:11 +0800 Subject: [PATCH 1/2] Update basic RAG pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 只加了基本的 pipeline,还未进行测试,等具体接口确定之后进行调试 --- rag/requirements.txt | 4 +- rag/src/config/config.py | 4 ++ rag/src/main.py | 28 +++++++++- rag/src/util/pipeline.py | 114 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 3 deletions(-) create mode 100644 rag/src/util/pipeline.py diff --git a/rag/requirements.txt b/rag/requirements.txt index 08289b2..15f915c 100644 --- a/rag/requirements.txt +++ b/rag/requirements.txt @@ -1,4 +1,6 @@ sentence_transformers transformers numpy -loguru \ No newline at end of file +loguru +langchain +torch diff --git a/rag/src/config/config.py b/rag/src/config/config.py index 4c7e335..b84327f 100644 --- a/rag/src/config/config.py +++ b/rag/src/config/config.py @@ -3,6 +3,7 @@ import os cur_dir = os.path.dirname(os.path.abspath(__file__)) # config src_dir = os.path.dirname(cur_dir) # src base_dir = os.path.dirname(src_dir) # base +model_repo = 'ajupyter/EmoLLM_aiwei' # model model_dir = os.path.join(base_dir, 'model') # model @@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pickle # log log_dir = os.path.join(base_dir, 'log') # log log_path = os.path.join(log_dir, 'log.log') # file + +select_num = 3 +retrieval_num = 10 \ No newline at end of file diff --git a/rag/src/main.py b/rag/src/main.py index 219ce85..97f60a0 100644 --- a/rag/src/main.py +++ b/rag/src/main.py @@ -5,8 +5,19 @@ import numpy as np from typing import Tuple from sentence_transformers import SentenceTransformer -from config.config import knowledge_json_path, knowledge_pkl_path +from config.config import knowledge_json_path, knowledge_pkl_path, model_repo from util.encode import load_embedding, encode_qa +from util.pipeline import EmoLLMRAG + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import streamlit as st +from openxlab.model import download + +download( + model_repo=model_repo, + output='model' +) """ @@ -62,6 +73,19 @@ def main(): ## 2. 将 contents 拼接为 prompt,传给 LLM,作为 {已知内容} ## 3. 要求 LLM 根据已知内容回复 +@st.cache_resource +def load_model(): + model = ( + AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True) + .to(torch.bfloat16) + .cuda() + ) + tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) + return model, tokenizer if __name__ == '__main__': - main() + #main() + query = '' + model, tokenizer = load_model() + rag_obj = EmoLLMRAG(model) + response = rag_obj.main(query) \ No newline at end of file diff --git a/rag/src/util/pipeline.py b/rag/src/util/pipeline.py new file mode 100644 index 0000000..a6f2cdf --- /dev/null +++ b/rag/src/util/pipeline.py @@ -0,0 +1,114 @@ +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from transformers.utils import logging + +from config.config import retrieval_num, select_num + +logger = logging.get_logger(__name__) + + +class EmoLLMRAG(object): + """ + EmoLLM RAG Pipeline + 1. 根据 query 进行 embedding + 2. 从 vector DB 中检索数据 + 3. rerank 检索后的结果 + 4. 将 query 和检索回来的 content 传入 LLM 中 + """ + + def __init__(self, model) -> None: + """ + 输入 Model 进行初始化 + + DataProcessing obj: 进行数据处理,包括数据 embedding/rerank + vectorstores: 加载vector DB。如果没有应该重新创建 + system prompt: 获取预定义的 system prompt + prompt template: 定义最后的输入到 LLM 中的 template + + """ + self.model = model + self.vectorstores = self._load_vector_db() + self.system_prompt = self._get_system_prompt() + self.prompt_template = self._get_prompt_template() + + # 等待 embedding team 封装对应接口 + #self.data_process_obj = DataProcessing() + + def _load_vector_db(self): + """ + 调用 embedding 模块给出接口 load vector DB + """ + return + + def _get_system_prompt(self) -> str: + """ + 加载 system prompt + """ + return '' + + def _get_prompt_template(self) -> str: + """ + 加载 prompt template + """ + return '' + + def get_retrieval_content(self, query, rerank_flag=False) -> str: + """ + Input: 用户提问, 是否需要rerank + ouput: 检索后并且 rerank 的内容 + """ + + content = '' + documents = self.vectorstores.similarity_search(query, k=retrieval_num) + + # 如果需要rerank,调用接口对 documents 进行 rerank + if rerank_flag: + pass + # 等后续调用接口 + #documents = self.data_process_obj.rerank_documents(documents, select_num) + + for doc in documents: + content += doc.page_content + + return content + + def generate_answer(self, query, content) -> str: + """ + Input: 用户提问, 检索返回的内容 + Output: 模型生成结果 + """ + + # 构建 template + # 第一版不涉及 history 信息,因此将 system prompt 直接纳入到 template 之中 + prompt = PromptTemplate( + template=self.prompt_template, + input_variables=["query", "content", "system_prompt"], + ) + + # 定义 chain + # output格式为 string + rag_chain = prompt | self.model | StrOutputParser() + + # Run + generation = rag_chain.invoke( + { + "query": query, + "content": content, + "system_prompt": self.system_prompt + } + ) + return generation + + def main(self, query) -> str: + """ + Input: 用户提问 + output: LLM 生成的结果 + + 定义整个 RAG 的 pipeline 流程,调度各个模块 + TODO: + 加入 RAGAS 评分系统 + """ + content = self.get_retrieval_content(query) + response = self.generate_answer(query, content) + + return response From b050fe8122aaf7c6e7f50b0718b90e330bbb024a Mon Sep 17 00:00:00 2001 From: edward_ke Date: Sun, 17 Mar 2024 10:40:26 +0800 Subject: [PATCH 2/2] Update README_EN.md --- rag/README_EN.md | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/rag/README_EN.md b/rag/README_EN.md index e69de29..df4fe43 100644 --- a/rag/README_EN.md +++ b/rag/README_EN.md @@ -0,0 +1,66 @@ +# EmoLLM RAG + +## **Module purpose** + +Based on the customer's questions, the corresponding information is retrieved to enhance the professionalism of the answer, making EmoLLM's answer more professional and reliable. Search content includes but is not limited to the following: +- Psychology related theories +- Psychology methodology +- Classic Case +- Customer background knowledge + +## **Datasets** + + +- Cleaned QA pairs: Each QA pair is embedding as a sample +- Filtered TXT texts + - Directly generate embedding for TXT text (segmented based on token length) + - Filter out irrelevant information such as directories and generate embedding for TXT text (segmented based on token length) + - After filtering irrelevant information such as directories, the TXT is semantically segmented to generate embedding. + - Split TXT according to the directory structure, and generate embeddings based on the architecture hierarchy. + + +For details on data collection construction, please refer to [qa_generation_README](https://github.com/SmartFlowAI/EmoLLM/blob/ccfa75c493c4685e84073dfbc53c50c09a2988e3/scripts/qa_generation/README.md) + +## **Components** + +### [BCEmbedding](https://github.com/netease-youdao/BCEmbedding?tab=readme-ov-file) + +- [bce-embedding-base_v1](https://hf-mirror.com/maidalun1020/bce-embedding-base_v1): embedding model, used to build vector DB +- [bce-reranker-base_v1](https://hf-mirror.com/maidalun1020/bce-reranker-base_v1): rerank model, used to rerank retrieved documents + +### [Langchain](https://python.langchain.com/docs/get_started) + +LangChain is an open source framework for building large language model (LLM) based applications. LangChain provides a variety of tools and abstractions to increase the customization, accuracy, and relevance of the information generated by your models. + +### [FAISS](https://faiss.ai/) + +FAISS is a library for efficient similarity search and dense vector clustering. It contains algorithms that can search sets of vectors of any size. Since langchain has integrated FAISS, this project will no longer be developed based on native documents. [FAISS in Langchain](https://python.langchain.com/docs/integrations/vectorstores/faiss) + + +### [RAGAS](https://github.com/explodinggradients/ragas) + +RAG’s classic evaluation framework is evaluated through the following three aspects: + +- Faithfulness: The answers given should be generated based on the given context. +- Answer Relevance: The generated answer should solve the actual question asked. +- Context Relevance: The retrieved information should be highly concentrated and contain as little irrelevant information as possible. + +Later, more evaluation indicators were added, such as: context recall, etc. + +## **Detials** + +### RAG pipeline + +- Build vector DB based on data set +- Embedding questions entered by customers +- Search in vector database based on embedding results +- Reorder recall data +- Generate final results based on user questions and recall data + +**Noted**: The above process will only be carried out when the user chooses to use RAG + +### Follow-up actions + +- Add RAGAS evaluation results to the generation process. For example, when the generated results cannot solve the user's problem, it needs to be regenerated. +- Add web retrieval to deal with the problem that the corresponding information cannot be retrieved in vector DB +- Add multi-channel retrieval to increase recall rate. That is, multiple similar queries are generated based on user input for retrieval. \ No newline at end of file