diff --git a/rag/README.md b/rag/README.md index 15af25c..ad87af3 100644 --- a/rag/README.md +++ b/rag/README.md @@ -56,6 +56,12 @@ JSON 数据格式如下 会根据准备的数据构建vector DB,最终会在 data 文件夹下产生名为 vector_db 的文件夹包含 index.faiss 和 index.pkl 如果已经有 vector DB 则会直接加载对应数据库 +**注意**: 可以直接从 xlab 下载对应 DB(请在rag文件目录下执行对应 code) +```python +# https://openxlab.org.cn/models/detail/Anooyman/EmoLLMRAGTXT/tree/main +git lfs install +git clone https://code.openxlab.org.cn/Anooyman/EmoLLMRAGTXT.git +``` ### 配置 config 文件 diff --git a/rag/src/config/config.py b/rag/src/config/config.py index 3a1a6a9..fdb6fe1 100644 --- a/rag/src/config/config.py +++ b/rag/src/config/config.py @@ -20,6 +20,7 @@ knowledge_json_path = os.path.join(data_dir, 'knowledge.json') # json knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pkl doc_dir = os.path.join(data_dir, 'txt') qa_dir = os.path.join(data_dir, 'json') +cloud_vector_db_dir = os.path.join(base_dir, 'EmoLLMRAGTXT') # log log_dir = os.path.join(base_dir, 'log') # log @@ -30,13 +31,13 @@ chunk_size=1000 chunk_overlap=100 # vector DB -vector_db_dir = os.path.join(data_dir, 'vector_db') +vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db') # RAG related # select num: 代表rerank 之后选取多少个 documents 进入 LLM # retrieval num: 代表从 vector db 中检索多少 documents。(retrieval num 应该大于等于 select num) select_num = 3 -retrieval_num = 10 +retrieval_num = 3 # LLM key glm_key = '' diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py index d894faa..6c4103b 100644 --- a/rag/src/data_processing.py +++ b/rag/src/data_processing.py @@ -4,7 +4,7 @@ import os from loguru import logger from langchain_community.vectorstores import FAISS -from config.config import ( +from rag.src.config.config import ( embedding_path, embedding_model_name, doc_dir, qa_dir, @@ -19,7 +19,6 @@ from config.config import ( from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain_community.document_loaders import DirectoryLoader, TextLoader from langchain_text_splitters import RecursiveCharacterTextSplitter -from langchain.document_loaders import DirectoryLoader from langchain_core.documents.base import Document from FlagEmbedding import FlagReranker @@ -199,7 +198,7 @@ class Data_process(): 创建并保存向量库 ''' logger.info(f'Creating index...') - #split_doc = self.split_document(doc_dir) + split_doc = self.split_document(doc_dir) split_qa = self.split_conversation(qa_dir) # logger.info(f'split_doc == {len(split_doc)}') # logger.info(f'split_qa == {len(split_qa)}') @@ -218,7 +217,7 @@ class Data_process(): if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir): db = self.create_vector_db(emb_model) else: - db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True) + db = FAISS.load_local(vector_db_dir, emb_model) return db if __name__ == "__main__": diff --git a/rag/src/pipeline.py b/rag/src/pipeline.py index 8f59f55..08b9b96 100644 --- a/rag/src/pipeline.py +++ b/rag/src/pipeline.py @@ -2,8 +2,8 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from transformers.utils import logging -from data_processing import Data_process -from config.config import prompt_template +from rag.src.data_processing import Data_process +from rag.src.config.config import prompt_template logger = logging.get_logger(__name__) @@ -48,19 +48,19 @@ class EmoLLMRAG(object): ouput: 检索后并且 rerank 的内容 """ - content = '' + content = [] documents = self.vectorstores.similarity_search(query, k=self.retrieval_num) for doc in documents: - content += doc.page_content + content.append(doc.page_content) # 如果需要rerank,调用接口对 documents 进行 rerank if self.rerank_flag: documents, _ = self.data_processing_obj.rerank(documents, self.select_num) - content = '' + content = [] for doc in documents: - content += doc + content.append(doc) logger.info(f'Retrieval data: {content}') return content diff --git a/web_internlm2.py b/web_internlm2.py index 59640fc..b907156 100644 --- a/web_internlm2.py +++ b/web_internlm2.py @@ -12,6 +12,7 @@ import copy import os import warnings from dataclasses import asdict, dataclass +from rag.src.pipeline import EmoLLMRAG from typing import Callable, List, Optional import streamlit as st @@ -188,8 +189,9 @@ robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n" cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n" -def combine_history(prompt): +def combine_history(prompt, retrieval_content=''): messages = st.session_state.messages + prompt = f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}" meta_instruction = ( "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。" ) @@ -211,6 +213,7 @@ def main(): # torch.cuda.empty_cache() print("load model begin.") model, tokenizer = load_model() + rag_obj = EmoLLMRAG(model) print("load model end.") user_avator = "assets/user.png" @@ -232,9 +235,12 @@ def main(): # Accept user input if prompt := st.chat_input("What is up?"): # Display user message in chat message container + retrieval_content = rag_obj.get_retrieval_content(prompt) with st.chat_message("user", avatar=user_avator): st.markdown(prompt) - real_prompt = combine_history(prompt) + #st.markdown(retrieval_content) + + real_prompt = combine_history(prompt, retrieval_content) # Add user message to chat history st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})