"""Centralised filesystem paths for the RAG project.

Every path is derived from this file's own location on disk, so the
project works regardless of the current working directory.
"""
import os

# Directory layout: <base>/src/config/config.py
cur_dir = os.path.dirname(os.path.abspath(__file__))  # .../src/config
src_dir = os.path.dirname(cur_dir)                    # .../src
base_dir = os.path.dirname(src_dir)                   # project root

# Model locations
model_dir = os.path.join(base_dir, 'model')
embedding_path = os.path.join(model_dir, 'gte-small-zh')  # embedding model dir
llm_path = os.path.join(model_dir, 'pythia-14m')          # language model dir

# Knowledge-base data
data_dir = os.path.join(base_dir, 'data')
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')  # raw QA pairs
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')    # cached embeddings

# Logging
log_dir = os.path.join(base_dir, 'log')
log_path = os.path.join(log_dir, 'log.log')  # log file
"""Minimal RAG pipeline: load a QA knowledge base and retrieve the top-k
entries most relevant to a query."""
import os
import json
import pickle
from typing import TYPE_CHECKING, Tuple

import numpy as np

from config.config import knowledge_json_path, knowledge_pkl_path
from util.encode import load_embedding, encode_qa

if TYPE_CHECKING:
    # Needed only for type hints; keeps the heavy dependency out of import time.
    from sentence_transformers import SentenceTransformer


def load_knowledge() -> Tuple[list, "np.ndarray"]:
    """Load the knowledge base.

    Returns:
        A tuple of (raw QA entries parsed from the JSON file, the cached
        embedding matrix unpickled from disk).

    On first use the embeddings do not exist yet, so they are computed by
    ``encode_qa`` and cached as a pickle file.
    """
    # Encode and persist first if the pickle cache is missing.
    if not os.path.exists(knowledge_pkl_path):
        encode_qa(knowledge_json_path, knowledge_pkl_path)

    # Load both the raw JSON entries and the cached embeddings.
    with open(knowledge_json_path, 'r', encoding='utf-8') as f1, \
            open(knowledge_pkl_path, 'rb') as f2:
        knowledge = json.load(f1)
        encoded_knowledge = pickle.load(f2)
    return knowledge, encoded_knowledge


def find_top_k(
    emb: "SentenceTransformer",
    query: str,
    knowledge: list,
    encoded_knowledge: "np.ndarray",
    k: int = 3,
) -> list:
    """Return the ``k`` knowledge entries most similar to ``query``.

    Args:
        emb: Sentence-embedding model; only ``emb.encode`` is used.
        query: The user query to embed.
        knowledge: Raw knowledge entries (QA dicts loaded from JSON).
        encoded_knowledge: 2-D embedding matrix, one row per entry.
        k: Number of entries to retrieve.

    Returns:
        Up to ``k`` entries from ``knowledge``, ordered by descending score.
        (Note: entries are whatever the JSON holds — dicts here, not str;
        the original ``list[str]`` annotation was wrong.)
    """
    # Guard the degenerate cases: argpartition raises when k is out of bounds.
    k = min(k, len(knowledge))
    if k <= 0:
        return []

    # Embed the query, then score it against every entry via dot product.
    query_embedding = emb.encode(query)
    scores = query_embedding @ encoded_knowledge.T

    # argpartition places the k largest scores (unordered) in the last k
    # slots in O(n) — cheaper than a full sort.
    top_k_indices = np.argpartition(scores, -k)[-k:]
    # argpartition guarantees no ordering, so sort those k by descending score.
    top_k_indices = top_k_indices[np.argsort(scores[top_k_indices])[::-1]]

    return [knowledge[index] for index in top_k_indices]


def main():
    """Demo entry point: retrieve top-k context for a sample query."""
    emb = load_embedding()
    knowledge, encoded_knowledge = load_knowledge()
    query = "认知心理学研究哪些心理活动?"
    contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)
    print('召回的 top-k 条相关内容如下:')
    print(json.dumps(contents, ensure_ascii=False, indent=2))
    # The LLM stage is not implemented yet. With an LLM one would:
    ## 1. load the LLM
    ## 2. join `contents` into the prompt as the {known context}
    ## 3. ask the LLM to answer based on that context


if __name__ == '__main__':
    main()
"""Helpers for loading the embedding model and encoding the knowledge base."""
import json
import pickle

from loguru import logger
from sentence_transformers import SentenceTransformer

from config.config import embedding_path


def load_embedding() -> SentenceTransformer:
    """Load and return the sentence-embedding model from ``embedding_path``."""
    logger.info('Loading embedding...')
    emb = SentenceTransformer(embedding_path)
    logger.info('Embedding loaded.')
    return emb


def encode_raw_corpus(file_path: str, store_path: str) -> None:
    """Encode a raw text corpus and pickle the embedding matrix.

    Args:
        file_path: Path of the UTF-8 text corpus to encode.
        store_path: Path the pickled embeddings are written to.

    Each non-empty line is treated as one passage (simple line-based
    split; swap in a sentence splitter later if finer granularity is
    needed).
    """
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        read_lines = f.readlines()

    # Fix: the original left `lines` empty (the split step was a TODO),
    # so an empty list was encoded and nothing useful was ever stored.
    lines = [line.strip() for line in read_lines if line.strip()]

    # Encode (convert to vectors) and cache to disk.
    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)


def encode_qa(file_path: str, store_path: str) -> None:
    """Encode QA pairs from a JSON file and pickle the embeddings.

    Args:
        file_path: Path of a JSON file holding a list of
            ``{"question": ..., "answer": ...}`` entries.
        store_path: Path the pickled embeddings are written to.

    Only JSON is supported for now (csv/txt not implemented yet).
    """
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        qa_list = json.load(f)

    # Concatenate each question with its answer and encode the pair as one
    # text; encoding only the question would also be a valid strategy.
    lines = [qa['question'] + qa['answer'] for qa in qa_list]

    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)