diff --git a/rag/src/main.py b/rag/src/main.py
index 97f60a0..7dd7639 100644
--- a/rag/src/main.py
+++ b/rag/src/main.py
@@ -5,87 +5,67 @@
 import numpy as np
 from typing import Tuple
 from sentence_transformers import SentenceTransformer
-from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
+from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
 from util.encode import load_embedding, encode_qa
 from util.pipeline import EmoLLMRAG
-
+from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import streamlit as st
 from openxlab.model import download
+from data_processing import load_index_and_knowledge, create_index_cpu, create_index_gpu, find_top_k, rerank
+from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
 
-download(
-    model_repo=model_repo,
-    output='model'
-)
+'''
+    1) Build the complete RAG pipeline: the input is a user query, the output is an answer
+    2) Call the embedding interface to vectorize the query
+    3) Download the vector DB prebuilt with FAISS and retrieve the matching passages
+    4) Call the rerank interface to re-order the retrieved passages
+    5) Call the prompt interface to obtain the system prompt and prompt template
+    6) Assemble the prompt and call the model to return the result
 
-
-"""
-Load the knowledge base
-"""
-def load_knowledge() -> Tuple[list, list]:
-    # If the pkl does not exist yet, encode and store it first
-    if not os.path.exists(knowledge_pkl_path):
-        encode_qa(knowledge_json_path, knowledge_pkl_path)
-
-    # Load the json and the pkl
-    with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
-        knowledge = json.load(f1)
-        encoded_knowledge = pickle.load(f2)
-    return knowledge, encoded_knowledge
-
-
-"""
-Recall the top_k most relevant text passages
-"""
-def find_top_k(
-    emb: SentenceTransformer,
-    query: str,
-    knowledge: list,
-    encoded_knowledge: list,
-    k=3
-) -> list[str]:
-    # Encode the query
-    query_embedding = emb.encode(query)
-
-    # Find the top_k
-    scores = query_embedding @ encoded_knowledge.T
-    # Use argpartition so that the indices of the k largest values end up in the last k positions
-    top_k_indices = np.argpartition(scores, -k)[-k:]
-    # argpartition does not guarantee order, so sort the extracted k indices by score
-    top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
-    top_k_indices = top_k_indices[top_k_values_sorted_indices]
-
-    # Return the matching contents
-    contents = [knowledge[index] for index in top_k_indices]
-    return contents
-
-
-def main():
-    emb = load_embedding()
-    knowledge, encoded_knowledge = load_knowledge()
-    query = "认知心理学研究哪些心理活动?"
-    contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
-    print('召回的 top-k 条相关内容如下:')
-    print(json.dumps(contents, ensure_ascii=False, indent=2))
-    # The LLM part is not implemented here; with an LLM, the steps would be:
-    ## 1. Load the LLM
-    ## 2. Concatenate contents into the prompt and pass it to the LLM as the {known content}
-    ## 3. Ask the LLM to reply based on the known content
+'''
+# download(
+#     model_repo=model_repo,
+#     output='model'
+# )
 
 @st.cache_resource
 def load_model():
+    model_dir = os.path.join(base_dir, '../model')
+    logger.info(f'Loading model from {model_dir}')
     model = (
-        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
     return model, tokenizer
 
-if __name__ == '__main__':
-    #main()
-    query = ''
+def get_prompt():
+    pass
+
+def get_prompt_template():
+    pass
+
+def main(query, system_prompt):
     model, tokenizer = load_model()
-    rag_obj = EmoLLMRAG(model)
-    response = rag_obj.main(query)
\ No newline at end of file
+    model = model.eval()
+    if not os.path.exists(data_dir):
+        os.mkdir(data_dir)
+    # Download the vector DB prebuilt with FAISS along with the raw knowledge base
+    faiss_index, knowledge_chunks = load_index_and_knowledge()
+    distances, indices = find_top_k(query, faiss_index, 5)
+    rerank_results = rerank(query, indices, knowledge_chunks)
+    messages = [(system_prompt, rerank_results['rerank_passages'][0])]
+    logger.info(f'messages: {messages}')
+    response, history = model.chat(tokenizer, query, history=messages)
+    messages.append((query, response))
+    print(f"robot >>> {response}")
+
+if __name__ == '__main__':
+    # query = '你好'
+    query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
+    # TODO: system_prompt = get_prompt()
+    system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
+    main(query, system_prompt)
\ No newline at end of file
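
For reference, here is a minimal sketch of what the data_processing helpers called by the new main() might look like. The diff only shows the call sites (load_index_and_knowledge, find_top_k, rerank), so the bodies and model names below are assumptions inferred from those calls, not the repository's actual implementation:

    # Hypothetical sketch of the retrieval helpers used by main().
    # The embedding and reranker model names are placeholders, not
    # values taken from this diff.
    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer, CrossEncoder

    embedding_model = SentenceTransformer('BAAI/bge-small-zh-v1.5')  # assumed
    cross_encoder = CrossEncoder('BAAI/bge-reranker-base')           # assumed

    def find_top_k(query: str, faiss_index: faiss.Index, k: int = 5):
        # Embed the query and search the prebuilt FAISS index;
        # FAISS expects a 2-D float32 array of shape (n_queries, dim).
        query_emb = embedding_model.encode([query]).astype(np.float32)
        distances, indices = faiss_index.search(query_emb, k)
        return distances, indices

    def rerank(query: str, indices, knowledge_chunks):
        # Score each retrieved chunk against the query with a cross-encoder
        # and sort best-first; returns a dict shaped like the one main() reads.
        passages = [knowledge_chunks[i] for i in indices[0]]
        scores = cross_encoder.predict([(query, p) for p in passages])
        order = np.argsort(scores)[::-1]
        return {
            'rerank_passages': [passages[i] for i in order],
            'rerank_scores': [float(scores[i]) for i in order],
        }

One design choice worth noting in main(): InternLM-style checkpoints loaded with trust_remote_code expose chat(tokenizer, query, history=...), where history is a list of (query, response) tuples. The new code seeds that history with a fabricated first turn pairing the system prompt with the top reranked passage, which is how the retrieved context reaches the model until get_prompt() and get_prompt_template() are implemented.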