OliveSensorAPI/IOTLLM/generate_data/EC_process/Embedding_similarity.py

123 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines
# Load the question-answer embeddings: the file holds one JSON object
# per line; merge every line's mapping into a single
# {question_text: embedding_vector} dict.
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
    for record in reader:
        qa_embeddings.update(record)
# Load the QA pairs (one JSON record per line of the jsonl file).
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
    qa_pairs = [json.loads(row) for row in f]
# Pull out the question texts and stack their embedding vectors into
# a single 2-D array (row i is the embedding of questions[i]).
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))
# Keyword taxonomy: maps each olive-cultivation category name to the
# keywords characterising it.  Category names and keywords are runtime
# data (dict keys, matched against embedding keys), so they are kept
# verbatim.
categories = {
"栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
"油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
"油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
"油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
"油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
"油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
"油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
"土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
"整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
"病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"]
}
# Initialize one (initially empty) keyword-embedding list per category;
# filled below and later collapsed to a mean vector.
category_embeddings = {category: [] for category in categories}
def get_keyword_embedding(keyword, embeddings_map=None):
    """Return the stored embedding for *keyword*, or ``None`` if absent.

    Args:
        keyword: keyword string to look up.
        embeddings_map: optional mapping of text -> embedding vector.
            Defaults to the module-level ``qa_embeddings`` loaded above,
            preserving the original call signature.

    Returns:
        The embedding vector associated with *keyword*, or ``None``
        when the keyword has no precomputed embedding.
    """
    source = qa_embeddings if embeddings_map is None else embeddings_map
    return source.get(keyword, None)
# Collapse each category's keyword embeddings into one mean vector.
# Keywords without a precomputed embedding are skipped; a category
# whose keywords are all missing falls back to a zero vector of the
# same dimensionality as the question embeddings.
category_embeddings = {}
for category, keywords in categories.items():
    vectors = []
    for kw in keywords:
        vec = get_keyword_embedding(kw)
        if vec is not None:
            vectors.append(vec)
    if vectors:
        category_embeddings[category] = np.mean(vectors, axis=0)
    else:
        category_embeddings[category] = np.zeros(embeddings.shape[1])
# Score every question against every category centroid with cosine
# similarity: category_similarities[question][category] -> float.
category_similarities = {}
for i, q in enumerate(questions):
    q_vec = embeddings[i]
    category_similarities[q] = {
        cat: cosine_similarity([q_vec], [cat_vec])[0][0]
        for cat, cat_vec in category_embeddings.items()
    }
# Assign each question to the single category it scores highest on.
category_assignments = {cat: [] for cat in categories}
for q in questions:
    scores = category_similarities[q]
    top_category = max(scores, key=scores.get)
    category_assignments[top_category].append(q)
# Assemble the fine-tuning records.  The original code re-scanned
# qa_pairs with next(...) for every question (O(n) per lookup inside a
# nested loop); index the pairs by question once instead.  setdefault
# keeps the FIRST occurrence on duplicate questions, matching the
# original next(...) first-match semantics.
qa_by_input = {}
for qa in qa_pairs:
    qa_by_input.setdefault(qa['input'], qa)

fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for question in assigned_questions:
        qa_pair = qa_by_input.get(question)
        if qa_pair is None:
            # Question has no matching QA pair: nothing to emit
            # (same outcome as the original empty-instruction guard).
            continue
        instruction = qa_pair['input']   # the question itself
        output = qa_pair['output']       # its reference answer
        # Rank the OTHER questions of the same category by cosine
        # similarity to the current question; the question's own
        # embedding is hoisted out of the inner loop.
        question_vec = qa_embeddings[question]
        ranked = []
        for related_question in assigned_questions:
            if related_question == question:
                continue
            similarity = cosine_similarity(
                [question_vec], [qa_embeddings[related_question]]
            )[0][0]
            ranked.append((related_question, similarity))
        ranked.sort(key=lambda item: item[1], reverse=True)
        # The up-to-3 most similar same-category QA pairs become the
        # conversational history.
        history = []
        for related_question, _ in ranked[:3]:
            related_qa = qa_by_input.get(related_question)
            if related_qa is not None:
                history.append([related_qa['input'], related_qa['output']])
        # Emit only complete records (non-empty question and answer),
        # as in the original.
        if instruction and output:
            fine_tune_data.append({
                "instruction": instruction,
                "input": "",  # input intentionally left empty
                "output": output,
                "history": history,
                "system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"
            })
# Persist the records as jsonl: one JSON document per line, with
# non-ASCII text written verbatim.
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
    for entry in fine_tune_data:
        f.write(json.dumps(entry, ensure_ascii=False) + '\n')
print("对话数据整理完成")