Update RAG pipeline

This commit is contained in:
Anooyman 2024-03-21 22:43:09 +08:00
parent 6c2c7496ba
commit 2d3bd4a8f5
4 changed files with 58 additions and 32 deletions

View File

@ -33,5 +33,5 @@ prompt_template = """
{system_prompt} {system_prompt}
根据下面检索回来的信息回答问题 根据下面检索回来的信息回答问题
{content} {content}
问题{question} 问题{query}
""" """

View File

@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel from BCEmbedding import EmbeddingModel, RerankerModel
from util.pipeline import EmoLLMRAG # from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
@ -254,7 +254,8 @@ if __name__ == "__main__":
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版1979年再版。" # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版1979年再版。"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
query = "我现在心情非常差,有什么解决办法吗?"
docs, retriever = dp.retrieve(query, vector_db, k=10) docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}') logger.info(f'Query: {query}')
logger.info("Retrieve results:") logger.info("Retrieve results:")

View File

@ -1,20 +1,17 @@
import os import os
import json import time
import pickle import jwt
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir from config.config import base_dir, data_dir
from util.encode import load_embedding, encode_qa from data_processing import Data_process
from util.pipeline import EmoLLMRAG from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from loguru import logger from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
import torch import torch
import streamlit as st import streamlit as st
from openxlab.model import download from openxlab.model import download
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
from data_processing import Data_process
''' '''
1构建完整的 RAG pipeline输入为用户 query输出为 answer 1构建完整的 RAG pipeline输入为用户 query输出为 answer
2调用 embedding 提供的接口对 query 向量化 2调用 embedding 提供的接口对 query 向量化
@ -24,21 +21,45 @@ from data_processing import Data_process
6拼接 prompt 并调用模型返回结果 6拼接 prompt 并调用模型返回结果
''' '''
# download( def get_glm(temprature):
# model_repo=model_repo, llm = ChatOpenAI(
# output='model' model_name="glm-4",
# ) openai_api_base="https://open.bigmodel.cn/api/paas/v4",
openai_api_key=generate_token("api-key"),
streaming=False,
temperature=temprature
)
return llm
def generate_token(apikey: str, exp_seconds: int=100):
    """Build an HS256-signed JWT from an ``<id>.<secret>`` style API key.

    The key is split on ``"."`` into exactly two parts: the first part is
    placed in the payload as ``api_key`` and the second part is used as the
    signing secret.

    Args:
        apikey: API key in the form ``"<id>.<secret>"``.
        exp_seconds: Token lifetime in seconds, counted from the moment this
            function is called.

    Returns:
        Whatever ``jwt.encode`` returns for the payload (the encoded token).

    Raises:
        Exception: ``("invalid apikey", <original error>)`` when ``apikey``
            cannot be split into exactly two dot-separated parts. The
            original error is also attached as ``__cause__``.
    """
    try:
        key_id, secret = apikey.split(".")
    except (AttributeError, TypeError, ValueError) as e:
        # Same exception type and args as before; now explicitly chained so
        # the traceback shows the underlying split/unpack failure as the cause.
        raise Exception("invalid apikey", e) from e

    # Read the clock once so that exp - timestamp is exactly the requested
    # lifetime; both fields are expressed in milliseconds.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
@st.cache_resource @st.cache_resource
def load_model(): def load_model():
model_dir = os.path.join(base_dir,'../model') model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}') logger.info(f'Loading model from {model_dir}')
model = ( model = (
AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
.to(torch.bfloat16) .to(torch.bfloat16)
.cuda() .cuda()
) )
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
return model, tokenizer return model, tokenizer
def main(query, system_prompt=''): def main(query, system_prompt=''):
@ -61,3 +82,8 @@ def main(query, system_prompt=''):
if __name__ == "__main__": if __name__ == "__main__":
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
main(query) main(query)
#model = get_glm(0.7)
#rag_obj = EmoLLMRAG(model, 3)
#res = rag_obj.main(query)
#logger.info(res)

View File

@ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from transformers.utils import logging from transformers.utils import logging
from data_processing import DataProcessing from data_processing import Data_process
from config.config import retrieval_num, select_num, system_prompt, prompt_template from config.config import system_prompt, prompt_template
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -28,10 +27,8 @@ class EmoLLMRAG(object):
""" """
self.model = model self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db() self.vectorstores = self._load_vector_db()
self.system_prompt = self._get_system_prompt()
self.prompt_template = self._get_prompt_template()
self.data_processing_obj = DataProcessing()
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.prompt_template = prompt_template self.prompt_template = prompt_template
self.retrieval_num = retrieval_num self.retrieval_num = retrieval_num
@ -43,8 +40,6 @@ class EmoLLMRAG(object):
调用 embedding 模块给出接口 load vector DB 调用 embedding 模块给出接口 load vector DB
""" """
vectorstores = self.data_processing_obj.load_vector_db() vectorstores = self.data_processing_obj.load_vector_db()
if not vectorstores:
vectorstores = self.data_processing_obj.load_index_and_knowledge()
return vectorstores return vectorstores
@ -57,13 +52,17 @@ class EmoLLMRAG(object):
content = '' content = ''
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
# 如果需要rerank调用接口对 documents 进行 rerank
if self.rerank_flag:
documents = self.data_processing_obj.rerank(documents, self.select_num)
for doc in documents: for doc in documents:
content += doc.page_content content += doc.page_content
# 如果需要rerank调用接口对 documents 进行 rerank
if self.rerank_flag:
documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
content = ''
for doc in documents:
content += doc
logger.info(f'Retrieval data: {content}')
return content return content
def generate_answer(self, query, content) -> str: def generate_answer(self, query, content) -> str: