update rag/src/data_processing.py (#121)
This commit is contained in:
commit
382d338ab3
@ -91,13 +91,13 @@ class Data_process():
|
|||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
for key, value in obj.items():
|
for key, value in obj.items():
|
||||||
try:
|
try:
|
||||||
self.extract_text_from_json(value, content)
|
content = self.extract_text_from_json(value, content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing value: {e}")
|
print(f"Error processing value: {e}")
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
for index, item in enumerate(obj):
|
for index, item in enumerate(obj):
|
||||||
try:
|
try:
|
||||||
self.extract_text_from_json(item, content)
|
content = self.extract_text_from_json(item, content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing item: {e}")
|
print(f"Error processing item: {e}")
|
||||||
elif isinstance(obj, str):
|
elif isinstance(obj, str):
|
||||||
@ -157,7 +157,7 @@ class Data_process():
|
|||||||
logger.info(f'splitting file {file_path}')
|
logger.info(f'splitting file {file_path}')
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
print(data)
|
# print(data)
|
||||||
for conversation in data:
|
for conversation in data:
|
||||||
# for dialog in conversation['conversation']:
|
# for dialog in conversation['conversation']:
|
||||||
##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document
|
##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document
|
||||||
@ -165,6 +165,7 @@ class Data_process():
|
|||||||
# split_qa.append(Document(page_content = content))
|
# split_qa.append(Document(page_content = content))
|
||||||
#按conversation块切分
|
#按conversation块切分
|
||||||
content = self.extract_text_from_json(conversation['conversation'], '')
|
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||||
|
logger.info(f'content====={content}')
|
||||||
split_qa.append(Document(page_content = content))
|
split_qa.append(Document(page_content = content))
|
||||||
# logger.info(f'split_qa size====={len(split_qa)}')
|
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||||
return split_qa
|
return split_qa
|
||||||
@ -230,7 +231,6 @@ class Data_process():
|
|||||||
# compressed_docs = compression_retriever.get_relevant_documents(query)
|
# compressed_docs = compression_retriever.get_relevant_documents(query)
|
||||||
# return compressed_docs
|
# return compressed_docs
|
||||||
|
|
||||||
|
|
||||||
def rerank(self, query, docs):
|
def rerank(self, query, docs):
|
||||||
reranker = self.load_rerank_model()
|
reranker = self.load_rerank_model()
|
||||||
passages = []
|
passages = []
|
||||||
@ -242,6 +242,38 @@ class Data_process():
|
|||||||
return sorted_passages, sorted_scores
|
return sorted_passages, sorted_scores
|
||||||
|
|
||||||
|
|
||||||
|
# def create_prompt(question, context):
|
||||||
|
# from langchain.prompts import PromptTemplate
|
||||||
|
# prompt_template = f"""请基于以下内容回答问题:
|
||||||
|
|
||||||
|
# {context}
|
||||||
|
|
||||||
|
# 问题: {question}
|
||||||
|
# 回答:"""
|
||||||
|
# prompt = PromptTemplate(
|
||||||
|
# template=prompt_template, input_variables=["context", "question"]
|
||||||
|
# )
|
||||||
|
# logger.info(f'Prompt: {prompt}')
|
||||||
|
# return prompt
|
||||||
|
|
||||||
|
def create_prompt(question, context):
|
||||||
|
prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
|
||||||
|
logger.info(f'Prompt: {prompt}')
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def test_zhipu(prompt):
|
||||||
|
from zhipuai import ZhipuAI
|
||||||
|
api_key = "" # 填写您自己的APIKey
|
||||||
|
if api_key == "":
|
||||||
|
raise ValueError("请填写api_key")
|
||||||
|
client = ZhipuAI(api_key=api_key)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="glm-4", # 填写需要调用的模型名称
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": prompt[:100]}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(response.choices[0].message)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info(data_dir)
|
logger.info(data_dir)
|
||||||
@ -269,3 +301,5 @@ if __name__ == "__main__":
|
|||||||
for i in range(len(scores)):
|
for i in range(len(scores)):
|
||||||
logger.info(str(scores[i]) + '\n')
|
logger.info(str(scores[i]) + '\n')
|
||||||
logger.info(passages[i])
|
logger.info(passages[i])
|
||||||
|
prompt = create_prompt(query, passages[0])
|
||||||
|
test_zhipu(prompt) ## 如果显示'Server disconnected without sending a response.'可能是由于上下文窗口限制
|
Loading…
Reference in New Issue
Block a user