RAG - Initial commit
This commit is contained in:
		
							parent
							
								
									18997ec79c
								
							
						
					
					
						commit
						1ca8349839
					
				
							
								
								
									
										0
									
								
								rag/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								rag/README.md
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										4
									
								
								rag/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								rag/requirements.txt
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
sentence_transformers
 | 
			
		||||
transformers
 | 
			
		||||
numpy
 | 
			
		||||
loguru
 | 
			
		||||
							
								
								
									
										19
									
								
								rag/src/config/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								rag/src/config/config.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
			
		||||
import os

# Directory layout, resolved relative to this file (rag/src/config/config.py).
cur_dir = os.path.dirname(os.path.abspath(__file__))                # .../rag/src/config
src_dir = os.path.dirname(cur_dir)                                  # .../rag/src
base_dir = os.path.dirname(src_dir)                                 # .../rag

# Model artifacts.
model_dir = os.path.join(base_dir, 'model')                         # model root
embedding_path = os.path.join(model_dir, 'gte-small-zh')            # sentence-embedding model
llm_path = os.path.join(model_dir, 'pythia-14m')                    # generative LLM

# Knowledge-base data.
data_dir = os.path.join(base_dir, 'data')                           # data root
knowledge_json_path = os.path.join(data_dir, 'knowledge.json')      # raw QA pairs (JSON)
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')        # cached embeddings (pickle)

# Logging.
log_dir = os.path.join(base_dir, 'log')                             # log directory
log_path = os.path.join(log_dir, 'log.log')                         # log file
 | 
			
		||||
							
								
								
									
										67
									
								
								rag/src/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								rag/src/main.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,67 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import pickle
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from sentence_transformers import SentenceTransformer
 | 
			
		||||
 | 
			
		||||
from config.config import knowledge_json_path, knowledge_pkl_path
 | 
			
		||||
from util.encode import load_embedding, encode_qa
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_knowledge() -> Tuple[list, list]:
    """Load the knowledge base.

    Returns
    -------
    knowledge : entries parsed from the knowledge JSON file.
    encoded_knowledge : their pre-computed embeddings loaded from the pickle cache.
    """
    # Build the embedding cache on first use.
    if not os.path.exists(knowledge_pkl_path):
        encode_qa(knowledge_json_path, knowledge_pkl_path)

    # Load the raw entries and the cached embeddings side by side.
    with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2:
        knowledge = json.load(f1)
        encoded_knowledge = pickle.load(f2)
    return knowledge, encoded_knowledge
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_top_k(
    emb: SentenceTransformer,
    query: str,
    knowledge: list,
    encoded_knowledge: np.ndarray,
    k=3
) -> list[str]:
    """Recall the top-k knowledge entries most relevant to ``query``.

    Parameters
    ----------
    emb : sentence-embedding model used to encode ``query``.
    query : the user question to search for.
    knowledge : raw knowledge entries, aligned index-for-index with
        ``encoded_knowledge``.
    encoded_knowledge : 2-D array of embeddings, one row per entry.
        (Was annotated ``list``, but the code uses ``.T`` — it is an array.)
    k : number of entries to return; clamped to the corpus size.
    """
    # Encode the query into the same vector space as the knowledge base.
    query_embedding = emb.encode(query)

    # Dot-product similarity against every knowledge vector (1-D result).
    scores = query_embedding @ encoded_knowledge.T

    # Clamp k: np.argpartition raises when asked for more items than exist,
    # and an empty corpus has nothing to return.
    k = min(k, len(scores))
    if k <= 0:
        return []

    # argpartition places the k largest scores in the last k slots — O(n),
    # but with no ordering among those k.
    top_k_indices = np.argpartition(scores, -k)[-k:]
    # Sort the extracted k indices by score, descending.
    top_k_indices = top_k_indices[np.argsort(scores[top_k_indices])[::-1]]

    # Map indices back to the raw entries.
    return [knowledge[index] for index in top_k_indices]
 | 
			
		||||
 | 
			
		||||
def main():
    """Demo entry point: load the embedding model and knowledge base, then
    print the top-k entries relevant to a sample query."""
    emb = load_embedding()
    knowledge, encoded_knowledge = load_knowledge()

    query = "认知心理学研究哪些心理活动?"
    contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2)

    print('召回的 top-k 条相关内容如下:')
    print(json.dumps(contents, ensure_ascii=False, indent=2))

    # The LLM stage is not implemented here. With an LLM, one would:
    ## 1. load the LLM
    ## 2. join `contents` into the prompt as the "known content"
    ## 3. ask the LLM to answer based on that known content


if __name__ == '__main__':
    main()
 | 
			
		||||
							
								
								
									
										57
									
								
								rag/src/util/encode.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								rag/src/util/encode.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,57 @@
 | 
			
		||||
import json
import pickle
import re

from loguru import logger
from sentence_transformers import SentenceTransformer

from config.config import embedding_path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_embedding() -> SentenceTransformer:
    """Load the sentence-embedding model from ``embedding_path`` and return it."""
    # Log around the load — model initialization can take a while.
    logger.info('Loading embedding...')
    model = SentenceTransformer(embedding_path)
    logger.info('Embedding loaded.')
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encode_raw_corpus(file_path: str, store_path: str) -> None:
    """Encode a raw text corpus and cache the embeddings as a pickle.

    Parameters
    ----------
    file_path : path to a raw UTF-8 text file, one or more sentences per line.
    store_path : path the pickled embedding matrix is written to.
    """
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        read_lines = f.readlines()

    # Split the text into sentences. The original left this step unimplemented
    # (``lines`` stayed empty, so nothing was ever encoded); split on common
    # CJK and Western sentence terminators, keeping each terminator attached
    # to its sentence via a lookbehind.
    lines = []
    for raw_line in read_lines:
        for sentence in re.split(r'(?<=[。!?.!?])', raw_line.strip()):
            sentence = sentence.strip()
            if sentence:
                lines.append(sentence)

    # Encode the sentences to vectors and cache them.
    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encode_qa(file_path: str, store_path: str) -> None:
    """Encode QA pairs and cache the embeddings as a pickle.

    Only JSON input is implemented for now (CSV and TXT loaders are not
    written yet). Each pair is embedded as the question concatenated with
    the answer; encoding the question alone would also be a valid choice.
    """
    emb = load_embedding()
    with open(file_path, 'r', encoding='utf-8') as f:
        qa_list = json.load(f)

    # One string per pair: question + answer concatenated.
    lines = [qa['question'] + qa['answer'] for qa in qa_list]

    encoded_knowledge = emb.encode(lines)
    with open(store_path, 'wb') as f:
        pickle.dump(encoded_knowledge, f)
 | 
			
		||||
							
								
								
									
										0
									
								
								rag/src/util/llm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								rag/src/util/llm.py
									
									
									
									
									
										Normal file
									
								
							
		Loading…
	
		Reference in New Issue
	
	Block a user