This commit is contained in:
Anooyman 2024-03-24 15:18:35 +08:00
parent de0674ccf7
commit f44310f665
7 changed files with 146 additions and 87 deletions

View File

@ -8,6 +8,61 @@
- 经典案例
- 客户背景知识
## **环境准备**
```text
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
```
```bash
cd rag
pip3 install -r requirements.txt
```
## **使用指南**
### 配置 config 文件
根据需要改写 config.config 文件:
```python
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
select_num = 3
# retrieval num 代表从 vector db 中检索多少 documents。retrieval num 应该大于等于 select num
retrieval_num = 10
# 智谱 LLM 的 API key。目前 demo 仅支持使用智谱 AI 的 API 进行最终的回答生成
glm_key = ''
# Prompt template: 定义最终传给 LLM 的 prompt 模板,其中 {content} 为检索到的内容,{query} 为用户问题
prompt_template = """
你是一个拥有丰富心理学知识的温柔邻家大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
根据下面检索回来的信息,回答问题。
{content}
问题:{query}
"""
```
### 调用
```bash
cd rag/src
python main.py
```
## **数据集**
- 经过清洗的QA对: 每一个QA对作为一个样本进行 embedding
@ -65,12 +120,3 @@ RAG的经典评估框架通过以下三个方面进行评估:
- 增加多路检索以增加召回率。即根据用户输入生成多个类似的query进行检索

View File

@ -2,5 +2,10 @@ sentence_transformers
transformers
numpy
loguru
langchain
torch
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8

View File

@ -23,13 +23,18 @@ log_dir = os.path.join(base_dir, 'log') # log
log_path = os.path.join(log_dir, 'log.log') # file
# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
vector_db_dir = os.path.join(data_dir, 'vector_db')
# RAG related
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
# retrieval num 代表从 vector db 中检索多少 documents。retrieval num 应该大于等于 select num
select_num = 3
retrieval_num = 10
system_prompt = """
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇我有一些心理问题请你用专业的知识和温柔可爱俏皮的口吻帮我解决回复中可以穿插一些可爱的Emoji表情符号或者文本符号\n
"""
# LLM key
glm_key = ''
# prompt
prompt_template = """
{system_prompt}
根据下面检索回来的信息回答问题

View File

@ -1,24 +1,14 @@
import json
import pickle
import faiss
import pickle
import os
from loguru import logger
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
# from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
from langchain_community.llms import Cohere
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker

View File

@ -1,17 +1,9 @@
import os
import time
import jwt
from config.config import base_dir, data_dir
from config.config import data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from util.llm import get_glm
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
'''
1构建完整的 RAG pipeline输入为用户 query输出为 answer
2调用 embedding 提供的接口对 query 向量化
@ -21,46 +13,6 @@ from openxlab.model import download
6拼接 prompt 并调用模型返回结果
'''
def get_glm(temprature):
    """Build a ChatOpenAI client configured for the GLM-4 endpoint.

    Args:
        temprature: Sampling temperature, forwarded to ``ChatOpenAI`` as
            ``temperature``.

    Returns:
        The constructed ``ChatOpenAI`` instance.
    """
    # NOTE(review): "api-key" looks like a placeholder; it contains no "."
    # so generate_token cannot split it into id/secret -- confirm where the
    # real key is supposed to come from.
    token = generate_token("api-key")
    return ChatOpenAI(
        model_name="glm-4",
        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
        openai_api_key=token,
        streaming=False,
        temperature=temprature,
    )
def generate_token(apikey: str, exp_seconds: int = 100):
    """Create a signed JWT from an API key of the form ``"<id>.<secret>"``.

    Args:
        apikey: Key string containing exactly one ``.`` that separates the
            key id from the signing secret.
        exp_seconds: Token lifetime in seconds, counted from now.

    Returns:
        The value returned by ``jwt.encode`` for the payload below, signed
        with HS256 using the secret part of ``apikey``.

    Raises:
        Exception: With args ``("invalid apikey", <original error>)`` when
            ``apikey`` cannot be split into exactly two parts. The original
            error is also attached as ``__cause__``.
    """
    try:
        key_id, secret = apikey.split(".")
    except (AttributeError, TypeError, ValueError) as e:
        # Chain explicitly so the traceback shows why the split failed.
        raise Exception("invalid apikey", e) from e

    # Read the clock once so "exp" is exactly exp_seconds after "timestamp".
    # Both values are expressed in milliseconds.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
@st.cache_resource
def load_model():
    """Load the causal language model and its tokenizer.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is converted to bfloat16
        and moved to CUDA before being returned.
    """
    model_dir = os.path.join(base_dir, '../model')
    logger.info(f'Loading model from {model_dir}')
    # NOTE(review): the path logged above is not what gets loaded -- both
    # from_pretrained calls use the literal 'model'. Confirm which is intended.
    pretrained = AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
    model = pretrained.to(torch.bfloat16).cuda()
    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
    return model, tokenizer
def main(query, system_prompt=''):
logger.info(data_dir)
@ -81,9 +33,28 @@ def main(query, system_prompt=''):
if __name__ == "__main__":
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    main(query)

    # Build the LLM client. Arguments left at their defaults here:
    #   model_name='glm-4'
    #   api_base="https://open.bigmodel.cn/api/paas/v4"
    #   temprature=0.7
    #   streaming=False
    # Output: the LLM model object.
    model = get_glm()

    # Build the RAG pipeline around the LLM. Arguments left at their
    # defaults here:
    #   retrieval_num=3
    #   rerank_flag=False
    #   select_num=3
    rag_obj = EmoLLMRAG(model)
    res = rag_obj.main(query)
    logger.info(res)

View File

@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
from transformers.utils import logging
from data_processing import Data_process
from config.config import system_prompt, prompt_template
from config.config import prompt_template
logger = logging.get_logger(__name__)
@ -16,7 +16,7 @@ class EmoLLMRAG(object):
4. query 和检索回来的 content 传入 LLM
"""
def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
"""
输入 Model 进行初始化
@ -29,7 +29,6 @@ class EmoLLMRAG(object):
self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db()
self.system_prompt = system_prompt
self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
self.rerank_flag = rerank_flag
@ -75,7 +74,7 @@ class EmoLLMRAG(object):
# 第一版不涉及 history 信息,因此将 system prompt 直接纳入到 template 之中
prompt = PromptTemplate(
template=self.prompt_template,
input_variables=["query", "content", "system_prompt"],
input_variables=["query", "content"],
)
# 定义 chain
@ -87,7 +86,6 @@ class EmoLLMRAG(object):
{
"query": query,
"content": content,
"system_prompt": self.system_prompt
}
)
return generation

View File

@ -0,0 +1,44 @@
import time
import jwt
from langchain_openai import ChatOpenAI
from config.config import glm_key
def get_glm(
    model_name='glm-4',
    api_base="https://open.bigmodel.cn/api/paas/v4",
    temprature=0.7,
    streaming=False,
):
    """Build a ChatOpenAI client for the given endpoint.

    Args:
        model_name: Model identifier, passed to ``ChatOpenAI`` as
            ``model_name``.
        api_base: Base URL, passed as ``openai_api_base``.
        temprature: Sampling temperature, passed as ``temperature``.
        streaming: Passed through as ``streaming``.

    Returns:
        The constructed ``ChatOpenAI`` instance, whose API key is the token
        produced by ``generate_token(glm_key)``.
    """
    client_kwargs = {
        "model_name": model_name,
        "openai_api_base": api_base,
        "openai_api_key": generate_token(glm_key),
        "streaming": streaming,
        "temperature": temprature,
    }
    return ChatOpenAI(**client_kwargs)
def generate_token(apikey: str, exp_seconds: int = 100):
    """Create a signed JWT from an API key of the form ``"<id>.<secret>"``.

    Args:
        apikey: Key string containing exactly one ``.`` that separates the
            key id from the signing secret.
        exp_seconds: Token lifetime in seconds, counted from now.

    Returns:
        The value returned by ``jwt.encode`` for the payload below, signed
        with HS256 using the secret part of ``apikey``.

    Raises:
        Exception: With args ``("invalid apikey", <original error>)`` when
            ``apikey`` cannot be split into exactly two parts. The original
            error is also attached as ``__cause__``.
    """
    try:
        key_id, secret = apikey.split(".")
    except (AttributeError, TypeError, ValueError) as e:
        # Chain explicitly so the traceback shows why the split failed.
        raise Exception("invalid apikey", e) from e

    # Read the clock once so "exp" is exactly exp_seconds after "timestamp".
    # Both values are expressed in milliseconds.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )