Merge pull request #91 from Anooyman/main

RAG module update
This commit is contained in:
xzw 2024-03-17 13:15:13 +08:00 committed by GitHub
commit 1e93a3dbd1
5 changed files with 213 additions and 3 deletions


@@ -0,0 +1,66 @@
# EmoLLM RAG
## **Module purpose**
Based on the user's question, relevant information is retrieved to make EmoLLM's answers more professional and reliable. The retrieved content includes, but is not limited to, the following:
- Psychology-related theories
- Psychological methodologies
- Classic cases
- User background knowledge
## **Datasets**
- Cleaned QA pairs: each QA pair is embedded as one sample
- Filtered TXT texts
  - Generate embeddings for the raw TXT directly (segmented by token length; see the sketch after this section)
  - Filter out irrelevant information such as tables of contents, then generate embeddings for the TXT (segmented by token length)
  - After filtering out irrelevant information such as tables of contents, segment the TXT semantically and generate embeddings
  - Split the TXT according to its table-of-contents structure and generate embeddings following that hierarchy
For details on how the data was collected and constructed, please refer to [qa_generation_README](https://github.com/SmartFlowAI/EmoLLM/blob/ccfa75c493c4685e84073dfbc53c50c09a2988e3/scripts/qa_generation/README.md)
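The token-length-based variant might look like the following minimal sketch, which assumes LangChain's `RecursiveCharacterTextSplitter` (with `tiktoken` installed for token counting) and the `bce-embedding-base_v1` model loaded through `sentence_transformers`; the file path and chunk sizes are placeholders, not the project's actual settings:

```python
# Minimal sketch: split a cleaned TXT file by approximate token length and embed each chunk.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

with open("data/psychology_book.txt", encoding="utf-8") as f:  # hypothetical path
    text = f.read()

# Token-length-based segmentation; chunk_size/chunk_overlap are illustrative values.
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=512, chunk_overlap=64
)
chunks = splitter.split_text(text)

# Each chunk becomes one embedding sample (QA pairs are embedded the same way, one pair per sample).
embedder = SentenceTransformer("maidalun1020/bce-embedding-base_v1")
embeddings = embedder.encode(chunks, normalize_embeddings=True)
```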
## **Components**
### [BCEmbedding](https://github.com/netease-youdao/BCEmbedding?tab=readme-ov-file)
- [bce-embedding-base_v1](https://hf-mirror.com/maidalun1020/bce-embedding-base_v1): embedding model, used to build vector DB
- [bce-reranker-base_v1](https://hf-mirror.com/maidalun1020/bce-reranker-base_v1): rerank model, used to rerank retrieved documents
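A minimal usage sketch for these two models, assuming the `EmbeddingModel`/`RerankerModel` interfaces from the BCEmbedding README; the query and passages are illustrative:

```python
from BCEmbedding import EmbeddingModel, RerankerModel

# Embedding model: vectorizes documents and queries when building/querying the vector DB.
embed_model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")
embeddings = embed_model.encode(
    ["How do I cope with exam anxiety?", "Basic steps of mindfulness meditation"]
)

# Rerank model: reorders retrieved passages by relevance to the query.
query = "How do I cope with exam anxiety?"
passages = [
    "Common causes of anxiety include ...",
    "Mindfulness meditation can help ...",
    "The weather is nice today.",
]
rerank_model = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1")
scores = rerank_model.compute_score([[query, p] for p in passages])  # higher score = more relevant
```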
### [Langchain](https://python.langchain.com/docs/get_started)
LangChain is an open-source framework for building applications based on large language models (LLMs). It provides a variety of tools and abstractions to improve the customization, accuracy, and relevance of the information generated by models.
### [FAISS](https://faiss.ai/)
FAISS is a library for efficient similarity search and clustering of dense vectors. It contains algorithms that can search vector sets of any size. Since LangChain already integrates FAISS, this project is developed against the LangChain integration rather than the native FAISS documentation. See [FAISS in Langchain](https://python.langchain.com/docs/integrations/vectorstores/faiss)
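Building and querying the vector DB through the LangChain integration might look like this minimal sketch, assuming `langchain_community` with `faiss-cpu` and `sentence-transformers` installed; the texts, save path, and `k` value are placeholders:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Wrap the embedding model so LangChain can call it while indexing and querying.
embeddings = HuggingFaceEmbeddings(
    model_name="maidalun1020/bce-embedding-base_v1",
    encode_kwargs={"normalize_embeddings": True},
)

# Build the vector DB from pre-segmented texts (QA pairs or TXT chunks).
texts = ["Cognitive behavioral therapy is ...", "Mindfulness meditation helps with ..."]
vector_db = FAISS.from_texts(texts, embeddings)
vector_db.save_local("vector_db")  # persist the index; reload later with FAISS.load_local("vector_db", embeddings)

# Retrieve the top-k most similar chunks for a user query.
docs = vector_db.similarity_search("How can I relieve mild anxiety?", k=10)
```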
### [RAGAS](https://github.com/explodinggradients/ragas)
A classic evaluation framework for RAG. It evaluates the following three aspects:
- Faithfulness: the generated answer should be grounded in the given context.
- Answer Relevance: the generated answer should address the actual question asked.
- Context Relevance: the retrieved information should be highly focused and contain as little irrelevant information as possible.

Additional metrics, such as context recall, were added later.
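A RAGAS evaluation run might look like the following minimal sketch, assuming the `ragas` and `datasets` packages and a judge LLM configured via the usual OpenAI environment variables; note that metric and column names (e.g. `context_relevancy` vs. `context_precision`, `ground_truth` vs. `ground_truths`) differ across ragas versions:

```python
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import answer_relevancy, context_recall, context_relevancy, faithfulness

# One illustrative record: question, generated answer, retrieved contexts, reference answer.
samples = {
    "question": ["How can I relieve mild anxiety?"],
    "answer": ["Regular exercise, enough sleep, and mindfulness practice can help."],
    "contexts": [["Mindfulness-based interventions are effective for mild anxiety ..."]],
    "ground_truth": ["Mild anxiety can often be relieved through mindfulness and lifestyle changes."],
}

result = evaluate(
    Dataset.from_dict(samples),
    metrics=[faithfulness, answer_relevancy, context_relevancy, context_recall],
)
print(result)  # per-metric scores between 0 and 1
```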
## **Details**
### RAG pipeline
- Build the vector DB from the dataset
- Embed the question entered by the user
- Search the vector DB with the embedded query
- Rerank the recalled documents
- Generate the final answer from the user's question and the recalled content

**Note**: the above process runs only when the user chooses to use RAG.
### Follow-up actions
- Add RAGAS evaluation to the generation process, e.g. regenerate the answer when the current one does not address the user's question
- Add web retrieval to handle cases where the relevant information cannot be found in the vector DB
- Add multi-route retrieval to increase recall, i.e. generate several similar queries from the user's input and retrieve with each of them
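As a rough sketch of that last idea (the `llm_rephrase` helper and the merging strategy here are hypothetical, not part of this PR):

```python
# Hypothetical multi-route retrieval: generate several rephrasings of the user
# question, retrieve for each of them, and merge the results to raise recall.
def multi_query_retrieve(vector_db, llm_rephrase, query: str, n: int = 3, k: int = 10):
    # llm_rephrase(query, n) is an assumed helper returning n similar queries.
    queries = [query] + llm_rephrase(query, n)
    seen, merged = set(), []
    for q in queries:
        for doc in vector_db.similarity_search(q, k=k):  # LangChain vector store API
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```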


@@ -1,4 +1,6 @@
sentence_transformers
transformers
numpy
loguru
langchain
torch


@@ -3,6 +3,7 @@ import os
cur_dir = os.path.dirname(os.path.abspath(__file__)) # config
src_dir = os.path.dirname(cur_dir) # src
base_dir = os.path.dirname(src_dir) # base
model_repo = 'ajupyter/EmoLLM_aiwei'
# model
model_dir = os.path.join(base_dir, 'model') # model
@@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pickle
# log
log_dir = os.path.join(base_dir, 'log') # log
log_path = os.path.join(log_dir, 'log.log') # file
select_num = 3       # number of documents kept after reranking
retrieval_num = 10   # number of documents recalled from the vector DB


@@ -5,8 +5,19 @@ import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
# Download the model weights from OpenXLab into the local 'model' directory
download(
    model_repo=model_repo,
    output='model'
)
"""
@@ -62,6 +73,19 @@ def main():
    ## 2. Concatenate the contents into a prompt and pass it to the LLM as the {known content}
    ## 3. Ask the LLM to reply based on the known content
@st.cache_resource
def load_model():
    # Load the downloaded LLM and tokenizer once and cache them across Streamlit reruns
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer
if __name__ == '__main__':
    main()
    #main()
    query = ''
    model, tokenizer = load_model()
    rag_obj = EmoLLMRAG(model)
    response = rag_obj.main(query)

rag/src/util/pipeline.py (new file, 114 lines)

@@ -0,0 +1,114 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from config.config import retrieval_num, select_num

logger = logging.get_logger(__name__)


class EmoLLMRAG(object):
    """
    EmoLLM RAG Pipeline
    1. Embed the query
    2. Retrieve data from the vector DB
    3. Rerank the retrieved results
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model) -> None:
        """
        Initialize with the model.
        DataProcessing obj: handles data processing, including embedding/rerank
        vectorstores: load the vector DB; if it does not exist, it should be recreated
        system prompt: load the predefined system prompt
        prompt template: define the final template passed to the LLM
        """
        self.model = model
        self.vectorstores = self._load_vector_db()
        self.system_prompt = self._get_system_prompt()
        self.prompt_template = self._get_prompt_template()

        # Waiting for the embedding team to wrap the corresponding interface
        #self.data_process_obj = DataProcessing()

    def _load_vector_db(self):
        """
        Load the vector DB through the interface provided by the embedding module
        """
        return

    def _get_system_prompt(self) -> str:
        """
        Load the system prompt
        """
        return ''

    def _get_prompt_template(self) -> str:
        """
        Load the prompt template
        """
        return ''

    def get_retrieval_content(self, query, rerank_flag=False) -> str:
        """
        Input: user question, whether rerank is needed
        Output: retrieved (and optionally reranked) content
        """
        content = ''
        documents = self.vectorstores.similarity_search(query, k=retrieval_num)

        # If rerank is needed, call the interface to rerank the documents
        if rerank_flag:
            pass
            # Call the interface once it is available
            #documents = self.data_process_obj.rerank_documents(documents, select_num)

        for doc in documents:
            content += doc.page_content

        return content

    def generate_answer(self, query, content) -> str:
        """
        Input: user question, retrieved content
        Output: model generation result
        """
        # Build the template
        # The first version does not handle history, so the system prompt is folded directly into the template
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content", "system_prompt"],
        )

        # Define the chain
        # The output format is a string
        rag_chain = prompt | self.model | StrOutputParser()

        # Run
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
                "system_prompt": self.system_prompt
            }
        )
        return generation

    def main(self, query) -> str:
        """
        Input: user question
        Output: result generated by the LLM

        Defines the whole RAG pipeline and orchestrates each module
        TODO:
            add the RAGAS scoring system
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)

        return response