Custom dataset processing scripts

This commit is contained in:
黄子寒 2024-11-11 17:32:36 +08:00
parent 2065b2176c
commit 1125b67f50
18 changed files with 4737 additions and 2 deletions

@@ -0,0 +1,217 @@
import base64
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import os
class AssembleHeaderException(Exception):
    def __init__(self, msg):
        self.message = msg
class Url:
    def __init__(self, host, path, schema):
        self.host = host
        self.path = path
        self.schema = schema
# calculate sha256 and encode to base64
def sha256base64(data):
sha256 = hashlib.sha256()
sha256.update(data)
digest = base64.b64encode(sha256.digest()).decode(encoding='utf-8')
return digest
def parse_url(request_url):
    stidx = request_url.index("://")
    host = request_url[stidx + 3:]
    schema = request_url[:stidx + 3]
    edidx = host.index("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url: " + request_url)
    path = host[edidx:]
    host = host[:edidx]
    u = Url(host, path, schema)
    return u
# Build the signed (authenticated) request URL
def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""):
    u = parse_url(request_url)
host = u.host
path = u.path
now = datetime.now()
date = format_date_time(time.mktime(now.timetuple()))
signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path)
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
api_key, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
values = {
"host": host,
"date": date,
"authorization": authorization
}
    return request_url + "?" + urlencode(values)
def get_Body(appid, text, style):
org_content = json.dumps(text).encode('utf-8')
body = {
"header": {
"app_id": appid,
"uid": "39769795890",
"status": 3
},
"parameter": {
"emb": {
"domain": style,
"feature": {
"encoding": "utf8"
}
}
},
"payload": {
"messages": {
"text": base64.b64encode(json.dumps(text).encode('utf-8')).decode()
}
}
}
return body
# Send the request and return the raw response text
def get_embp_embedding(text, appid, apikey, apisecret):
host = 'https://emb-cn-huabei-1.xf-yun.com/'
url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
content = get_Body(appid, text, "para")
response = requests.post(url, json=content, headers={'content-type': "application/json"}).text
return response
# Parse the response and return the embedding vector
def parser_Message(message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
        print(f'Request error: {code}, {data}')
return None
else:
text_base = data["payload"]["feature"]["text"]
text_data = base64.b64decode(text_base)
dt = np.dtype(np.float32).newbyteorder("<")
text = np.frombuffer(text_data, dtype=dt)
return text
# Load QA pairs from a JSONL file
def load_qa_data(file_path):
qa_pairs = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
qa_pairs.append(json.loads(line.strip()))
return qa_pairs
# Save all embeddings to a file
def save_embeddings(embeddings, file_path):
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(embeddings, f, ensure_ascii=False)
# Get the embedding for a single piece of text
def get_embedding_for_text(text, appid, apikey, apisecret):
desc = {"messages": [{"content": text, "role": "user"}]}
res = get_embp_embedding(desc, appid=appid, apikey=apikey, apisecret=apisecret)
return parser_Message(res)
# Load previously computed embeddings line by line
def load_embeddings(file_path):
embeddings = {}
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
                if line.strip():  # skip blank lines
embedding_data = json.loads(line.strip())
embeddings.update(embedding_data)
except FileNotFoundError:
print(f"文件 {file_path} 不存在,将创建新文件")
return embeddings
# Append a single embedding to the file, one JSON object per line
def save_embedding_line_by_line(qa, embedding, file_path):
if embedding is not None:
        embedding_as_list = embedding.tolist()  # convert the numpy array to a plain list
with open(file_path, 'a', encoding='utf-8') as f:
json.dump({qa: embedding_as_list}, f, ensure_ascii=False)
f.write("\n") # 每行一个embedding
# Get the embedding for one question, retrying on request errors
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5):
retries = 0
while retries < max_retries:
try:
embedding = get_embedding_for_text(question, appid, apikey, apisecret)
if embedding is not None:
return embedding
except Exception as e:
print(f"请求错误: {e}")
retries += 1
print(f"重试第 {retries} 次...")
time.sleep(5) # 每次重试前等待 5 秒
print(f"获取'{question}' 的embedding失败")
return None
# Compute embeddings for all QA pairs and save them line by line
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    all_embeddings = load_embeddings(file_path)  # try to load existing embeddings
    interval = 1 / qps_limit  # request interval derived from the QPS limit
for qa in qa_pairs:
question = qa['input']
if question in all_embeddings:
print(f"'{question}' 的embedding已存在跳过计算")
continue
print(f"计算'{question}' 的embedding...")
embedding = get_embedding_with_retry(question, appid, apikey, apisecret) # 带重试机制的请求
if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)  # append to file
            all_embeddings[question] = embedding  # record it as computed
        time.sleep(interval)  # stay within the QPS limit
if __name__ == '__main__':
    # File paths
    qa_file = "output/train_optimized_multiple.jsonl"  # source QA-pair file
    embedding_file = "output/qa_embeddings.json"  # where embeddings are stored
appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"
    # Load the data
    qa_pairs = load_qa_data(qa_file)
    # Compute and save the embeddings
get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)
print(f"已保存所有问答对的embedding到 {embedding_file}")

@@ -0,0 +1,122 @@
import json
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import jsonlines
# Load the QA-pair embeddings
qa_embeddings = {}
with jsonlines.open('output/qa_embeddings.json', 'r') as reader:
for obj in reader:
        qa_embeddings.update(obj)  # merge each line's JSON object into qa_embeddings
# Load the QA pairs
qa_pairs = []
with open('output/train_optimized_multiple.jsonl', 'r', encoding='utf-8') as f:
for line in f:
qa_pairs.append(json.loads(line))
# Extract the questions and their embedding matrix
questions = list(qa_embeddings.keys())
embeddings = np.array(list(qa_embeddings.values()))
# Keywords grouped by category
categories = {
"栽培油橄榄的意义": ["栽培油橄榄", "经济价值", "引种"],
"油橄榄属植物分类": ["油橄榄属", "植物分类", "植物种", "原产地"],
"油橄榄生物学特性": ["根系类型", "土壤关系", "花芽分化", "花序", "授粉特性", "果实发育", "油脂形成"],
"油橄榄的生态环境条件": ["气候条件", "温度", "光照", "水分", "土壤生态", "海拔高度", "坡度"],
"油橄榄品种": ["佛奥", "莱星", "皮削利", "阿斯", "配多灵", "果大尔", "皮瓜尔", "科拉蒂", "克里", "爱桑", "贝拉", "实生种"],
"油橄榄育苗技术": ["育苗场地", "种子繁殖", "实生苗", "嫁接繁殖", "砧木", "接穗", "扦插繁殖", "组织培养"],
"油橄榄种植": ["园地选择", "种植密度", "栽植方式", "栽后管理"],
"土壤、肥料、水管理": ["土壤管理", "矿质营养", "果园灌溉", "果实采收"],
"整形修剪": ["整形修剪", "生物学原理", "结果习性", "树形", "幼树修剪", "复壮修剪"],
"病虫害防治": ["孔雀斑病", "炭疽病", "黄萎病", "肿瘤病", "根腐病", "云斑天牛", "油橄榄片盾", "大粒横沟象"]
}
# Initialise the per-category keyword-embedding lists
category_embeddings = {category: [] for category in categories}
# Look up a keyword's embedding (here simply taken from qa_embeddings if present)
def get_keyword_embedding(keyword):
    return qa_embeddings.get(keyword, None)
# Collect the keyword embeddings for each category
for category, keywords in categories.items():
for keyword in keywords:
keyword_embedding = get_keyword_embedding(keyword)
if keyword_embedding is not None:
category_embeddings[category].append(keyword_embedding)
# Average each category's keyword embeddings into a single vector
for category in category_embeddings:
if category_embeddings[category]:
category_embeddings[category] = np.mean(category_embeddings[category], axis=0)
else:
        category_embeddings[category] = np.zeros(embeddings.shape[1])  # fall back to a zero vector
# Compute the similarity between every question and every category
category_similarities = {}
for idx, question in enumerate(questions):
question_embedding = embeddings[idx]
category_similarities[question] = {}
for category, category_embedding in category_embeddings.items():
similarity = cosine_similarity([question_embedding], [category_embedding])[0][0]
category_similarities[question][category] = similarity
# Assign each question to its most similar category
category_assignments = {category: [] for category in categories}
for question in questions:
best_category = max(category_similarities[question], key=category_similarities[question].get)
category_assignments[best_category].append(question)
# Assemble the new JSONL format, making sure every QA pair is included
fine_tune_data = []
for category, assigned_questions in category_assignments.items():
    for question in assigned_questions:
history = []
output = ""
instruction = ""
        # Look up the current question and its answer
qa_pair = next((qa for qa in qa_pairs if qa['input'] == question), None)
if qa_pair:
            instruction = qa_pair['input']  # the current question becomes the instruction
            output = qa_pair['output']  # its answer becomes the output
        # Build the history from other questions in the same category, so every history entry shares the category
history_similarities = []
for related_question in assigned_questions:
if related_question != question:
related_embedding = qa_embeddings[related_question]
similarity = cosine_similarity([qa_embeddings[question]], [related_embedding])[0][0]
history_similarities.append((related_question, similarity))
        # Sort by similarity and keep at most the top 3 questions as history
history_similarities = sorted(history_similarities, key=lambda x: x[1], reverse=True)
for related_question, _ in history_similarities[:3]:
related_qa_pair = next((qa for qa in qa_pairs if qa['input'] == related_question), None)
if related_qa_pair:
history.append([related_qa_pair['input'], related_qa_pair['output']])
        # Build the final record
if instruction and output:
fine_tune_entry = {
"instruction": instruction,
"input": "", # input为空
"output": output, # 当前问题的回答
"history": history, # 最多包含3条相关问题
"system": "你是一位油橄榄栽培专家,熟知油橄榄的品种分类、栽培技术、生态环境要求以及病虫害防治。"
}
fine_tune_data.append(fine_tune_entry)
# Write out the new JSONL file
with open('output/fine_tune_data.jsonl', 'w', encoding='utf-8') as f:
for entry in fine_tune_data:
json.dump(entry, f, ensure_ascii=False)
f.write('\n')
print("对话数据整理完成")

@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 11:10
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : LDArec.py
# @Project : EmoLLM
import json
import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel
from collections import defaultdict
# Load QA pairs from a JSONL file
def load_qa_data(file_path):
qa_pairs = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
qa_pairs.append(json.loads(line.strip()))
return qa_pairs
# Load Chinese stopwords
def load_stopwords(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return set([line.strip() for line in f])
# Tokenise Chinese text with jieba and drop stopwords
def preprocess_text(text, stopwords):
    words = jieba.lcut(text)  # Chinese word segmentation with jieba
    words = [word for word in words if word not in stopwords and len(word) > 1]  # drop stopwords and single-character tokens
return words
# Build the LDA topic model
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    # Preprocess all question texts
questions = [qa['input'] for qa in qa_pairs]
processed_questions = [preprocess_text(question, stopwords) for question in questions]
    # Build the dictionary and bag-of-words corpus
dictionary = corpora.Dictionary(processed_questions)
corpus = [dictionary.doc2bow(text) for text in processed_questions]
    # Train the LDA model
lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
return lda_model, dictionary, corpus
# Print the top keywords of each topic
def print_topics(lda_model, num_words=10):
for idx, topic in lda_model.print_topics(num_words=num_words):
print(f"主题 {idx}: {topic}")
if __name__ == '__main__':
qa_file = "output/train_optimized_multiple.jsonl" # 问答对文件
stopwords_file = "chinese_stopwords.txt" # 停用词文件
# 加载问答对
qa_pairs = load_qa_data(qa_file)
    # Load the stopwords
stopwords = load_stopwords(stopwords_file)
    # Build the LDA topic model
lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)
    # Print the topics and their keywords
print_topics(lda_model)

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
import json
import random
# Generate a synthetic dataset of num_samples (default 5000) records
def generate_dataset(num_samples=5000):
dataset = []
invoke_types = [1, 2, 3]
area_codes = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
parameters = [
{"name": "土壤湿度", "unit": "%", "min": 10, "max": 100},
{"name": "土壤温度", "unit": "", "min": 5, "max": 40},
{"name": "空气温度", "unit": "", "min": -10, "max": 45},
{"name": "电导率", "unit": "mS/cm", "min": 0.1, "max": 5.0}
]
for _ in range(num_samples):
invoke_type = random.choice(invoke_types)
area_code = random.choice(area_codes)
parameter = random.choice(parameters)
        value = round(random.uniform(parameter["min"], parameter["max"]), 1)
        # Vary the phrasing of the question so the data looks more natural
instruction_templates = [
f"现在{area_code}种植区内{parameter['name']}如何?",
f"请告诉我{area_code}区的{parameter['name']}情况。",
f"{area_code}区当前的{parameter['name']}是多少?",
f"我想知道{area_code}区的{parameter['name']}",
f"{area_code}区的{parameter['name']}现在是多少?",
f"{area_code}种植区目前的{parameter['name']}是多少?",
f"能提供{area_code}区的{parameter['name']}数据吗?",
f"{area_code}种植区的{parameter['name']}是多少?",
f"请查询{area_code}区的{parameter['name']}",
f"{area_code}区现在的{parameter['name']}数据是多少?",
f"帮我看看{area_code}{parameter['name']}的情况。",
f"{area_code}区的{parameter['name']}值是多少?",
f"帮我查一下{area_code}区的{parameter['name']}",
f"{area_code}区的{parameter['name']}现在什么情况?",
f"请帮我查一下{area_code}种植区的{parameter['name']}是多少?",
f"我需要知道{area_code}区的{parameter['name']}数据。",
f"请问{area_code}区的{parameter['name']}如何?",
f"帮我查询{area_code}区的{parameter['name']}情况。",
f"现在{area_code}区的{parameter['name']}值是多少?"
]
instruction = random.choice(instruction_templates)
output = f"{area_code}区现在{parameter['name']}{value}{parameter['unit']}"
data = {
"instruction": instruction,
"invokeType": str(invoke_type),
"areaCode": area_code,
"output": output
}
dataset.append(data)
return dataset
# Generate the data and save it as a JSON file
if __name__ == '__main__':
dataset = generate_dataset()
output_file = 'output/synthetic_dataset.json'
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(dataset, f, ensure_ascii=False, indent=4)
print(f"已生成 {output_file} 文件,包含{len(dataset)}条数据。")

@@ -0,0 +1,136 @@
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket  # provided by the websocket-client package
answer = ""
class Ws_Param(object):
    # Initialisation
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
    # Build the signed websocket URL
    def create_url(self):
        # RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
        # Assemble the string to sign
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
        # Sign it with HMAC-SHA256
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Collect the authentication parameters into a dict
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
        # Append the authentication parameters to build the final URL
        url = self.Spark_url + '?' + urlencode(v)
        # When debugging, print this URL and compare it with one generated elsewhere from the same parameters
return url
# Handle websocket errors
def on_error(ws, error):
    print("### error:", error)
# Handle websocket close events
def on_close(ws, one, two):
    print(" ")
# Handle websocket open events
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
ws.send(data)
# Handle incoming websocket messages
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
        print(f'Request error: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
        print(content, end="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain, question):
    """
    Build the request parameters from the appid and the user's question.
    """
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"temperature": 0.5,
"max_tokens": 2048
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def main(appid, api_key, api_secret, Spark_url, domain, question):
    # print("Spark:")
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
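A minimal usage sketch (not part of this commit), mirroring how the generation scripts in this commit call this module; the credential values below are placeholders:

import SparkApi

appid = "<your appid>"          # placeholder credentials
api_key = "<your api key>"
api_secret = "<your api secret>"
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"

question = [{"role": "user", "content": "你好"}]
SparkApi.answer = ""            # the module accumulates the reply in this global
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
print(SparkApi.answer)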

@@ -0,0 +1,24 @@
import requests
import json
url = "https://chatapi.midjourney-vip.cn/v1/chat/completions"
payload = json.dumps({
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "测试"
}
]
})
headers = {
'Accept': 'application/json',
'Authorization': 'sk-ATDf2Ax1YTGeeTaBD9Be2a7bE0064618Ae3378EaF0Df6f24',
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
print(response.text)

File diff suppressed because it is too large

@@ -0,0 +1,66 @@
栽培油橄榄
经济价值
引种
油橄榄属
植物分类
植物种
原产地
根系类型
土壤关系
花芽分化
花序
授粉特性
果实发育
油脂形成
气候条件
温度
光照
水分
土壤生态
海拔高度
坡度
佛奥
莱星
皮削利
阿斯
配多灵
果大尔
皮瓜尔
科拉蒂
克里
爱桑
贝拉
实生种
育苗场地
种子繁殖
实生苗
嫁接繁殖
砧木
接穗
扦插繁殖
组织培养
园地选择
种植密度
栽植方式
栽后管理
土壤管理
矿质营养
果园灌溉
果实采收
整形修剪
生物学原理
结果习性
树形
幼树修剪
复壮修剪
孔雀斑病
炭疽病
黄萎病
肿瘤病
根腐病
云斑天牛
油橄榄片盾
大粒横沟象
引进品种名录
中英对照品种名称
病虫害判定表

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import json
import os
import re
from tqdm import tqdm
import SparkApi
# Input file path
input_file = 'output/train_expanded.jsonl'
# Output file path
output_file = 'output/train_expanded_2.jsonl'
# Checkpoint file path
checkpoint_file = 'output/e2_progress_checkpoint.txt'
# Generate QA pairs via the Spark API
def generate_qa_via_api(content):
appid = "48d04aae"
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
prompt = (
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
f"1. 根据给定内容生成**三个**相关的问题和回答。\n"
f"2. 你可以简化问题、提取具体要素进行提问,或扩展内容生成额外的相关问题。\n"
f"3. **问题必须简洁明了**,并涵盖内容中的关键信息。\n"
f"4. 每个回答应该准确且**不超过50字**,同时**不少于20字**,以保证内容的简洁和有用性。\n"
f"5. 仅围绕油橄榄栽培的相关内容生成问答对,忽略其他无关信息。\n\n"
f"以下是给定内容:\n\n"
f"内容:{content}\n\n"
f"请按如下格式生成输出:\n"
f"问题1<生成第一个问题>\n"
f"回答1<生成第一个回答>\n"
f"问题2<生成第二个问题>\n"
f"回答2<生成第二个回答>\n"
f"问题3<生成第三个问题>\n"
f"回答3<生成第三个回答>\n\n"
f"请确保每个问题和回答都保持与内容的紧密相关性,并保持专业性。"
)
question = [{"role": "user", "content": prompt}]
SparkApi.answer = ""
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
return SparkApi.answer.strip()
# Load the checkpoint (resume position)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # no checkpoint: start from the beginning
# Save the checkpoint
def save_checkpoint(index):
with open(checkpoint_file, 'w') as f:
f.write(str(index))
# Parse the returned text, handling multiple QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Find all QA pairs with a regular expression
pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(问题|$)", re.S)
matches = pattern.findall(answer_text)
for match in matches:
question = match[0].strip()
answer = match[1].strip()
qa_pairs.append({"input": question, "output": answer})
return qa_pairs
if __name__ == '__main__':
    # Load the original dataset
with open(input_file, 'r', encoding='utf-8') as f:
text_data = [json.loads(line) for line in f]
    # Load the checkpoint
start_index = load_checkpoint()
    # Resume QA generation from the checkpoint
with open(output_file, 'a', encoding='utf-8') as f:
for i in tqdm(range(start_index, len(text_data))):
item = text_data[i]
input_content = item['input']
try:
                # Generate new QA pairs via the API
                api_generated_qa = generate_qa_via_api(input_content)
                # Parse the generated QA pairs and add them to the dataset
qa_pairs = parse_multiple_qa(api_generated_qa)
expanded_data = [{"input": qa_pair['input'], "output": qa_pair['output']} for qa_pair in qa_pairs]
                # Save the generated QA pairs
for qa in expanded_data:
json.dump(qa, f, ensure_ascii=False)
f.write('\n')
                # Record progress
save_checkpoint(i)
except Exception as e:
print(f"Error processing item {i}: {e}")
                # Skip this item and continue
save_checkpoint(i)
continue
print(f"已生成 {output_file} 文件,包含扩展的问答对。")

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/22
# @Author : 黄子寒
# @File : generate_qa_with_multiple_pairs.py
# @Project : EmoLLM
import os
import re
from tqdm import tqdm
import SparkApi
import json
appid = "f0f73de5"
api_secret = "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk"
api_key = "5773f6f95563708de994d17b7ea5d414"
# Spark service URL and version
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
# Holds the cleaned text
text_data = []
# Checkpoint file storing the index of the last processed paragraph
checkpoint_file = "output/progress_checkpoint.txt"
# Load the cleaned text file
with open("../processPDF/cleaned_data.txt", "r", encoding="utf-8") as f:
    cleaned_text = f.read()
# Split the text into sentence groups of at most max_length characters
def split_text_to_sentences(text, max_length=300):
sentences = re.split('(?<=。)', text)
grouped_sentences = []
current_group = ""
for sentence in sentences:
if len(current_group) + len(sentence) <= max_length:
current_group += sentence
else:
grouped_sentences.append(current_group.strip())
current_group = sentence
if current_group:
grouped_sentences.append(current_group.strip())
return grouped_sentences
# Load the checkpoint (resume position)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed paragraph
    return 0  # no checkpoint: start from the beginning
# Save the checkpoint
def save_checkpoint(index):
    with open(checkpoint_file, 'w') as f:
        f.write(str(index))
# Split the text into chunks of the required length
paragraphs = split_text_to_sentences(cleaned_text, 300)
# Build the detailed LLM prompt; the model may generate several QA pairs
def create_prompt(content):
prompt = (
f"你是一位油橄榄栽培专家。"
f"根据以下内容生成一个或多个问题和回答对,请保证语句通顺有逻辑,同时忽略所有内容中和图示相关的内容:\n\n"
f"内容:{content}\n\n"
f"请以如下格式生成输出:\n"
f"问题1<在这里生成第一个问题>\n"
f"回答1<在这里生成第一个回答>\n"
f"问题2<在这里生成第二个问题(如有)>\n"
f"回答2<在这里生成第二个回答(如有)>\n"
f"..."
)
return prompt
# Parse the returned text, handling multiple QA pairs
def parse_multiple_qa(answer_text):
    qa_pairs = []
    # Find all QA pairs with a regular expression
pattern = re.compile(r"问题\d+(.*?)回答\d+(.*?)(问题|$)", re.S)
matches = pattern.findall(answer_text)
for match in matches:
question = match[0].strip()
answer = match[1].strip()
qa_pairs.append({"input": question, "output": answer})
return qa_pairs
# Trim the oldest messages so the total prompt length stays within the API limit
def checklen(text):
    while sum(len(t["content"]) for t in text) > 8000:  # keep the total under 8000 characters
        del text[0]
    return text
if __name__ == '__main__':
text_data.clear()
file_name = 'output/train_optimized_multiple.jsonl'
conversations = []
    # Load the previous progress
    start_index = load_checkpoint()
    # Resume QA generation from the checkpoint
    for i in tqdm(range(start_index, len(paragraphs))):  # process all remaining paragraphs
        content = paragraphs[i].strip()  # strip surrounding whitespace
print("====================\ncontent:", content, "\n==================\n")
if len(content) == 0:
continue
        # Build the LLM prompt
prompt = create_prompt(content)
question = checklen([{"role": "user", "content": prompt}])
        # Call the LLM to generate QA pairs
        SparkApi.answer = ""  # clear the previous answer
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)  # call the API
        # Split the generated text into questions and answers
answer_text = SparkApi.answer.strip()
        # Parse the (possibly multiple) QA pairs
qa_pairs = parse_multiple_qa(answer_text)
for qa_pair in qa_pairs:
conversation = {
"input": qa_pair['input'],
"output": qa_pair['output']
}
            # Append the conversation to the output file
with open(file_name, 'a', encoding='utf-8') as file:
json.dump(conversation, file, ensure_ascii=False)
file.write("\n")
        # Save progress after each paragraph
save_checkpoint(i)
print(f"已生成 {file_name} 文件,包含问答对。")

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/24 20:47
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : jsonl2json.py
# @Project : EmoLLM
import json
input_file = 'output/fine_tune_data.jsonl'
output_file = 'output/fine_tune_data.json'
data_list = []
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
entry = json.loads(line.strip())
new_entry = {
"instruction": entry.get("instruction", ""),
"input": entry.get("input", ""),
"output": entry.get("output", ""),
"system": entry.get("system", ""),
"history": entry.get("history", [])
}
data_list.append(new_entry)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
print(f" {output_file}")

File diff suppressed because it is too large

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/18 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : OCR.py
# @Project : EmoLLM
import cv2
from paddleocr import PaddleOCR
import os
import glob
# Initialise the OCR model
ocr = PaddleOCR(use_angle_cls=True, lang='ch')
image_dir = "output"
output_txt_dir = "output_txt"
if not os.path.exists(output_txt_dir):
os.makedirs(output_txt_dir)
image_list = glob.glob(os.path.join(image_dir, "*.png"))
# Batch recognition
for img_path in image_list:
    # Read the image
    img = cv2.imread(img_path)
    # Run OCR on the image
result = ocr.ocr(img)
    # Image file name without extension
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    # Collect the OCR result as text
    txt_file_path = os.path.join(output_txt_dir, f"{img_name}.txt")
    # Open the file and write the OCR result
with open(txt_file_path, 'w', encoding='utf-8') as f:
for line in result:
for word_info in line:
                # Extract the recognised text and its confidence
word, confidence = word_info[1][0], word_info[1][1]
f.write(f"{word}\n")
print(f"Word: {word}, Confidence: {confidence}")
print(f"{txt_file_path}")

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/21 22:09
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : PDF2Pic.py
# @Project : EmoLLM
import fitz # PyMuPDF
from PIL import Image
import os
# PDF file path and output image directory
pdf_file_path = "input.pdf"
output_image_dir = "output"
# Create the output directory
if not os.path.exists(output_image_dir):
    os.makedirs(output_image_dir)
# Open the PDF file
pdf_document = fitz.open(pdf_file_path)
# Render every page and save it as an image
for page_number in range(len(pdf_document)):
    # Get the current page
page = pdf_document.load_page(page_number)
    # Render the page to an image
zoom = 4
mat = fitz.Matrix(zoom, zoom)
pix = page.get_pixmap(matrix=mat)
image_path = os.path.join(output_image_dir, f"{page_number + 1}.png")
pix.save(image_path)
print(f"Saved {image_path}")
pdf_document.close()

@@ -0,0 +1,25 @@
import os
import re
import natsort
folder_path = "output_txt"
combined_text = ""
# Read the files in natural sort order
for filename in natsort.natsorted(os.listdir(folder_path)):
if filename.endswith(".txt"):
file_path = os.path.join(folder_path, filename)
with open(file_path, 'r', encoding='utf-8') as file:
combined_text += file.read()
combined_text = combined_text.replace('\n', '')
# Collapse runs of three or more identical punctuation marks into one
combined_text = re.sub(r'([。,!?:;. ·])\1{2,}', r'\1', combined_text)
# Save the cleaned text to a new file
with open("cleaned_data.txt", 'w', encoding='utf-8') as file:
file.write(combined_text)
print("数据处理完成")

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
import json
import os
from tqdm import tqdm
import SparkApi
# Input file path
input_file = 'output/train_expanded.jsonl'
# Checkpoint file path
checkpoint_file = 'output/expand_checkpoint.txt'
# Temporary file path
temp_file = 'output/tmp_train_expanded.jsonl'
# Generate an answer via the Spark API
def generate_answer_via_api(question):
appid = "48d04aae"
api_secret = "ZDE1ZGZmNTQ1YWYxZjcxYTI5Mjk0NGIz"
api_key = "3ad87d03c4e3a4fb7d7b36a7dfa3be00"
domain = "4.0Ultra"
Spark_url = "wss://spark-api.xf-yun.com/v4.0/chat"
prompt = (
f"你是一位油橄榄栽培领域的专家,需要基于给定内容生成高质量的问答对。"
f"生成的问答对用于油橄榄知识库微调,请确保问答的准确性和相关性。具体要求如下:\n"
f"每个回答应该准确且不超过50字同时不少于20字以保证内容的简洁和有用性。\n"
f"问题:{question}\n\n"
f"请生成一个详细回答。"
)
question_data = [{"role": "user", "content": prompt}]
SparkApi.answer = ""
SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question_data)
return SparkApi.answer.strip()
# Load the checkpoint (resume position)
def load_checkpoint():
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as f:
            return int(f.read().strip())  # index of the last processed line
    return 0  # no checkpoint: start from the beginning
# Save the checkpoint
def save_checkpoint(index):
with open(checkpoint_file, 'w') as f:
f.write(str(index))
if __name__ == '__main__':
    # Load the checkpoint
start_index = load_checkpoint()
with open(input_file, 'r', encoding='utf-8') as f, open(temp_file, 'w', encoding='utf-8') as temp_f:
for i, line in enumerate(tqdm(f)):
item = json.loads(line)
            # Only reprocess items from the checkpoint onwards
if i >= start_index:
input_content = item['input']
output_content = item['output']
                # # Check whether the QA pair lacks a proper answer
                # if "未给" in output_content:
                #     # Regenerate the answer via the API
                #     new_answer = generate_answer_via_api(input_content)
                #     item['output'] = new_answer
                if len(output_content) < 11:
                    # Regenerate the answer via the API
                    new_answer = generate_answer_via_api(input_content)
                    item['output'] = new_answer
                # Record progress
save_checkpoint(i)
            # Write the (possibly updated) item to the temporary file
json.dump(item, temp_f, ensure_ascii=False)
temp_f.write('\n')
    # Replace the original file
os.replace(temp_file, input_file)
print(f"已更新 {input_file} 文件,包含重新生成的回答。")

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# @Time : 2024/10/23 23:16
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : topic_model.py
# @Project : EmoLLM
import json
import gensim
from gensim import corpora
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from collections import defaultdict
# Load QA pairs from a JSONL file
def load_qa_data(file_path):
qa_pairs = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
qa_pairs.append(json.loads(line.strip()))
return qa_pairs
# Text preprocessing
def preprocess_text(text):
stop_words = set(stopwords.words('english'))
tokens = word_tokenize(text.lower())
tokens = [word for word in tokens if word.isalnum() and word not in stop_words]
return tokens
# Build the LDA topic model
def build_lda_model(qa_pairs, num_topics=5):
    # Preprocess all question texts
questions = [qa['input'] for qa in qa_pairs]
processed_questions = [preprocess_text(question) for question in questions]
    # Build the dictionary and bag-of-words corpus
dictionary = corpora.Dictionary(processed_questions)
corpus = [dictionary.doc2bow(text) for text in processed_questions]
    # Train the LDA model
lda_model = gensim.models.ldamodel.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
return lda_model, dictionary, corpus
# Print the top keywords of each topic
def print_topics(lda_model, num_words=10):
for idx, topic in lda_model.print_topics(num_words=num_words):
print(f"主题 {idx}: {topic}")
if __name__ == '__main__':
qa_file = "output/train_optimized_multiple.jsonl" # 问答对文件
# 加载问答对
qa_pairs = load_qa_data(qa_file)
    # Build the LDA topic model
lda_model, dictionary, corpus = build_lda_model(qa_pairs, num_topics=5)
    # Print the topics and their keywords
print_topics(lda_model)

@@ -5,8 +5,34 @@ streamlit==1.24.0
sentencepiece==0.1.99
accelerate==0.24.1
transformers_stream_generator==0.0.4
openxlab~=0.0.11
tiktoken
einops
oss2
requests~=2.32.3
pyjwt~=2.8.0
loguru~=0.6.0
yaml~=0.2.5
pyyaml~=6.0.1
tqdm~=4.66.2
langchain~=0.0.352
torch~=2.5.0
metagpt~=0.8.1
erniebot~=0.5.9
python-dotenv~=1.0.0
zhipuai~=2.0.1
uvicorn~=0.32.0
fastapi~=0.115.2
opencv-python~=4.10.0.84
paddleocr~=2.9.0
dashscope~=1.14.1
numpy~=1.24.3
jieba~=0.42.1
nltk~=3.9.1
setuptools~=65.6.3
websocket~=0.2.1
websocket-client~=1.6.2
gensim~=4.3.3
pillow~=9.5.0
natsort~=8.4.0