Update
This commit is contained in:
parent
f44310f665
commit
8c81c222a9
@ -18,22 +18,47 @@ langchain_core==0.1.33
|
|||||||
langchain_openai==0.0.8
|
langchain_openai==0.0.8
|
||||||
langchain_text_splitters==0.0.1
|
langchain_text_splitters==0.0.1
|
||||||
FlagEmbedding==1.2.8
|
FlagEmbedding==1.2.8
|
||||||
|
unstructured==0.12.6
|
||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
cd rag
|
|
||||||
|
|
||||||
|
cd rag
|
||||||
pip3 install -r requirements.txt
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## **使用指南**
|
## **使用指南**
|
||||||
|
|
||||||
|
### 准备数据
|
||||||
|
|
||||||
|
- txt数据:放入到 src.data.txt 目录下
|
||||||
|
- json 数据:放入到 src.data.json 目录下
|
||||||
|
|
||||||
|
会根据准备的数据构建vector DB,最终会在 data 文件夹下产生名为 vector_db 的文件夹包含 index.faiss 和 index.pkl
|
||||||
|
|
||||||
|
如果已经有 vector DB 则会直接加载对应数据库
|
||||||
|
|
||||||
|
|
||||||
### 配置 config 文件
|
### 配置 config 文件
|
||||||
|
|
||||||
根据需要改写 config.config 文件:
|
根据需要改写 config.config 文件:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
|
# 存放所有 model
|
||||||
|
model_dir = os.path.join(base_dir, 'model')
|
||||||
|
|
||||||
|
# embedding model 路径以及 model name
|
||||||
|
embedding_path = os.path.join(model_dir, 'embedding_model')
|
||||||
|
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||||
|
|
||||||
|
|
||||||
|
# rerank model 路径以及 model name
|
||||||
|
rerank_path = os.path.join(model_dir, 'rerank_model')
|
||||||
|
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||||
|
|
||||||
|
|
||||||
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
|
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
|
||||||
select_num = 3
|
select_num = 3
|
||||||
|
|
||||||
@ -52,7 +77,7 @@ prompt_template = """
|
|||||||
{content}
|
{content}
|
||||||
|
|
||||||
问题:{query}
|
问题:{query}
|
||||||
"
|
"""
|
||||||
```
|
```
|
||||||
|
|
||||||
### 调用
|
### 调用
|
||||||
|
@ -9,3 +9,4 @@ langchain_core==0.1.33
|
|||||||
langchain_openai==0.0.8
|
langchain_openai==0.0.8
|
||||||
langchain_text_splitters==0.0.1
|
langchain_text_splitters==0.0.1
|
||||||
FlagEmbedding==1.2.8
|
FlagEmbedding==1.2.8
|
||||||
|
unstructured==0.12.6
|
@ -8,7 +8,9 @@ model_repo = 'ajupyter/EmoLLM_aiwei'
|
|||||||
# model
|
# model
|
||||||
model_dir = os.path.join(base_dir, 'model') # model
|
model_dir = os.path.join(base_dir, 'model') # model
|
||||||
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
|
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
|
||||||
|
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
|
||||||
rerank_path = os.path.join(model_dir, 'rerank_model') # embedding
|
rerank_path = os.path.join(model_dir, 'rerank_model') # embedding
|
||||||
|
rerank_model_name = 'BAAI/bge-reranker-large'
|
||||||
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
|
llm_path = os.path.join(model_dir, 'pythia-14m') # llm
|
||||||
|
|
||||||
# data
|
# data
|
||||||
@ -36,7 +38,8 @@ glm_key = ''
|
|||||||
|
|
||||||
# prompt
|
# prompt
|
||||||
prompt_template = """
|
prompt_template = """
|
||||||
{system_prompt}
|
你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n
|
||||||
|
|
||||||
根据下面检索回来的信息,回答问题。
|
根据下面检索回来的信息,回答问题。
|
||||||
{content}
|
{content}
|
||||||
问题:{query}
|
问题:{query}
|
||||||
|
@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from langchain_community.vectorstores import FAISS
|
from langchain_community.vectorstores import FAISS
|
||||||
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
|
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
|
||||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader
|
from langchain_community.document_loaders import DirectoryLoader, TextLoader
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
@ -13,11 +13,12 @@ from langchain_core.documents.base import Document
|
|||||||
from FlagEmbedding import FlagReranker
|
from FlagEmbedding import FlagReranker
|
||||||
|
|
||||||
class Data_process():
|
class Data_process():
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.chunk_size: int=1000
|
self.chunk_size: int=1000
|
||||||
self.chunk_overlap: int=100
|
self.chunk_overlap: int=100
|
||||||
|
|
||||||
def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
|
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
|
||||||
"""
|
"""
|
||||||
加载嵌入模型。
|
加载嵌入模型。
|
||||||
|
|
||||||
@ -51,7 +52,7 @@ class Data_process():
|
|||||||
return None
|
return None
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
|
def load_rerank_model(self, model_name=rerank_model_name):
|
||||||
"""
|
"""
|
||||||
加载重排名模型。
|
加载重排名模型。
|
||||||
|
|
||||||
@ -89,7 +90,6 @@ class Data_process():
|
|||||||
|
|
||||||
return reranker_model
|
return reranker_model
|
||||||
|
|
||||||
|
|
||||||
def extract_text_from_json(self, obj, content=None):
|
def extract_text_from_json(self, obj, content=None):
|
||||||
"""
|
"""
|
||||||
抽取json中的文本,用于向量库构建
|
抽取json中的文本,用于向量库构建
|
||||||
@ -118,7 +118,7 @@ class Data_process():
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
|
def split_document(self, data_path):
|
||||||
"""
|
"""
|
||||||
切分data_path文件夹下的所有txt文件
|
切分data_path文件夹下的所有txt文件
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ class Data_process():
|
|||||||
|
|
||||||
|
|
||||||
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
|
||||||
split_docs = []
|
split_docs = []
|
||||||
logger.info(f'Loading txt files from {data_path}')
|
logger.info(f'Loading txt files from {data_path}')
|
||||||
if os.path.isdir(data_path):
|
if os.path.isdir(data_path):
|
||||||
@ -178,7 +178,7 @@ class Data_process():
|
|||||||
# split_qa.append(Document(page_content = content))
|
# split_qa.append(Document(page_content = content))
|
||||||
#按conversation块切分
|
#按conversation块切分
|
||||||
content = self.extract_text_from_json(conversation['conversation'], '')
|
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||||
logger.info(f'content====={content}')
|
#logger.info(f'content====={content}')
|
||||||
split_qa.append(Document(page_content = content))
|
split_qa.append(Document(page_content = content))
|
||||||
# logger.info(f'split_qa size====={len(split_qa)}')
|
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||||
return split_qa
|
return split_qa
|
||||||
|
@ -32,7 +32,10 @@ def main(query, system_prompt=''):
|
|||||||
logger.info(f'score: {str(scores[i])}')
|
logger.info(f'score: {str(scores[i])}')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
query = """
|
||||||
|
我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。
|
||||||
|
无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想
|
||||||
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
输入:
|
输入:
|
||||||
|
Loading…
Reference in New Issue
Block a user