update main.py

This commit is contained in:
zealot52099 2024-03-18 10:33:01 +08:00
parent 5879afffe6
commit 74db6d9893

View File

@ -5,87 +5,67 @@ import numpy as np
from typing import Tuple from typing import Tuple
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
from util.encode import load_embedding, encode_qa from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG from util.pipeline import EmoLLMRAG
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
import torch import torch
import streamlit as st import streamlit as st
from openxlab.model import download from openxlab.model import download
from data_processing import load_index_and_knowledge, create_index_cpu, create_index_gpu, find_top_k, rerank
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
download( '''
model_repo=model_repo, 1构建完整的 RAG pipeline输入为用户 query输出为 answer
output='model' 2调用 embedding 提供的接口对 query 向量化
) 3下载基于 FAISS 预构建的 vector DB 并检索对应信息
4调用 rerank 接口重排序检索内容
5调用 prompt 接口获取 system prompt prompt template
6拼接 prompt 并调用模型返回结果
'''
""" # download(
读取知识库 # model_repo=model_repo,
""" # output='model'
def load_knowledge() -> Tuple[list, list]: # )
# 如果 pkl 不存在,则先编码存储
if not os.path.exists(knowledge_pkl_path):
encode_qa(knowledge_json_path, knowledge_pkl_path)
# 加载 json 和 pkl
with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
knowledge = json.load(f1)
encoded_knowledge = pickle.load(f2)
return knowledge, encoded_knowledge
"""
召回 top_k 个相关的文本段
"""
def find_top_k(
emb: SentenceTransformer,
query: str,
knowledge: list,
encoded_knowledge: list,
k=3
) -> list[str]:
# 编码 query
query_embedding = emb.encode(query)
# 查找 top_k
scores = query_embedding @ encoded_knowledge.T
# 使用 argpartition 找出每行第 k 个大的值的索引,第 k 个位置左侧都是比它大的值,右侧都是比它小的值
top_k_indices = np.argpartition(scores, -k)[-k:]
# 由于 argpartition 不保证顺序,我们需要对提取出的 k 个索引进行排序
top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
top_k_indices = top_k_indices[top_k_values_sorted_indices]
# 返回
contents = [knowledge[index] for index in top_k_indices]
return contents
def main():
emb = load_embedding()
knowledge, encoded_knowledge = load_knowledge()
query = "认知心理学研究哪些心理活动?"
contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
print('召回的 top-k 条相关内容如下:')
print(json.dumps(contents, ensure_ascii=False, indent=2))
# 这里我没实现 LLM 部分,如果有 LLM
## 1. 读取 LLM
## 2. 将 contents 拼接为 prompt传给 LLM作为 {已知内容}
## 3. 要求 LLM 根据已知内容回复
@st.cache_resource @st.cache_resource
def load_model(): def load_model():
model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}')
model = ( model = (
AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True) AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
.to(torch.bfloat16) .to(torch.bfloat16)
.cuda() .cuda()
) )
tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
return model, tokenizer return model, tokenizer
if __name__ == '__main__': def get_prompt():
#main() pass
query = ''
def get_prompt_template():
pass
def main(query, system_prompt):
model, tokenizer = load_model() model, tokenizer = load_model()
rag_obj = EmoLLMRAG(model) model = model.eval()
response = rag_obj.main(query) if not os.path.exists(data_dir):
os.mkdir(data_dir)
# 下载基于 FAISS 预构建的 vector DB 以及原始知识库
faiss_index, knowledge_chunks = load_index_and_knowledge()
distances, indices = find_top_k(query, faiss_index, 5)
rerank_results = rerank(query, indices, knowledge_chunks)
messages = [(system_prompt, rerank_results['rerank_passages'][0])]
logger.info(f'messages:{messages}')
response, history = model.chat(tokenizer, query, history=messages)
messages.append((query, response))
print(f"robot >>> {response}")
if __name__ == '__main__':
# query = '你好'
query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
#TODO system_prompt = get_prompt()
system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发排名按字母顺序排序不分先后、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家我有一些心理问题请你用专业的知识帮我解决。"
main(query, system_prompt)