68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
![]() |
import os
|
|||
|
import json
|
|||
|
import pickle
|
|||
|
import numpy as np
|
|||
|
from typing import Tuple
|
|||
|
from sentence_transformers import SentenceTransformer
|
|||
|
|
|||
|
from config.config import knowledge_json_path, knowledge_pkl_path
|
|||
|
from util.encode import load_embedding, encode_qa
|
|||
|
|
|||
|
|
|||
|
"""
|
|||
|
读取知识库
|
|||
|
"""
|
|||
|
def load_knowledge() -> Tuple[list, list]:
    """Load the knowledge base together with its cached encoding.

    If nothing exists at ``knowledge_pkl_path`` yet, ``encode_qa`` is called
    with the JSON path and the pickle path first (presumably it writes the
    pickle file -- verify against ``util.encode``). Both files are then read.

    Returns:
        A ``(knowledge, encoded_knowledge)`` pair: the object parsed from the
        JSON file and the object unpickled from the pickle file.
    """
    # Build the encoded cache on first use.
    cache_exists = os.path.exists(knowledge_pkl_path)
    if not cache_exists:
        encode_qa(knowledge_json_path, knowledge_pkl_path)

    # Open both files (JSON first) before reading either of them.
    with open(knowledge_json_path, 'r', encoding='utf-8') as json_file:
        with open(knowledge_pkl_path, 'rb') as pkl_file:
            knowledge = json.load(json_file)
            # NOTE(review): pickle must only be loaded from a trusted,
            # locally generated cache file.
            encoded_knowledge = pickle.load(pkl_file)

    return knowledge, encoded_knowledge
|
|||
|
|
|||
|
|
|||
|
"""
|
|||
|
召回 top_k 个相关的文本段
|
|||
|
"""
|
|||
|
def find_top_k(
    emb: "SentenceTransformer",
    query: str,
    knowledge: list,
    encoded_knowledge: list,
    k=3
) -> list[str]:
    """Retrieve the ``k`` knowledge entries that score highest against ``query``.

    The query is encoded with ``emb.encode`` and scored against every row of
    ``encoded_knowledge`` with a dot product; the matching entries of
    ``knowledge`` are returned from highest to lowest score.

    Args:
        emb: Model exposing ``encode(query)``. The annotation is quoted so it
            is not evaluated when the function is defined.
        query: Text to search for.
        knowledge: Entries to return. ``knowledge[i]`` is assumed to belong
            to row ``i`` of ``encoded_knowledge`` -- TODO confirm against the
            encoder that produced the cache.
        encoded_knowledge: Embeddings as an array or a nested list; assumed
            to be laid out as (num_entries, dim) -- TODO confirm.
        k: Maximum number of entries to return. Values larger than the number
            of rows are clamped; values ``<= 0`` yield an empty list.

    Returns:
        Up to ``k`` items of ``knowledge``, best match first.
    """
    # Coerce to an array: the parameter is annotated as a list, but `.T`
    # below only exists on arrays.
    matrix = np.asarray(encoded_knowledge)
    row_count = matrix.shape[0] if matrix.ndim else 0

    # Clamp k. Without this, k == 0 made the `[-k:]` slice return *every*
    # entry, and k > row_count made argpartition raise ValueError.
    k = min(k, row_count)
    if k <= 0:
        return []

    # Encode the query and score it against every knowledge row.
    query_embedding = emb.encode(query)
    scores = query_embedding @ matrix.T

    # argpartition places the k largest scores in the last k slots, in no
    # particular order ...
    top_k_indices = np.argpartition(scores, -k)[-k:]
    # ... so order those k candidates by descending score.
    descending = np.argsort(scores[top_k_indices])[::-1]
    top_k_indices = top_k_indices[descending]

    return [knowledge[index] for index in top_k_indices]
|
|||
|
|
|||
|
|
|||
|
def main():
    """Run a retrieval demo for one hard-coded query and print the hits."""
    embedder = load_embedding()
    kb_entries, kb_vectors = load_knowledge()

    question = "认知心理学研究哪些心理活动?"
    hits = find_top_k(embedder, question, kb_entries, kb_vectors, 2)

    print('召回的 top-k 条相关内容如下:')
    rendered = json.dumps(hits, ensure_ascii=False, indent=2)
    print(rendered)

    # The LLM part is not implemented here. With an LLM available:
    #   1. Load the LLM.
    #   2. Join the retrieved contents into a prompt and pass it to the LLM
    #      as the {known content}.
    #   3. Ask the LLM to answer based on that known content.


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|