# OliveSensorAPI/rag/src/main.py

import os
import time
import jwt
from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
'''
1. Build the complete RAG pipeline: the input is the user query, the output is the answer.
2. Call the embedding interface to vectorize the query.
3. Download the pre-built FAISS vector DB and retrieve the matching information.
4. Call the rerank interface to re-rank the retrieved content.
5. Call the prompt interface to get the system prompt and prompt template.
6. Assemble the prompt and call the model to return the result.
'''
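
# How the pieces below compose end to end (a minimal sketch; it assumes the
# EmoLLMRAG interface used in the commented-out example at the bottom of this
# file, i.e. EmoLLMRAG(llm, retrieval_num) with a .main(query) entry point):
#
#     llm = get_glm(0.7)           # GLM-4 client authenticated with a JWT
#     rag_obj = EmoLLMRAG(llm, 3)  # retrieve/rerank top passages per query
#     answer = rag_obj.main("user query")
#     logger.info(answer)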


def get_glm(temperature):
    """Build a GLM-4 chat client that authenticates with a short-lived JWT."""
    llm = ChatOpenAI(
        model_name="glm-4",
        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
        # "api-key" is a placeholder; a real ZhipuAI key has the form "<id>.<secret>"
        openai_api_key=generate_token("api-key"),
        streaming=False,
        temperature=temperature,
    )
    return llm


def generate_token(apikey: str, exp_seconds: int = 100):
    """Create the HS256-signed JWT expected by the ZhipuAI open platform."""
    try:
        id, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid apikey", e)

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,  # expiry, in milliseconds
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
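
# Usage sketch for the two helpers above (assumes a real ZhipuAI key of the
# form "<id>.<secret>" has replaced the "api-key" placeholder in get_glm;
# .invoke() is the standard LangChain chat-model call):
#
#     llm = get_glm(0.7)
#     reply = llm.invoke("Hello, please briefly introduce yourself.")
#     print(reply.content)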


@st.cache_resource
def load_model():
    """Load the local causal-LM checkpoint and tokenizer (cached by Streamlit)."""
    model_dir = os.path.join(base_dir, '../model')
    logger.info(f'Loading model from {model_dir}')
    # Note: the weights are loaded from the local "model" directory (relative to
    # the working directory), not from the model_dir path logged above.
    model = (
        AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
    return model, tokenizer
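
# Generation sketch with the locally loaded checkpoint (a plain transformers
# generate() call; the exact chat template depends on which model actually
# sits in the local "model" directory, so this is only illustrative):
#
#     model, tokenizer = load_model()
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     output_ids = model.generate(**inputs, max_new_tokens=128)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))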


def main(query, system_prompt=''):
    """Retrieve and rerank passages for a single query, logging the results."""
    logger.info(data_dir)
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)

    dp = Data_process()
    vector_db = dp.load_vector_db()
    docs, retriever = dp.retrieve(query, vector_db, k=10)

    logger.info(f'Query: {query}')
    logger.info("Retrieve results===============================")
    for doc in docs:
        logger.info(doc)

    passages, scores = dp.rerank(query, docs)
    logger.info("After reranking===============================")
    for i in range(len(scores)):
        logger.info(passages[i])
        logger.info(f'score: {str(scores[i])}')


if __name__ == "__main__":
    # Example query (in Chinese): a third-year high-school student who feels
    # lost, afraid, and rejected, and is distressed about disappointing exam results.
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    main(query)

    # Full RAG generation with GLM-4 (left disabled):
    # model = get_glm(0.7)
    # rag_obj = EmoLLMRAG(model, 3)
    # res = rag_obj.main(query)
    # logger.info(res)