From b5af7793d6a0d867af87aa00c5289052d890a110 Mon Sep 17 00:00:00 2001 From: zealot52099 Date: Fri, 22 Mar 2024 07:39:44 +0800 Subject: [PATCH] update rag/src/data_processing.py --- rag/src/data_processing.py | 55 +++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py index 334ce13..4126ee3 100644 --- a/rag/src/data_processing.py +++ b/rag/src/data_processing.py @@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter from BCEmbedding import EmbeddingModel, RerankerModel -from util.pipeline import EmoLLMRAG +# from util.pipeline import EmoLLMRAG from transformers import AutoTokenizer, AutoModelForCausalLM from langchain.document_loaders.pdf import PyPDFDirectoryLoader from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader @@ -91,13 +91,13 @@ class Data_process(): if isinstance(obj, dict): for key, value in obj.items(): try: - self.extract_text_from_json(value, content) + content = self.extract_text_from_json(value, content) except Exception as e: print(f"Error processing value: {e}") elif isinstance(obj, list): for index, item in enumerate(obj): try: - self.extract_text_from_json(item, content) + content = self.extract_text_from_json(item, content) except Exception as e: print(f"Error processing item: {e}") elif isinstance(obj, str): @@ -157,7 +157,7 @@ class Data_process(): logger.info(f'splitting file {file_path}') with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) - print(data) + # print(data) for conversation in data: # for dialog in conversation['conversation']: ##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document @@ -165,6 +165,7 @@ class Data_process(): # split_qa.append(Document(page_content = content)) #按conversation块切分 content = self.extract_text_from_json(conversation['conversation'], '') + logger.info(f'content====={content}') split_qa.append(Document(page_content = content)) # logger.info(f'split_qa size====={len(split_qa)}') return split_qa @@ -229,9 +230,8 @@ class Data_process(): # compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever) # compressed_docs = compression_retriever.get_relevant_documents(query) # return compressed_docs - - def rerank(self, query, docs): + def rerank(self, query, docs): reranker = self.load_rerank_model() passages = [] for doc in docs: @@ -240,9 +240,41 @@ class Data_process(): sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True) sorted_passages, sorted_scores = zip(*sorted_pairs) return sorted_passages, sorted_scores + + +# def create_prompt(question, context): +# from langchain.prompts import PromptTemplate +# prompt_template = f"""请基于以下内容回答问题: + +# {context} + +# 问题: {question} +# 回答:""" +# prompt = PromptTemplate( +# template=prompt_template, input_variables=["context", "question"] +# ) +# logger.info(f'Prompt: {prompt}') +# return prompt + +def create_prompt(question, context): + prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:""" + logger.info(f'Prompt: {prompt}') + return prompt - - +def test_zhipu(prompt): + from zhipuai import ZhipuAI + api_key = "" # 填写您自己的APIKey + if api_key == "": + raise ValueError("请填写api_key") + client = ZhipuAI(api_key=api_key) + response = client.chat.completions.create( + model="glm-4", # 填写需要调用的模型名称 + messages=[ + {"role": "user", "content": prompt[:100]} + ], +) + print(response.choices[0].message) + if __name__ == "__main__": logger.info(data_dir) if not os.path.exists(data_dir): @@ -254,7 +286,8 @@ if __name__ == "__main__": # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。" # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?" # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性" - query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" + # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想" + query = "我现在心情非常差,有什么解决办法吗?" docs, retriever = dp.retrieve(query, vector_db, k=10) logger.info(f'Query: {query}') logger.info("Retrieve results:") @@ -267,4 +300,6 @@ if __name__ == "__main__": logger.info("After reranking...") for i in range(len(scores)): logger.info(str(scores[i]) + '\n') - logger.info(passages[i]) \ No newline at end of file + logger.info(passages[i]) + prompt = create_prompt(query, passages[0]) + test_zhipu(prompt) ## 如果显示'Server disconnected without sending a response.'可能是由于上下文窗口限制 \ No newline at end of file