RAG - Initial commit
parent 18997ec79c
commit 1ca8349839
0  rag/README.md  Normal file
4  rag/requirements.txt  Normal file
@@ -0,0 +1,4 @@
sentence_transformers
transformers
numpy
loguru
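These four dependencies can be installed in one step, for example with pip install -r rag/requirements.txt; note that no versions are pinned in this commit.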
19  rag/src/config/config.py  Normal file
@@ -0,0 +1,19 @@
import os

cur_dir = os.path.dirname(os.path.abspath(__file__))  # .../rag/src/config
src_dir = os.path.dirname(cur_dir)  # .../rag/src
base_dir = os.path.dirname(src_dir)  # .../rag

# model
model_dir = os.path.join(base_dir, 'model')  # model directory
embedding_path = os.path.join(model_dir, 'gte-small-zh')  # embedding model
llm_path = os.path.join(model_dir, 'pythia-14m')  # llm

# data
data_dir = os.path.join(base_dir, 'data')  # data directory
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')  # raw QA pairs (json)
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')  # encoded embeddings (pickle)

# log
log_dir = os.path.join(base_dir, 'log')  # log directory
log_path = os.path.join(log_dir, 'log.log')  # log file
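config.py only builds paths; the model and data files themselves are not part of this commit and have to be provided separately. The layout it assumes under the rag/ directory would look roughly like this (a sketch inferred from the paths above):

model/
    gte-small-zh/      # embedding model (embedding_path)
    pythia-14m/        # LLM checkpoint (llm_path)
data/
    knowledge.json     # QA pairs (knowledge_json_path)
    knowledge.pkl      # pickled embeddings written by encode_qa (knowledge_pkl_path)
log/
    log.log            # log file (log_path)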
67  rag/src/main.py  Normal file
@@ -0,0 +1,67 @@
import os
import json
import pickle
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer

from config.config import knowledge_json_path, knowledge_pkl_path
from util.encode import load_embedding, encode_qa


def load_knowledge() -> Tuple[list, np.ndarray]:
    """Load the knowledge base."""
    # If the pickle does not exist yet, encode the QA pairs and store them first.
    if not os.path.exists(knowledge_pkl_path):
        encode_qa(knowledge_json_path, knowledge_pkl_path)

    # Load the raw JSON and the pickled embeddings.
    with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
        knowledge = json.load(f1)
        encoded_knowledge = pickle.load(f2)
    return knowledge, encoded_knowledge


def find_top_k(
    emb: SentenceTransformer,
    query: str,
    knowledge: list,
    encoded_knowledge: np.ndarray,
    k: int = 3
) -> list[dict]:
    """Recall the top_k most relevant knowledge entries for the query."""
    # Encode the query.
    query_embedding = emb.encode(query)

    # Score every knowledge entry against the query.
    scores = query_embedding @ encoded_knowledge.T
    # argpartition puts the indices of the k largest scores in the last k positions
    # (their order among themselves is not guaranteed).
    top_k_indices = np.argpartition(scores, -k)[-k:]
    # Since argpartition does not sort, order those k indices by score, highest first.
    top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
    top_k_indices = top_k_indices[top_k_values_sorted_indices]

    # Return the matching entries (each a {'question', 'answer'} dict from knowledge.json).
    contents = [knowledge[index] for index in top_k_indices]
    return contents


def main():
    emb = load_embedding()
    knowledge, encoded_knowledge = load_knowledge()
    query = "认知心理学研究哪些心理活动?"  # example query, matching the Chinese gte-small-zh embedding model
    contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
    print('Top-k recalled contents:')
    print(json.dumps(contents, ensure_ascii=False, indent=2))
    # The LLM part is not implemented here. With an LLM, the remaining steps would be:
    ## 1. Load the LLM.
    ## 2. Join contents into the prompt as the "known context".
    ## 3. Ask the LLM to answer based on that known context.


if __name__ == '__main__':
    main()
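The three numbered steps at the end of main() are left as comments, and rag/src/util/llm.py is added further down as an empty placeholder. A minimal sketch of how that part could be filled in, assuming the pythia-14m checkpoint at llm_path is loaded with Hugging Face transformers (already listed in requirements.txt); build_prompt and generate_answer are illustrative names, not part of this commit, and pythia-14m is far too small to follow instructions, so it only stands in for a real chat model:

from transformers import AutoModelForCausalLM, AutoTokenizer

from config.config import llm_path


def build_prompt(query: str, contents: list[dict]) -> str:
    # 2. Join the recalled QA entries into a "known context" block and append the question.
    context = '\n'.join(f"Q: {c['question']}\nA: {c['answer']}" for c in contents)
    return (
        'Known context:\n'
        f'{context}\n\n'
        f'Answer the question using only the known context.\nQuestion: {query}\nAnswer:'
    )


def generate_answer(query: str, contents: list[dict]) -> str:
    # 1. Load the LLM from the local checkpoint directory.
    tokenizer = AutoTokenizer.from_pretrained(llm_path)
    model = AutoModelForCausalLM.from_pretrained(llm_path)

    # 3. Ask the LLM to answer based on the known context.
    prompt = build_prompt(query, contents)
    inputs = tokenizer(prompt, return_tensors='pt')
    output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    answer_ids = output_ids[0][inputs['input_ids'].shape[1]:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)

In main() this would replace the numbered comments with something like print(generate_answer(query, contents)).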
57  rag/src/util/encode.py  Normal file
@@ -0,0 +1,57 @@
import json
import pickle
from loguru import logger
from sentence_transformers import SentenceTransformer

from config.config import embedding_path


def load_embedding() -> SentenceTransformer:
    """Load the embedding model."""
    logger.info('Loading embedding...')
    emb = SentenceTransformer(embedding_path)
    logger.info('Embedding loaded.')
    return emb


def encode_raw_corpus(file_path: str, store_path: str) -> None:
    """Encode a raw text corpus."""
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        read_lines = f.readlines()

    # Split the text (for example, into sentences)
    # and store the resulting segments in lines.
    lines = []

    # Encode (convert to vectors).
    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)


def encode_qa(file_path: str, store_path: str) -> None:
    """Encode QA pairs.

    Only JSON loading is implemented for now; CSV and TXT are not handled yet.
    """
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        qa_list = json.load(f)

    # Concatenate each question with its answer and encode the pair as one passage;
    # encoding only the question would also work.
    lines = []
    for qa in qa_list:
        question = qa['question']
        answer = qa['answer']
        lines.append(question + answer)

    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)
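encode_raw_corpus reads the file but leaves the splitting step empty, so lines stays [] and nothing useful gets encoded. One way the gap could be filled is sentence-level splitting on end punctuation with the standard re module; split_sentences below is an illustrative helper, not part of this commit:

import re


def split_sentences(read_lines: list[str]) -> list[str]:
    # Split each raw line after Chinese or Western sentence-ending punctuation,
    # keeping only non-empty segments.
    lines = []
    for raw in read_lines:
        for segment in re.split(r'(?<=[。!?.!?])', raw.strip()):
            segment = segment.strip()
            if segment:
                lines.append(segment)
    return lines

Inside encode_raw_corpus, the empty lines = [] would then become lines = split_sentences(read_lines).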
0  rag/src/util/llm.py  Normal file