123 lines
5.6 KiB
Python
123 lines
5.6 KiB
Python
import json
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
import numpy as np
|
||
import jsonlines
|
||
|
||
# Load the key -> embedding mapping. Every record in the file is a dict that is
# merged into qa_embeddings, so a key repeated in a later record overwrites the
# earlier value.
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
    # skip_empty=True tolerates blank lines instead of raising on them.
    for obj in reader.iter(skip_empty=True):
        qa_embeddings.update(obj)

# Load the question/answer records: one JSON object per line.
qa_pairs = []
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # A blank (or whitespace-only) line would make json.loads raise
        # JSONDecodeError, so skip it rather than abort the whole load.
        if line.strip():
            qa_pairs.append(json.loads(line))

# Parallel views of the mapping: questions[i] is the key whose vector is
# embeddings[i] (dicts preserve insertion order).
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))
|
||
|
||
# Category name -> keywords that characterise that category.
categories = {
    "栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
    "油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
    "油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
    "油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
    "油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
    "油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
    "油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
    "土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
    "整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
    "病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"],
}

# One independent, initially empty list per category; keyword embeddings are
# collected into these lists later.
category_embeddings = {}
for category_name in categories:
    category_embeddings[category_name] = []
|
||
|
||
|
||
# Keyword embeddings are not computed here; they are looked up in qa_embeddings.
def get_keyword_embedding(keyword):
    """Return the embedding stored under *keyword* in ``qa_embeddings``.

    Returns ``None`` when ``qa_embeddings`` has no entry for the keyword.

    NOTE(review): the keys of ``qa_embeddings`` are also used as the list of
    questions, so a keyword only resolves if it appears verbatim as a key --
    confirm the embeddings file actually contains keyword entries.
    """
    return qa_embeddings.get(keyword)
|
||
|
||
|
||
# Collect, per category, the embeddings of those keywords that have one.
for category_name, keyword_list in categories.items():
    looked_up = (get_keyword_embedding(word) for word in keyword_list)
    category_embeddings[category_name].extend(
        vector for vector in looked_up if vector is not None
    )

# Collapse each category's list of keyword embeddings into a single vector:
# the element-wise mean, or an all-zero vector of the embedding width when no
# keyword of the category had an embedding.
# NOTE(review): an all-zero vector has no direction -- confirm that the
# downstream cosine-similarity scoring treats such categories as intended.
for category_name, collected in list(category_embeddings.items()):
    category_embeddings[category_name] = (
        np.mean(collected, axis=0)
        if collected
        else np.zeros(embeddings.shape[1])
    )
|
||
|
||
# Score every question against every category vector and assign each question
# to its best-scoring category.
#
# One cosine_similarity call over the full (questions x categories) matrix
# replaces a separate call per (question, category) pair.
category_names = list(category_embeddings)

# question -> {category -> similarity}
category_similarities = {}
# category -> questions assigned to it, in the order of `questions`
category_assignments = {category: [] for category in categories}

# cosine_similarity rejects empty input, so only score when there is
# something to score.
if questions:
    category_matrix = np.array(
        [category_embeddings[name] for name in category_names]
    )
    similarity_matrix = cosine_similarity(embeddings, category_matrix)

    for question, scores in zip(questions, similarity_matrix):
        category_similarities[question] = dict(zip(category_names, scores))
        # np.argmax returns the first maximum, so ties go to the category
        # that comes first in iteration order.
        best_category = category_names[int(np.argmax(scores))]
        category_assignments[best_category].append(question)
|
||
|
||
# Build the fine-tuning records: one entry per question that has a QA record,
# with up to three of the most similar same-category QA pairs as history.

# Index the QA records by question once instead of rescanning qa_pairs for
# every lookup. setdefault keeps the first record of a duplicated question,
# which is what a first-match linear scan would return.
qa_by_question = {}
for qa in qa_pairs:
    qa_by_question.setdefault(qa['input'], qa)

system_prompt = "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"

fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for question in assigned_questions:
        qa_pair = qa_by_question.get(question)

        # No QA record, or an empty question/answer, yields no entry, so skip
        # before doing any similarity work for it.
        if not qa_pair or not qa_pair['input'] or not qa_pair['output']:
            continue

        # History candidates are the other questions of the same category.
        related_questions = [
            other for other in assigned_questions if other != question
        ]

        history = []
        # cosine_similarity rejects empty input; a category holding a single
        # question simply has no history.
        if related_questions:
            scores = cosine_similarity(
                [qa_embeddings[question]],
                [qa_embeddings[other] for other in related_questions],
            )[0]
            # Stable sort: equally similar candidates keep their original order.
            ranked = sorted(
                zip(related_questions, scores),
                key=lambda pair: pair[1],
                reverse=True,
            )
            # Take the three most similar candidates; any of them lacking a QA
            # record is dropped, so history may hold fewer than three pairs.
            for related_question, _ in ranked[:3]:
                related_qa_pair = qa_by_question.get(related_question)
                if related_qa_pair:
                    history.append(
                        [related_qa_pair['input'], related_qa_pair['output']]
                    )

        fine_tune_data.append({
            "instruction": qa_pair['input'],  # the current question
            "input": "",  # intentionally left empty
            "output": qa_pair['output'],  # the answer to the current question
            "history": history,  # at most three related QA pairs
            "system": system_prompt,
        })
|
||
|
||
# Write the fine-tuning records as JSONL: one JSON object per line, with
# non-ASCII characters written as-is rather than escaped.
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
    for record in fine_tune_data:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

print("对话数据整理完成")
|