update rag/src/data_processing.py & main.py

This commit is contained in:
zealot52099 2024-03-20 16:51:07 +08:00
parent a119fdb507
commit fdf05f480c
2 changed files with 268 additions and 268 deletions

rag/src/data_processing.py

@@ -1,114 +1,155 @@
import json
import pickle
import faiss
import os
from loguru import logger
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader, DirectoryLoader
from langchain_community.llms import Cohere
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker


class Data_process():
    def __init__(self):
        self.vector_db_dir = vector_db_dir
        self.doc_dir = doc_dir
        self.qa_dir = qa_dir
        self.knowledge_pkl_path = knowledge_pkl_path
        self.chunk_size: int = 1000
        self.chunk_overlap: int = 100

    def load_embedding_model(self, model_name="BAAI/bge-small-zh-v1.5", device='cpu', normalize_embeddings=True):
        """
        Load the embedding model.

        Args:
        - model_name: model name, str, defaults to "BAAI/bge-small-zh-v1.5".
        - device: device to load the model on, 'cpu' or 'cuda', defaults to 'cpu'.
        - normalize_embeddings: whether to normalize the embedding vectors, bool, defaults to True.
        """
        logger.info('Loading embedding model...')
        try:
            embeddings = HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs={'device': device},
                encode_kwargs={'normalize_embeddings': normalize_embeddings}
            )
        except Exception as e:
            logger.error(f'Failed to load embedding model: {e}')
            return None
        logger.info('Embedding model loaded.')
        return embeddings

    def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
        """
        Load the rerank model.

        Args:
        - model_name (str): name of the model, defaults to 'BAAI/bge-reranker-large'.

        Returns:
        - a FlagReranker instance.

        Raises:
        - ValueError: if the model name is not in the approved model list.
        - Exception: if any other error occurs while loading the model.
        """
        try:
            reranker_model = FlagReranker(model_name, use_fp16=True)
        except Exception as e:
            logger.error(f'Failed to load rerank model: {e}')
            raise
        return reranker_model

    def extract_text_from_json(self, obj, content=""):
        """
        Extract the text from a JSON object, used to build the vector store.

        Args:
        - obj: dict, list or str
        - content: str

        Returns:
        - content: str
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                try:
                    # accumulate the text recursively extracted from each value
                    content = self.extract_text_from_json(value, content)
                except Exception as e:
                    print(f"Error processing value: {e}")
        elif isinstance(obj, list):
            for index, item in enumerate(obj):
                try:
                    # accumulate the text recursively extracted from each item
                    content = self.extract_text_from_json(item, content)
                except Exception as e:
                    print(f"Error processing item: {e}")
        elif isinstance(obj, str):
            content += obj
        return content

    def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
        """
        Split all txt files under the data_path directory.

        Args:
        - data_path: str
        - chunk_size: int
        - chunk_overlap: int

        Returns:
        - split_docs: list
        """
        # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        split_docs = []
        logger.info(f'Loading txt files from {data_path}')
        if os.path.isdir(data_path):
            loader = DirectoryLoader(data_path, glob="**/*.txt", show_progress=True)
            docs = loader.load()
            split_docs = text_spliter.split_documents(docs)
        elif data_path.endswith('.txt'):
            file_path = data_path
            logger.info(f'splitting file {file_path}')
            text_loader = TextLoader(file_path, encoding='utf-8')
            text = text_loader.load()
            splits = text_spliter.split_documents(text)
            split_docs = splits
        logger.info(f'split_docs size {len(split_docs)}')
        return split_docs

    def split_conversation(self, path):
        """
        Split all json files under the path directory into conversation blocks.
        ##TODO limit the sequence length
        """
        # json_spliter = RecursiveJsonSplitter(max_chunk_size=500)
        logger.info(f'Loading json files from {path}')
        split_qa = []
        if os.path.isdir(path):
            # loader = DirectoryLoader(path, glob="**/*.json", show_progress=True)
            # jsons = loader.load()
            for root, dirs, files in os.walk(path):
                for file in files:
                    if file.endswith('.json'):
@@ -116,147 +157,114 @@ def split_conversation(path):
                        logger.info(f'splitting file {file_path}')
                        with open(file_path, 'r', encoding='utf-8') as f:
                            data = json.load(f)
                            for conversation in data:
                                # for dialog in conversation['conversation']:
                                ## Split per QA pair: turn each QA round into a langchain_core.documents.base.Document
                                #     content = self.extract_text_from_json(dialog, '')
                                #     split_qa.append(Document(page_content=content))
                                # Split per conversation block
                                content = self.extract_text_from_json(conversation['conversation'], '')
                                split_qa.append(Document(page_content=content))
        # logger.info(f'split_qa size====={len(split_qa)}')
        return split_qa

    def load_knowledge(self, knowledge_pkl_path):
        '''
        Load the knowledge .pkl file, creating it first if it does not exist.
        '''
        if not os.path.exists(knowledge_pkl_path):
            split_doc = self.split_document(doc_dir)
            split_qa = self.split_conversation(qa_dir)
            knowledge_chunks = split_doc + split_qa
            with open(knowledge_pkl_path, 'wb') as file:
                pickle.dump(knowledge_chunks, file)
        else:
            with open(knowledge_pkl_path, 'rb') as f:
                knowledge_chunks = pickle.load(f)
        return knowledge_chunks

    def create_vector_db(self, emb_model):
        '''
        Create and save the vector store.
        '''
        logger.info(f'Creating index...')
        split_doc = self.split_document(self.doc_dir)
        split_qa = self.split_conversation(self.qa_dir)
        # logger.info(f'split_doc == {len(split_doc)}')
        # logger.info(f'split_qa == {len(split_qa)}')
        # logger.info(f'split_doc type == {type(split_doc[0])}')
        # logger.info(f'split_qa type== {type(split_qa[0])}')
        db = FAISS.from_documents(split_doc + split_qa, emb_model)
        db.save_local(vector_db_dir)
        return db

    def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir, qa_dir=qa_dir):
        '''
        Load the vector store, building it first if it does not exist.
        '''
        # current_os = platform.system()
        emb_model = self.load_embedding_model()
        if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
            db = self.create_vector_db(emb_model)
        else:
            db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
        return db

    def retrieve(self, query, vector_db, k=5):
        '''
        Retrieve the top-k documents from the vector store for the given query.
        '''
        retriever = vector_db.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(query)
        return docs, retriever

    ## FlashrankRerank gives mediocre results
    # def rerank(self, query, retriever):
    #     compressor = FlashrankRerank()
    #     compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
    #     compressed_docs = compression_retriever.get_relevant_documents(query)
    #     return compressed_docs

    def rerank(self, query, docs):
        # Score each (query, passage) pair with the FlagReranker and sort passages by score, descending.
        reranker = self.load_rerank_model()
        passages = []
        for doc in docs:
            passages.append(str(doc.page_content))
        scores = reranker.compute_score([[query, passage] for passage in passages])
        sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
        sorted_passages, sorted_scores = zip(*sorted_pairs)
        return sorted_passages, sorted_scores


if __name__ == "__main__":
    logger.info(data_dir)
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    dp = Data_process()
    # faiss_index, knowledge_chunks = dp.load_index_and_knowledge(knowledge_pkl_path='')
    vector_db = dp.load_vector_db()
    # Run a query against the vector store
    # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编写的。1962年初版,1979年再版。"
    # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
    # query = "这在一定程度上限制了其思维能力,特别是辩证逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    logger.info(f'Query: {query}')
    logger.info("Retrieve results:")
    for i, doc in enumerate(docs):
        logger.info(str(i) + '\n')
        logger.info(doc)
    # print(f'get num of docs:{len(docs)}')
    # print(docs)
    passages, scores = dp.rerank(query, docs)
    logger.info("After reranking...")
    for i in range(len(scores)):
        logger.info(str(scores[i]) + '\n')
        logger.info(passages[i])
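
For reference, split_conversation expects the QA data under qa_dir to follow the multi-turn layout documented in the earlier version of this module: a list of records, each holding a "conversation" list of input/output pairs. The sketch below is illustrative only; the file name and the placeholder Q/A strings are assumptions, not part of this commit.

# Illustrative only: a hypothetical qa_dir file in the structure split_conversation expects.
import json

sample = [
    {
        "conversation": [
            {"input": "Q1", "output": "A1"},
            {"input": "Q2", "output": "A2"},
        ]
    },
]

with open("sample_qa.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)

# Data_process().extract_text_from_json(sample[0]["conversation"], "") concatenates the
# strings ("Q1A1Q2A2"), so each conversation becomes one Document in the vector store.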

rag/src/main.py

@@ -13,9 +13,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
from data_processing import Data_process
'''
1. Build the full RAG pipeline: the input is the user query and the output is the answer.
2. Call the embedding interface to vectorize the query.
@@ -42,30 +41,23 @@ def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    return model, tokenizer

def main(query, system_prompt=''):
    logger.info(data_dir)
    if not os.path.exists(data_dir):
        os.mkdir(data_dir)
    dp = Data_process()
    vector_db = dp.load_vector_db()
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    logger.info(f'Query: {query}')
    logger.info("Retrieve results===============================")
    for i, doc in enumerate(docs):
        logger.info(doc)
    passages, scores = dp.rerank(query, docs)
    logger.info("After reranking===============================")
    for i in range(len(scores)):
        logger.info(passages[i])
        logger.info(f'score: {str(scores[i])}')

if __name__ == "__main__":
    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
    main(query)
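
Since main() now stops at logging the reranked passages, a natural next step is to feed the top passage back into generation. The sketch below follows the pattern used in the earlier version of main.py (a (system_prompt, passage) history tuple passed to model.chat on an InternLM-style chat model); the helper name answer, the empty default system prompt, and k=10 are assumptions, not part of this commit.

# A minimal sketch, reusing load_model() and Data_process from this repo; the function
# name and prompt wiring are assumptions based on the earlier version of main.py.
def answer(query, system_prompt=""):
    model, tokenizer = load_model()
    model = model.eval()
    dp = Data_process()
    vector_db = dp.load_vector_db()
    docs, _ = dp.retrieve(query, vector_db, k=10)
    passages, scores = dp.rerank(query, docs)
    # Seed the chat history with the system prompt paired with the top reranked passage.
    messages = [(system_prompt, passages[0])]
    response, history = model.chat(tokenizer, query, history=messages)
    return response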