OliveSensorAPI/IOTLLM/generate_data/EC_process/Embedding_similarity.py

123 lines
5.6 KiB
Python
Raw Normal View History

2024-11-11 17:32:36 +08:00
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines
# Load the question embeddings: every JSON object in the file is merged into a
# single mapping (later objects overwrite earlier keys). Downstream code treats
# the keys as question texts and the values as their embedding vectors.
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', mode='r') as embedding_reader:
    for record in embedding_reader:
        qa_embeddings.update(record)
# Load the question/answer pairs, one JSON object per line.
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as pairs_file:
    qa_pairs = [json.loads(raw_line) for raw_line in pairs_file]
# Split the embedding mapping into two index-aligned views:
# questions[i] is the text whose vector is embeddings[i].
questions = [*qa_embeddings]
embeddings = np.array([*qa_embeddings.values()])
# Topic categories of the olive-cultivation corpus, each paired with the
# keywords used to build that category's embedding. Insertion order matters:
# categories are scored, tie-broken and written out in this order.
_CATEGORY_TABLE = (
    ("栽培油橄榄的意义", ("栽培油橄榄", "经济价值", "引种")),
    ("油橄榄属植物分类", ("油橄榄属", "植物分类", "植物种", "原产地")),
    ("油橄榄生物学特性", ("根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成")),
    ("油橄榄的生态环境条件", ("气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度")),
    ("油橄榄品种", ("佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种")),
    ("油橄榄育苗技术", ("育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养")),
    ("油橄榄种植", ("园地选择", "种植密度", "栽植方式", "栽后管理")),
    ("土壤、肥料、水管理", ("土壤管理", "矿质营养", "果园灌溉", "果实采收")),
    ("整形修剪", ("整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪")),
    ("病虫害防治", ("孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象")),
)
categories = {topic: list(keywords) for topic, keywords in _CATEGORY_TABLE}
# One initially empty list of keyword vectors per category, in `categories` order.
category_embeddings = {}
for category_name in categories:
    category_embeddings[category_name] = []
# Keyword vectors are derived from the already-loaded question embeddings;
# no embedding model is called here.
def get_keyword_embedding(keyword, embeddings_by_text=None):
    """Return an embedding vector for ``keyword``, or None if none can be derived.

    Lookup order:
      1. an exact key match in ``embeddings_by_text`` (returned as stored);
      2. otherwise the element-wise mean of the vectors of every text that
         contains ``keyword`` as a substring (returned as a NumPy array);
      3. otherwise None.

    Args:
        keyword: Category keyword to embed.
        embeddings_by_text: Mapping of text -> embedding vector. Defaults to the
            module-level ``qa_embeddings``.

    Returns:
        The embedding vector, or None when no key equals or contains ``keyword``.
    """
    if embeddings_by_text is None:
        embeddings_by_text = qa_embeddings

    exact = embeddings_by_text.get(keyword)
    if exact is not None:
        return exact

    # The keys are question texts, so a bare keyword rarely equals one exactly.
    # With exact matching alone, every category was left without vectors and
    # fell back to the zero vector. That vector has cosine similarity 0 with
    # every question, so all questions tied into the first category. Averaging
    # the questions that mention the keyword gives each category a real
    # direction in embedding space.
    matches = [vector for text, vector in embeddings_by_text.items() if keyword in text]
    if not matches:
        return None
    return np.mean(np.asarray(matches, dtype=float), axis=0)
# Collect, per category, the vector of every keyword that get_keyword_embedding
# can resolve; keywords it returns None for are skipped.
for category_name, keyword_list in categories.items():
    resolved = (get_keyword_embedding(term) for term in keyword_list)
    category_embeddings[category_name].extend(vec for vec in resolved if vec is not None)
# Collapse each category's keyword vectors into their element-wise mean. A
# category with no vectors gets an all-zero vector as wide as the question
# embeddings.
category_embeddings = {
    category_name: (np.mean(vectors, axis=0) if vectors else np.zeros(embeddings.shape[1]))
    for category_name, vectors in category_embeddings.items()
}
# Score every question against every category vector.
# Result: category_similarities[question][category] -> cosine similarity.
# A single matrix call replaces one cosine_similarity call (with its own input
# validation) per (question, category) pair.
category_names = list(category_embeddings)
category_similarities = {}
if questions:  # with no questions `embeddings` is 1-D and there is nothing to score
    category_matrix = np.vstack([category_embeddings[name] for name in category_names])
    similarity_matrix = cosine_similarity(embeddings, category_matrix)
    # Row i of `embeddings` belongs to questions[i], so rows zip with questions.
    for question, scores in zip(questions, similarity_matrix):
        category_similarities[question] = dict(zip(category_names, scores))
# Assign each question to its single highest-scoring category. max() keeps the
# first maximum it sees, so ties go to the earliest category.
category_assignments = {category_name: [] for category_name in categories}
for question in questions:
    best_category, _ = max(category_similarities[question].items(), key=lambda item: item[1])
    category_assignments[best_category].append(question)
# Build one fine-tuning sample per categorized question that has a matching QA
# pair. Each sample's history holds up to MAX_HISTORY_TURNS of the most similar
# *other* questions from the same category.
MAX_HISTORY_TURNS = 3
SYSTEM_PROMPT = "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"

# Index the QA pairs by question text once, instead of linearly scanning
# qa_pairs for every question and every history candidate. setdefault keeps the
# first pair for a duplicated question, which is what a front-to-back search found.
qa_pair_by_input = {}
for qa in qa_pairs:
    qa_pair_by_input.setdefault(qa['input'], qa)

# NOTE(review): only questions present in qa_embeddings are assigned to a
# category, so a QA pair whose question has no embedding is never emitted here.
fine_tune_data = []
for assigned_questions in category_assignments.values():
    if not assigned_questions:
        continue
    # All pairwise cosine similarities within this category in one call, rather
    # than one cosine_similarity call per question pair.
    pairwise = cosine_similarity(np.array([qa_embeddings[q] for q in assigned_questions]))

    for idx, question in enumerate(assigned_questions):
        qa_pair = qa_pair_by_input.get(question)
        if qa_pair is None:
            continue
        instruction = qa_pair['input']  # the current question becomes the instruction
        output = qa_pair['output']  # its answer becomes the output
        if not (instruction and output):
            continue

        # Other questions of the same category, most similar first. list.sort is
        # stable even with reverse=True, so ties keep their order in
        # assigned_questions.
        row = pairwise[idx]
        others = [j for j in range(len(assigned_questions)) if j != idx]
        others.sort(key=row.__getitem__, reverse=True)

        # Take the top candidates first, then drop any without a QA pair, so a
        # history can be shorter than MAX_HISTORY_TURNS.
        history = []
        for j in others[:MAX_HISTORY_TURNS]:
            related_qa_pair = qa_pair_by_input.get(assigned_questions[j])
            if related_qa_pair is not None:
                history.append([related_qa_pair['input'], related_qa_pair['output']])

        fine_tune_data.append({
            "instruction": instruction,
            "input": "",  # intentionally left empty
            "output": output,
            "history": history,  # at most MAX_HISTORY_TURNS related turns
            "system": SYSTEM_PROMPT,
        })
# Write the samples as JSON Lines: one object per line, non-ASCII text kept
# as-is (ensure_ascii=False).
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as out_file:
    out_file.writelines(
        json.dumps(sample, ensure_ascii=False) + '\n' for sample in fine_tune_data
    )
print("对话数据整理完成")