Add RAG to internlm2

Anooyman 2024-04-14 12:22:35 +08:00
parent 14890fad56
commit 2632ec390d
5 changed files with 26 additions and 14 deletions

View File

@@ -56,6 +56,12 @@ The JSON data format is as follows
 A vector DB is built from the prepared data; this produces a folder named vector_db under the data folder, containing index.faiss and index.pkl.
 If a vector DB already exists, the corresponding database is loaded directly.
+**Note**: the corresponding DB can be downloaded directly from xlab; run the following commands in the rag directory:
+```bash
+# https://openxlab.org.cn/models/detail/Anooyman/EmoLLMRAGTXT/tree/main
+git lfs install
+git clone https://code.openxlab.org.cn/Anooyman/EmoLLMRAGTXT.git
+```
 ### Configure the config file
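As a quick sanity check after cloning, the downloaded vector_db can be opened directly with LangChain's FAISS wrapper. A minimal sketch follows; the embedding model name is an assumption, not something this commit pins, and it must match whatever `embedding_path`/`embedding_model_name` in the config point to:

```python
# Minimal sketch: open the cloned vector_db folder and run one query.
# "BAAI/bge-small-zh-v1.5" is an assumed embedding model; it must be the
# same model that built index.faiss, or similarity scores are meaningless.
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

emb_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")
db = FAISS.load_local("EmoLLMRAGTXT/vector_db", emb_model)
docs = db.similarity_search("如何缓解焦虑", k=3)
for doc in docs:
    print(doc.page_content[:80])
```

(On recent langchain-community releases, `FAISS.load_local` additionally requires `allow_dangerous_deserialization=True`; see the data_processing.py hunk below.)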

View File

@@ -20,6 +20,7 @@ knowledge_json_path = os.path.join(data_dir, 'knowledge.json') # json
 knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pkl
 doc_dir = os.path.join(data_dir, 'txt')
 qa_dir = os.path.join(data_dir, 'json')
+cloud_vector_db_dir = os.path.join(base_dir, 'EmoLLMRAGTXT')
 # log
 log_dir = os.path.join(base_dir, 'log') # log
@@ -30,13 +31,13 @@ chunk_size=1000
 chunk_overlap=100
 # vector DB
-vector_db_dir = os.path.join(data_dir, 'vector_db')
+vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db')
 # RAG related
 # select_num: how many documents are kept after reranking and passed to the LLM
 # retrieval_num: how many documents are retrieved from the vector DB; retrieval_num should be >= select_num
 select_num = 3
-retrieval_num = 10
+retrieval_num = 3
 # LLM key
 glm_key = ''
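The two knobs work as a funnel: retrieval_num candidates come back from the vector DB and, when reranking is enabled, only the top select_num survive. A hedged sketch of that selection step using FlagReranker (which data_processing.py below imports); the reranker model name here is an assumption, not taken from this repo:

```python
# Retrieve-then-rerank funnel implied by the config above.
# 'BAAI/bge-reranker-base' is an assumed model name.
from FlagEmbedding import FlagReranker

retrieval_num, select_num = 10, 3            # retrieval_num must be >= select_num
reranker = FlagReranker('BAAI/bge-reranker-base', use_fp16=True)

query = "如何缓解焦虑"
candidates = [f"passage {i}" for i in range(retrieval_num)]   # stand-ins for retrieved chunks
scores = reranker.compute_score([[query, passage] for passage in candidates])
top = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)[:select_num]
print([passage for passage, _ in top])
```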

View File

@@ -4,7 +4,7 @@ import os
 from loguru import logger
 from langchain_community.vectorstores import FAISS
-from config.config import (
+from rag.src.config.config import (
     embedding_path,
     embedding_model_name,
     doc_dir, qa_dir,
@@ -19,7 +19,6 @@ from config.config import (
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain.document_loaders import DirectoryLoader
 from langchain_core.documents.base import Document
 from FlagEmbedding import FlagReranker
@@ -199,7 +198,7 @@ class Data_process():
         Create and save the vector store
         '''
         logger.info(f'Creating index...')
-        #split_doc = self.split_document(doc_dir)
+        split_doc = self.split_document(doc_dir)
         split_qa = self.split_conversation(qa_dir)
         # logger.info(f'split_doc == {len(split_doc)}')
         # logger.info(f'split_qa == {len(split_qa)}')
@@ -218,7 +217,7 @@ class Data_process():
         if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
             db = self.create_vector_db(emb_model)
         else:
-            db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
+            db = FAISS.load_local(vector_db_dir, emb_model)
         return db

 if __name__ == "__main__":
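One caveat on the load_local change above: recent langchain-community releases refuse to unpickle index.pkl unless `allow_dangerous_deserialization=True` is passed, so dropping the flag only works on older versions. A version-tolerant sketch (the helper name `load_faiss` is hypothetical):

```python
# Hypothetical version-tolerant loader: try the newer keyword first and fall
# back for older langchain-community releases that do not accept it.
from langchain_community.vectorstores import FAISS

def load_faiss(vector_db_dir, emb_model):
    try:
        # newer releases require an explicit opt-in to pickle deserialization
        return FAISS.load_local(vector_db_dir, emb_model,
                                allow_dangerous_deserialization=True)
    except TypeError:
        # older releases raise TypeError on the unknown keyword
        return FAISS.load_local(vector_db_dir, emb_model)
```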

View File

@@ -2,8 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging
-from data_processing import Data_process
-from config.config import prompt_template
+from rag.src.data_processing import Data_process
+from rag.src.config.config import prompt_template

 logger = logging.get_logger(__name__)
@@ -48,19 +48,19 @@ class EmoLLMRAG(object):
         output: retrieved and reranked content
         """
-        content = ''
+        content = []
         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
         for doc in documents:
-            content += doc.page_content
+            content.append(doc.page_content)

         # If reranking is enabled, call the rerank interface on the documents
         if self.rerank_flag:
             documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
-            content = ''
+            content = []
             for doc in documents:
-                content += doc
+                content.append(doc)

         logger.info(f'Retrieval data: {content}')
         return content
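Note one consequence of returning a list: combine_history below interpolates retrieval_content straight into an f-string, so a raw list renders as `['...', '...']` inside the prompt. A hedged sketch of the join the caller likely wants (the separator is an assumption, not part of this commit):

```python
# Flatten the retrieved chunks before prompt interpolation; "\n" as the
# separator is an assumption, not taken from the commit.
retrieval_content = rag_obj.get_retrieval_content(prompt)   # list[str] after this change
retrieval_text = "\n".join(retrieval_content)
real_prompt = combine_history(prompt, retrieval_text)
```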

View File

@@ -12,6 +12,7 @@ import copy
 import os
 import warnings
 from dataclasses import asdict, dataclass
+from rag.src.pipeline import EmoLLMRAG
 from typing import Callable, List, Optional

 import streamlit as st
@@ -188,8 +189,9 @@ robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
 cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"

-def combine_history(prompt):
+def combine_history(prompt, retrieval_content=''):
     messages = st.session_state.messages
+    prompt = f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}"
     meta_instruction = (
         "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发（排名按字母顺序排序，不分先后）、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家，我有一些心理问题，请你用专业的知识帮我解决。"
     )
@@ -211,6 +213,7 @@ def main():
     # torch.cuda.empty_cache()
     print("load model begin.")
     model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
     print("load model end.")

     user_avator = "assets/user.png"
@@ -232,9 +235,12 @@ def main():
     # Accept user input
     if prompt := st.chat_input("What is up?"):
         # Display user message in chat message container
+        retrieval_content = rag_obj.get_retrieval_content(prompt)
         with st.chat_message("user", avatar=user_avator):
             st.markdown(prompt)
-        real_prompt = combine_history(prompt)
+        #st.markdown(retrieval_content)
+        real_prompt = combine_history(prompt, retrieval_content)
         # Add user message to chat history
         st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
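For orientation, the prompt that reaches the model each turn is now layered roughly as follows. This is a schematic reconstruction from combine_history and the templates above; the system wrapper and the single-turn history are assumptions based on the usual InternLM2 demo code, not shown in this diff:

```python
# Schematic layering of the final real_prompt, assembled by hand only to
# illustrate the order; the "<s><|im_start|>system" wrapper is an assumption.
system = "<s><|im_start|>system\n" + meta_instruction + "<|im_end|>\n"
history = robot_prompt.format(robot="上一轮回答")        # prior turns, via the templates above
current = cur_query_prompt.format(
    user=f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。"
         f"从一个心理专家的专业角度来回答后续提问:{prompt}"
)
real_prompt = system + history + current
```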