Add RAG into internlm2
parent 14890fad56
commit 2632ec390d
```diff
@@ -56,6 +56,12 @@ The JSON data format is as follows
 The vector DB is built from the prepared data; this produces a folder named vector_db under the data directory, containing index.faiss and index.pkl.
 If a vector DB already exists, it is loaded directly instead of being rebuilt.
+**Note**: the prebuilt DB can be downloaded directly from xlab (run the commands below inside the rag directory):
+```bash
+# https://openxlab.org.cn/models/detail/Anooyman/EmoLLMRAGTXT/tree/main
+git lfs install
+git clone https://code.openxlab.org.cn/Anooyman/EmoLLMRAGTXT.git
+```

 ### Configuring the config file
```
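For orientation, loading the downloaded store from Python might look like the sketch below. This is a minimal sketch, not the repo's code: the embedding model name is an assumption and must match whatever model was used to build index.faiss.

```python
# Minimal sketch: load the cloned EmoLLMRAGTXT vector DB with LangChain + FAISS.
# "BAAI/bge-small-zh-v1.5" is an assumed embedding model; it must match the
# model that built the index, or similarity search will return nonsense.
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

emb_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-zh-v1.5",
    encode_kwargs={"normalize_embeddings": True},
)

# Path mirrors vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db')
# from the config hunk below. Newer langchain_community releases additionally
# require allow_dangerous_deserialization=True here.
db = FAISS.load_local("EmoLLMRAGTXT/vector_db", emb_model)
print(db.similarity_search("如何缓解焦虑", k=3))
```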
```diff
@@ -20,6 +20,7 @@ knowledge_json_path = os.path.join(data_dir, 'knowledge.json')  # json
 knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')  # pkl
 doc_dir = os.path.join(data_dir, 'txt')
 qa_dir = os.path.join(data_dir, 'json')
+cloud_vector_db_dir = os.path.join(base_dir, 'EmoLLMRAGTXT')

 # log
 log_dir = os.path.join(base_dir, 'log')  # log
@@ -30,13 +31,13 @@ chunk_size=1000
 chunk_overlap=100

 # vector DB
-vector_db_dir = os.path.join(data_dir, 'vector_db')
+vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db')

 # RAG related
 # select_num: how many documents are kept after reranking and passed to the LLM
 # retrieval_num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
 select_num = 3
-retrieval_num = 10
+retrieval_num = 3

 # LLM key
 glm_key = ''
```
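The select_num / retrieval_num pair encodes a standard retrieve-then-rerank funnel: over-fetch candidates from the vector DB, then keep only the best few after reranking. A hedged sketch of that relationship, reusing the FlagReranker import that appears in data_processing.py below (the reranker checkpoint name and the db/query arguments are assumptions):

```python
from FlagEmbedding import FlagReranker

def retrieve_then_rerank(db, query: str, retrieval_num: int = 3, select_num: int = 3):
    """Over-fetch retrieval_num docs from FAISS, rerank, keep the top select_num."""
    docs = db.similarity_search(query, k=retrieval_num)
    # Checkpoint name is an assumption; any BGE reranker works the same way.
    reranker = FlagReranker('BAAI/bge-reranker-base', use_fp16=True)
    # Score (query, passage) pairs; higher score = more relevant.
    scores = reranker.compute_score([[query, d.page_content] for d in docs])
    ranked = sorted(zip(scores, docs), key=lambda p: p[0], reverse=True)
    return [d for _, d in ranked[:select_num]]
```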
```diff
@@ -4,7 +4,7 @@ import os

 from loguru import logger
 from langchain_community.vectorstores import FAISS
-from config.config import (
+from rag.src.config.config import (
     embedding_path,
     embedding_model_name,
     doc_dir, qa_dir,
@@ -19,7 +19,6 @@ from config.config import (
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain.document_loaders import DirectoryLoader
 from langchain_core.documents.base import Document
 from FlagEmbedding import FlagReranker
```
```diff
@@ -199,7 +198,7 @@ class Data_process():
         Create and save the vector store
         '''
         logger.info(f'Creating index...')
-        #split_doc = self.split_document(doc_dir)
+        split_doc = self.split_document(doc_dir)
         split_qa = self.split_conversation(qa_dir)
         # logger.info(f'split_doc == {len(split_doc)}')
         # logger.info(f'split_qa == {len(split_qa)}')
```
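create_vector_db itself is not part of this diff; based on the save/load layout described in the README hunk above (index.faiss plus index.pkl under vector_db), it plausibly merges both chunk lists and persists a FAISS index. A hedged sketch, not the repo's actual implementation:

```python
# Hedged sketch of what create_vector_db likely does (not shown in this diff).
from langchain_community.vectorstores import FAISS

def create_vector_db_sketch(split_doc, split_qa, emb_model, vector_db_dir):
    documents = split_doc + split_qa           # merge txt chunks and QA chunks
    db = FAISS.from_documents(documents, emb_model)
    db.save_local(vector_db_dir)               # writes index.faiss / index.pkl
    return db
```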
```diff
@@ -218,7 +217,7 @@ class Data_process():
         if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
             db = self.create_vector_db(emb_model)
         else:
-            db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
+            db = FAISS.load_local(vector_db_dir, emb_model)
         return db


 if __name__ == "__main__":
```
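The removed allow_dangerous_deserialization=True keyword exists only in newer langchain_community releases, where it is mandatory because index.pkl is unpickled on load. If the installed version is not pinned, a version-tolerant variant of the else branch above is possible; a sketch:

```python
# Sketch: tolerate both old and new langchain_community signatures of load_local.
try:
    db = FAISS.load_local(vector_db_dir, emb_model,
                          allow_dangerous_deserialization=True)
except TypeError:
    # Older releases do not accept the keyword at all.
    db = FAISS.load_local(vector_db_dir, emb_model)
```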
```diff
@@ -2,8 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging

-from data_processing import Data_process
-from config.config import prompt_template
+from rag.src.data_processing import Data_process
+from rag.src.config.config import prompt_template
 logger = logging.get_logger(__name__)
```
```diff
@@ -48,19 +48,19 @@ class EmoLLMRAG(object):
         output: the retrieved (and optionally reranked) content
         """

-        content = ''
+        content = []
         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

         for doc in documents:
-            content += doc.page_content
+            content.append(doc.page_content)

         # If reranking is enabled, rerank the documents via the rerank interface
         if self.rerank_flag:
             documents, _ = self.data_processing_obj.rerank(documents, self.select_num)

-            content = ''
+            content = []
             for doc in documents:
-                content += doc
+                content.append(doc)
         logger.info(f'Retrieval data: {content}')
         return content
```
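Because get_retrieval_content now returns a list of chunks instead of one concatenated string, a caller that interpolates the result into a prompt may want to join it first (otherwise Python's list repr, brackets and quotes included, lands in the prompt). A hypothetical caller-side sketch:

```python
# Hypothetical caller-side handling of the new list return value.
chunks = rag_obj.get_retrieval_content(query)   # list[str] after this commit
retrieval_block = "\n\n".join(chunks)           # one clean string for the prompt
```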
```diff
@@ -12,6 +12,7 @@ import copy
 import os
 import warnings
 from dataclasses import asdict, dataclass
+from rag.src.pipeline import EmoLLMRAG
 from typing import Callable, List, Optional

 import streamlit as st
```
```diff
@@ -188,8 +189,9 @@ robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
 cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"


-def combine_history(prompt):
+def combine_history(prompt, retrieval_content=''):
     messages = st.session_state.messages
+    prompt = f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}"
     meta_instruction = (
         "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
     )
```
```diff
@@ -211,6 +213,7 @@ def main():
     # torch.cuda.empty_cache()
     print("load model begin.")
     model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
     print("load model end.")

     user_avator = "assets/user.png"
@@ -232,9 +235,12 @@ def main():
     # Accept user input
     if prompt := st.chat_input("What is up?"):
         # Display user message in chat message container
+        retrieval_content = rag_obj.get_retrieval_content(prompt)
         with st.chat_message("user", avatar=user_avator):
             st.markdown(prompt)
-        real_prompt = combine_history(prompt)
+        #st.markdown(retrieval_content)
+
+        real_prompt = combine_history(prompt, retrieval_content)
         # Add user message to chat history
         st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
```
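Taken together, the per-turn flow this commit wires into the Streamlit demo is: build the EmoLLMRAG object once after load_model(), run get_retrieval_content on the raw user prompt, then pass the result into combine_history so the retrieved text is folded into the final prompt ahead of the chat history.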