Add files via upload
Allow users to load the embedding and rerank models from cache
This commit is contained in:
parent
382d338ab3
commit
0aa58372bb
@ -7,7 +7,7 @@ import os
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from langchain_community.vectorstores import FAISS
|
from langchain_community.vectorstores import FAISS
|
||||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir
|
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
|
||||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
||||||
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
||||||
@ -24,14 +24,10 @@ from FlagEmbedding import FlagReranker
|
|||||||
|
|
||||||
class Data_process():
|
class Data_process():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vector_db_dir = vector_db_dir
|
|
||||||
self.doc_dir = doc_dir
|
|
||||||
self.qa_dir = qa_dir
|
|
||||||
self.knowledge_pkl_path = knowledge_pkl_path
|
|
||||||
self.chunk_size: int=1000
|
self.chunk_size: int=1000
|
||||||
self.chunk_overlap: int=100
|
self.chunk_overlap: int=100
|
||||||
|
|
||||||
def load_embedding_model(self, model_name="BAAI/bge-small-zh-v1.5", device='cpu', normalize_embeddings=True):
|
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
|
||||||
"""
|
"""
|
||||||
加载嵌入模型。
|
加载嵌入模型。
|
||||||
|
|
||||||
@ -40,18 +36,29 @@ class Data_process():
|
|||||||
- device: 指定模型加载的设备,'cpu' 或 'cuda',默认为'cpu'。
|
- device: 指定模型加载的设备,'cpu' 或 'cuda',默认为'cpu'。
|
||||||
- normalize_embeddings: 是否标准化嵌入向量,布尔类型,默认为 True。
|
- normalize_embeddings: 是否标准化嵌入向量,布尔类型,默认为 True。
|
||||||
"""
|
"""
|
||||||
|
if not os.path.exists(embedding_path):
|
||||||
|
os.makedirs(embedding_path, exist_ok=True)
|
||||||
|
embedding_model_path = os.path.join(embedding_path,model_name.split('/')[1] + '.pkl')
|
||||||
logger.info('Loading embedding model...')
|
logger.info('Loading embedding model...')
|
||||||
|
if os.path.exists(embedding_model_path):
|
||||||
|
try:
|
||||||
|
with open(embedding_model_path , 'rb') as f:
|
||||||
|
embeddings = pickle.load(f)
|
||||||
|
logger.info('Embedding model loaded.')
|
||||||
|
return embeddings
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to load embedding model from {embedding_model_path}')
|
||||||
try:
|
try:
|
||||||
embeddings = HuggingFaceBgeEmbeddings(
|
embeddings = HuggingFaceBgeEmbeddings(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_kwargs={'device': device},
|
model_kwargs={'device': device},
|
||||||
encode_kwargs={'normalize_embeddings': normalize_embeddings}
|
encode_kwargs={'normalize_embeddings': normalize_embeddings})
|
||||||
)
|
logger.info('Embedding model loaded.')
|
||||||
|
with open(embedding_model_path, 'wb') as file:
|
||||||
|
pickle.dump(embeddings, file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Failed to load embedding model: {e}')
|
logger.error(f'Failed to load embedding model: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info('Embedding model loaded.')
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
||||||
@ -67,9 +74,25 @@ class Data_process():
|
|||||||
异常:
|
异常:
|
||||||
- ValueError: 如果模型名称不在批准的模型列表中。
|
- ValueError: 如果模型名称不在批准的模型列表中。
|
||||||
- Exception: 如果模型加载过程中发生任何其他错误。
|
- Exception: 如果模型加载过程中发生任何其他错误。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if not os.path.exists(rerank_path):
|
||||||
|
os.makedirs(rerank_path, exist_ok=True)
|
||||||
|
rerank_model_path = os.path.join(rerank_path, model_name.split('/')[1] + '.pkl')
|
||||||
|
logger.info('Loading rerank model...')
|
||||||
|
if os.path.exists(rerank_model_path):
|
||||||
|
try:
|
||||||
|
with open(rerank_model_path , 'rb') as f:
|
||||||
|
reranker_model = pickle.load(f)
|
||||||
|
logger.info('Rerank model loaded.')
|
||||||
|
return reranker_model
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Failed to load embedding model from {rerank_model_path}')
|
||||||
try:
|
try:
|
||||||
reranker_model = FlagReranker(model_name, use_fp16=True)
|
reranker_model = FlagReranker(model_name, use_fp16=True)
|
||||||
|
logger.info('Rerank model loaded.')
|
||||||
|
with open(rerank_model_path, 'wb') as file:
|
||||||
|
pickle.dump(reranker_model, file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Failed to load rerank model: {e}')
|
logger.error(f'Failed to load rerank model: {e}')
|
||||||
raise
|
raise
|
||||||
@ -192,8 +215,8 @@ class Data_process():
|
|||||||
创建并保存向量库
|
创建并保存向量库
|
||||||
'''
|
'''
|
||||||
logger.info(f'Creating index...')
|
logger.info(f'Creating index...')
|
||||||
split_doc = self.split_document(self.doc_dir)
|
split_doc = self.split_document(doc_dir)
|
||||||
split_qa = self.split_conversation(self.qa_dir)
|
split_qa = self.split_conversation(qa_dir)
|
||||||
# logger.info(f'split_doc == {len(split_doc)}')
|
# logger.info(f'split_doc == {len(split_doc)}')
|
||||||
# logger.info(f'split_qa == {len(split_qa)}')
|
# logger.info(f'split_qa == {len(split_qa)}')
|
||||||
# logger.info(f'split_doc type == {type(split_doc[0])}')
|
# logger.info(f'split_doc type == {type(split_doc[0])}')
|
||||||
@ -287,7 +310,8 @@ if __name__ == "__main__":
|
|||||||
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
|
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
|
||||||
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
|
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
|
||||||
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||||
query = "我现在心情非常差,有什么解决办法吗?"
|
# query = "我现在心情非常差,有什么解决办法吗?"
|
||||||
|
query = "我最近总感觉胸口很闷,但医生检查过说身体没问题。可我就是觉得喘不过气来,尤其是看到那些旧照片,想起过去的日子"
|
||||||
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
||||||
logger.info(f'Query: {query}')
|
logger.info(f'Query: {query}')
|
||||||
logger.info("Retrieve results:")
|
logger.info("Retrieve results:")
|
||||||
|
Loading…
Reference in New Issue
Block a user