commit f44310f665 (parent de0674ccf7): update
This commit updates the project README and the RAG source tree. In the README, the existing knowledge-base list is kept and new environment-setup and usage sections are added after it:

- Classic cases
- Client background knowledge
## **Environment Setup**

The demo pins the following dependencies:

```
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
```
Install them with:

```bash
cd rag
pip3 install -r requirements.txt
```
## **Usage Guide**

### Configure the config file

Modify the `config.config` file as needed:
```python
# select_num: how many documents go to the LLM after reranking
select_num = 3

# retrieval_num: how many documents are retrieved from the vector DB
# (retrieval_num should be >= select_num)
retrieval_num = 10

# API key for the Zhipu (GLM) LLM; for now the demo only supports the Zhipu AI API for final generation
glm_key = ''

# Prompt template definition
prompt_template = """
你是一个拥有丰富心理学知识的温柔邻家大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n

根据下面检索回来的信息,回答问题。

{content}

问题:{query}
"""
```
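For reference, here is a minimal sketch of how `retrieval_num` candidates can be cut down to `select_num` with the pinned FlagEmbedding reranker. It is an illustration under assumptions, not the repo's code: the reranker checkpoint name is a guess, and `vector_db` stands for the loaded FAISS store.

```python
from FlagEmbedding import FlagReranker

# Assumed reranker checkpoint; the repo only pins the FlagEmbedding package.
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)

def retrieve_then_rerank(query, vector_db, retrieval_num=10, select_num=3):
    # Step 1: pull a wide candidate set from the vector DB.
    docs = vector_db.similarity_search(query, k=retrieval_num)
    # Step 2: score (query, passage) pairs and keep the top select_num.
    scores = reranker.compute_score([[query, d.page_content] for d in docs])
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [d for d, _ in ranked[:select_num]]
```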
### Run

```bash
cd rag/src
python main.py
```
## **Dataset**

- Cleaned QA pairs: each QA pair is embedded as a single sample
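As a hedged illustration of "one QA pair = one embedded sample" (the sample data and embedding model below are assumptions):

```python
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

# Placeholder QA pairs standing in for the cleaned dataset.
qa_pairs = [{"q": "How do I cope with insomnia?", "a": "Start with a regular sleep schedule..."}]
texts = [f"Q: {p['q']}\nA: {p['a']}" for p in qa_pairs]

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")  # assumed model
vector_db = FAISS.from_texts(texts, embeddings)  # one FAISS entry per QA pair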
Further down, the README keeps its evaluation and improvement notes. The classic RAG evaluation framework evaluates on three aspects, and one planned improvement reads:

- Add multi-query retrieval to increase recall, i.e. generate several similar queries from the user input and retrieve with each one (see the sketch below).
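A minimal sketch of that multi-query idea, assuming `llm` is the chat model from `get_glm()` and `vector_db` is the FAISS store; the paraphrase prompt wording is made up for illustration:

```python
def multi_query_retrieve(llm, vector_db, query, n_variants=3, k=10):
    # Ask the LLM for paraphrases of the user query.
    resp = llm.invoke(
        f"Rewrite the following question in {n_variants} different ways, one per line:\n{query}"
    )
    variants = [query] + [line.strip() for line in resp.content.splitlines() if line.strip()]
    # Retrieve for every variant and union the hits to boost recall.
    seen, merged = set(), []
    for q in variants:
        for doc in vector_db.similarity_search(q, k=k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```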
In `requirements.txt`, the unpinned `langchain` entry gives way to pinned packages:

```diff
@@ -2,5 +2,10 @@ sentence_transformers
 transformers
 numpy
 loguru
-langchain
 torch
+langchain==0.1.13
+langchain_community==0.0.29
+langchain_core==0.1.33
+langchain_openai==0.0.8
+langchain_text_splitters==0.0.1
+FlagEmbedding==1.2.8
```
In the config module, the vector DB path becomes a directory, the RAG knobs gain comments, and the standalone system prompt gives way to an LLM key:

```diff
@@ -23,13 +23,18 @@ log_dir = os.path.join(base_dir, 'log')  # log
 log_path = os.path.join(log_dir, 'log.log')  # file
 
 # vector DB
-vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
+vector_db_dir = os.path.join(data_dir, 'vector_db')
 
+# RAG related
+# select_num: how many documents go to the LLM after reranking
+# retrieval_num: how many documents are retrieved from the vector DB (should be >= select_num)
 select_num = 3
 retrieval_num = 10
-system_prompt = """
-你是一个拥有丰富心理学知识的温柔邻家大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
-"""
+# LLM key
+glm_key = ''
 
+# prompt
 prompt_template = """
 {system_prompt}
 根据下面检索回来的信息,回答问题。
```
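The `.pkl`-to-directory change matches how langchain's FAISS wrapper persists an index. A hedged sketch (embedding model assumed; the `allow_dangerous_deserialization` flag is required on recent `langchain_community` releases):

```python
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from config.config import vector_db_dir

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")  # assumed model
db = FAISS.from_texts(["sample document"], embeddings)
db.save_local(vector_db_dir)  # writes index.faiss / index.pkl into the directory
db = FAISS.load_local(vector_db_dir, embeddings, allow_dangerous_deserialization=True)
```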
In the data-processing module, the import block is trimmed to what is actually used:

```diff
@@ -1,24 +1,14 @@
 import json
 import pickle
-import faiss
-import pickle
 import os
 
 from loguru import logger
-from sentence_transformers import SentenceTransformer
 from langchain_community.vectorstores import FAISS
 from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
 from langchain.embeddings import HuggingFaceBgeEmbeddings
-from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
-from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
-from BCEmbedding import EmbeddingModel, RerankerModel
-# from util.pipeline import EmoLLMRAG
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from langchain.document_loaders.pdf import PyPDFDirectoryLoader
-from langchain.document_loaders import UnstructuredFileLoader, DirectoryLoader
-from langchain_community.llms import Cohere
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import FlashrankRerank
+from langchain_community.document_loaders import DirectoryLoader, TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain.document_loaders import DirectoryLoader
 from langchain_core.documents.base import Document
 from FlagEmbedding import FlagReranker
```
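The surviving loaders and splitter cover the plain-text ingestion path. A hedged sketch of how they combine (glob pattern and chunk sizes are assumptions, not the repo's settings):

```python
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from config.config import doc_dir

# Load every text file under the docs directory.
docs = DirectoryLoader(doc_dir, glob="**/*.txt", loader_cls=TextLoader).load()
# Split into overlapping chunks before embedding.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = splitter.split_documents(docs)
```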
In `main.py`, the GLM helper code moves out to `util.llm` and the unused imports go:

```diff
@@ -1,17 +1,9 @@
 import os
-import time
-import jwt
-
-from config.config import base_dir, data_dir
+from config.config import data_dir
 from data_processing import Data_process
 from pipeline import EmoLLMRAG
+from util.llm import get_glm
-from langchain_openai import ChatOpenAI
 from loguru import logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-import streamlit as st
-from openxlab.model import download
 '''
 1) Build the complete RAG pipeline: input is the user query, output is the answer
 2) Call the embedding interface to vectorize the query
```
The in-file GLM client, token generator, and local-model loader are deleted; the GLM helpers reappear in the new `util.llm` module:

```diff
@@ -21,46 +13,6 @@ from openxlab.model import download
 6) Assemble the prompt and call the model to get the answer
 
 '''
-def get_glm(temprature):
-    llm = ChatOpenAI(
-        model_name="glm-4",
-        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
-        openai_api_key=generate_token("api-key"),
-        streaming=False,
-        temperature=temprature
-    )
-    return llm
-
-def generate_token(apikey: str, exp_seconds: int = 100):
-    try:
-        id, secret = apikey.split(".")
-    except Exception as e:
-        raise Exception("invalid apikey", e)
-
-    payload = {
-        "api_key": id,
-        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
-        "timestamp": int(round(time.time() * 1000)),
-    }
-
-    return jwt.encode(
-        payload,
-        secret,
-        algorithm="HS256",
-        headers={"alg": "HS256", "sign_type": "SIGN"},
-    )
-
-@st.cache_resource
-def load_model():
-    model_dir = os.path.join(base_dir, '../model')
-    logger.info(f'Loading model from {model_dir}')
-    model = (
-        AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
-        .to(torch.bfloat16)
-        .cuda()
-    )
-    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
-    return model, tokenizer
-
 def main(query, system_prompt=''):
     logger.info(data_dir)
```
The entry point now builds the LLM via `get_glm` and runs the RAG object directly:

```diff
@@ -81,9 +33,28 @@ def main(query, system_prompt=''):
 
 if __name__ == "__main__":
     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
-    main(query)
-    # model = get_glm(0.7)
-    # rag_obj = EmoLLMRAG(model, 3)
-    # res = rag_obj.main(query)
-    # logger.info(res)
+    """
+    Inputs:
+        model_name='glm-4',
+        api_base="https://open.bigmodel.cn/api/paas/v4",
+        temperature=0.7,
+        streaming=False,
+    Output:
+        LLM model
+    """
+    model = get_glm()
+
+    """
+    Inputs:
+        LLM model
+        retrieval_num=3
+        rerank_flag=False
+        select_num=3
+    """
+    rag_obj = EmoLLMRAG(model)
+
+    res = rag_obj.main(query)
+
+    logger.info(res)
```
In `pipeline.py`, `system_prompt` stops being a separate input (it is folded into the template) and `retrieval_num` gains a default:

```diff
@@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging
 
 from data_processing import Data_process
-from config.config import system_prompt, prompt_template
+from config.config import prompt_template
 logger = logging.get_logger(__name__)
 
 
@@ -16,7 +16,7 @@ class EmoLLMRAG(object):
     4. Pass the query and the retrieved content to the LLM
     """
 
-    def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
+    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
         """
         Initialize with the given model
 
@@ -29,7 +29,6 @@ class EmoLLMRAG(object):
         self.model = model
         self.data_processing_obj = Data_process()
         self.vectorstores = self._load_vector_db()
-        self.system_prompt = system_prompt
         self.prompt_template = prompt_template
         self.retrieval_num = retrieval_num
         self.rerank_flag = rerank_flag
@@ -75,7 +74,7 @@ class EmoLLMRAG(object):
         # v1 has no chat history, so the system prompt is folded directly into the template
         prompt = PromptTemplate(
             template=self.prompt_template,
-            input_variables=["query", "content", "system_prompt"],
+            input_variables=["query", "content"],
         )
 
         # define the chain
@@ -87,7 +86,6 @@ class EmoLLMRAG(object):
             {
                 "query": query,
                 "content": content,
-                "system_prompt": self.system_prompt
             }
         )
         return generation
```
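A hedged sketch of the slimmed-down chain call (the template literal below is a stand-in; the real one comes from `config.config`):

```python
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(
    template="Answer the question using the retrieved context.\n{content}\nQuestion: {query}",
    input_variables=["query", "content"],
)
chain = prompt | model  # `model` is the ChatOpenAI client from get_glm()
generation = chain.invoke({"query": "...", "content": "..."})
```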
A new module (imported in `main.py` as `util.llm`) now owns the GLM client and Zhipu token signing; the file is entirely new (`@@ -0,0 +1,44 @@`):

```python
import time
import jwt

from langchain_openai import ChatOpenAI
from config.config import glm_key


def get_glm(
    model_name='glm-4',
    api_base="https://open.bigmodel.cn/api/paas/v4",
    temperature=0.7,
    streaming=False,
):
    """Build a ChatOpenAI client pointed at the Zhipu GLM endpoint."""
    llm = ChatOpenAI(
        model_name=model_name,
        openai_api_base=api_base,
        openai_api_key=generate_token(glm_key),
        streaming=streaming,
        temperature=temperature,
    )
    return llm


def generate_token(apikey: str, exp_seconds: int = 100):
    # Zhipu keys have the form "<id>.<secret>"; sign a short-lived JWT with the secret.
    try:
        id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e)

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
```
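Minimal usage sketch, assuming `glm_key` holds a valid "<id>.<secret>" Zhipu key:

```python
llm = get_glm(temperature=0.3)
print(llm.invoke("Hello").content)  # AIMessage.content holds the reply text
```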