update rag/src/data_processing.py (#121)

This commit is contained in:
xzw 2024-03-22 10:04:35 +08:00 committed by GitHub
commit 382d338ab3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -91,13 +91,13 @@ class Data_process():
if isinstance(obj, dict):
for key, value in obj.items():
try:
self.extract_text_from_json(value, content)
content = self.extract_text_from_json(value, content)
except Exception as e:
print(f"Error processing value: {e}")
elif isinstance(obj, list):
for index, item in enumerate(obj):
try:
self.extract_text_from_json(item, content)
content = self.extract_text_from_json(item, content)
except Exception as e:
print(f"Error processing item: {e}")
elif isinstance(obj, str):
@ -157,7 +157,7 @@ class Data_process():
logger.info(f'splitting file {file_path}')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(data)
# print(data)
for conversation in data:
# for dialog in conversation['conversation']:
##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document
@ -165,6 +165,7 @@ class Data_process():
# split_qa.append(Document(page_content = content))
#按conversation块切分
content = self.extract_text_from_json(conversation['conversation'], '')
logger.info(f'content====={content}')
split_qa.append(Document(page_content = content))
# logger.info(f'split_qa size====={len(split_qa)}')
return split_qa
@ -229,9 +230,8 @@ class Data_process():
# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
# compressed_docs = compression_retriever.get_relevant_documents(query)
# return compressed_docs
def rerank(self, query, docs):
def rerank(self, query, docs):
reranker = self.load_rerank_model()
passages = []
for doc in docs:
@ -240,9 +240,41 @@ class Data_process():
sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
sorted_passages, sorted_scores = zip(*sorted_pairs)
return sorted_passages, sorted_scores
# def create_prompt(question, context):
# from langchain.prompts import PromptTemplate
# prompt_template = f"""请基于以下内容回答问题:
# {context}
# 问题: {question}
# 回答:"""
# prompt = PromptTemplate(
# template=prompt_template, input_variables=["context", "question"]
# )
# logger.info(f'Prompt: {prompt}')
# return prompt
def create_prompt(question, context):
    """Build the prompt text that asks a model to answer from a given context.

    The ``context`` and ``question`` values are interpolated into a fixed
    Chinese instruction template. The finished prompt is logged at info
    level and then returned.

    Args:
        question: The question to be answered.
        context: The reference content the answer should be based on.

    Returns:
        The formatted prompt string.
    """
    filled = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
    logger.info(f'Prompt: {filled}')
    return filled
def test_zhipu(prompt, api_key=None, model="glm-4", max_prompt_chars=100):
    """Send ``prompt`` to a ZhipuAI chat model and print the reply message.

    Args:
        prompt: Text sent as the single user message.
        api_key: ZhipuAI API key. When omitted (or empty), the
            ``ZHIPUAI_API_KEY`` environment variable is used instead, so the
            key never has to be written into this source file.
        model: Name of the chat model to call. Defaults to ``"glm-4"``.
        max_prompt_chars: Only this many leading characters of ``prompt`` are
            sent. The default of 100 keeps the previous hard-coded
            truncation; pass ``None`` to send the whole prompt.
            NOTE(review): the truncation looks like a guard against
            over-long requests (the caller mentions 'Server disconnected
            without sending a response.') -- confirm the real limit.

    Raises:
        ValueError: If no API key is given and none is set in the
            environment.
    """
    import os

    from zhipuai import ZhipuAI

    # Prefer the explicit argument; fall back to the environment variable.
    api_key = api_key or os.environ.get("ZHIPUAI_API_KEY", "")
    if api_key == "":
        raise ValueError("请填写api_key")
    client = ZhipuAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            # Slicing with None leaves the prompt untouched.
            {"role": "user", "content": prompt[:max_prompt_chars]}
        ],
    )
    print(response.choices[0].message)
if __name__ == "__main__":
logger.info(data_dir)
if not os.path.exists(data_dir):
@ -268,4 +300,6 @@ if __name__ == "__main__":
logger.info("After reranking...")
for i in range(len(scores)):
logger.info(str(scores[i]) + '\n')
logger.info(passages[i])
logger.info(passages[i])
prompt = create_prompt(query, passages[0])
test_zhipu(prompt) ## 如果显示'Server disconnected without sending a response.'可能是由于上下文窗口限制