diff --git a/rag/README.md b/rag/README.md
index 9c16408..e247c8a 100644
--- a/rag/README.md
+++ b/rag/README.md
@@ -8,6 +8,61 @@
 - Classic counseling cases
 - Client background knowledge
 
+## **Environment Setup**
+
+```text
+langchain==0.1.13
+langchain_community==0.0.29
+langchain_core==0.1.33
+langchain_openai==0.0.8
+langchain_text_splitters==0.0.1
+FlagEmbedding==1.2.8
+```
+
+```bash
+cd rag
+
+pip3 install -r requirements.txt
+```
+
+## **Usage Guide**
+
+### Configure the config file
+
+Edit `rag/src/config/config.py` as needed:
+
+```python
+# select_num: how many documents are passed to the LLM after reranking
+select_num = 3
+
+# retrieval_num: how many documents are retrieved from the vector DB
+# (retrieval_num should be greater than or equal to select_num)
+retrieval_num = 10
+
+# API key for the Zhipu GLM LLM. For now the demo only supports the Zhipu AI API for the final generation step
+glm_key = ''
+
+# Prompt template: defines the persona and how the retrieved content and the query are assembled
+prompt_template = """
+    You are Aiwei, a gentle big-sister type from next door with a wealth of psychological knowledge. I have some psychological problems; please use your professional knowledge to help me in a gentle, cute, playful tone, and feel free to sprinkle some cute emoji or text emoticons into your replies.\n
+
+    Answer the question based on the information retrieved below.
+
+    {content}
+
+    Question: {query}
+"""
+```
+
+### Running
+
+```bash
+cd rag/src
+python main.py
+```
+
 ## **Dataset**
 
 - Cleaned QA pairs: each QA pair is embedded as a single sample
@@ -65,12 +120,3 @@ The classic RAG evaluation framework evaluates along the following three dimensions:
 
 - Add multi-path retrieval to improve recall, i.e., generate several similar queries from the user input and retrieve with each of them
 
-
-
-
-
-
-
-
-
-
diff --git a/rag/requirements.txt b/rag/requirements.txt
index 15f915c..ef8a833 100644
--- a/rag/requirements.txt
+++ b/rag/requirements.txt
@@ -2,5 +2,10 @@ sentence_transformers
 transformers
 numpy
 loguru
-langchain
 torch
+langchain==0.1.13
+langchain_community==0.0.29
+langchain_core==0.1.33
+langchain_openai==0.0.8
+langchain_text_splitters==0.0.1
+FlagEmbedding==1.2.8
diff --git a/rag/src/config/config.py b/rag/src/config/config.py
index 366cf85..bcb84a0 100644
--- a/rag/src/config/config.py
+++ b/rag/src/config/config.py
@@ -23,13 +23,18 @@ log_dir = os.path.join(base_dir, 'log')  # log
 log_path = os.path.join(log_dir, 'log.log')  # file
 
 # vector DB
-vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
+vector_db_dir = os.path.join(data_dir, 'vector_db')
 
+# RAG related
+# select_num: how many documents are passed to the LLM after reranking
+# retrieval_num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
 select_num = 3
 retrieval_num = 10
-system_prompt = """
-    You are Aiwei, a gentle big-sister type from next door with a wealth of psychological knowledge. I have some psychological problems; please use your professional knowledge to help me in a gentle, cute, playful tone, and feel free to sprinkle some cute emoji or text emoticons into your replies.\n
-"""
+
+# LLM key
+glm_key = ''
+
+# prompt
 prompt_template = """
 {system_prompt}
 Answer the question based on the information retrieved below.
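Taken together, `retrieval_num`, `select_num`, and `prompt_template` describe a retrieve-then-rerank funnel: fetch broadly, keep the best few, then fill the template. A minimal sketch of how the settings interact at query time, assuming it is run from `rag/src` so that `config` is importable; the `retrieve` and `rerank` helpers are hypothetical stand-ins for the real pipeline methods and return dummy data:

```python
from config.config import select_num, retrieval_num, prompt_template

# Hypothetical stand-ins for the real Data_process / EmoLLMRAG methods.
def retrieve(query: str, k: int) -> list:
    # Would do a similarity search against the FAISS store.
    return [f"retrieved document {i}" for i in range(k)]

def rerank(query: str, docs: list, k: int) -> list:
    # Would rescore docs against the query with the reranker, keeping the top k.
    return docs[:k]

query = "I feel anxious before every exam."
docs = retrieve(query, k=retrieval_num)   # cast a wide net first ...
docs = rerank(query, docs, k=select_num)  # ... then keep only the best few

prompt = prompt_template.format(system_prompt="", content="\n".join(docs), query=query)
print(prompt)
```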
diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py
index edf1a37..82aa628 100644
--- a/rag/src/data_processing.py
+++ b/rag/src/data_processing.py
@@ -1,24 +1,14 @@
 import json
 import pickle
-import faiss
-import pickle
 import os
 from loguru import logger
-from sentence_transformers import SentenceTransformer
 from langchain_community.vectorstores import FAISS
 from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
 from langchain.embeddings import HuggingFaceBgeEmbeddings
-from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
-from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
-from BCEmbedding import EmbeddingModel, RerankerModel
-# from util.pipeline import EmoLLMRAG
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from langchain.document_loaders.pdf import PyPDFDirectoryLoader
-from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
-from langchain_community.llms import Cohere
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import FlashrankRerank
+from langchain_community.document_loaders import DirectoryLoader, TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain.document_loaders import DirectoryLoader
 from langchain_core.documents.base import Document
 from FlagEmbedding import FlagReranker
diff --git a/rag/src/main.py b/rag/src/main.py
index abd6056..339a8a2 100644
--- a/rag/src/main.py
+++ b/rag/src/main.py
@@ -1,17 +1,9 @@
 import os
-import time
-import jwt
-
-from config.config import base_dir, data_dir
+from config.config import data_dir
 from data_processing import Data_process
 from pipeline import EmoLLMRAG
-
-from langchain_openai import ChatOpenAI
+from util.llm import get_glm
 from loguru import logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-import streamlit as st
-from openxlab.model import download
 
 '''
     1) Build the complete RAG pipeline: input is the user query, output is the answer
     2) Call the embedding interface to vectorize the query
@@ -21,46 +13,6 @@ from openxlab.model import download
     6) Assemble the prompt and call the model to return the result
 '''
 
-def get_glm(temprature):
-    llm = ChatOpenAI(
-        model_name="glm-4",
-        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
-        openai_api_key=generate_token("api-key"),
-        streaming=False,
-        temperature=temprature
-    )
-    return llm
-
-def generate_token(apikey: str, exp_seconds: int=100):
-    try:
-        id, secret = apikey.split(".")
-    except Exception as e:
-        raise Exception("invalid apikey", e)
-
-    payload = {
-        "api_key": id,
-        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
-        "timestamp": int(round(time.time() * 1000)),
-    }
-
-    return jwt.encode(
-        payload,
-        secret,
-        algorithm="HS256",
-        headers={"alg": "HS256", "sign_type": "SIGN"},
-    )
-
-@st.cache_resource
-def load_model():
-    model_dir = os.path.join(base_dir,'../model')
-    logger.info(f'Loading model from {model_dir}')
-    model = (
-        AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
-        .to(torch.bfloat16)
-        .cuda()
-    )
-    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
-    return model, tokenizer
 
 def main(query, system_prompt=''):
     logger.info(data_dir)
@@ -81,9 +33,28 @@
 
 if __name__ == "__main__":
     query = "I am in my final year of high school and feel very lost and afraid. I feel that I have been superfluous ever since I was born and that there is no need for me to exist in this world. Whether with family, at school, with friends, or in front of teachers, I feel rejected. I am very sad; I had high hopes for the gaokao, but my grades are not good"
-    main(query)
-    #model = get_glm(0.7)
-    #rag_obj = EmoLLMRAG(model, 3)
-    #res = rag_obj.main(query)
-    #logger.info(res)
+
+    """
+    Input:
+        model_name='glm-4',
+        api_base="https://open.bigmodel.cn/api/paas/v4",
+        temperature=0.7,
+        streaming=False,
+    Output:
+        LLM model
+    """
+    model = get_glm()
+
+    """
+    Input:
+        LLM model
+        retrieval_num=3
+        rerank_flag=False
+        select_num=3
+    """
+    rag_obj = EmoLLMRAG(model)
+
+    res = rag_obj.main(query)
+
+    logger.info(res)
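The imports kept in `data_processing.py` cover the full indexing path: load raw text, split it, embed it into FAISS, and rescore at query time. A minimal sketch of how these pieces wire together; the model names, data paths, and splitter sizes below are illustrative assumptions, not values taken from `Data_process`:

```python
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from FlagEmbedding import FlagReranker

# Load raw .txt knowledge files and split them into overlapping chunks.
docs = DirectoryLoader("data/txt", glob="**/*.txt", loader_cls=TextLoader).load()
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100).split_documents(docs)

# Embed the chunks with a BGE model and persist the FAISS index.
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")
db = FAISS.from_documents(chunks, embeddings)
db.save_local("data/vector_db")

# At query time: retrieve broadly, then let the reranker rescore each (query, chunk) pair.
query = "How can I cope with exam stress?"
hits = db.similarity_search(query, k=10)
reranker = FlagReranker("BAAI/bge-reranker-base", use_fp16=True)
scores = reranker.compute_score([[query, d.page_content] for d in hits])
```

`FAISS.save_local` writes a directory rather than a single pickle, which lines up with the `vector_db_dir` change above from `vector_db.pkl` to a `vector_db` directory.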
diff --git a/rag/src/pipeline.py b/rag/src/pipeline.py
index b81b26c..8f59f55 100644
--- a/rag/src/pipeline.py
+++ b/rag/src/pipeline.py
@@ -3,7 +3,7 @@ from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging
 
 from data_processing import Data_process
-from config.config import system_prompt, prompt_template
+from config.config import prompt_template
 logger = logging.get_logger(__name__)
 
 
@@ -16,7 +16,7 @@ class EmoLLMRAG(object):
         4. Pass the query and the retrieved content to the LLM
     """
 
-    def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
+    def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
         """
         Initialize with the model passed in
 
@@ -29,7 +29,6 @@
         self.model = model
         self.data_processing_obj = Data_process()
         self.vectorstores = self._load_vector_db()
-        self.system_prompt = system_prompt
         self.prompt_template = prompt_template
         self.retrieval_num = retrieval_num
         self.rerank_flag = rerank_flag
@@ -75,7 +74,7 @@
         # The first version does not involve history, so the system prompt is folded directly into the template
         prompt = PromptTemplate(
             template=self.prompt_template,
-            input_variables=["query", "content", "system_prompt"],
+            input_variables=["query", "content"],
         )
 
         # Define the chain
@@ -87,7 +86,6 @@
             {
                 "query": query,
                 "content": content,
-                "system_prompt": self.system_prompt
             }
         )
         return generation
diff --git a/rag/src/util/llm.py b/rag/src/util/llm.py
index e69de29..b254722 100644
--- a/rag/src/util/llm.py
+++ b/rag/src/util/llm.py
@@ -0,0 +1,44 @@
+import time
+import jwt
+from langchain_openai import ChatOpenAI
+from config.config import glm_key
+
+
+def get_glm(
+        model_name='glm-4',
+        api_base="https://open.bigmodel.cn/api/paas/v4",
+        temperature=0.7,
+        streaming=False,
+    ):
+    """
+    Build a GLM chat model client against Zhipu's OpenAI-compatible endpoint.
+    """
+
+    llm = ChatOpenAI(
+        model_name=model_name,
+        openai_api_base=api_base,
+        openai_api_key=generate_token(glm_key),
+        streaming=streaming,
+        temperature=temperature
+    )
+
+    return llm
+
+def generate_token(apikey: str, exp_seconds: int=100):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
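The JWT scheme in `generate_token` can be sanity-checked in isolation. A small sketch, assuming it is run from `rag/src` with PyJWT installed; `"my-id.my-secret"` is a dummy key for illustration, not a real credential:

```python
import jwt  # PyJWT, the same library generate_token uses

from util.llm import generate_token

token = generate_token("my-id.my-secret", exp_seconds=100)
claims = jwt.decode(token, "my-secret", algorithms=["HS256"])

assert claims["api_key"] == "my-id"
# "exp" and "timestamp" are in milliseconds, as the Zhipu API expects; PyJWT's
# own expiry check assumes seconds, so it is effectively a no-op on this token.
```

Note that the token is minted once when `get_glm` constructs the client, so with the default `exp_seconds=100` a long-lived client may end up holding an expired token.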