OliveSensorAPI/rag/src/main.py
2024-03-21 22:43:09 +08:00

90 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import jwt
from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG
from langchain_openai import ChatOpenAI
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
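'''
1. Build the complete RAG pipeline: the input is the user query, the output is the answer.
2. Call the interface provided by the embedding module to vectorize the query.
3. Download the pre-built FAISS-based vector DB and retrieve the matching information.
4. Call the rerank interface to re-order the retrieved content.
5. Call the prompt interface to obtain the system prompt and the prompt template.
6. Assemble the prompt and call the model to return the result.
'''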
def get_glm(temprature):
    """Create a non-streaming ``ChatOpenAI`` client configured for ``glm-4``.

    Args:
        temprature: Sampling temperature forwarded to ``ChatOpenAI`` as
            ``temperature`` (the parameter name keeps its original spelling
            so keyword callers continue to work).

    Returns:
        The configured ``ChatOpenAI`` instance.
    """
    # NOTE(review): "api-key" is a placeholder. It contains no "." separator,
    # so generate_token() rejects it as written; substitute a real
    # "<id>.<secret>" credential before calling this function.
    signed_key = generate_token("api-key")
    return ChatOpenAI(
        model_name="glm-4",
        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
        openai_api_key=signed_key,
        streaming=False,
        temperature=temprature,
    )
def generate_token(apikey: str, exp_seconds: int = 100):
    """Sign a short-lived HS256 JWT from an ``"<id>.<secret>"`` API key.

    Args:
        apikey: Credential made of exactly two parts separated by a single
            ``"."``. The first part is placed in the payload under
            ``"api_key"``; the second part is used as the signing secret.
        exp_seconds: Lifetime of the token in seconds, measured from the
            moment of the call.

    Returns:
        The value returned by ``jwt.encode`` for the payload.

    Raises:
        ValueError: If ``apikey`` cannot be split into exactly two parts.
            The underlying error is attached as ``__cause__`` and is also
            kept as the second element of ``args``.
    """
    try:
        key_id, secret = apikey.split(".")
    except (ValueError, AttributeError, TypeError) as err:
        # ValueError: wrong number of parts; AttributeError / TypeError:
        # apikey is not a str. All of them mean the key is unusable.
        raise ValueError("invalid apikey", err) from err

    # Read the clock once so "exp" is exactly exp_seconds after "timestamp".
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
@st.cache_resource
def load_model():
    """Load the causal language model and its tokenizer from ``'model'``.

    The model is converted to ``torch.bfloat16`` and moved to CUDA. Both
    loads pass ``trust_remote_code=True``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_dir = os.path.join(base_dir, '../model')
    logger.info(f'Loading model from {model_dir}')

    # NOTE(review): the path logged above is NOT the one used below. The
    # loaders receive the literal 'model', not model_dir. Confirm which
    # location is intended before relying on the log line.
    checkpoint = 'model'

    causal_lm = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
    causal_lm = causal_lm.to(torch.bfloat16)
    causal_lm = causal_lm.cuda()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    return causal_lm, tokenizer
def main(query, system_prompt=''):
    """Run retrieval and reranking for ``query`` and log every result.

    Ensures ``data_dir`` exists, loads the vector DB through
    ``Data_process``, retrieves 10 candidates for the query, logs them,
    reranks them and logs each reranked passage with its score.

    Args:
        query: The user query to retrieve and rerank documents for.
        system_prompt: Accepted for interface compatibility; not used by
            this function.

    Returns:
        None. Results are only written to the log.
    """
    logger.info(data_dir)
    # exist_ok avoids the check-then-create race, and makedirs also creates
    # any missing parent directories.
    os.makedirs(data_dir, exist_ok=True)

    dp = Data_process()
    vector_db = dp.load_vector_db()
    # The retriever returned alongside the documents is not needed here.
    docs, _retriever = dp.retrieve(query, vector_db, k=10)

    logger.info(f'Query: {query}')
    logger.info("Retrieve results===============================")
    for doc in docs:
        logger.info(doc)

    passages, scores = dp.rerank(query, docs)
    logger.info("After reranking===============================")
    for passage, score in zip(passages, scores):
        logger.info(passage)
        # !s forces str(), matching the original explicit str(score) call.
        logger.info(f'score: {score!s}')
if __name__ == "__main__":
    # Sample user message used for a manual run of the retrieve + rerank flow.
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    main(query)

    # Disabled: end-to-end generation through the GLM-backed RAG pipeline.
    #model = get_glm(0.7)
    #rag_obj = EmoLLMRAG(model, 3)
    #res = rag_obj.main(query)
    #logger.info(res)