diff --git a/rag/src/data_processing.py b/rag/src/data_processing.py index 55a439f..4126ee3 100644 --- a/rag/src/data_processing.py +++ b/rag/src/data_processing.py @@ -91,13 +91,13 @@ class Data_process(): if isinstance(obj, dict): for key, value in obj.items(): try: - self.extract_text_from_json(value, content) + content = self.extract_text_from_json(value, content) except Exception as e: print(f"Error processing value: {e}") elif isinstance(obj, list): for index, item in enumerate(obj): try: - self.extract_text_from_json(item, content) + content = self.extract_text_from_json(item, content) except Exception as e: print(f"Error processing item: {e}") elif isinstance(obj, str): @@ -157,7 +157,7 @@ class Data_process(): logger.info(f'splitting file {file_path}') with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) - print(data) + # print(data) for conversation in data: # for dialog in conversation['conversation']: ##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document @@ -165,6 +165,7 @@ class Data_process(): # split_qa.append(Document(page_content = content)) #按conversation块切分 content = self.extract_text_from_json(conversation['conversation'], '') + logger.info(f'content====={content}') split_qa.append(Document(page_content = content)) # logger.info(f'split_qa size====={len(split_qa)}') return split_qa @@ -229,9 +230,8 @@ class Data_process(): # compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever) # compressed_docs = compression_retriever.get_relevant_documents(query) # return compressed_docs - - def rerank(self, query, docs): + def rerank(self, query, docs): reranker = self.load_rerank_model() passages = [] for doc in docs: @@ -240,9 +240,41 @@ class Data_process(): sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True) sorted_passages, sorted_scores = zip(*sorted_pairs) return sorted_passages, sorted_scores + + +# def create_prompt(question, context): +# from 
def create_prompt(question, context):
    """Build the single-turn question-answering prompt for the chat model.

    The retrieved context is placed first and the question last, so any
    truncation applied downstream removes the question before the context.

    Args:
        question: The user's question.
        context: The retrieved passage the answer should be based on.

    Returns:
        The formatted prompt string.
    """
    prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
    logger.info(f'Prompt: {prompt}')
    return prompt


def test_zhipu(prompt, api_key=None, model="glm-4", max_chars=100):
    """Send ``prompt`` to a ZhipuAI chat model, print the reply and return it.

    This is a manual smoke-test helper, not a unit test. Typical use is
    ``test_zhipu(create_prompt(query, passages[0]))``.

    If the call fails with 'Server disconnected without sending a response.',
    the request may have exceeded the model's context window; lower
    ``max_chars`` in that case.

    Args:
        prompt: The full prompt text to send as the single user message.
        api_key: Your own ZhipuAI API key. When omitted, the
            ``ZHIPUAI_API_KEY`` environment variable is used instead.
        model: Name of the model to call.
        max_chars: Maximum number of prompt characters to send. ``None``
            disables truncation. The default of 100 is very small and will
            usually cut off the question, which sits at the end of prompts
            built by ``create_prompt``.

    Returns:
        The message object of the first returned choice.

    Raises:
        ValueError: If no API key was passed and none is set in the
            environment.
    """
    # Resolve and validate the key before importing the SDK, so a missing
    # key is reported immediately and never has to be written into the file.
    api_key = api_key or os.environ.get("ZHIPUAI_API_KEY", "")
    if not api_key:
        raise ValueError("请填写api_key")

    from zhipuai import ZhipuAI

    content = prompt if max_chars is None else prompt[:max_chars]
    if len(content) < len(prompt):
        # Make the truncation visible instead of silently dropping text.
        logger.warning(
            f'Prompt truncated from {len(prompt)} to {len(content)} characters'
        )

    client = ZhipuAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": content},
        ],
    )
    message = response.choices[0].message
    print(message)
    return message


# The "test_" prefix would otherwise make pytest collect this helper as a
# test case whenever the module's names are imported into a test file.
test_zhipu.__test__ = False