Update RAG (#140)

xzw 2024-03-24 16:11:15 +08:00 committed by GitHub
commit 7fa3a8b706
7 changed files with 190 additions and 116 deletions

View File

@ -8,6 +8,86 @@
- Classic case studies
- Client background knowledge
## **Environment Setup**
```text
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6
```
```bash
cd rag
pip3 install -r requirements.txt
```
## **Usage Guide**
### Prepare the Data
- txt data: place it in the src.data.txt directory
- json data: place it in the src.data.json directory

A vector DB is built from the prepared data, producing a folder named vector_db under the data directory that contains index.faiss and index.pkl. If a vector DB already exists, it is loaded directly, as the sketch below illustrates.
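A minimal sketch of that build-or-load behavior, assuming the LangChain FAISS wrapper and the BGE embedding model used elsewhere in this PR (the helper name and argument layout are illustrative, not the repo's exact code):

```python
import os
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

def build_or_load_vector_db(split_docs, vector_db_dir,
                            model_name='BAAI/bge-small-zh-v1.5'):
    """Load the FAISS index if it already exists, otherwise build and save it."""
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    if os.path.exists(os.path.join(vector_db_dir, 'index.faiss')):
        # index.faiss and index.pkl were written by a previous run;
        # newer langchain versions may require allow_dangerous_deserialization=True
        return FAISS.load_local(vector_db_dir, embeddings)
    vector_db = FAISS.from_documents(split_docs, embeddings)
    vector_db.save_local(vector_db_dir)  # writes index.faiss and index.pkl
    return vector_db
```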
### Configure the config File
Edit the config.config file as needed:
```python
# Directory holding all models
model_dir = os.path.join(base_dir, 'model')
# Embedding model path and model name
embedding_path = os.path.join(model_dir, 'embedding_model')
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
# Rerank model path and model name
rerank_path = os.path.join(model_dir, 'rerank_model')
rerank_model_name = 'BAAI/bge-reranker-large'
# select_num: how many documents are kept after reranking and passed to the LLM
select_num = 3
# retrieval_num: how many documents are retrieved from the vector DB; retrieval_num should be >= select_num
retrieval_num = 10
# API key for the Zhipu LLM. The demo currently supports only the Zhipu AI API for final generation
glm_key = ''
# Prompt template definition
prompt_template = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me solve them with professional knowledge and a gentle, cute, playful tone. Feel free to sprinkle cute Emoji or text emoticons into your replies.\n
Answer the question based on the information retrieved below.
{content}
Question: {query}
"""
```
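retrieval_num and select_num form a standard two-stage retrieve-then-rerank funnel. A sketch, assuming the FlagReranker imported later in this PR scores the pairs (the function name here is illustrative):

```python
from FlagEmbedding import FlagReranker

def retrieve_then_rerank(query, vector_db, retrieval_num=10, select_num=3):
    # Stage 1: wide recall from the FAISS vector DB
    docs = vector_db.similarity_search(query, k=retrieval_num)
    # Stage 2: score each (query, passage) pair with the rerank model
    reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)
    scores = reranker.compute_score([[query, d.page_content] for d in docs])
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    # Only the top select_num documents reach the LLM
    return [doc for doc, _ in ranked[:select_num]]
```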
### Run
```bash
cd rag/src
python main.py
```
## **Dataset**
- Cleaned QA pairs: each QA pair is embedded as one sample
@ -65,12 +145,3 @@ The classic RAG evaluation framework evaluates along the following three dimensions:
- Add multi-path retrieval to improve recall: generate several similar queries from the user input and retrieve with each, as sketched below.
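A sketch of what that multi-path retrieval could look like: ask the LLM for paraphrases of the user query, retrieve for each variant, and deduplicate the union (the paraphrase prompt and helper are assumptions, not code in this PR):

```python
def multi_query_retrieve(query, vector_db, llm, n_variants=3, k=10):
    # Generate paraphrases of the original query with the LLM (illustrative prompt)
    prompt = (f"Rewrite the following question in {n_variants} different ways, "
              f"one per line:\n{query}")
    variants = [query] + [q for q in llm.invoke(prompt).content.splitlines() if q.strip()]
    seen, merged = set(), []
    for q in variants:
        for doc in vector_db.similarity_search(q, k=k):
            if doc.page_content not in seen:  # deduplicate across retrieval paths
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```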

View File

@ -2,5 +2,11 @@ sentence_transformers
transformers
numpy
loguru
langchain
torch
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6

View File

@ -8,7 +8,10 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
# model
model_dir = os.path.join(base_dir, 'model') # model
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
rerank_path = os.path.join(model_dir, 'rerank_model') # embedding
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
rerank_model_name = 'BAAI/bge-reranker-large'
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
# data
@ -23,15 +26,21 @@ log_dir = os.path.join(base_dir, 'log') # log
log_path = os.path.join(log_dir, 'log.log') # file
# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
vector_db_dir = os.path.join(data_dir, 'vector_db')
# RAG related
# select_num: how many documents are kept after reranking and passed to the LLM
# retrieval_num: how many documents are retrieved from the vector DB; retrieval_num should be >= select_num
select_num = 3
retrieval_num = 10
system_prompt = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me solve them with professional knowledge and a gentle, cute, playful tone. Feel free to sprinkle cute Emoji or text emoticons into your replies.\n
"""
# LLM key
glm_key = ''
# prompt
prompt_template = """
{system_prompt}
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me solve them with professional knowledge and a gentle, cute, playful tone. Feel free to sprinkle cute Emoji or text emoticons into your replies.\n
Answer the question based on the information retrieved below.
{content}
Question: {query}

View File

@ -1,33 +1,24 @@
import json
import pickle
import faiss
import pickle
import os
from loguru import logger
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
# from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
from langchain_community.llms import Cohere
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker
class Data_process():
def __init__(self):
self.chunk_size: int=1000
self.chunk_overlap: int=100
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
"""
Load the embedding model.
@ -61,7 +52,8 @@ class Data_process():
return None
return embeddings
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
def load_rerank_model(self, model_name=rerank_model_name):
"""
Load the rerank model.
@ -99,7 +91,6 @@ class Data_process():
return reranker_model
def extract_text_from_json(self, obj, content=None):
"""
Extract text from JSON for building the vector store.
@ -128,7 +119,8 @@ class Data_process():
return content
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
def split_document(self, data_path):
"""
Split all txt files under the data_path folder.
@ -143,7 +135,7 @@ class Data_process():
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
split_docs = []
logger.info(f'Loading txt files from {data_path}')
if os.path.isdir(data_path):
@ -188,7 +180,7 @@ class Data_process():
# split_qa.append(Document(page_content = content))
# Split by conversation block
content = self.extract_text_from_json(conversation['conversation'], '')
logger.info(f'content====={content}')
#logger.info(f'content====={content}')
split_qa.append(Document(page_content = content))
# logger.info(f'split_qa size====={len(split_qa)}')
return split_qa
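For orientation, a minimal recursive extractor in the spirit of the extract_text_from_json call used above might look like this (a sketch, not the repo's exact implementation):

```python
def extract_text_from_json(obj, content=''):
    # Walk dicts and lists recursively, concatenating every string leaf
    if isinstance(obj, dict):
        for value in obj.values():
            content = extract_text_from_json(value, content)
    elif isinstance(obj, list):
        for item in obj:
            content = extract_text_from_json(item, content)
    elif isinstance(obj, str):
        content += obj + '\n'
    return content
```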

View File

@ -1,17 +1,6 @@
import os
import time
import jwt
from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from util.llm import get_glm
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
'''
1. Build the complete RAG pipeline: input is the user query, output is the answer
2. Call the embedding interface to vectorize the query
@ -21,69 +10,34 @@ from openxlab.model import download
6. Assemble the prompt and call the model to return the result
'''
def get_glm(temprature):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_base="https://open.bigmodel.cn/api/paas/v4",
openai_api_key=generate_token("api-key"),
streaming=False,
temperature=temprature
)
return llm
def generate_token(apikey: str, exp_seconds: int=100):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
@st.cache_resource
def load_model():
model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}')
model = (
AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
.to(torch.bfloat16)
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
return model, tokenizer
def main(query, system_prompt=''):
logger.info(data_dir)
if not os.path.exists(data_dir):
os.mkdir(data_dir)
dp = Data_process()
vector_db = dp.load_vector_db()
docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}')
logger.info("Retrieve results===============================")
for i, doc in enumerate(docs):
logger.info(doc)
passages,scores = dp.rerank(query, docs)
logger.info("After reranking===============================")
for i in range(len(scores)):
logger.info(passages[i])
logger.info(f'score: {str(scores[i])}')
if __name__ == "__main__":
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
main(query)
#model = get_glm(0.7)
#rag_obj = EmoLLMRAG(model, 3)
#res = rag_obj.main(query)
#logger.info(res)
query = """
I am in my final year of high school (preparing for the gaokao) and feel very lost and scared. I feel I have been superfluous since birth and have no need to exist in this world.
With family, at school, with friends and teachers alike, I feel rejected. I am very sad; I had high hopes for the exam, but my grades are poor.
"""
"""
Inputs:
model_name='glm-4',
api_base="https://open.bigmodel.cn/api/paas/v4",
temperature=0.7,
streaming=False,
Output:
LLM Model
"""
model = get_glm()
"""
Inputs:
LLM model
retrieval_num=3
rerank_flag=False
select_num=3
"""
rag_obj = EmoLLMRAG(model)
res = rag_obj.main(query)
logger.info(res)

View File

@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
from transformers.utils import logging
from data_processing import Data_process
from config.config import system_prompt, prompt_template
from config.config import prompt_template
logger = logging.get_logger(__name__)
@ -16,7 +16,7 @@ class EmoLLMRAG(object):
4. Pass the query and the retrieved content to the LLM
"""
def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
"""
Initialize with the given Model.
@ -29,7 +29,6 @@ class EmoLLMRAG(object):
self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db()
self.system_prompt = system_prompt
self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
self.rerank_flag = rerank_flag
@ -75,7 +74,7 @@ class EmoLLMRAG(object):
# The first version does not handle history, so the system prompt is folded directly into the template
prompt = PromptTemplate(
template=self.prompt_template,
input_variables=["query", "content", "system_prompt"],
input_variables=["query", "content"],
)
# Define the chain
@ -87,7 +86,6 @@ class EmoLLMRAG(object):
{
"query": query,
"content": content,
"system_prompt": self.system_prompt
}
)
return generation
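With system_prompt gone, the persona lives inline in prompt_template and the chain only needs query and content. A minimal sketch of the generation step under that assumption (the LCEL composition is illustrative; the class's actual chain definition sits outside this hunk):

```python
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

def generate_answer(model, prompt_template: str, query: str, content: str) -> str:
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["query", "content"],
    )
    # LCEL chain: fill the template, call the LLM, return a plain string
    chain = prompt | model | StrOutputParser()
    return chain.invoke({"query": query, "content": content})
```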

View File

@ -0,0 +1,44 @@
import time
import jwt
from langchain_openai import ChatOpenAI
from config.config import glm_key
def get_glm(
    model_name='glm-4',
    api_base="https://open.bigmodel.cn/api/paas/v4",
    temperature=0.7,
    streaming=False,
):
    """
    Build a ChatOpenAI client pointed at the Zhipu GLM endpoint,
    authenticated with a short-lived JWT derived from glm_key.
    """
    llm = ChatOpenAI(
        model_name=model_name,
        openai_api_base=api_base,
        openai_api_key=generate_token(glm_key),
        streaming=streaming,
        temperature=temperature,
    )
    return llm
def generate_token(apikey: str, exp_seconds: int = 100):
    try:
        id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e)

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
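A hedged usage sketch, assuming glm_key in config.config holds a valid Zhipu key of the form 'id.secret':

```python
if __name__ == '__main__':
    # generate_token() signs a short-lived HS256 JWT that the Zhipu API accepts
    llm = get_glm(temperature=0.3)
    reply = llm.invoke("Hello, please briefly introduce yourself.")
    print(reply.content)
```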