Custom dataset processing scripts
This commit is contained in:
parent
2065b2176c
commit
1125b67f50
217
generate_data/EC_process/Embedding_merge.py
Normal file
@ -0,0 +1,217 @@
import base64
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import os


class AssembleHeaderException(Exception):
    def __init__(self, msg):
        self.message = msg


class Url:
    def __init__(self, host, path, schema):
        self.host = host
        self.path = path
        self.schema = schema


# calculate sha256 and encode to base64
def sha256base64(data):
    sha256 = hashlib.sha256()
    sha256.update(data)
    digest = base64.b64encode(sha256.digest()).decode(encoding='utf-8')
    return digest


def parse_url(request_url):
    stidx = request_url.index("://")
    host = request_url[stidx + 3:]
    schema = request_url[:stidx + 3]
    edidx = host.index("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + request_url)
    path = host[edidx:]
    host = host[:edidx]
    u = Url(host, path, schema)
    return u


# Build the authenticated request URL
def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""):
    u = parse_url(request_url)
    host = u.host
    path = u.path
    now = datetime.now()
    date = format_date_time(time.mktime(now.timetuple()))
    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path)
    signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
    authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        api_key, "hmac-sha256", "host date request-line", signature_sha)
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
    values = {
        "host": host,
        "date": date,
        "authorization": authorization
    }

    return request_url + "?" + urlencode(values)


def get_Body(appid, text, style):
    org_content = json.dumps(text).encode('utf-8')
    body = {
        "header": {
            "app_id": appid,
            "uid": "39769795890",
            "status": 3
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8"
                }
            }
        },
        "payload": {
            "messages": {
                "text": base64.b64encode(json.dumps(text).encode('utf-8')).decode()
            }
        }
    }
    return body


# Send the request and return the raw response text
def get_embp_embedding(text, appid, apikey, apisecret):
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    response = requests.post(url, json=content, headers={'content-type': "application/json"}).text
    return response


# Parse the response and return the embedding vector
def parser_Message(message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        return None
    else:
        text_base = data["payload"]["feature"]["text"]
        text_data = base64.b64decode(text_base)
        dt = np.dtype(np.float32).newbyteorder("<")
        text = np.frombuffer(text_data, dtype=dt)
        return text


# Load QA pairs from a jsonl file
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Save all embeddings to a file
def save_embeddings(embeddings, file_path):
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(embeddings, f, ensure_ascii=False)


# Get the embedding for a single text
def get_embedding_for_text(text, appid, apikey, apisecret):
    desc = {"messages": [{"content": text, "role": "user"}]}
    res = get_embp_embedding(desc, appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(res)


# Load existing embeddings line by line
def load_embeddings(file_path):
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # skip empty lines
                    embedding_data = json.loads(line.strip())
                    embeddings.update(embedding_data)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings


# Append one embedding per line to the output file
def save_embedding_line_by_line(qa, embedding, file_path):
    if embedding is not None:
        embedding_as_list = embedding.tolist()  # convert the numpy array to a list
        with open(file_path, 'a', encoding='utf-8') as f:
            json.dump({qa: embedding_as_list}, f, ensure_ascii=False)
            f.write("\n")  # one embedding per line


# Get the embedding for a single question, retrying on request errors
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5):
    retries = 0
    while retries < max_retries:
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:
            print(f"请求错误: {e}")
        retries += 1
        print(f"重试第 {retries} 次...")
        time.sleep(5)  # wait 5 seconds before each retry
    print(f"获取'{question}' 的embedding失败")
    return None


# Get embeddings for all QA pairs and save them line by line
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    all_embeddings = load_embeddings(file_path)  # load any embeddings that already exist
    interval = 1 / qps_limit  # request interval derived from the QPS limit
    for qa in qa_pairs:
        question = qa['input']
        if question in all_embeddings:
            print(f"'{question}' 的embedding已存在,跳过计算")
            continue
        print(f"计算'{question}' 的embedding...")
        embedding = get_embedding_with_retry(question, appid, apikey, apisecret)  # request with retry
        if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)  # save line by line
            all_embeddings[question] = embedding  # record the computed embedding
        time.sleep(interval)  # stay within the QPS limit


if __name__ == '__main__':
    # File paths
    qa_file = "output/train_optimized_multiple.jsonl"  # source QA pairs
    embedding_file = "output/qa_embeddings.json"  # embedding output file

    appid = "f0f73de5"
    api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
    api_key = "5773f6f95563708de994d17b7ea5d414"

    # Load the QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Compute and save the embeddings
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)

    print(f"已保存所有问答对的embedding到 {embedding_file}")
122
generate_data/EC_process/Embedding_similarity.py
Normal file
@ -0,0 +1,122 @@
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines

# Load QA-pair embeddings
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
    for obj in reader:
        qa_embeddings.update(obj)  # merge each line's JSON object into qa_embeddings

# Load QA pairs
qa_pairs = []
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        qa_pairs.append(json.loads(line))

# Extract questions and their embeddings
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))

# Keywords for each category
categories = {
    "栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
    "油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
    "油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
    "油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
    "油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
    "油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
    "油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
    "土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
    "整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
    "病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"]
}

# Initialize the per-category keyword embedding lists
category_embeddings = {category: [] for category in categories}


# Assume keyword embeddings can be looked up, e.g. from qa_embeddings
def get_keyword_embedding(keyword):
    return qa_embeddings.get(keyword, None)


# Collect keyword embeddings for each category
for category, keywords in categories.items():
    for keyword in keywords:
        keyword_embedding = get_keyword_embedding(keyword)
        if keyword_embedding is not None:
            category_embeddings[category].append(keyword_embedding)

# Average the keyword embeddings into one vector per category
for category in category_embeddings:
    if category_embeddings[category]:
        category_embeddings[category] = np.mean(category_embeddings[category], axis=0)
    else:
        category_embeddings[category] = np.zeros(embeddings.shape[1])  # default zero vector

# Compute the similarity between each question and each category
category_similarities = {}
for idx, question in enumerate(questions):
    question_embedding = embeddings[idx]
    category_similarities[question] = {}

    for category, category_embedding in category_embeddings.items():
        similarity = cosine_similarity([question_embedding], [category_embedding])[0][0]
        category_similarities[question][category] = similarity

# Assign each question to its most similar category
category_assignments = {category: [] for category in categories}
for question in questions:
    best_category = max(category_similarities[question], key=category_similarities[question].get)
    category_assignments[best_category].append(question)

# Assemble the new jsonl format, making sure every QA pair is included
fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for idx, question in enumerate(assigned_questions):
        history = []
        output = ""
        instruction = ""

        # Find the current question and its answer
        qa_pair = next((qa for qa in qa_pairs if qa['input'] == question), None)

        if qa_pair:
            instruction = qa_pair['input']  # current question becomes the instruction
            output = qa_pair['output']  # its answer becomes the output

        # Build history from other questions in the same category
        history_similarities = []
        for related_question in assigned_questions:
            if related_question != question:
                related_embedding = qa_embeddings[related_question]
                similarity = cosine_similarity([qa_embeddings[question]], [related_embedding])[0][0]
                history_similarities.append((related_question, similarity))

        # Sort by similarity and keep up to 3 related questions as history
        history_similarities = sorted(history_similarities, key=lambda x: x[1], reverse=True)
        for related_question, _ in history_similarities[:3]:
            related_qa_pair = next((qa for qa in qa_pairs if qa['input'] == related_question), None)
            if related_qa_pair:
                history.append([related_qa_pair['input'], related_qa_pair['output']])

        # Build the final record
        if instruction and output:
            fine_tune_entry = {
                "instruction": instruction,
                "input": "",  # input stays empty
                "output": output,  # answer to the current question
                "history": history,  # at most 3 related QA pairs
                "system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"
            }
            fine_tune_data.append(fine_tune_entry)

# Save the new jsonl file
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
    for entry in fine_tune_data:
        json.dump(entry, f, ensure_ascii=False)
        f.write('\n')

print("对话数据整理完成")
71
generate_data/EC_process/LDArec.py
Normal file
@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 11:10
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : LDArec.py
# @Project : EmoLLM
import json
import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel
from collections import defaultdict


# Load QA pairs
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Load Chinese stopwords
def load_stopwords(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return set([line.strip() for line in f])


# Tokenize Chinese text with jieba and drop stopwords
def preprocess_text(text, stopwords):
    words = jieba.lcut(text)  # Chinese word segmentation with jieba
    words = [word for word in words if word not in stopwords and len(word) > 1]  # drop stopwords and single characters
    return words


# Build the LDA topic model
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    # Preprocess all question texts
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question, stopwords) for question in questions]

    # Build the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus


# Print the keywords of each topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"主题 {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pairs file
    stopwords_file = "chinese_stopwords.txt"  # stopwords file

    # Load the QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Load the stopwords
    stopwords = load_stopwords(stopwords_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)

    # Print the topics and their keywords
    print_topics(lda_model)
70
generate_data/EC_process/Sensor_QA.py
Normal file
@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import json
import random

# Generate a synthetic sensor-QA dataset (default: 5000 samples)
def generate_dataset(num_samples=5000):
    dataset = []
    invoke_types = [1, 2, 3]
    area_codes = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    parameters = [
        {"name": "土壤湿度", "unit": "%", "min": 10, "max": 100},
        {"name": "土壤温度", "unit": "℃", "min": 5, "max": 40},
        {"name": "空气温度", "unit": "℃", "min": -10, "max": 45},
        {"name": "电导率", "unit": "mS/cm", "min": 0.1, "max": 5.0}
    ]

    for _ in range(num_samples):
        invoke_type = random.choice(invoke_types)
        area_code = random.choice(area_codes)
        parameter = random.choice(parameters)

        # Sample a reading within the parameter's range, rounded to one decimal place
        # (integer and float ranges are handled the same way)
        value = round(random.uniform(parameter["min"], parameter["max"]), 1)

        # Vary the question phrasing to make the data more natural
        instruction_templates = [
            f"现在{area_code}种植区内{parameter['name']}如何?",
            f"请告诉我{area_code}区的{parameter['name']}情况。",
            f"{area_code}区当前的{parameter['name']}是多少?",
            f"我想知道{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在是多少?",
            f"{area_code}种植区目前的{parameter['name']}是多少?",
            f"能提供{area_code}区的{parameter['name']}数据吗?",
            f"{area_code}种植区的{parameter['name']}是多少?",
            f"请查询{area_code}区的{parameter['name']}。",
            f"{area_code}区现在的{parameter['name']}数据是多少?",
            f"帮我看看{area_code}区{parameter['name']}的情况。",
            f"{area_code}区的{parameter['name']}值是多少?",
            f"帮我查一下{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在什么情况?",
            f"请帮我查一下{area_code}种植区的{parameter['name']}是多少?",
            f"我需要知道{area_code}区的{parameter['name']}数据。",
            f"请问{area_code}区的{parameter['name']}如何?",
            f"帮我查询{area_code}区的{parameter['name']}情况。",
            f"现在{area_code}区的{parameter['name']}值是多少?"
        ]
        instruction = random.choice(instruction_templates)
        output = f"{area_code}区现在{parameter['name']}{value}{parameter['unit']}"

        data = {
            "instruction": instruction,
            "invokeType": str(invoke_type),
            "areaCode": area_code,
            "output": output
        }
        dataset.append(data)

    return dataset

# Generate the data and save it as a JSON file
if __name__ == '__main__':
    dataset = generate_dataset()
    output_file = 'output/synthetic_dataset.json'

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)

    print(f"已生成 {output_file} 文件,包含{len(dataset)}条数据。")
136
generate_data/EC_process/SparkApi.py
Normal file
@ -0,0 +1,136 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # requires the websocket-client package

answer = ""


class Ws_Param(object):
    # Initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # Build the authenticated URL
    def create_url(self):
        # RFC1123-formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # Assemble the string to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # Sign with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Collect the auth parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # Append the auth parameters to build the final URL
        url = self.Spark_url + '?' + urlencode(v)
        # When debugging, print this URL and compare it against one generated with the same parameters
        return url


# Handle websocket errors
def on_error(ws, error):
    print("### error:", error)


# Handle websocket close
def on_close(ws, one, two):
    print(" ")


# Handle websocket open
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
    ws.send(data)


# Handle websocket messages
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content, end="")
        global answer
        answer += content
        # print(1)
        if status == 2:
            ws.close()


def gen_params(appid, domain, question):
    """
    Build the request parameters from the appid and the user's question.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 2048
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url, domain, question):
    # print("星火:")
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
24
generate_data/EC_process/api_test.py
Normal file
@ -0,0 +1,24 @@
import requests
import json

url = "https://chatapi.midjourney-vip.cn/v1/chat/completions"

payload = json.dumps({
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "测试"
        }
    ]
})
headers = {
    'Accept': 'application/json',
    'Authorization': 'sk-ATDf2Ax1YTGeeTaBD9Be2a7bE0064618Ae3378EaF0Df6f24',
    'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
    'Content-Type': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload)

print(response.text)
1598
generate_data/EC_process/chinese_stopwords.txt
Normal file
File diff suppressed because it is too large
66
generate_data/EC_process/custom_dict.txt
Normal file
@ -0,0 +1,66 @@
栽培油橄榄
经济价值
引种
油橄榄属
植物分类
植物种
原产地
根系类型
土壤关系
花芽分化
花序
授粉特性
果实发育
油脂形成
气候条件
温度
光照
水分
土壤生态
海拔高度
坡度
佛奥
莱星
皮削利
阿斯
配多灵
果大尔
皮瓜尔
科拉蒂
克里
爱桑
贝拉
实生种
育苗场地
种子繁殖
实生苗
嫁接繁殖
砧木
接穗
扦插繁殖
组织培养
园地选择
种植密度
栽植方式
栽后管理
土壤管理
矿质营养
果园灌溉
果实采收
整形修剪
生物学原理
结果习性
树形
幼树修剪
复壮修剪
孔雀斑病
炭疽病
黄萎病
肿瘤病
根腐病
云斑天牛
油橄榄片盾
大粒横沟象
引进品种名录
中英对照品种名称
病虫害判定表
116
generate_data/EC_process/extend_QA.py
Normal file
@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm

import SparkApi

# Input file path
input_file = 'output/train_expanded.jsonl'
# Output file path
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint file path
checkpoint_file = 'output/e2_progress_checkpoint.txt'


# Call the API to generate QA pairs
def generate_qa_via_api(content):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
        f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        f"以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        f"请按如下格式生成输出:\n"
        f"问题1:<生成第一个问题>\n"
        f"回答1:<生成第一个回答>\n"
        f"问题2:<生成第二个问题>\n"
        f"回答2:<生成第二个回答>\n"
        f"问题3:<生成第三个问题>\n"
        f"回答3:<生成第三个回答>\n\n"
        f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )

    question = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
    return SparkApi.answer.strip()


# Load checkpoint progress
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # start from 0 when there is no checkpoint


# Save checkpoint progress
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Parse the returned text, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Find all QA pairs with a regular expression
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(问题|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


if __name__ == '__main__':
    # Load the original dataset
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]

    # Load the checkpoint
    start_index = load_checkpoint()

    # Continue generating QA pairs from the checkpoint
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']

            try:
                # Generate new QA pairs via the API
                api_generated_qa = generate_qa_via_api(input_content)

                # Parse the generated QA pairs and add them to the dataset
                qa_pairs = parse_multiple_qa(api_generated_qa)
                expanded_data = [{"input": qa_pair['input'], "output": qa_pair['output']} for qa_pair in qa_pairs]

                # Save the generated QA pairs
                for qa in expanded_data:
                    json.dump(qa, f, ensure_ascii=False)
                    f.write('\n')

                # Save the current progress index
                save_checkpoint(i)

            except Exception as e:
                print(f"Error processing item {i}: {e}")
                # Skip this item and continue
                save_checkpoint(i)
                continue

    print(f"已生成 {output_file} 文件,包含扩展的问答对。")
153
generate_data/EC_process/gen_QA.py
Normal file
@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM

import os
import re
from tqdm import tqdm
import SparkApi
import json


appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"

# Spark service address and version
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds the cleaned text
text_data = []

# Checkpoint file storing the index of the last processed paragraph
checkpoint_file = "output/progress_checkpoint.txt"

# Load the cleaned text file
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()


# Split the text into sentence groups of at most max_length characters
def split_text_to_sentences(text, max_length=300):
    sentences = re.split('(?<=。)', text)
    grouped_sentences = []
    current_group = ""

    for sentence in sentences:
        if len(current_group) + len(sentence) <= max_length:
            current_group += sentence
        else:
            grouped_sentences.append(current_group.strip())
            current_group = sentence

    if current_group:
        grouped_sentences.append(current_group.strip())

    return grouped_sentences


# Load checkpoint progress
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed paragraph
    return 0  # start from 0 when there is no checkpoint


# Save checkpoint progress
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Split the text into paragraphs of the required length
paragraphs = split_text_to_sentences(cleaned_text, 300)


# Build a detailed prompt that lets the LLM generate one or more QA pairs
def create_prompt(content):
    prompt = (
        f"你是一位油橄榄栽培专家。"
        f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
        f"内容:{content}\n\n"
        f"请以如下格式生成输出:\n"
        f"问题1:<在这里生成第一个问题>\n"
        f"回答1:<在这里生成第一个回答>\n"
        f"问题2:<在这里生成第二个问题(如有)>\n"
        f"回答2:<在这里生成第二个回答(如有)>\n"
        f"..."
    )
    return prompt


# Parse the returned text, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Find all QA pairs with a regular expression
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(问题|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


# Cap the message list to keep API usage in check
def checklen(text):
    while len(text) > 8000:  # trim the oldest messages if the list grows too large
        del text[0]
    return text


if __name__ == '__main__':
    text_data.clear()
    file_name = 'output/train_optimized_multiple.jsonl'
    conversations = []

    # Load the previous progress
    start_index = load_checkpoint()

    # Continue generating QA pairs from the checkpoint
    for i in tqdm(range(start_index, len(paragraphs))):  # process all remaining paragraphs
        content = paragraphs[i].strip()  # strip surrounding whitespace
        print("====================\ncontent:", content, "\n==================\n")
        if len(content) == 0:
            continue

        # Build the LLM prompt
        prompt = create_prompt(content)
        question = checklen([{"role": "user", "content": prompt}])

        # Call the LLM to generate QA pairs
        SparkApi.answer = ""  # reset the previous answer
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)  # call the API

        # The generated text contains the questions and answers
        answer_text = SparkApi.answer.strip()

        # Parse the QA pairs
        qa_pairs = parse_multiple_qa(answer_text)

        for qa_pair in qa_pairs:
            conversation = {
                "input": qa_pair['input'],
                "output": qa_pair['output']
            }

            # Append the conversation to the output file
            with open(file_name, 'a', encoding='utf-8') as file:
                json.dump(conversation, file, ensure_ascii=False)
                file.write("\n")

        # Save the progress index after each paragraph
        save_checkpoint(i)

    print(f"已生成 {file_name} 文件,包含问答对。")
32
generate_data/EC_process/jsonl2json.py
Normal file
@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 20:47
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : jsonl2json.py
# @Project : EmoLLM
import json


input_file = 'output/fine_tune_data.jsonl'
output_file = 'output/fine_tune_data.json'


data_list = []
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        entry = json.loads(line.strip())

        new_entry = {
            "instruction": entry.get("instruction", ""),
            "input": entry.get("input", ""),
            "output": entry.get("output", ""),
            "system": entry.get("system", ""),
            "history": entry.get("history", [])
        }
        data_list.append(new_entry)


with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(data_list, f, ensure_ascii=False, indent=4)

print(f"已生成 {output_file}")
1848
generate_data/EC_process/output/train_expanded_part2.jsonl
Normal file
File diff suppressed because it is too large
50
generate_data/EC_process/processPDF/OCR.py
Normal file
@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/18 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : OCR.py
# @Project : EmoLLM
import cv2
from paddleocr import PaddleOCR
import os
import glob

# Initialize the OCR model
ocr = PaddleOCR(use_angle_cls=True, lang='ch')


image_dir = "output"
output_txt_dir = "output_txt"


if not os.path.exists(output_txt_dir):
    os.makedirs(output_txt_dir)

image_list = glob.glob(os.path.join(image_dir, "*.png"))

# Batch recognition
for img_path in image_list:
    # Read the image
    img = cv2.imread(img_path)

    # Run OCR on the image
    result = ocr.ocr(img)

    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(img_path))[0]

    # Path of the text file holding the OCR result
    txt_file_path = os.path.join(output_txt_dir, f"{img_name}.txt")

    # Write the OCR result to the file
    with open(txt_file_path, 'w', encoding='utf-8') as f:
        for line in result:
            for word_info in line:
                # Extract the recognized text and its confidence
                word, confidence = word_info[1][0], word_info[1][1]

                f.write(f"{word}\n")

                print(f"Word: {word}, Confidence: {confidence}")

    print(f"{txt_file_path}")
39
generate_data/EC_process/processPDF/PDF2Pic.py
Normal file
@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/21 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : PDF2Pic.py
# @Project : EmoLLM
import fitz  # PyMuPDF
from PIL import Image
import os

# PDF file path and output image directory
pdf_file_path = "input.pdf"
output_image_dir = "output"

# Create the output directory
if not os.path.exists(output_image_dir):
    os.makedirs(output_image_dir)

# Open the PDF file
pdf_document = fitz.open(pdf_file_path)

# Iterate over the pages and save each one as an image
for page_number in range(len(pdf_document)):
    # Load the current page
    page = pdf_document.load_page(page_number)

    # Render the page as an image
    zoom = 4
    mat = fitz.Matrix(zoom, zoom)
    pix = page.get_pixmap(matrix=mat)

    image_path = os.path.join(output_image_dir, f"{page_number + 1}.png")
    pix.save(image_path)

    print(f"Saved {image_path}")


pdf_document.close()
25
generate_data/EC_process/processPDF/mergeTXT.py
Normal file
@ -0,0 +1,25 @@
import os
import re
import natsort

folder_path = "output_txt"
combined_text = ""

# Read the files in natural sort order
for filename in natsort.natsorted(os.listdir(folder_path)):
    if filename.endswith(".txt"):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            combined_text += file.read()

combined_text = combined_text.replace('\n', '')

# Collapse runs of three or more identical punctuation marks
combined_text = re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', combined_text)

# Save the cleaned text to a new file
with open("cleaned_data.txt", 'w', encoding='utf-8') as file:
    file.write(combined_text)

print("数据处理完成")
84
generate_data/EC_process/process_missing_QA.py
Normal file
@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi

# Input file path
input_file = 'output/train_expanded.jsonl'
# Checkpoint file path
checkpoint_file = 'output/expand_checkpoint.txt'
# Temporary file path
temp_file = 'output/tmp_train_expanded.jsonl'


# Call the API to generate an answer
def generate_answer_via_api(question):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        f"请生成一个详细回答。"
    )

    question_data = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()


# Load checkpoint progress
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # start from 0 when there is no checkpoint


# Save checkpoint progress
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


if __name__ == '__main__':
    # Load the checkpoint
    start_index = load_checkpoint()

    with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
        for i, line in enumerate(tqdm(f)):
            item = json.loads(line)

            # Process from the checkpoint onwards
            if i >= start_index:
                input_content = item['input']
                output_content = item['output']

                # # Check whether this QA pair lacks a real answer
                # if "未给" in output_content:
                #     # Generate a new answer via the API
                #     new_answer = generate_answer_via_api(input_content)
                #     item['output'] = new_answer

                if len(output_content) < 11:
                    # Generate a new answer via the API
                    new_answer = generate_answer_via_api(input_content)
                    item['output'] = new_answer

                # Save the current progress index
                save_checkpoint(i)

            # Write the (possibly updated) entry to the temporary file
            json.dump(item, temp_f, ensure_ascii=False)
            temp_f.write('\n')

    # Replace the original file
    os.replace(temp_file, input_file)
    print(f"已更新 {input_file} 文件,包含重新生成的回答。")
58
generate_data/EC_process/topic_model.py
Normal file
@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/23 23:16
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : topic_model.py
# @Project : EmoLLM
import json
import gensim
from gensim import corpora
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from collections import defaultdict

# Load QA pairs
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs

# Text preprocessing
def preprocess_text(text):
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(text.lower())
    tokens = [word for word in tokens if word.isalnum() and word not in stop_words]
    return tokens

# Build the LDA topic model
def build_lda_model(qa_pairs, num_topics=5):
    # Preprocess all question texts
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question) for question in questions]

    # Build the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = gensim.models.ldamodel.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus

# Print the keywords of each topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"主题 {idx}: {topic}")

if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pairs file

    # Load the QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, num_topics=5)

    # Print the topics and their keywords
    print_topics(lda_model)
@ -5,8 +5,34 @@ streamlit==1.24.0
sentencepiece==0.1.99
accelerate==0.24.1
transformers_stream_generator==0.0.4
openxlab~=0.0.11
tiktoken
einops
oss2
requests~=2.32.3

pyjwt~=2.8.0
loguru~=0.6.0
yaml~=0.2.5
pyyaml~=6.0.1
tqdm~=4.66.2
langchain~=0.0.352
torch~=2.5.0
metagpt~=0.8.1
erniebot~=0.5.9
python-dotenv~=1.0.0
zhipuai~=2.0.1
uvicorn~=0.32.0
fastapi~=0.115.2
opencv-python~=4.10.0.84
paddleocr~=2.9.0
dashscope~=1.14.1
numpy~=1.24.3
jieba~=0.42.1
nltk~=3.9.1
setuptools~=65.6.3
websocket~=0.2.1
websocket-client~=1.6.2
gensim~=4.3.3
pillow~=9.5.0
natsort~=8.4.0