RAG - Initial commit

This commit is contained in:
Mxode 2024-03-07 18:05:10 +08:00
parent 18997ec79c
commit 1ca8349839
6 changed files with 147 additions and 0 deletions

0
rag/README.md Normal file
View File

4
rag/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
sentence_transformers
transformers
numpy
loguru

19
rag/src/config/config.py Normal file
View File

@ -0,0 +1,19 @@
import os

# Directory layout, resolved relative to this file:
#   <base>/src/config/config.py
cur_dir = os.path.dirname(os.path.abspath(__file__))   # .../src/config
src_dir = os.path.dirname(cur_dir)                     # .../src
base_dir = os.path.dirname(src_dir)                    # project root

# Model artifacts live under <base>/model.
model_dir = os.path.join(base_dir, 'model')
embedding_path = os.path.join(model_dir, 'gte-small-zh')   # sentence-embedding model
llm_path = os.path.join(model_dir, 'pythia-14m')           # language model

# Knowledge-base data lives under <base>/data.
data_dir = os.path.join(base_dir, 'data')
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')   # raw QA pairs (JSON)
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')     # pickled embeddings

# Log output goes to <base>/log/log.log.
log_dir = os.path.join(base_dir, 'log')
log_path = os.path.join(log_dir, 'log.log')

67
rag/src/main.py Normal file
View File

@ -0,0 +1,67 @@
import os
import json
import pickle
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer
from config.config import knowledge_json_path, knowledge_pkl_path
from util.encode import load_embedding, encode_qa
"""
读取知识库
"""
def load_knowledge() -> Tuple[list, list]:
# 如果 pkl 不存在,则先编码存储
if not os.path.exists(knowledge_pkl_path):
encode_qa(knowledge_json_path, knowledge_pkl_path)
# 加载 json 和 pkl
with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
knowledge = json.load(f1)
encoded_knowledge = pickle.load(f2)
return knowledge, encoded_knowledge
"""
召回 top_k 个相关的文本段
"""
def find_top_k(
emb: SentenceTransformer,
query: str,
knowledge: list,
encoded_knowledge: list,
k=3
) -> list[str]:
# 编码 query
query_embedding = emb.encode(query)
# 查找 top_k
scores = query_embedding @ encoded_knowledge.T
# 使用 argpartition 找出每行第 k 个大的值的索引,第 k 个位置左侧都是比它大的值,右侧都是比它小的值
top_k_indices = np.argpartition(scores, -k)[-k:]
# 由于 argpartition 不保证顺序,我们需要对提取出的 k 个索引进行排序
top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
top_k_indices = top_k_indices[top_k_values_sorted_indices]
# 返回
contents = [knowledge[index] for index in top_k_indices]
return contents
def main():
    """Entry point: retrieve passages for a demo query and print them."""
    emb = load_embedding()
    knowledge, encoded_knowledge = load_knowledge()

    query = "认知心理学研究哪些心理活动?"
    top_contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)

    print('召回的 top-k 条相关内容如下:')
    print(json.dumps(top_contents, ensure_ascii=False, indent=2))

    # The LLM stage is not implemented here. With an LLM available:
    #   1. load the LLM
    #   2. join the retrieved contents into the prompt as the "known context"
    #   3. ask the LLM to answer based on that context


if __name__ == '__main__':
    main()

57
rag/src/util/encode.py Normal file
View File

@ -0,0 +1,57 @@
import json
import pickle
from loguru import logger
from sentence_transformers import SentenceTransformer
from config.config import embedding_path
"""
加载向量模型
"""
def load_embedding() -> SentenceTransformer:
logger.info('Loading embedding...')
emb = SentenceTransformer(embedding_path)
logger.info('Embedding loaded.')
return emb
"""
文本编码
"""
def encode_raw_corpus(file_path: str, store_path: str) -> None:
emb = load_embedding()
with open(file_path, 'r', encoding='utf-8') as f:
read_lines = f.readlines()
"""
对文本分割例如按句子分割
"""
lines = []
# 分割好后的存入 lines 中
# 编码(转换为向量)
encoded_knowledge = emb.encode(lines)
with open(store_path, 'wb') as f:
pickle.dump(encoded_knowledge, f)
"""
QA 对编码
暂时只实现了加载 jsoncsv和txt先没写
"""
def encode_qa(file_path: str, store_path: str) -> None:
emb = load_embedding()
with open(file_path, 'r', encoding='utf-8') as f:
qa_list = json.load(f)
# 将 QA 对拼起来作为完整一句来编码,也可以只编码 Q
lines = []
for qa in qa_list:
question = qa['question']
answer = qa['answer']
lines.append(question + answer)
encoded_knowledge = emb.encode(lines)
with open(store_path, 'wb') as f:
pickle.dump(encoded_knowledge, f)

0
rag/src/util/llm.py Normal file
View File