commit
						c7d916bf4f
					
				| @ -13,11 +13,16 @@ llm_path = os.path.join(model_dir, 'pythia-14m')                    # llm | ||||
| # data | ||||
| data_dir = os.path.join(base_dir, 'data')                           # data | ||||
| knowledge_json_path = os.path.join(data_dir, 'knowledge.json')      # json | ||||
| knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')        # pickle | ||||
| knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')        # pkl | ||||
| doc_dir = os.path.join(data_dir, 'txt')    | ||||
| qa_dir = os.path.join(data_dir, 'json')    | ||||
| 
 | ||||
| # log | ||||
| log_dir = os.path.join(base_dir, 'log')                             # log | ||||
| log_path = os.path.join(log_dir, 'log.log')                         # file | ||||
| 
 | ||||
| # vector DB | ||||
| vector_db_dir = os.path.join(data_dir, 'vector_db.pkl') | ||||
| 
 | ||||
| select_num = 3 | ||||
| retrieval_num = 10 | ||||
							
								
								
									
										262
									
								
								rag/src/data_processing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										262
									
								
								rag/src/data_processing.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,262 @@ | ||||
| import json | ||||
| import pickle | ||||
| from loguru import logger | ||||
| from sentence_transformers import SentenceTransformer | ||||
| 
 | ||||
| from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, base_dir, vector_db_dir | ||||
| import os | ||||
| import faiss | ||||
| import platform | ||||
| from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader | ||||
| from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter | ||||
| from BCEmbedding import EmbeddingModel, RerankerModel | ||||
| from util.pipeline import EmoLLMRAG | ||||
| import pickle | ||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||||
| import torch | ||||
| import streamlit as st | ||||
| from openxlab.model import download | ||||
| 
 | ||||
| 
 | ||||
| ''' | ||||
| 1)根据QA对/TXT 文本生成 embedding  | ||||
| 2)调用 langchain FAISS 接口构建 vector DB  | ||||
| 3)存储到 openxlab.dataset 中,方便后续调用 | ||||
| 4)提供 embedding 的接口函数,方便后续调用 | ||||
| 5)提供 rerank 的接口函数,方便后续调用 | ||||
| ''' | ||||
| 
 | ||||
| """ | ||||
| 加载向量模型 | ||||
| """ | ||||
| def load_embedding_model(): | ||||
|     logger.info('Loading embedding model...') | ||||
|     # model = EmbeddingModel(model_name_or_path="huggingface/bce-embedding-base_v1")    | ||||
|     model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1") | ||||
|     logger.info('Embedding model loaded.') | ||||
|     return model | ||||
| 
 | ||||
| def load_rerank_model(): | ||||
|     logger.info('Loading rerank_model...') | ||||
|     model = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1") | ||||
|     # model = RerankerModel(model_name_or_path="huggingface/bce-reranker-base_v1") | ||||
|     logger.info('Rerank model loaded.') | ||||
|     return model | ||||
| 
 | ||||
| 
 | ||||
| def split_document(data_path, chunk_size=1000, chunk_overlap=100): | ||||
|     # text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | ||||
|     text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)  | ||||
|     split_docs = [] | ||||
|     logger.info(f'Loading txt files from {data_path}') | ||||
|     if os.path.isdir(data_path): | ||||
|         # 如果是文件夹,则遍历读取 | ||||
|         for root, dirs, files in os.walk(data_path): | ||||
|             for file in files:       | ||||
|                 if file.endswith('.txt'):             | ||||
|                     file_path = os.path.join(root, file) | ||||
|                     # logger.info(f'splitting file {file_path}') | ||||
|                     text_loader = TextLoader(file_path, encoding='utf-8')         | ||||
|                     text = text_loader.load() | ||||
|                      | ||||
|                     splits = text_spliter.split_documents(text) | ||||
|                     # logger.info(f"splits type {type(splits[0])}") | ||||
|                     # logger.info(f'splits size {len(splits)}') | ||||
|                     split_docs += splits | ||||
|     elif data_path.endswith('.txt'):  | ||||
|         file_path = os.path.join(root, data_path) | ||||
|         # logger.info(f'splitting file {file_path}') | ||||
|         text_loader = TextLoader(file_path, encoding='utf-8')         | ||||
|         text = text_loader.load() | ||||
|         splits = text_spliter.split_documents(text) | ||||
|         # logger.info(f"splits type {type(splits[0])}") | ||||
|         # logger.info(f'splits size {len(splits)}') | ||||
|         split_docs = splits | ||||
|     logger.info(f'split_docs size {len(split_docs)}') | ||||
|     return split_docs | ||||
| 
 | ||||
| 
 | ||||
| ##TODO 1、读取system prompt 2、限制序列长度 | ||||
| def split_conversation(path): | ||||
|     ''' | ||||
|     data format: | ||||
|     [ | ||||
|         { | ||||
|             "conversation": [ | ||||
|                 { | ||||
|                     "input":  Q1 | ||||
|                     "output": A1 | ||||
|                 }, | ||||
|                 { | ||||
|                     "input":  Q2 | ||||
|                     "output": A2 | ||||
|                 }, | ||||
|             ] | ||||
|         }, | ||||
|     ] | ||||
|     ''' | ||||
|     qa_pairs = [] | ||||
|     logger.info(f'Loading json files from {path}') | ||||
|     if os.path.isfile(path): | ||||
|         with open(path, 'r', encoding='utf-8') as file: | ||||
|             data = json.load(file) | ||||
|         for conversation in data: | ||||
|             for dialog in conversation['conversation']: | ||||
|                 # input_text = dialog['input'] | ||||
|                 # output_text = dialog['output']       | ||||
|                 # if len(input_text) > max_length or len(output_text) > max_length: | ||||
|                 #     continue | ||||
|                 qa_pairs.append(dialog)         | ||||
|     elif os.path.isdir(path): | ||||
|         # 如果是文件夹,则遍历读取 | ||||
|         for root, dirs, files in os.walk(path): | ||||
|             for file in files: | ||||
|                 if file.endswith('.json'):  | ||||
|                     file_path = os.path.join(root, file) | ||||
|                     logger.info(f'splitting file {file_path}') | ||||
|                     with open(file_path, 'r', encoding='utf-8') as f: | ||||
|                         data = json.load(f) | ||||
|                         for conversation in data: | ||||
|                             for dialog in conversation['conversation']: | ||||
|                                 qa_pairs.append(dialog) | ||||
|     return qa_pairs | ||||
| 
 | ||||
| 
 | ||||
|          | ||||
| # 加载本地索引 | ||||
| def load_index_and_knowledge(): | ||||
|     current_os = platform.system() | ||||
|     split_doc = [] | ||||
|     split_qa = [] | ||||
|     #读取知识库 | ||||
|     if not os.path.exists(knowledge_pkl_path): | ||||
|         split_doc = split_document(doc_dir) | ||||
|         split_qa = split_conversation(qa_dir) | ||||
|     # logger.info(f'split_qa size:{len(split_qa)}') | ||||
|     # logger.info(f'type of split_qa:{type(split_qa[0])}') | ||||
|     # logger.info(f'split_doc size:{len(split_doc)}') | ||||
|     # logger.info(f'type of doc:{type(split_doc[0])}') | ||||
|         knowledge_chunks = split_doc + split_qa | ||||
|         with open(knowledge_pkl_path, 'wb') as file: | ||||
|             pickle.dump(knowledge_chunks, file) | ||||
|     else: | ||||
|         with open(knowledge_pkl_path , 'rb') as f: | ||||
|             knowledge_chunks = pickle.load(f) | ||||
|          | ||||
|     #读取vector DB | ||||
|     if not os.path.exists(vector_db_dir): | ||||
|         logger.info(f'Creating index...') | ||||
|         emb_model = load_embedding_model() | ||||
|         if not split_doc: | ||||
|             split_doc = split_document(doc_dir) | ||||
|         if not split_qa: | ||||
|             split_qa = split_conversation(qa_dir) | ||||
|         # 创建索引,windows不支持faiss-gpu | ||||
|         if current_os == 'Linux': | ||||
|             index = create_index_gpu(split_doc, split_qa, emb_model, vector_db_dir) | ||||
|         else: | ||||
|             index = create_index_cpu(split_doc, split_qa, emb_model, vector_db_dir) | ||||
|     else: | ||||
|         if current_os == 'Linux': | ||||
|             res = faiss.StandardGpuResources() | ||||
|             index = faiss.index_cpu_to_gpu(res, 0, index, vector_db_dir) | ||||
|         else: | ||||
|             index = faiss.read_index(vector_db_dir) | ||||
|      | ||||
|     return index, knowledge_chunks | ||||
| 
 | ||||
| 
 | ||||
| def create_index_cpu(split_doc, split_qa, emb_model, knowledge_pkl_path, dimension = 768, question_only=False): | ||||
|     # 假设BCE嵌入的维度是768,根据你选择的模型可能不同 | ||||
|     faiss_index_cpu = faiss.IndexFlatIP(dimension)  # 创建一个使用内积的FAISS索引 | ||||
|     # 将问答对转换为向量并添加到FAISS索引中 | ||||
|     for doc in split_doc: | ||||
|         # type_of_docs = type(split_doc)  | ||||
|         text = f"{doc.page_content}" | ||||
|         vector = emb_model.encode([text]) | ||||
|         faiss_index_cpu.add(vector) | ||||
|     for qa in split_qa: | ||||
|         #仅对Q对进行编码 | ||||
|         text = f"{qa['input']}" | ||||
|         vector = emb_model.encode([text]) | ||||
|         faiss_index_cpu.add(vector) | ||||
|     faiss.write_index(faiss_index_cpu, knowledge_pkl_path) | ||||
|     return faiss_index_cpu | ||||
| 
 | ||||
| def create_index_gpu(split_doc, split_qa, emb_model, knowledge_pkl_path, dimension = 768, question_only=False): | ||||
|     res = faiss.StandardGpuResources() | ||||
|     index = faiss.IndexFlatIP(dimension) | ||||
|     faiss_index_gpu = faiss.index_cpu_to_gpu(res, 0, index) | ||||
|     for doc in split_doc: | ||||
|         # type_of_docs = type(split_doc) | ||||
|         text = f"{doc.page_content}" | ||||
|         vector = emb_model.encode([text]) | ||||
|         faiss_index_gpu.add(vector) | ||||
|     for qa in split_qa: | ||||
|         #仅对Q对进行编码 | ||||
|         text = f"{qa['input']}" | ||||
|         vector = emb_model.encode([text]) | ||||
|         faiss_index_gpu.add(vector) | ||||
|     faiss.write_index(faiss_index_gpu, knowledge_pkl_path) | ||||
|     return faiss_index_gpu | ||||
| 
 | ||||
|     | ||||
| 
 | ||||
| # 根据query搜索相似文本 | ||||
| def find_top_k(query, faiss_index, k=5): | ||||
|     emb_model = load_embedding_model() | ||||
|     emb_query = emb_model.encode([query]) | ||||
|     distances, indices = faiss_index.search(emb_query, k) | ||||
|     return distances, indices | ||||
| 
 | ||||
| def rerank(query, indices, knowledge_chunks): | ||||
|     passages = [] | ||||
|     for index in indices[0]: | ||||
|         content = knowledge_chunks[index] | ||||
|         ''' | ||||
|         txt: 'langchain_core.documents.base.Document' | ||||
|         json: dict | ||||
|         ''' | ||||
|         # logger.info(f'retrieved content:{content}') | ||||
|         # logger.info(f'type of content:{type(content)}') | ||||
|         if type(content) == dict: | ||||
|             content = content["input"] + '\n' + content["output"] | ||||
|         else: | ||||
|             content = content.page_content | ||||
|         passages.append(content) | ||||
|      | ||||
|     model = load_rerank_model() | ||||
|     rerank_results = model.rerank(query, passages) | ||||
|     return rerank_results | ||||
| 
 | ||||
| @st.cache_resource | ||||
| def load_model(): | ||||
|     model = ( | ||||
|         AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True) | ||||
|         .to(torch.bfloat16) | ||||
|         .cuda() | ||||
|     ) | ||||
|     tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) | ||||
|     return model, tokenizer | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     logger.info(data_dir) | ||||
|     if not os.path.exists(data_dir): | ||||
|          os.mkdir(data_dir) | ||||
|     faiss_index, knowledge_chunks = load_index_and_knowledge() | ||||
|     # 按照query进行查询 | ||||
|     # query = "她要阻挠姐姐的婚姻,即使她自己的尸体在房门跟前" | ||||
|     # query = "肯定的。我最近睡眠很差,总是做噩梦。而且我吃得也不好,体重一直在下降" | ||||
|     # query = "序言 (一) 变态心理学是心理学本科生的必修课程之一,教材更新的问题一直在困扰着我们。" | ||||
|     query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕" | ||||
|     distances, indices = find_top_k(query, faiss_index, 5) | ||||
|     logger.info(f'distances==={distances}') | ||||
|     logger.info(f'indices==={indices}') | ||||
|     | ||||
| 
 | ||||
|     # rerank无法返回id,先实现按整个问答对排序 | ||||
|     rerank_results = rerank(query, indices, knowledge_chunks) | ||||
|     for passage, score in zip(rerank_results['rerank_passages'], rerank_results['rerank_scores']): | ||||
|         print(str(score)+'\n') | ||||
|         print(passage+'\n') | ||||
|    | ||||
							
								
								
									
										112
									
								
								rag/src/main.py
									
									
									
									
									
								
							
							
						
						
									
										112
									
								
								rag/src/main.py
									
									
									
									
									
								
							| @ -5,87 +5,67 @@ import numpy as np | ||||
| from typing import Tuple | ||||
| from sentence_transformers import SentenceTransformer | ||||
| 
 | ||||
| from config.config import knowledge_json_path, knowledge_pkl_path, model_repo | ||||
| from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir | ||||
| from util.encode import load_embedding, encode_qa | ||||
| from util.pipeline import EmoLLMRAG | ||||
| 
 | ||||
| from loguru import logger | ||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||||
| import torch | ||||
| import streamlit as st | ||||
| from openxlab.model import download | ||||
| from data_processing import load_index_and_knowledge, create_index_cpu, create_index_gpu, find_top_k, rerank | ||||
| from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir | ||||
| 
 | ||||
| download( | ||||
|     model_repo=model_repo,  | ||||
|     output='model' | ||||
| ) | ||||
| ''' | ||||
| 	1)构建完整的 RAG pipeline。输入为用户 query,输出为 answer | ||||
| 	2)调用 embedding 提供的接口对 query 向量化 | ||||
| 	3)下载基于 FAISS 预构建的 vector DB ,并检索对应信息 | ||||
| 	4)调用 rerank 接口重排序检索内容 | ||||
| 	5)调用 prompt 接口获取 system prompt 和 prompt template | ||||
| 	6)拼接 prompt 并调用模型返回结果 | ||||
| 
 | ||||
| 
 | ||||
| """ | ||||
| 读取知识库 | ||||
| """ | ||||
| def load_knowledge() -> Tuple[list, list]: | ||||
|     # 如果 pkl 不存在,则先编码存储 | ||||
|     if not os.path.exists(knowledge_pkl_path): | ||||
|         encode_qa(knowledge_json_path, knowledge_pkl_path) | ||||
| 
 | ||||
|     # 加载 json 和 pkl | ||||
|     with open(knowledge_json_path, 'r', encoding='utf-8') as f1, open(knowledge_pkl_path, 'rb') as f2: | ||||
|         knowledge = json.load(f1) | ||||
|         encoded_knowledge = pickle.load(f2) | ||||
|     return knowledge, encoded_knowledge | ||||
| 
 | ||||
| 
 | ||||
| """ | ||||
| 召回 top_k 个相关的文本段 | ||||
| """ | ||||
| def find_top_k( | ||||
|     emb: SentenceTransformer, | ||||
|     query: str, | ||||
|     knowledge: list, | ||||
|     encoded_knowledge: list, | ||||
|     k=3 | ||||
| ) -> list[str]: | ||||
|     # 编码 query | ||||
|     query_embedding = emb.encode(query) | ||||
| 
 | ||||
|     # 查找 top_k | ||||
|     scores = query_embedding @ encoded_knowledge.T | ||||
|     # 使用 argpartition 找出每行第 k 个大的值的索引,第 k 个位置左侧都是比它大的值,右侧都是比它小的值 | ||||
|     top_k_indices = np.argpartition(scores, -k)[-k:] | ||||
|     # 由于 argpartition 不保证顺序,我们需要对提取出的 k 个索引进行排序 | ||||
|     top_k_values_sorted_indices = np.argsort(scores[top_k_indices])[::-1] | ||||
|     top_k_indices = top_k_indices[top_k_values_sorted_indices] | ||||
| 
 | ||||
|     # 返回 | ||||
|     contents = [knowledge[index] for index in top_k_indices] | ||||
|     return contents | ||||
|      | ||||
| 
 | ||||
| def main(): | ||||
|     emb = load_embedding() | ||||
|     knowledge, encoded_knowledge = load_knowledge() | ||||
|     query = "认知心理学研究哪些心理活动?" | ||||
|     contents = find_top_k(emb, query, knowledge, encoded_knowledge, 2) | ||||
|     print('召回的 top-k 条相关内容如下:') | ||||
|     print(json.dumps(contents, ensure_ascii=False, indent=2)) | ||||
|     # 这里我没实现 LLM 部分,如果有 LLM | ||||
|     ## 1. 读取 LLM | ||||
|     ## 2. 将 contents 拼接为 prompt,传给 LLM,作为 {已知内容} | ||||
|     ## 3. 要求 LLM 根据已知内容回复 | ||||
| ''' | ||||
| # download( | ||||
| #     model_repo=model_repo,  | ||||
| #     output='model' | ||||
| # ) | ||||
| 
 | ||||
| @st.cache_resource | ||||
| def load_model(): | ||||
|     model_dir = os.path.join(base_dir,'../model')  | ||||
|     logger.info(f'Loading model from {model_dir}') | ||||
|     model = ( | ||||
|         AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True) | ||||
|         AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) | ||||
|         .to(torch.bfloat16) | ||||
|         .cuda() | ||||
|     ) | ||||
|     tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) | ||||
|     tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) | ||||
|     return model, tokenizer | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     #main() | ||||
|     query = '' | ||||
| def get_prompt(): | ||||
|     pass | ||||
| 
 | ||||
| def get_prompt_template(): | ||||
|     pass | ||||
| 
 | ||||
| def main(query, system_prompt): | ||||
|     model, tokenizer = load_model() | ||||
|     rag_obj = EmoLLMRAG(model) | ||||
|     response = rag_obj.main(query) | ||||
|     model = model.eval() | ||||
|     if not os.path.exists(data_dir): | ||||
|          os.mkdir(data_dir) | ||||
|     # 下载基于 FAISS 预构建的 vector DB 以及原始知识库 | ||||
|     faiss_index, knowledge_chunks = load_index_and_knowledge() | ||||
|     distances, indices = find_top_k(query, faiss_index, 5) | ||||
|     rerank_results = rerank(query, indices, knowledge_chunks) | ||||
|     messages = [(system_prompt, rerank_results['rerank_passages'][0])] | ||||
|     logger.info(f'messages:{messages}') | ||||
|     response, history = model.chat(tokenizer, query, history=messages) | ||||
|     messages.append((query, response)) | ||||
|     print(f"robot >>> {response}")   | ||||
|      | ||||
| if __name__ == '__main__': | ||||
|     # query = '你好'  | ||||
|     query = "心理咨询师,我觉得我的胸闷症状越来越严重了,这让我很害怕" | ||||
|     #TODO system_prompt = get_prompt() | ||||
|     system_prompt = "你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。" | ||||
|     main(query, system_prompt) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 xzw
						xzw