Update RAG pipeline

This commit is contained in:
Anooyman 2024-03-21 22:43:09 +08:00
parent 6c2c7496ba
commit 2d3bd4a8f5
4 changed files with 58 additions and 32 deletions

View File

@ -33,5 +33,5 @@ prompt_template = """
{system_prompt}
根据下面检索回来的信息回答问题
{content}
问题{question}
问题{query}
"""

View File

@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
from util.pipeline import EmoLLMRAG
# from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
@ -254,7 +254,8 @@ if __name__ == "__main__":
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版1979年再版。"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
query = "我现在心情非常差,有什么解决办法吗?"
docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}')
logger.info("Retrieve results:")

View File

@ -1,20 +1,17 @@
import os
import json
import pickle
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
import time
import jwt
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG
from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
from data_processing import Data_process
'''
1构建完整的 RAG pipeline输入为用户 query输出为 answer
2调用 embedding 提供的接口对 query 向量化
@ -24,21 +21,45 @@ from data_processing import Data_process
6拼接 prompt 并调用模型返回结果
'''
# download(
# model_repo=model_repo,
# output='model'
# )
def get_glm(temprature):
    """Build a ChatOpenAI client pointed at Zhipu's GLM-4 endpoint.

    Args:
        temprature: Sampling temperature forwarded to the model.
            (NOTE(review): parameter name is misspelled, but it is part of
            the public interface, so it is kept as-is.)

    Returns:
        A non-streaming ChatOpenAI instance authenticated with a JWT
        produced by generate_token().
    """
    # Mint the auth token first, then hand everything to the client in one go.
    api_token = generate_token("api-key")
    return ChatOpenAI(
        model_name="glm-4",
        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
        openai_api_key=api_token,
        streaming=False,
        temperature=temprature,
    )
def generate_token(apikey: str, exp_seconds: int = 100) -> str:
    """Create a signed JWT for the Zhipu (GLM) open API.

    Args:
        apikey: Combined credential of the form "<id>.<secret>".
        exp_seconds: Token lifetime in seconds (default 100).

    Returns:
        The encoded JWT string.

    Raises:
        ValueError: If ``apikey`` is not exactly "<id>.<secret>".
    """
    try:
        key_id, secret = apikey.split(".")
    except Exception as e:
        # ValueError is more precise than the original bare Exception and is
        # still caught by callers that catch Exception; chain the cause so the
        # underlying split/unpack failure is not lost.
        raise ValueError("invalid apikey") from e
    # Sample the clock once so "exp" and "timestamp" are guaranteed to be
    # consistent (the original called time.time() twice).
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    # Zhipu requires HS256 plus the custom "sign_type" header.
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
@st.cache_resource
# Load the causal LM and its tokenizer once per Streamlit session
# (st.cache_resource memoizes the pair across reruns).
def load_model():
model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}')
model = (
# NOTE(review): the next two lines are overlaid old/new versions of the same
# call from a diff rendering (model_dir vs the literal 'model'); as written
# this is not valid Python — exactly one of the two should remain.
AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
.to(torch.bfloat16)
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# NOTE(review): duplicated tokenizer load — diff residue; keep only one, and
# make it consistent with whichever model path is kept above.
tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
return model, tokenizer
def main(query, system_prompt=''):
@ -60,4 +81,9 @@ def main(query, system_prompt=''):
if __name__ == "__main__":
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
main(query)
main(query)
#model = get_glm(0.7)
#rag_obj = EmoLLMRAG(model, 3)
#res = rag_obj.main(query)
#logger.info(res)

View File

@ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging
from data_processing import DataProcessing
from config.config import retrieval_num, select_num, system_prompt, prompt_template
from data_processing import Data_process
from config.config import system_prompt, prompt_template
logger = logging.get_logger(__name__)
@ -28,10 +27,8 @@ class EmoLLMRAG(object):
"""
self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db()
self.system_prompt = self._get_system_prompt()
self.prompt_template = self._get_prompt_template()
self.data_processing_obj = DataProcessing()
self.system_prompt = system_prompt
self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
@ -43,8 +40,6 @@ class EmoLLMRAG(object):
调用 embedding 模块给出接口 load vector DB
"""
vectorstores = self.data_processing_obj.load_vector_db()
if not vectorstores:
vectorstores = self.data_processing_obj.load_index_and_knowledge()
return vectorstores
@ -57,13 +52,17 @@ class EmoLLMRAG(object):
content = ''
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
# 如果需要rerank调用接口对 documents 进行 rerank
if self.rerank_flag:
documents = self.data_processing_obj.rerank(documents, self.select_num)
for doc in documents:
content += doc.page_content
# 如果需要rerank调用接口对 documents 进行 rerank
if self.rerank_flag:
documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
content = ''
for doc in documents:
content += doc
logger.info(f'Retrieval data: {content}')
return content
def generate_answer(self, query, content) -> str: