Add RAG into internlm2
parent 14890fad56 · commit 2632ec390d
@@ -56,6 +56,12 @@ The JSON data format is as follows
 A vector DB is built from the prepared data; this produces a folder named vector_db under the data folder, containing index.faiss and index.pkl
 
 If a vector DB already exists, the corresponding database is loaded directly
+**Note**: the prebuilt DB can be downloaded directly from xlab (run the following commands inside the rag directory)
+```bash
+# https://openxlab.org.cn/models/detail/Anooyman/EmoLLMRAGTXT/tree/main
+git lfs install
+git clone https://code.openxlab.org.cn/Anooyman/EmoLLMRAGTXT.git
+```
 
 
 ### Configure the config file
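For readers who want the build-or-load behavior described above in isolation, here is a minimal sketch under stated assumptions: the path, the embedding model name, and the stand-in document list are illustrative, not the repo's exact values.

```python
import os

from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document

vector_db_dir = "data/vector_db"  # illustrative path; the repo derives it from config
emb_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-zh-v1.5")  # assumed model

# Stand-in for the real chunked corpus produced by the text splitter.
split_docs = [Document(page_content="示例文本 chunk")]

if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
    # First run: embed the chunks, build the index, persist it to disk.
    db = FAISS.from_documents(split_docs, emb_model)
    db.save_local(vector_db_dir)  # writes index.faiss and index.pkl
else:
    # Later runs: load the persisted index instead of rebuilding it.
    db = FAISS.load_local(vector_db_dir, emb_model)
```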
@@ -20,6 +20,7 @@ knowledge_json_path = os.path.join(data_dir, 'knowledge.json') # json
 knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pkl
 doc_dir = os.path.join(data_dir, 'txt')
 qa_dir = os.path.join(data_dir, 'json')
+cloud_vector_db_dir = os.path.join(base_dir, 'EmoLLMRAGTXT')
 
 # log
 log_dir = os.path.join(base_dir, 'log') # log
@@ -30,13 +31,13 @@ chunk_size=1000
 chunk_overlap=100
 
 # vector DB
-vector_db_dir = os.path.join(data_dir, 'vector_db')
+vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db')
 
 # RAG related
 # select_num: how many documents are kept after rerank and fed to the LLM
 # retrieval_num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
 select_num = 3
-retrieval_num = 10
+retrieval_num = 3
 
 # LLM key
 glm_key = ''
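The relationship between the two config knobs is worth making concrete: the retriever first pulls retrieval_num candidates from the vector DB, and the reranker then keeps the best select_num of them. A toy sketch, with a random scoring function standing in for the real reranker:

```python
import random

retrieval_num = 10   # candidates pulled from the vector DB
select_num = 3       # survivors after rerank (must be <= retrieval_num)

candidates = [f"doc-{i}" for i in range(retrieval_num)]   # stand-in for similarity_search results
scores = [random.random() for _ in candidates]            # stand-in for reranker scores
reranked = [d for _, d in sorted(zip(scores, candidates), reverse=True)]
selected = reranked[:select_num]                          # only these reach the LLM prompt
print(selected)
```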
@@ -4,7 +4,7 @@ import os
 
 from loguru import logger
 from langchain_community.vectorstores import FAISS
-from config.config import (
+from rag.src.config.config import (
     embedding_path,
     embedding_model_name,
     doc_dir, qa_dir,
@@ -19,7 +19,6 @@ from config.config import (
 from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain.document_loaders import DirectoryLoader
 from langchain_core.documents.base import Document
 from FlagEmbedding import FlagReranker
 
@@ -199,7 +198,7 @@ class Data_process():
         Create and save the vector store
         '''
         logger.info(f'Creating index...')
-        #split_doc = self.split_document(doc_dir)
+        split_doc = self.split_document(doc_dir)
         split_qa = self.split_conversation(qa_dir)
         # logger.info(f'split_doc == {len(split_doc)}')
         # logger.info(f'split_qa == {len(split_qa)}')
@@ -218,7 +217,7 @@ class Data_process():
         if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
             db = self.create_vector_db(emb_model)
         else:
-            db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
+            db = FAISS.load_local(vector_db_dir, emb_model)
         return db
 
 if __name__ == "__main__":
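One caveat on the load_local change above: recent langchain_community releases refuse to unpickle index.pkl unless the caller opts in, so dropping allow_dangerous_deserialization=True can fail depending on the installed version. A version-tolerant load might look like this (a sketch, not the project's code):

```python
from langchain_community.vectorstores import FAISS

def load_faiss(vector_db_dir, emb_model):
    """Load a persisted FAISS index across old and new langchain_community APIs."""
    try:
        # Newer versions require an explicit opt-in because index.pkl is loaded via pickle.
        return FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
    except TypeError:
        # Older versions do not know the keyword at all.
        return FAISS.load_local(vector_db_dir, emb_model)
```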
@@ -2,8 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging
 
-from data_processing import Data_process
-from config.config import prompt_template
+from rag.src.data_processing import Data_process
+from rag.src.config.config import prompt_template
 logger = logging.get_logger(__name__)
 
 
@@ -48,19 +48,19 @@ class EmoLLMRAG(object):
         output: retrieved and reranked content
         """
 
-        content = ''
+        content = []
         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
 
         for doc in documents:
-            content += doc.page_content
+            content.append(doc.page_content)
 
         # If rerank is enabled, call the rerank interface on the documents
         if self.rerank_flag:
             documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
 
-            content = ''
+            content = []
             for doc in documents:
-                content += doc
+                content.append(doc)
         logger.info(f'Retrieval data: {content}')
         return content
 
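For context on the rerank call above: data_processing.py imports FlagReranker from FlagEmbedding, so the retrieve-then-rerank step plausibly has the following shape. This is an independent sketch; the model name and the helper's exact signature are assumptions, not the repo's code.

```python
from FlagEmbedding import FlagReranker

def rerank_documents(query, documents, select_num):
    """Score (query, passage) pairs with a BGE reranker and keep the best select_num."""
    reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)  # assumed model name
    scores = reranker.compute_score([[query, doc] for doc in documents])
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    top_docs = [doc for doc, _ in ranked[:select_num]]
    top_scores = [score for _, score in ranked[:select_num]]
    return top_docs, top_scores
```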
@@ -12,6 +12,7 @@ import copy
 import os
 import warnings
 from dataclasses import asdict, dataclass
+from rag.src.pipeline import EmoLLMRAG
 from typing import Callable, List, Optional
 
 import streamlit as st
@@ -188,8 +189,9 @@ robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
 cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
 
 
-def combine_history(prompt):
+def combine_history(prompt, retrieval_content=''):
     messages = st.session_state.messages
+    prompt = f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}"
     meta_instruction = (
         "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
     )
@@ -211,6 +213,7 @@ def main():
     # torch.cuda.empty_cache()
     print("load model begin.")
     model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
     print("load model end.")
 
     user_avator = "assets/user.png"
@@ -232,9 +235,12 @@ def main():
     # Accept user input
     if prompt := st.chat_input("What is up?"):
         # Display user message in chat message container
+        retrieval_content = rag_obj.get_retrieval_content(prompt)
         with st.chat_message("user", avatar=user_avator):
             st.markdown(prompt)
-        real_prompt = combine_history(prompt)
+        #st.markdown(retrieval_content)
 
+        real_prompt = combine_history(prompt, retrieval_content)
         # Add user message to chat history
         st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
 
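Putting the demo wiring together, the RAG path through the web demo reduces to three calls per user turn. The snippet below paraphrases the hunks above with minimal stand-ins so it runs outside Streamlit; in the real app these are the objects defined in rag.src.pipeline and the demo script.

```python
# Minimal stand-ins so the flow runs standalone; in the demo these are the real objects.
class EmoLLMRAG:
    def __init__(self, model): self.model = model
    def get_retrieval_content(self, query): return "…retrieved passages…"

def combine_history(prompt, retrieval_content=''):
    return f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}"

rag_obj = EmoLLMRAG(model=None)                             # built once at startup, after load_model()
prompt = "最近总是失眠,我该怎么办?"                         # what st.chat_input would return
retrieval_content = rag_obj.get_retrieval_content(prompt)   # retrieve (and optionally rerank)
real_prompt = combine_history(prompt, retrieval_content)    # fold passages into the chat prompt
print(real_prompt)
```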