From 2d3bd4a8f53de95eceb56ea481718f319ca5ca07 Mon Sep 17 00:00:00 2001 From: Anooyman <875734078@qq.com> Date: Thu, 21 Mar 2024 22:43:09 +0800 Subject: [PATCH] Update RAG pipeline --- rag/src/config/config.py | 2 +- rag/src/data_processing.py | 5 ++-- rag/src/main.py | 60 +++++++++++++++++++++++++++----------- rag/src/pipeline.py | 23 +++++++-------- 4 files changed, 58 insertions(+), 32 deletions(-) diff --git a/rag/src/config/config.py b/rag/src/config/config.py index d4dcfe3..5b81b72 100644 --- a/rag/src/config/config.py +++ b/rag/src/config/config.py @@ -33,5 +33,5 @@ prompt_template = """ {system_prompt} 根据下面检索回来的信息,回答问题。 {content} - 问题:{question} + 问题:{query} """ \ No newline at end of file diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py index 334ce13..55a439f 100644 --- a/rag/src/data_processing.py +++ b/rag/src/data_processing.py @@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter from BCEmbedding import EmbeddingModel, RerankerModel -from util.pipeline import EmoLLMRAG +# from util.pipeline import EmoLLMRAG from transformers import AutoTokenizer, AutoModelForCausalLM from langchain.document_loaders.pdf import PyPDFDirectoryLoader from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader @@ -254,7 +254,8 @@ if __name__ == "__main__": # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" - query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" + # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" + query = "我现在心情非常差,有什么解决办法吗?" docs, retriever = dp.retrieve(query, vector_db, k=10) logger.info(f'Query: {query}') logger.info("Retrieve results:") diff --git a/rag/src/main.py b/rag/src/main.py index 86a2f04..abd6056 100644 --- a/rag/src/main.py +++ b/rag/src/main.py @@ -1,20 +1,17 @@ import os -import json -import pickle -import numpy as np -from typing import Tuple -from sentence_transformers import SentenceTransformer +import time +import jwt -from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir -from util.encode import load_embedding, encode_qa -from util.pipeline import EmoLLMRAG +from config.config import base_dir, data_dir +from data_processing import Data_process +from pipeline import EmoLLMRAG + +from langchain_openai import ChatOpenAI from loguru import logger from transformers import AutoTokenizer, AutoModelForCausalLM import torch import streamlit as st from openxlab.model import download -from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir -from data_processing import Data_process ''' 1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer 2)调用 embedding 提供的接口对 query 向量化 @@ -24,21 +21,45 @@ from data_processing import Data_process 6)拼接 prompt 并调用模型返回结果 ''' -# download( -# model_repo=model_repo, -# output='model' -# ) +def get_glm(temprature): + llm = ChatOpenAI( + model_name="glm-4", + openai_api_base="https://open.bigmodel.cn/api/paas/v4", + openai_api_key=generate_token("api-key"), + streaming=False, + temperature=temprature + ) + return llm + +def generate_token(apikey: str, exp_seconds: int=100): + try: + id, secret = apikey.split(".") + except Exception as e: + raise Exception("invalid apikey", e) + + payload = { + "api_key": id, + "exp": int(round(time.time() * 1000)) + exp_seconds * 1000, + "timestamp": int(round(time.time() * 1000)), + } + + return jwt.encode( + payload, + secret, + algorithm="HS256", + headers={"alg": "HS256", "sign_type": "SIGN"}, + ) @st.cache_resource def load_model(): model_dir = os.path.join(base_dir,'../model') logger.info(f'Loading model from {model_dir}') model = ( - AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) + AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True) .to(torch.bfloat16) .cuda() ) - tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True) return model, tokenizer def main(query, system_prompt=''): @@ -60,4 +81,9 @@ def main(query, system_prompt=''): if __name__ == "__main__": query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" - main(query) \ No newline at end of file + main(query) + #model = get_glm(0.7) + #rag_obj = EmoLLMRAG(model, 3) + #res = rag_obj.main(query) + #logger.info(res) + diff --git a/rag/src/pipeline.py b/rag/src/pipeline.py index 214eef3..b81b26c 100644 --- a/rag/src/pipeline.py +++ b/rag/src/pipeline.py @@ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from transformers.utils import logging -from data_processing import DataProcessing -from config.config import retrieval_num, select_num, system_prompt, prompt_template - +from data_processing import Data_process +from config.config import system_prompt, prompt_template logger = logging.get_logger(__name__) @@ -28,10 +27,8 @@ class EmoLLMRAG(object): """ self.model = model + self.data_processing_obj = Data_process() self.vectorstores = self._load_vector_db() - self.system_prompt = self._get_system_prompt() - self.prompt_template = self._get_prompt_template() - self.data_processing_obj = DataProcessing() self.system_prompt = system_prompt self.prompt_template = prompt_template self.retrieval_num = retrieval_num @@ -43,8 +40,6 @@ class EmoLLMRAG(object): 调用 embedding 模块给出接口 load vector DB """ vectorstores = self.data_processing_obj.load_vector_db() - if not vectorstores: - vectorstores = self.data_processing_obj.load_index_and_knowledge() return vectorstores @@ -57,13 +52,17 @@ class EmoLLMRAG(object): content = '' documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) - # 如果需要rerank,调用接口对 documents 进行 rerank - if self.rerank_flag: - documents = self.data_processing_obj.rerank(documents, self.select_num) - for doc in documents: content += doc.page_content + # 如果需要rerank,调用接口对 documents 进行 rerank + if self.rerank_flag: + documents, _ = self.data_processing_obj.rerank(documents, self.select_num) + + content = '' + for doc in documents: + content += doc + logger.info(f'Retrieval data: {content}') return content def generate_answer(self, query, content) -> str: