Custom dataset processing scripts
parent 2065b2176c
commit 1125b67f50
217
generate_data/EC_process/Embedding_merge.py
Normal file
@@ -0,0 +1,217 @@
import base64
import hashlib
import hmac
import json
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import numpy as np
import requests


class AssembleHeaderException(Exception):
    def __init__(self, msg):
        self.message = msg


class Url:
    def __init__(self, host, path, schema):
        self.host = host
        self.path = path
        self.schema = schema


# calculate sha256 and encode to base64
def sha256base64(data):
    sha256 = hashlib.sha256()
    sha256.update(data)
    digest = base64.b64encode(sha256.digest()).decode(encoding='utf-8')
    return digest


def parse_url(request_url):
    stidx = request_url.index("://")
    host = request_url[stidx + 3:]
    schema = request_url[:stidx + 3]
    edidx = host.index("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + request_url)
    path = host[edidx:]
    host = host[:edidx]
    return Url(host, path, schema)


# Build the signed (HMAC-SHA256) authentication URL required by the embedding API
def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""):
    u = parse_url(request_url)
    host = u.host
    path = u.path
    now = datetime.now()
    date = format_date_time(time.mktime(now.timetuple()))
    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path)
    signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
    authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        api_key, "hmac-sha256", "host date request-line", signature_sha)
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
    values = {
        "host": host,
        "date": date,
        "authorization": authorization
    }

    return request_url + "?" + urlencode(values)


# Build the request body expected by the embedding endpoint
def get_Body(appid, text, style):
    body = {
        "header": {
            "app_id": appid,
            "uid": "39769795890",
            "status": 3
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8"
                }
            }
        },
        "payload": {
            "messages": {
                "text": base64.b64encode(json.dumps(text).encode('utf-8')).decode()
            }
        }
    }
    return body


# Send the request and return the raw response text
def get_embp_embedding(text, appid, apikey, apisecret):
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    response = requests.post(url, json=content, headers={'content-type': "application/json"}).text
    return response


# Parse the response and return the embedding vector
def parser_Message(message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        return None
    else:
        text_base = data["payload"]["feature"]["text"]
        text_data = base64.b64decode(text_base)
        dt = np.dtype(np.float32).newbyteorder("<")
        text = np.frombuffer(text_data, dtype=dt)
        return text


# Load QA pairs from a JSONL file
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Save all embeddings to a single JSON file
def save_embeddings(embeddings, file_path):
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(embeddings, f, ensure_ascii=False)


# Get the embedding for a single text
def get_embedding_for_text(text, appid, apikey, apisecret):
    desc = {"messages": [{"content": text, "role": "user"}]}
    res = get_embp_embedding(desc, appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(res)


# Load previously computed embeddings, one JSON object per line
def load_embeddings(file_path):
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # skip empty lines
                    embedding_data = json.loads(line.strip())
                    embeddings.update(embedding_data)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings


# Append one embedding per line to the output file
def save_embedding_line_by_line(qa, embedding, file_path):
    if embedding is not None:
        embedding_as_list = embedding.tolist()  # convert the numpy array to a plain list
        with open(file_path, 'a', encoding='utf-8') as f:
            json.dump({qa: embedding_as_list}, f, ensure_ascii=False)
            f.write("\n")  # one embedding per line


# Get the embedding for a single question, retrying on request errors
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5):
    retries = 0
    while retries < max_retries:
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:
            print(f"请求错误: {e}")
        retries += 1
        print(f"重试第 {retries} 次...")
        time.sleep(5)  # wait 5 seconds before each retry
    print(f"获取'{question}' 的embedding失败")
    return None


# Compute embeddings for all QA pairs and save them line by line
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    all_embeddings = load_embeddings(file_path)  # reuse embeddings that already exist
    interval = 1 / qps_limit  # pause between requests to respect the QPS limit
    for qa in qa_pairs:
        question = qa['input']
        if question in all_embeddings:
            print(f"'{question}' 的embedding已存在,跳过计算")
            continue
        print(f"计算'{question}' 的embedding...")
        embedding = get_embedding_with_retry(question, appid, apikey, apisecret)  # request with retries
        if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)  # save incrementally
            all_embeddings[question] = embedding  # record it as computed
        time.sleep(interval)  # stay within the QPS limit


if __name__ == '__main__':
    # File paths
    qa_file = "output/train_optimized_multiple.jsonl"  # source QA pairs
    embedding_file = "output/qa_embeddings.json"  # where embeddings are stored

    appid = "f0f73de5"
    api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
    api_key = "5773f6f95563708de994d17b7ea5d414"

    # Load the data
    qa_pairs = load_qa_data(qa_file)

    # Compute and save the embeddings
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)

    print(f"已保存所有问答对的embedding到 {embedding_file}")
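The storage format above is one JSON object per line, mapping a question to its float32 vector, which load_embeddings then folds back into a single dict. A minimal round-trip sketch of that contract (standalone; the file name demo_embeddings.jsonl and the toy vectors are invented for illustration):

import json

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

path = "demo_embeddings.jsonl"  # hypothetical demo file

# Write two toy records in the same one-object-per-line layout used by save_embedding_line_by_line
with open(path, "w", encoding="utf-8") as f:
    for question, vector in {"q1": [0.1, 0.2, 0.3], "q2": [0.3, 0.2, 0.1]}.items():
        json.dump({question: vector}, f, ensure_ascii=False)
        f.write("\n")

# Read them back the same way load_embeddings does
embeddings = {}
with open(path, "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            embeddings.update(json.loads(line))

score = cosine_similarity([np.array(embeddings["q1"])], [np.array(embeddings["q2"])])[0][0]
print(f"q1 vs q2 cosine similarity: {score:.3f}")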
122
generate_data/EC_process/Embedding_similarity.py
Normal file
@@ -0,0 +1,122 @@
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines

# Load the QA embeddings (one {question: vector} JSON object per line)
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
    for obj in reader:
        qa_embeddings.update(obj)  # merge each line into qa_embeddings

# Load the QA pairs
qa_pairs = []
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        qa_pairs.append(json.loads(line))

# Extract the questions and their embedding matrix
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))

# Keywords for each category
categories = {
    "栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
    "油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
    "油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
    "油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
    "油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
    "油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
    "油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
    "土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
    "整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
    "病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"]
}

# Keyword embeddings collected per category
category_embeddings = {category: [] for category in categories}


# Look up a keyword embedding; here it is simply taken from qa_embeddings when present
def get_keyword_embedding(keyword):
    return qa_embeddings.get(keyword, None)


# Collect the keyword embeddings for every category
for category, keywords in categories.items():
    for keyword in keywords:
        keyword_embedding = get_keyword_embedding(keyword)
        if keyword_embedding is not None:
            category_embeddings[category].append(keyword_embedding)

# Reduce each category's keyword embeddings to their mean vector
for category in category_embeddings:
    if category_embeddings[category]:
        category_embeddings[category] = np.mean(category_embeddings[category], axis=0)
    else:
        category_embeddings[category] = np.zeros(embeddings.shape[1])  # fall back to a zero vector

# Compute the similarity between every question and every category
category_similarities = {}
for idx, question in enumerate(questions):
    question_embedding = embeddings[idx]
    category_similarities[question] = {}

    for category, category_embedding in category_embeddings.items():
        similarity = cosine_similarity([question_embedding], [category_embedding])[0][0]
        category_similarities[question][category] = similarity

# Assign each question to its most similar category
category_assignments = {category: [] for category in categories}
for question in questions:
    best_category = max(category_similarities[question], key=category_similarities[question].get)
    category_assignments[best_category].append(question)

# Assemble the new JSONL format, making sure every QA pair is included
fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for idx, question in enumerate(assigned_questions):
        history = []
        output = ""
        instruction = ""

        # Find the current question and its answer
        qa_pair = next((qa for qa in qa_pairs if qa['input'] == question), None)

        if qa_pair:
            instruction = qa_pair['input']  # the current question becomes the instruction
            output = qa_pair['output']  # its answer becomes the output

            # Build the history from other questions in the same category,
            # so every history turn stays on the same topic as the current question
            history_similarities = []
            for related_question in assigned_questions:
                if related_question != question:
                    related_embedding = qa_embeddings[related_question]
                    similarity = cosine_similarity([qa_embeddings[question]], [related_embedding])[0][0]
                    history_similarities.append((related_question, similarity))

            # Sort by similarity and keep the top 1-3 questions as history
            history_similarities = sorted(history_similarities, key=lambda x: x[1], reverse=True)
            for related_question, _ in history_similarities[:3]:
                related_qa_pair = next((qa for qa in qa_pairs if qa['input'] == related_question), None)
                if related_qa_pair:
                    history.append([related_qa_pair['input'], related_qa_pair['output']])

        # Build the final entry
        if instruction and output:
            fine_tune_entry = {
                "instruction": instruction,
                "input": "",  # input stays empty
                "output": output,  # the answer to the current question
                "history": history,  # at most 3 related turns
                "system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"
            }
            fine_tune_data.append(fine_tune_entry)

# Save the new JSONL file
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
    for entry in fine_tune_data:
        json.dump(entry, f, ensure_ascii=False)
        f.write('\n')

print("对话数据整理完成")
71
generate_data/EC_process/LDArec.py
Normal file
@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/24 11:10
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : LDArec.py
# @Project : EmoLLM
import json
import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel


# Load QA pairs from a JSONL file
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Load the Chinese stopword list
def load_stopwords(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return set([line.strip() for line in f])


# Tokenize Chinese text with jieba and drop stopwords
def preprocess_text(text, stopwords):
    words = jieba.lcut(text)  # jieba word segmentation
    words = [word for word in words if word not in stopwords and len(word) > 1]  # drop stopwords and single characters
    return words


# Build the LDA topic model
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    # Preprocess every question text
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question, stopwords) for question in questions]

    # Build the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus


# Print the top keywords of every topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"主题 {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pair file
    stopwords_file = "chinese_stopwords.txt"  # stopword file

    # Load the QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Load the stopwords
    stopwords = load_stopwords(stopwords_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)

    # Print the topics and their keywords
    print_topics(lda_model)
70
generate_data/EC_process/Sensor_QA.py
Normal file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import json
import random


# Generate a synthetic dataset of sensor-query dialogues (5000 samples by default)
def generate_dataset(num_samples=5000):
    dataset = []
    invoke_types = [1, 2, 3]
    area_codes = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    parameters = [
        {"name": "土壤湿度", "unit": "%", "min": 10, "max": 100},
        {"name": "土壤温度", "unit": "℃", "min": 5, "max": 40},
        {"name": "空气温度", "unit": "℃", "min": -10, "max": 45},
        {"name": "电导率", "unit": "mS/cm", "min": 0.1, "max": 5.0}
    ]

    for _ in range(num_samples):
        invoke_type = random.choice(invoke_types)
        area_code = random.choice(area_codes)
        parameter = random.choice(parameters)

        # Draw a random reading within the parameter's range, rounded to one decimal place
        value = round(random.uniform(parameter["min"], parameter["max"]), 1)

        # Vary the phrasing of the question to make the data more natural
        instruction_templates = [
            f"现在{area_code}种植区内{parameter['name']}如何?",
            f"请告诉我{area_code}区的{parameter['name']}情况。",
            f"{area_code}区当前的{parameter['name']}是多少?",
            f"我想知道{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在是多少?",
            f"{area_code}种植区目前的{parameter['name']}是多少?",
            f"能提供{area_code}区的{parameter['name']}数据吗?",
            f"{area_code}种植区的{parameter['name']}是多少?",
            f"请查询{area_code}区的{parameter['name']}。",
            f"{area_code}区现在的{parameter['name']}数据是多少?",
            f"帮我看看{area_code}区{parameter['name']}的情况。",
            f"{area_code}区的{parameter['name']}值是多少?",
            f"帮我查一下{area_code}区的{parameter['name']}。",
            f"{area_code}区的{parameter['name']}现在什么情况?",
            f"请帮我查一下{area_code}种植区的{parameter['name']}是多少?",
            f"我需要知道{area_code}区的{parameter['name']}数据。",
            f"请问{area_code}区的{parameter['name']}如何?",
            f"帮我查询{area_code}区的{parameter['name']}情况。",
            f"现在{area_code}区的{parameter['name']}值是多少?"
        ]
        instruction = random.choice(instruction_templates)
        output = f"{area_code}区现在{parameter['name']}{value}{parameter['unit']}"

        data = {
            "instruction": instruction,
            "invokeType": str(invoke_type),
            "areaCode": area_code,
            "output": output
        }
        dataset.append(data)

    return dataset


# Generate the data and save it as a JSON file
if __name__ == '__main__':
    dataset = generate_dataset()
    output_file = 'output/synthetic_dataset.json'

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)

    print(f"已生成 {output_file} 文件,包含{len(dataset)}条数据。")
136
generate_data/EC_process/SparkApi.py
Normal file
@@ -0,0 +1,136 @@
import _thread as thread
import base64
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # provided by websocket-client

answer = ""


class Ws_Param(object):
    # Initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # Build the signed websocket URL
    def create_url(self):
        # RFC 1123 formatted timestamp
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # String to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # Sign it with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Collect the authentication parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # Append the authentication parameters to build the final URL
        url = self.Spark_url + '?' + urlencode(v)
        # The URL can be printed here to compare against one generated elsewhere with the same parameters
        return url


# Handle websocket errors
def on_error(ws, error):
    print("### error:", error)


# Handle websocket close
def on_close(ws, one, two):
    print(" ")


# Handle websocket open: send the request in a background thread
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
    ws.send(data)


# Handle websocket messages: accumulate streamed content into the global `answer`
def on_message(ws, message):
    global answer
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content, end="")
        answer += content
        if status == 2:  # status 2 marks the final chunk
            ws.close()


def gen_params(appid, domain, question):
    """
    Build the request payload from the appid and the user's messages.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 2048
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url, domain, question):
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
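The callers below (gen_QA.py, extend_QA.py, process_missing_QA.py) all follow the same calling convention: reset the module-level answer, invoke main(), then read the accumulated reply. A minimal usage sketch of that pattern (the credentials shown are placeholders):

import SparkApi

SparkApi.answer = ""  # clear the previous reply; main() appends streamed chunks to this global
SparkApi.main("your_appid", "your_api_key", "your_api_secret",
              "wss://spark-api.xf-yun.com/v4.0/chat", "4.0Ultra",
              [{"role": "user", "content": "你好"}])
print(SparkApi.answer)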
24
generate_data/EC_process/api_test.py
Normal file
@@ -0,0 +1,24 @@
import requests
import json

url = "https://chatapi.midjourney-vip.cn/v1/chat/completions"

payload = json.dumps({
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "测试"
        }
    ]
})
headers = {
    'Accept': 'application/json',
    'Authorization': 'sk-ATDf2Ax1YTGeeTaBD9Be2a7bE0064618Ae3378EaF0Df6f24',
    'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
    'Content-Type': 'application/json'
}

response = requests.request("POST", url, headers=headers, data=payload)

print(response.text)
1598
generate_data/EC_process/chinese_stopwords.txt
Normal file
File diff suppressed because it is too large
66
generate_data/EC_process/custom_dict.txt
Normal file
@@ -0,0 +1,66 @@
栽培油橄榄
经济价值
引种
油橄榄属
植物分类
植物种
原产地
根系类型
土壤关系
花芽分化
花序
授粉特性
果实发育
油脂形成
气候条件
温度
光照
水分
土壤生态
海拔高度
坡度
佛奥
莱星
皮削利
阿斯
配多灵
果大尔
皮瓜尔
科拉蒂
克里
爱桑
贝拉
实生种
育苗场地
种子繁殖
实生苗
嫁接繁殖
砧木
接穗
扦插繁殖
组织培养
园地选择
种植密度
栽植方式
栽后管理
土壤管理
矿质营养
果园灌溉
果实采收
整形修剪
生物学原理
结果习性
树形
幼树修剪
复壮修剪
孔雀斑病
炭疽病
黄萎病
肿瘤病
根腐病
云斑天牛
油橄榄片盾
大粒横沟象
引进品种名录
中英对照品种名称
病虫害判定表
116
generate_data/EC_process/extend_QA.py
Normal file
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm

import SparkApi

# Input file
input_file = 'output/train_expanded.jsonl'
# Output file
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint file
checkpoint_file = 'output/e2_progress_checkpoint.txt'


# Call the Spark API to generate QA pairs for a piece of content
def generate_qa_via_api(content):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
        f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
        f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
        f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
        f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
        f"以下是给定内容:\n\n"
        f"内容:{content}\n\n"
        f"请按如下格式生成输出:\n"
        f"问题1:<生成第一个问题>\n"
        f"回答1:<生成第一个回答>\n"
        f"问题2:<生成第二个问题>\n"
        f"回答2:<生成第二个回答>\n"
        f"问题3:<生成第三个问题>\n"
        f"回答3:<生成第三个回答>\n\n"
        f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
    )

    question = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
    return SparkApi.answer.strip()


# Load the checkpoint (index of the last processed line)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())
    return 0  # start from the beginning if there is no checkpoint


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Parse the model reply, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Use a lookahead for the next "问题N:" so it is not consumed and every pair is matched
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


if __name__ == '__main__':
    # Load the source dataset
    with open(input_file, 'r', encoding='utf-8') as f:
        text_data = [json.loads(line) for line in f]

    # Load the checkpoint
    start_index = load_checkpoint()

    # Resume generating QA pairs from the checkpoint
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in tqdm(range(start_index, len(text_data))):
            item = text_data[i]
            input_content = item['input']

            try:
                # Generate new QA pairs via the API
                api_generated_qa = generate_qa_via_api(input_content)

                # Parse the generated QA pairs and add them to the dataset
                qa_pairs = parse_multiple_qa(api_generated_qa)
                expanded_data = [{"input": qa_pair['input'], "output": qa_pair['output']} for qa_pair in qa_pairs]

                # Save the generated QA pairs
                for qa in expanded_data:
                    json.dump(qa, f, ensure_ascii=False)
                    f.write('\n')

                # Record the current progress
                save_checkpoint(i)

            except Exception as e:
                print(f"Error processing item {i}: {e}")
                # Skip this item and continue
                save_checkpoint(i)
                continue

    print(f"已生成 {output_file} 文件,包含扩展的问答对。")
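For reference, a quick standalone check of the 问题N:/回答N: parsing contract used above (the sample reply text is made up for illustration):

import re

sample = "问题1:油橄榄喜欢什么气候?\n回答1:温和的地中海式气候。\n问题2:常用的繁殖方式有哪些?\n回答2:扦插繁殖和嫁接繁殖。"
pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
pairs = [{"input": q.strip(), "output": a.strip()} for q, a in pattern.findall(sample)]
print(pairs)  # two {"input": ..., "output": ...} dictionaries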
153
generate_data/EC_process/gen_QA.py
Normal file
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/22
# @Author  : 黄子寒
# @File    : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM

import os
import re
from tqdm import tqdm
import SparkApi
import json


appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"

# Spark service address and version
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

# Holds the cleaned text
text_data = []

# Checkpoint file storing the index of the last processed paragraph
checkpoint_file = "output/progress_checkpoint.txt"

# Load the cleaned text produced by the PDF pipeline
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()


# Split the text into sentence groups of at most max_length characters
def split_text_to_sentences(text, max_length=300):
    sentences = re.split('(?<=。)', text)
    grouped_sentences = []
    current_group = ""

    for sentence in sentences:
        if len(current_group) + len(sentence) <= max_length:
            current_group += sentence
        else:
            grouped_sentences.append(current_group.strip())
            current_group = sentence

    if current_group:
        grouped_sentences.append(current_group.strip())

    return grouped_sentences


# Load the checkpoint (index of the last processed paragraph)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())
    return 0  # start from the beginning if there is no checkpoint


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


# Split the text into paragraphs of the required length
paragraphs = split_text_to_sentences(cleaned_text, 300)


# Build the detailed prompt asking the LLM to generate one or more QA pairs
def create_prompt(content):
    prompt = (
        f"你是一位油橄榄栽培专家。"
        f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
        f"内容:{content}\n\n"
        f"请以如下格式生成输出:\n"
        f"问题1:<在这里生成第一个问题>\n"
        f"回答1:<在这里生成第一个回答>\n"
        f"问题2:<在这里生成第二个问题(如有)>\n"
        f"回答2:<在这里生成第二个回答(如有)>\n"
        f"..."
    )
    return prompt


# Parse the model reply, which may contain several QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Use a lookahead for the next "问题N:" so it is not consumed and every pair is matched
    pattern = re.compile(r"问题\d+:(.*?)回答\d+:(.*?)(?=问题\d+:|$)", re.S)
    matches = pattern.findall(answer_text)

    for match in matches:
        question = match[0].strip()
        answer = match[1].strip()
        qa_pairs.append({"input": question, "output": answer})

    return qa_pairs


# Cap the total prompt length to avoid burning through the API quota
def checklen(text):
    # Drop the oldest messages until the combined content is within 8000 characters
    while sum(len(m["content"]) for m in text) > 8000:
        del text[0]
    return text


if __name__ == '__main__':
    text_data.clear()
    file_name = 'output/train_optimized_multiple.jsonl'
    conversations = []

    # Load the previous progress
    start_index = load_checkpoint()

    # Resume generating QA pairs from the checkpoint
    for i in tqdm(range(start_index, len(paragraphs))):  # process all remaining paragraphs
        content = paragraphs[i].strip()  # trim surrounding whitespace
        print("====================\ncontent:", content, "\n==================\n")
        if len(content) == 0:
            continue

        # Build the LLM prompt
        prompt = create_prompt(content)
        question = checklen([{"role": "user", "content": prompt}])

        # Call the LLM to generate QA pairs
        SparkApi.answer = ""  # clear the previous reply
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)

        # The accumulated reply text
        answer_text = SparkApi.answer.strip()

        # Parse the QA pairs out of the reply
        qa_pairs = parse_multiple_qa(answer_text)

        for qa_pair in qa_pairs:
            conversation = {
                "input": qa_pair['input'],
                "output": qa_pair['output']
            }

            # Append the conversation to the output file
            with open(file_name, 'a', encoding='utf-8') as file:
                json.dump(conversation, file, ensure_ascii=False)
                file.write("\n")

        # Save the progress index after each paragraph
        save_checkpoint(i)

    print(f"已生成 {file_name} 文件,包含问答对。")
32
generate_data/EC_process/jsonl2json.py
Normal file
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/24 20:47
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : jsonl2json.py
# @Project : EmoLLM
import json


input_file = 'output/fine_tune_data.jsonl'
output_file = 'output/fine_tune_data.json'


# Read the JSONL file and normalize every entry
data_list = []
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        entry = json.loads(line.strip())

        new_entry = {
            "instruction": entry.get("instruction", ""),
            "input": entry.get("input", ""),
            "output": entry.get("output", ""),
            "system": entry.get("system", ""),
            "history": entry.get("history", [])
        }
        data_list.append(new_entry)


# Write all entries as a single JSON array
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(data_list, f, ensure_ascii=False, indent=4)

print(f"已生成 {output_file}")
1848
generate_data/EC_process/output/train_expanded_part2.jsonl
Normal file
File diff suppressed because it is too large
50
generate_data/EC_process/processPDF/OCR.py
Normal file
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/18 22:09
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : OCR.py
# @Project : EmoLLM
import cv2
from paddleocr import PaddleOCR
import os
import glob

# Initialize the OCR model
ocr = PaddleOCR(use_angle_cls=True, lang='ch')


image_dir = "output"
output_txt_dir = "output_txt"


if not os.path.exists(output_txt_dir):
    os.makedirs(output_txt_dir)

image_list = glob.glob(os.path.join(image_dir, "*.png"))

# Recognize the images in batch
for img_path in image_list:
    # Read the image
    img = cv2.imread(img_path)

    # Run OCR on it
    result = ocr.ocr(img)

    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(img_path))[0]

    # Target text file for the OCR result
    txt_file_path = os.path.join(output_txt_dir, f"{img_name}.txt")

    # Write the OCR result to the text file
    with open(txt_file_path, 'w', encoding='utf-8') as f:
        for line in result:
            for word_info in line:
                # Extract the recognized text and its confidence
                word, confidence = word_info[1][0], word_info[1][1]

                f.write(f"{word}\n")

                print(f"Word: {word}, Confidence: {confidence}")

    print(f"{txt_file_path}")
39
generate_data/EC_process/processPDF/PDF2Pic.py
Normal file
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/21 22:09
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : PDF2Pic.py
# @Project : EmoLLM
import fitz  # PyMuPDF
import os

# PDF input path and output directory for the page images
pdf_file_path = "input.pdf"
output_image_dir = "output"

# Create the output directory
if not os.path.exists(output_image_dir):
    os.makedirs(output_image_dir)

# Open the PDF file
pdf_document = fitz.open(pdf_file_path)

# Render every page and save it as an image
for page_number in range(len(pdf_document)):
    # Get the current page
    page = pdf_document.load_page(page_number)

    # Render the page at 4x zoom for a higher-resolution image
    zoom = 4
    mat = fitz.Matrix(zoom, zoom)
    pix = page.get_pixmap(matrix=mat)

    image_path = os.path.join(output_image_dir, f"{page_number + 1}.png")
    pix.save(image_path)

    print(f"Saved {image_path}")


pdf_document.close()
25
generate_data/EC_process/processPDF/mergeTXT.py
Normal file
@@ -0,0 +1,25 @@
import os
import re
import natsort

folder_path = "output_txt"
combined_text = ""

# Read the page files in natural sort order
for filename in natsort.natsorted(os.listdir(folder_path)):
    if filename.endswith(".txt"):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            combined_text += file.read()


combined_text = combined_text.replace('\n', '')

# Collapse runs of three or more identical punctuation marks into one
combined_text = re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', combined_text)

# Save the cleaned text to a new file
with open("cleaned_data.txt", 'w', encoding='utf-8') as file:
    file.write(combined_text)

print("数据处理完成")
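A quick illustration of the punctuation clean-up rule above (the sample string is made up):

import re

sample = "第一章。。。。概述,,,,正文"
print(re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', sample))  # -> 第一章。概述,正文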
84
generate_data/EC_process/process_missing_QA.py
Normal file
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi

# Input file
input_file = 'output/train_expanded.jsonl'
# Checkpoint file
checkpoint_file = 'output/expand_checkpoint.txt'
# Temporary file
temp_file = 'output/tmp_train_expanded.jsonl'


# Call the Spark API to regenerate an answer for a question
def generate_answer_via_api(question):
    appid = "48d04aae"
    api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
    api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
    domain = "4.0Ultra"
    Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

    prompt = (
        f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
        f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
        f"每个回答应该准确且不超过50字,同时不少于20字,以保证内容的简洁和有用性。\n"
        f"问题:{question}\n\n"
        f"请生成一个详细回答。"
    )

    question_data = [{"role": "user", "content": prompt}]
    SparkApi.answer = ""
    SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
    return SparkApi.answer.strip()


# Load the checkpoint (index of the last processed line)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())
    return 0  # start from the beginning if there is no checkpoint


# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))


if __name__ == '__main__':
    # Load the checkpoint
    start_index = load_checkpoint()

    with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
        for i, line in enumerate(tqdm(f)):
            item = json.loads(line)

            # Only reprocess lines at or after the checkpoint
            if i >= start_index:
                input_content = item['input']
                output_content = item['output']

                # Earlier rule: only regenerate answers containing "未给" (answer not provided)
                # if "未给" in output_content:
                #     new_answer = generate_answer_via_api(input_content)
                #     item['output'] = new_answer

                # Regenerate suspiciously short answers (fewer than 11 characters) via the API
                if len(output_content) < 11:
                    new_answer = generate_answer_via_api(input_content)
                    item['output'] = new_answer

                # Save the progress index
                save_checkpoint(i)

            # Write the (possibly updated) item to the temporary file
            json.dump(item, temp_f, ensure_ascii=False)
            temp_f.write('\n')

    # Replace the original file
    os.replace(temp_file, input_file)
    print(f"已更新 {input_file} 文件,包含重新生成的回答。")
58
generate_data/EC_process/topic_model.py
Normal file
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# @Time    : 2024/10/23 23:16
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : topic_model.py
# @Project : EmoLLM
import json
import gensim
from gensim import corpora
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords


# Load QA pairs from a JSONL file
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs


# Text preprocessing (NLTK tokenization with English stopwords)
def preprocess_text(text):
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(text.lower())
    tokens = [word for word in tokens if word.isalnum() and word not in stop_words]
    return tokens


# Build the LDA topic model
def build_lda_model(qa_pairs, num_topics=5):
    # Preprocess every question text
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question) for question in questions]

    # Build the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = gensim.models.ldamodel.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus


# Print the top keywords of every topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"主题 {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pair file

    # Load the QA pairs
    qa_pairs = load_qa_data(qa_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, num_topics=5)

    # Print the topics and their keywords
    print_topics(lda_model)
@@ -5,8 +5,34 @@ streamlit==1.24.0
sentencepiece==0.1.99
accelerate==0.24.1
transformers_stream_generator==0.0.4
openxlab
openxlab~=0.0.11
tiktoken
einops
oss2
requests
requests~=2.32.3

pyjwt~=2.8.0
loguru~=0.6.0
yaml~=0.2.5
pyyaml~=6.0.1
tqdm~=4.66.2
langchain~=0.0.352
torch~=2.5.0
metagpt~=0.8.1
erniebot~=0.5.9
python-dotenv~=1.0.0
zhipuai~=2.0.1
uvicorn~=0.32.0
fastapi~=0.115.2
opencv-python~=4.10.0.84
paddleocr~=2.9.0
dashscope~=1.14.1
numpy~=1.24.3
jieba~=0.42.1
nltk~=3.9.1
setuptools~=65.6.3
websocket~=0.2.1
websocket-client~=1.6.2
gensim~=4.3.3
pillow~=9.5.0
natsort~=8.4.0