Update
This commit is contained in:
parent
f44310f665
commit
8c81c222a9
@ -18,22 +18,47 @@ langchain_core==0.1.33
|
||||
langchain_openai==0.0.8
|
||||
langchain_text_splitters==0.0.1
|
||||
FlagEmbedding==1.2.8
|
||||
|
||||
unstructured==0.12.6
|
||||
```
|
||||
|
||||
```bash
|
||||
cd rag
|
||||
|
||||
cd rag
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
```
|
||||
|
||||
## **使用指南**
|
||||
|
||||
### 准备数据
|
||||
|
||||
- txt数据:放入到 src.data.txt 目录下
|
||||
- json 数据:放入到 src.data.json 目录下
|
||||
|
||||
会根据准备的数据构建vector DB,最终会在 data 文件夹下产生名为 vector_db 的文件夹包含 index.faiss 和 index.pkl
|
||||
|
||||
如果已经有 vector DB 则会直接加载对应数据库
|
||||
|
||||
|
||||
### 配置 config 文件
|
||||
|
||||
根据需要改写 config.config 文件:
|
||||
|
||||
```python
|
||||
|
||||
# 存放所有 model
|
||||
model_dir = os.path.join(base_dir, 'model')
|
||||
|
||||
# embedding model 路径以及 model name
|
||||
embedding_path = os.path.join(model_dir, 'embedding_model')
|
||||
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||
|
||||
|
||||
# rerank model 路径以及 model name
|
||||
rerank_path = os.path.join(model_dir, 'rerank_model')
|
||||
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||
|
||||
|
||||
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
|
||||
select_num = 3
|
||||
|
||||
@ -52,7 +77,7 @@ prompt_template = """
|
||||
{content}
|
||||
|
||||
问题:{query}
|
||||
"
|
||||
"""
|
||||
```
|
||||
|
||||
### 调用
|
||||
|
@ -9,3 +9,4 @@ langchain_core==0.1.33
|
||||
langchain_openai==0.0.8
|
||||
langchain_text_splitters==0.0.1
|
||||
FlagEmbedding==1.2.8
|
||||
unstructured==0.12.6
|
@ -8,7 +8,9 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
|
||||
# model
|
||||
model_dir = os.path.join(base_dir, 'model') # model
|
||||
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
|
||||
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
|
||||
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
|
||||
|
||||
# data
|
||||
@ -36,7 +38,8 @@ glm_key = ''
|
||||
|
||||
# prompt
|
||||
prompt_template = """
|
||||
{system_prompt}
|
||||
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||
|
||||
根据下面检索回来的信息,回答问题。
|
||||
{content}
|
||||
问题:{query}
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
|
||||
from loguru import logger
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
|
||||
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
@ -13,11 +13,12 @@ from langchain_core.documents.base import Document
|
||||
from FlagEmbedding import FlagReranker
|
||||
|
||||
class Data_process():
|
||||
|
||||
def __init__(self):
|
||||
self.chunk_size: int=1000
|
||||
self.chunk_overlap: int=100
|
||||
|
||||
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
|
||||
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
|
||||
"""
|
||||
加载嵌入模型。
|
||||
|
||||
@ -51,7 +52,7 @@ class Data_process():
|
||||
return None
|
||||
return embeddings
|
||||
|
||||
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
||||
def load_rerank_model(self, model_name=rerank_model_name):
|
||||
"""
|
||||
加载重排名模型。
|
||||
|
||||
@ -89,7 +90,6 @@ class Data_process():
|
||||
|
||||
return reranker_model
|
||||
|
||||
|
||||
def extract_text_from_json(self, obj, content=None):
|
||||
"""
|
||||
抽取json中的文本,用于向量库构建
|
||||
@ -118,7 +118,7 @@ class Data_process():
|
||||
return content
|
||||
|
||||
|
||||
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
|
||||
def split_document(self, data_path):
|
||||
"""
|
||||
切分data_path文件夹下的所有txt文件
|
||||
|
||||
@ -133,7 +133,7 @@ class Data_process():
|
||||
|
||||
|
||||
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||
split_docs = []
|
||||
logger.info(f'Loading txt files from {data_path}')
|
||||
if os.path.isdir(data_path):
|
||||
@ -178,7 +178,7 @@ class Data_process():
|
||||
# split_qa.append(Document(page_content = content))
|
||||
#按conversation块切分
|
||||
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||
logger.info(f'content====={content}')
|
||||
#logger.info(f'content====={content}')
|
||||
split_qa.append(Document(page_content = content))
|
||||
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||
return split_qa
|
||||
|
@ -32,7 +32,10 @@ def main(query, system_prompt=''):
|
||||
logger.info(f'score: {str(scores[i])}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||
query = """
|
||||
我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。
|
||||
无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想
|
||||
"""
|
||||
|
||||
"""
|
||||
输入:
|
||||
|
Loading…
Reference in New Issue
Block a user