Merge pull request #100 from zealot52099/main

update
Commit c7d916bf4f by xzw, 2024-03-18 23:22:02 +08:00 (committed by GitHub)
3 changed files with 314 additions and 67 deletions


@@ -13,11 +13,16 @@ llm_path = os.path.join(model_dir, 'pythia-14m') # llm
# data
data_dir = os.path.join(base_dir, 'data')                        # data directory
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')   # knowledge base (JSON)
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')     # encoded knowledge (pickle)
doc_dir = os.path.join(data_dir, 'txt')
qa_dir = os.path.join(data_dir, 'json')
# log
log_dir = os.path.join(base_dir, 'log')                          # log directory
log_path = os.path.join(log_dir, 'log.log')                      # log file
# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
select_num = 3
retrieval_num = 10
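# Presumably retrieval_num is the number of candidates fetched from the FAISS index and select_num
# the number of passages kept after reranking (an assumption based on the names).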

rag/src/data_processing.py (new file, 262 lines)

@@ -0,0 +1,262 @@
import json
import pickle
from loguru import logger
from sentence_transformers import SentenceTransformer
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
import os
import faiss
import platform
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
'''
1. Generate embeddings from the QA pairs / TXT documents.
2. Build the vector DB by calling the langchain FAISS interface.
3. Store it to openxlab.dataset for convenient later use.
4. Provide an embedding interface function for later calls.
5. Provide a rerank interface function for later calls.
'''
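# Rough mapping of the steps above to the code below: load_embedding_model() / create_index_cpu /
# create_index_gpu produce the embeddings and the FAISS index, load_index_and_knowledge() persists
# and reloads both, and find_top_k() / rerank() are the retrieval and rerank interfaces. Note that
# the index is built with the faiss library directly rather than the langchain FAISS wrapper, and
# nothing in this file uploads to openxlab.dataset yet.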
"""
加载向量模型
"""
def load_embedding_model():
logger.info('Loading embedding model...')
# model = EmbeddingModel(model_name_or_path="huggingface/bce-embedding-base_v1")
model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")
logger.info('Embedding model loaded.')
return model
def load_rerank_model():
    logger.info('Loading rerank_model...')
    model = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1")
    # model = RerankerModel(model_name_or_path="huggingface/bce-reranker-base_v1")
    logger.info('Rerank model loaded.')
    return model
def split_document(data_path, chunk_size=1000, chunk_overlap=100):
    # text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    split_docs = []
    logger.info(f'Loading txt files from {data_path}')
    if os.path.isdir(data_path):
        # directory: walk it and split every .txt file
        for root, dirs, files in os.walk(data_path):
            for file in files:
                if file.endswith('.txt'):
                    file_path = os.path.join(root, file)
                    # logger.info(f'splitting file {file_path}')
                    text_loader = TextLoader(file_path, encoding='utf-8')
                    text = text_loader.load()
                    splits = text_splitter.split_documents(text)
                    # logger.info(f"splits type {type(splits[0])}")
                    # logger.info(f'splits size {len(splits)}')
                    split_docs += splits
    elif data_path.endswith('.txt'):
        # single file: data_path is already the full path
        file_path = data_path
        # logger.info(f'splitting file {file_path}')
        text_loader = TextLoader(file_path, encoding='utf-8')
        text = text_loader.load()
        splits = text_splitter.split_documents(text)
        # logger.info(f"splits type {type(splits[0])}")
        # logger.info(f'splits size {len(splits)}')
        split_docs = splits
    logger.info(f'split_docs size {len(split_docs)}')
    return split_docs
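# Example usage (illustrative):
#   docs = split_document(doc_dir)                    # split every .txt under data/txt
#   docs = split_document('/path/to/one_file.txt')    # or split a single .txt file
# Each returned item is a langchain Document chunk of roughly chunk_size characters with
# chunk_overlap characters of overlap.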
## TODO: 1. read the system prompt  2. limit the sequence length
def split_conversation(path):
    '''
    data format:
    [
        {
            "conversation": [
                {
                    "input": "Q1",
                    "output": "A1"
                },
                {
                    "input": "Q2",
                    "output": "A2"
                }
            ]
        }
    ]
    '''
    qa_pairs = []
    logger.info(f'Loading json files from {path}')
    if os.path.isfile(path):
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
            for conversation in data:
                for dialog in conversation['conversation']:
                    # input_text = dialog['input']
                    # output_text = dialog['output']
                    # if len(input_text) > max_length or len(output_text) > max_length:
                    #     continue
                    qa_pairs.append(dialog)
    elif os.path.isdir(path):
        # directory: walk it and read every .json file
        for root, dirs, files in os.walk(path):
            for file in files:
                if file.endswith('.json'):
                    file_path = os.path.join(root, file)
                    logger.info(f'splitting file {file_path}')
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                        for conversation in data:
                            for dialog in conversation['conversation']:
                                qa_pairs.append(dialog)
    return qa_pairs
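# Each element returned here is a single {"input": ..., "output": ...} dict; conversations are
# flattened turn by turn, so multi-turn context is not preserved at this stage.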
# Load the local index and knowledge base, building both on first use
def load_index_and_knowledge():
    current_os = platform.system()
    split_doc = []
    split_qa = []
    # load the knowledge base
    if not os.path.exists(knowledge_pkl_path):
        split_doc = split_document(doc_dir)
        split_qa = split_conversation(qa_dir)
        # logger.info(f'split_qa size:{len(split_qa)}')
        # logger.info(f'type of split_qa:{type(split_qa[0])}')
        # logger.info(f'split_doc size:{len(split_doc)}')
        # logger.info(f'type of doc:{type(split_doc[0])}')
        knowledge_chunks = split_doc + split_qa
        with open(knowledge_pkl_path, 'wb') as file:
            pickle.dump(knowledge_chunks, file)
    else:
        with open(knowledge_pkl_path, 'rb') as f:
            knowledge_chunks = pickle.load(f)
    # load the vector DB
    if not os.path.exists(vector_db_dir):
        logger.info('Creating index...')
        emb_model = load_embedding_model()
        if not split_doc:
            split_doc = split_document(doc_dir)
        if not split_qa:
            split_qa = split_conversation(qa_dir)
        # build the index; faiss-gpu is not supported on Windows
        if current_os == 'Linux':
            index = create_index_gpu(split_doc, split_qa, emb_model, vector_db_dir)
        else:
            index = create_index_cpu(split_doc, split_qa, emb_model, vector_db_dir)
    else:
        # read the index from disk first, then move it to the GPU on Linux
        index = faiss.read_index(vector_db_dir)
        if current_os == 'Linux':
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
    return index, knowledge_chunks
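# Both artifacts are cached on disk: the raw chunks in knowledge.pkl and the FAISS index in
# vector_db.pkl, so subsequent runs skip re-splitting and re-encoding.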
def create_index_cpu(split_doc, split_qa, emb_model, index_path, dimension=768, question_only=False):
    # BCE embeddings are assumed to be 768-dimensional; adjust `dimension` if you switch models
    faiss_index_cpu = faiss.IndexFlatIP(dimension)  # FAISS index using inner-product similarity
    # encode the chunks and QA pairs and add them to the FAISS index
    for doc in split_doc:
        # type_of_docs = type(split_doc)
        text = f"{doc.page_content}"
        vector = emb_model.encode([text])
        faiss_index_cpu.add(vector)
    for qa in split_qa:
        # encode only the question (Q) of each QA pair
        text = f"{qa['input']}"
        vector = emb_model.encode([text])
        faiss_index_cpu.add(vector)
    faiss.write_index(faiss_index_cpu, index_path)
    return faiss_index_cpu
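# Note: IndexFlatIP scores by raw inner product; this equals cosine similarity only if the
# embeddings are L2-normalized, which depends on how the BCE EmbeddingModel is configured.
# Encoding one text per add() call works but is slow; emb_model.encode accepts a list, so
# batching the chunks is an easy speedup.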
def create_index_gpu(split_doc, split_qa, emb_model, index_path, dimension=768, question_only=False):
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatIP(dimension)
    faiss_index_gpu = faiss.index_cpu_to_gpu(res, 0, index)
    for doc in split_doc:
        # type_of_docs = type(split_doc)
        text = f"{doc.page_content}"
        vector = emb_model.encode([text])
        faiss_index_gpu.add(vector)
    for qa in split_qa:
        # encode only the question (Q) of each QA pair
        text = f"{qa['input']}"
        vector = emb_model.encode([text])
        faiss_index_gpu.add(vector)
    # a GPU index cannot be serialized directly; copy it back to the CPU before writing
    faiss.write_index(faiss.index_gpu_to_cpu(faiss_index_gpu), index_path)
    return faiss_index_gpu
# Search the index for the chunks most similar to the query
def find_top_k(query, faiss_index, k=5):
    emb_model = load_embedding_model()
    emb_query = emb_model.encode([query])
    distances, indices = faiss_index.search(emb_query, k)
    return distances, indices
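# Note: find_top_k reloads the embedding model on every call. For repeated queries it is cheaper
# to load the model once and reuse it, e.g.:
#   emb_model = load_embedding_model()
#   distances, indices = faiss_index.search(emb_model.encode([query]), k)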
def rerank(query, indices, knowledge_chunks):
    passages = []
    for index in indices[0]:
        content = knowledge_chunks[index]
        '''
        txt chunk: 'langchain_core.documents.base.Document'
        json chunk: dict
        '''
        # logger.info(f'retrieved content:{content}')
        # logger.info(f'type of content:{type(content)}')
        if isinstance(content, dict):
            content = content["input"] + '\n' + content["output"]
        else:
            content = content.page_content
        passages.append(content)
    model = load_rerank_model()
    rerank_results = model.rerank(query, passages)
    return rerank_results
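# model.rerank(query, passages) returns a dict; __main__ below reads its 'rerank_passages' and
# 'rerank_scores' fields, which are presumably ordered from most to least relevant.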
@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer
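# "model" is a relative path, i.e. the directory that openxlab's download(..., output='model')
# call writes the weights to; the load_model() in the other changed file resolves
# base_dir/../model explicitly instead.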
if __name__ == "__main__":
    logger.info(data_dir)
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    faiss_index, knowledge_chunks = load_index_and_knowledge()
    # run retrieval for a sample query
    # query = "她要阻挠姐姐的婚姻,即使她自己的尸体在房门跟前"
    # query = "肯定的。我最近睡眠很差,总是做噩梦。而且我吃得也不好,体重一直在下降"
    # query = "序言 (一) 变态心理学是心理学本科生的必修课程之一,教材更新的问题一直在困扰着我们。"
    query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
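    # query (English gloss): "Counselor, I feel my chest tightness is getting worse and worse, and it scares me."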
    distances, indices = find_top_k(query, faiss_index, 5)
    logger.info(f'distances==={distances}')
    logger.info(f'indices==={indices}')
    # rerank cannot return ids, so for now rank the whole QA pairs
    rerank_results = rerank(query, indices, knowledge_chunks)
    for passage, score in zip(rerank_results['rerank_passages'], rerank_results['rerank_scores']):
        print(str(score) + '\n')
        print(passage + '\n')


@@ -5,87 +5,67 @@ import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
from data_processing import load_index_and_knowledge, create_index_cpu, create_index_gpu, find_top_k, rerank
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
'''
1. Build the complete RAG pipeline: the input is the user query, the output is the answer.
2. Call the embedding interface to vectorize the query.
3. Download the FAISS-based prebuilt vector DB and retrieve the relevant information.
4. Call the rerank interface to re-rank the retrieved content.
5. Call the prompt interface to get the system prompt and the prompt template.
6. Assemble the prompt and call the model to return the result.
"""
读取知识库
"""
def load_knowledge() -> Tuple[list, list]:
# 如果 pkl 不存在,则先编码存储
if not os.path.exists(knowledge_pkl_path):
encode_qa(knowledge_json_path, knowledge_pkl_path)
# 加载 json 和 pkl
with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
knowledge = json.load(f1)
encoded_knowledge = pickle.load(f2)
return knowledge, encoded_knowledge
"""
召回 top_k 个相关的文本段
"""
def find_top_k(
emb: SentenceTransformer,
query: str,
knowledge: list,
encoded_knowledge: list,
k=3
) -> list[str]:
# 编码 query
query_embedding = emb.encode(query)
# 查找 top_k
scores = query_embedding @ encoded_knowledge.T
# 使用 argpartition 找出每行第 k 个大的值的索引,第 k 个位置左侧都是比它大的值,右侧都是比它小的值
top_k_indices = np.argpartition(scores, -k)[-k:]
# 由于 argpartition 不保证顺序,我们需要对提取出的 k 个索引进行排序
top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
top_k_indices = top_k_indices[top_k_values_sorted_indices]
# 返回
contents = [knowledge[index] for index in top_k_indices]
return contents
def main():
emb = load_embedding()
knowledge, encoded_knowledge = load_knowledge()
query = "认知心理学研究哪些心理活动?"
contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
print('召回的 top-k 条相关内容如下:')
print(json.dumps(contents, ensure_ascii=False, indent=2))
# 这里我没实现 LLM 部分,如果有 LLM
## 1. 读取 LLM
## 2. 将 contents 拼接为 prompt传给 LLM作为 {已知内容}
## 3. 要求 LLM 根据已知内容回复
'''
# download(
# model_repo=model_repo,
# output='model'
# )
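# The pipeline sketched in the module docstring above is carried out in main() below:
# load_index_and_knowledge() supplies the FAISS vector DB, find_top_k() retrieves candidate
# chunks, rerank() re-orders them, and the top passage is handed to the chat model together
# with the system prompt.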
@st.cache_resource
def load_model():
    model_dir = os.path.join(base_dir, '../model')
    logger.info(f'Loading model from {model_dir}')
    model = (
        AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    return model, tokenizer
def get_prompt():
    pass

def get_prompt_template():
    pass
def main(query, system_prompt):
    model, tokenizer = load_model()
    rag_obj = EmoLLMRAG(model)
    response = rag_obj.main(query)  # note: this EmoLLMRAG response is not used below
    model = model.eval()
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    # load the FAISS-based vector DB and the original knowledge base (built locally on first run)
    faiss_index, knowledge_chunks = load_index_and_knowledge()
    distances, indices = find_top_k(query, faiss_index, 5)
    rerank_results = rerank(query, indices, knowledge_chunks)
    messages = [(system_prompt, rerank_results['rerank_passages'][0])]
    logger.info(f'messages:{messages}')
    response, history = model.chat(tokenizer, query, history=messages)
    messages.append((query, response))
    print(f"robot >>> {response}")
if __name__ == '__main__':
    # query = '你好'
    query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕"
    # query (English gloss): "Counselor, I feel my chest tightness is getting worse and worse, and it scares me."
    # TODO: system_prompt = get_prompt()
    system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发排名按字母顺序排序不分先后、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家我有一些心理问题请你用专业的知识帮我解决。"
    # system_prompt (English gloss): "You are a mental-health LLM developed by aJupyter, Farewell,
    # jujimeizuo and Smiling&Weeping (listed alphabetically, in no particular order), with technical
    # support from 散步 and backing from the Shanghai AI Laboratory. You are now a psychology expert;
    # I have some psychological problems, please use your professional knowledge to help me solve them."
    main(query, system_prompt)