Add files via upload

allow user to load embedding & rerank models from cache
This commit is contained in:
zealot52099 2024-03-22 20:15:37 +08:00 committed by GitHub
parent 382d338ab3
commit 0aa58372bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,7 +7,7 @@ import os
from loguru import logger from loguru import logger
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
@ -24,14 +24,10 @@ from FlagEmbedding import FlagReranker
class Data_process(): class Data_process():
def __init__(self): def __init__(self):
self.vector_db_dir = vector_db_dir
self.doc_dir = doc_dir
self.qa_dir = qa_dir
self.knowledge_pkl_path = knowledge_pkl_path
self.chunk_size: int=1000 self.chunk_size: int=1000
self.chunk_overlap: int=100 self.chunk_overlap: int=100
def load_embedding_model(self, model_name="BAAI/bge-small-zh-v1.5", device='cpu', normalize_embeddings=True): def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
""" """
加载嵌入模型 加载嵌入模型
@ -40,18 +36,29 @@ class Data_process():
- device: 指定模型加载的设备'cpu' 'cuda'默认为'cpu' - device: 指定模型加载的设备'cpu' 'cuda'默认为'cpu'
- normalize_embeddings: 是否标准化嵌入向量布尔类型默认为 True - normalize_embeddings: 是否标准化嵌入向量布尔类型默认为 True
""" """
if not os.path.exists(embedding_path):
os.makedirs(embedding_path, exist_ok=True)
embedding_model_path = os.path.join(embedding_path,model_name.split('/')[1] + '.pkl')
logger.info('Loading embedding model...') logger.info('Loading embedding model...')
if os.path.exists(embedding_model_path):
try:
with open(embedding_model_path , 'rb') as f:
embeddings = pickle.load(f)
logger.info('Embedding model loaded.')
return embeddings
except Exception as e:
logger.error(f'Failed to load embedding model from {embedding_model_path}')
try: try:
embeddings = HuggingFaceBgeEmbeddings( embeddings = HuggingFaceBgeEmbeddings(
model_name=model_name, model_name=model_name,
model_kwargs={'device': device}, model_kwargs={'device': device},
encode_kwargs={'normalize_embeddings': normalize_embeddings} encode_kwargs={'normalize_embeddings': normalize_embeddings})
) logger.info('Embedding model loaded.')
with open(embedding_model_path, 'wb') as file:
pickle.dump(embeddings, file)
except Exception as e: except Exception as e:
logger.error(f'Failed to load embedding model: {e}') logger.error(f'Failed to load embedding model: {e}')
return None return None
logger.info('Embedding model loaded.')
return embeddings return embeddings
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'): def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
@ -67,9 +74,25 @@ class Data_process():
异常: 异常:
- ValueError: 如果模型名称不在批准的模型列表中 - ValueError: 如果模型名称不在批准的模型列表中
- Exception: 如果模型加载过程中发生任何其他错误 - Exception: 如果模型加载过程中发生任何其他错误
""" """
if not os.path.exists(rerank_path):
os.makedirs(rerank_path, exist_ok=True)
rerank_model_path = os.path.join(rerank_path, model_name.split('/')[1] + '.pkl')
logger.info('Loading rerank model...')
if os.path.exists(rerank_model_path):
try:
with open(rerank_model_path , 'rb') as f:
reranker_model = pickle.load(f)
logger.info('Rerank model loaded.')
return reranker_model
except Exception as e:
logger.error(f'Failed to load embedding model from {rerank_model_path}')
try: try:
reranker_model = FlagReranker(model_name, use_fp16=True) reranker_model = FlagReranker(model_name, use_fp16=True)
logger.info('Rerank model loaded.')
with open(rerank_model_path, 'wb') as file:
pickle.dump(reranker_model, file)
except Exception as e: except Exception as e:
logger.error(f'Failed to load rerank model: {e}') logger.error(f'Failed to load rerank model: {e}')
raise raise
@ -192,8 +215,8 @@ class Data_process():
创建并保存向量库 创建并保存向量库
''' '''
logger.info(f'Creating index...') logger.info(f'Creating index...')
split_doc = self.split_document(self.doc_dir) split_doc = self.split_document(doc_dir)
split_qa = self.split_conversation(self.qa_dir) split_qa = self.split_conversation(qa_dir)
# logger.info(f'split_doc == {len(split_doc)}') # logger.info(f'split_doc == {len(split_doc)}')
# logger.info(f'split_qa == {len(split_qa)}') # logger.info(f'split_qa == {len(split_qa)}')
# logger.info(f'split_doc type == {type(split_doc[0])}') # logger.info(f'split_doc type == {type(split_doc[0])}')
@ -287,7 +310,8 @@ if __name__ == "__main__":
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
query = "我现在心情非常差,有什么解决办法吗?" # query = "我现在心情非常差,有什么解决办法吗?"
query = "我最近总感觉胸口很闷,但医生检查过说身体没问题。可我就是觉得喘不过气来,尤其是看到那些旧照片,想起过去的日子"
docs, retriever = dp.retrieve(query, vector_db, k=10) docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}') logger.info(f'Query: {query}')
logger.info("Retrieve results:") logger.info("Retrieve results:")