RAG - Initial commit

parent 18997ec79c
commit 1ca8349839

rag/README.md (new file, 0 lines)

rag/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
sentence_transformers
transformers
numpy
loguru
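
The dependencies can be installed the usual way, e.g. `pip install -r rag/requirements.txt` (sentence_transformers pulls in torch as a transitive dependency).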

rag/src/config/config.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import os

cur_dir = os.path.dirname(os.path.abspath(__file__))  # config
src_dir = os.path.dirname(cur_dir)  # src
base_dir = os.path.dirname(src_dir)  # base

# model
model_dir = os.path.join(base_dir, 'model')  # model
embedding_path = os.path.join(model_dir, 'gte-small-zh')  # embedding
llm_path = os.path.join(model_dir, 'pythia-14m')  # llm

# data
data_dir = os.path.join(base_dir, 'data')  # data
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')  # json
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')  # pickle

# log
log_dir = os.path.join(base_dir, 'log')  # log
log_path = os.path.join(log_dir, 'log.log')  # file
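
For reference, these paths imply a layout along the following lines (a sketch; the model folders under model/ are assumed to be downloaded separately and are not part of this commit):

    rag/
    ├── data/
    │   ├── knowledge.json
    │   └── knowledge.pkl
    ├── log/
    │   └── log.log
    ├── model/
    │   ├── gte-small-zh/
    │   └── pythia-14m/
    └── src/
        ├── config/config.py
        ├── util/encode.py
        ├── util/llm.py
        └── main.py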

rag/src/main.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import os
import json
import pickle
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer

from config.config import knowledge_json_path, knowledge_pkl_path
from util.encode import load_embedding, encode_qa

"""
|
||||||
|
读取知识库
|
||||||
|
"""
|
||||||
|
def load_knowledge() -> Tuple[list, list]:
|
||||||
|
# 如果 pkl 不存在,则先编码存储
|
||||||
|
if not os.path.exists(knowledge_pkl_path):
|
||||||
|
encode_qa(knowledge_json_path, knowledge_pkl_path)
|
||||||
|
|
||||||
|
# 加载 json 和 pkl
|
||||||
|
with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
|
||||||
|
knowledge = json.load(f1)
|
||||||
|
encoded_knowledge = pickle.load(f2)
|
||||||
|
return knowledge, encoded_knowledge
|
||||||


"""
Retrieve the top_k most relevant passages.
"""
def find_top_k(
    emb: SentenceTransformer,
    query: str,
    knowledge: list,
    encoded_knowledge: np.ndarray,
    k=3
) -> list[str]:
    # Encode the query
    query_embedding = emb.encode(query)

    # Find the top_k
    scores = query_embedding @ encoded_knowledge.T
    # argpartition places the k largest scores (in arbitrary order) in the
    # last k positions; everything left of position -k is no larger than them
    top_k_indices = np.argpartition(scores, -k)[-k:]
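    # Worked example (illustrative): with scores = [0.2, 0.9, 0.1, 0.7, 0.5]
    # and k = 2, np.argpartition(scores, -2)[-2:] yields indices {3, 1} in
    # arbitrary order — the positions of the two largest scores, 0.7 and 0.9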
    # Since argpartition does not guarantee order, sort the k extracted indices by score
    top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1]
    top_k_indices = top_k_indices[top_k_values_sorted_indices]

    # Return the matching passages
    contents = [knowledge[index] for index in top_k_indices]
    return contents


def main():
    emb = load_embedding()
    knowledge, encoded_knowledge = load_knowledge()
    query = "认知心理学研究哪些心理活动?"
    contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
    print('Top-k retrieved passages:')
    print(json.dumps(contents, ensure_ascii=False, indent=2))
    # The LLM part is not implemented here. With an LLM, the steps would be:
    ## 1. Load the LLM
    ## 2. Join contents into the prompt as the {known context}
    ## 3. Ask the LLM to reply based on the known context
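    # A minimal sketch of those three steps (hypothetical, not part of this
    # commit; assumes llm_path from config.config points at a causal LM that
    # transformers can load):
    #   from transformers import AutoModelForCausalLM, AutoTokenizer
    #   from config.config import llm_path
    #   tokenizer = AutoTokenizer.from_pretrained(llm_path)
    #   model = AutoModelForCausalLM.from_pretrained(llm_path)
    #   context = '\n'.join(contents)
    #   prompt = f"Known context:\n{context}\nAnswer based on the known context: {query}"
    #   inputs = tokenizer(prompt, return_tensors='pt')
    #   outputs = model.generate(**inputs, max_new_tokens=128)
    #   print(tokenizer.decode(outputs[0], skip_special_tokens=True))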


if __name__ == '__main__':
    main()

rag/src/util/encode.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import json
import pickle
from loguru import logger
from sentence_transformers import SentenceTransformer

from config.config import embedding_path


"""
Load the embedding model.
"""
def load_embedding() -> SentenceTransformer:
    logger.info('Loading embedding...')
    emb = SentenceTransformer(embedding_path)
    logger.info('Embedding loaded.')
    return emb


"""
Encode a raw text corpus.
"""
def encode_raw_corpus(file_path: str, store_path: str) -> None:
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        read_lines = f.readlines()

    # Split the text here (e.g., split by sentence)
    lines = []
    # Store the finished splits in lines
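    # One possible split, as a sketch (assumes sentences end in full-width
    # Chinese punctuation; the pattern is an assumption, adjust per corpus):
    #   import re
    #   for line in read_lines:
    #       lines.extend(s for s in re.split(r'(?<=[。!?])', line.strip()) if s)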

    # Encode (convert to vectors)
    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)


"""
Encode QA pairs.
Only json loading is implemented for now; csv and txt are not written yet.
"""
def encode_qa(file_path: str, store_path: str) -> None:
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        qa_list = json.load(f)

    # Concatenate each QA pair and encode it as a single sentence;
    # encoding only the question would also work
    lines = []
    for qa in qa_list:
        question = qa['question']
        answer = qa['answer']
        lines.append(question + answer)

    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)
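
For reference, encode_qa expects knowledge.json to be a JSON array of objects with question and answer fields, along these lines (illustrative placeholder content, not from the commit):

    [
      {"question": "认知心理学研究哪些心理活动?", "answer": "..."},
      {"question": "...", "answer": "..."}
    ]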

rag/src/util/llm.py (new file, 0 lines)
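
Since main.py imports config.config and util.encode as top-level packages, the entry point is presumably meant to be run from the rag/src directory, e.g. `cd rag/src && python main.py`.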