OliveSensorAPI/IOTLLM/generate_data/EC_process/Embedding_merge.py

218 lines
7.0 KiB
Python
Raw Permalink Normal View History

2024-11-11 17:32:36 +08:00
import base64
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import os
class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be split into host and path for signing."""

    def __init__(self, msg):
        super().__init__(msg)
        # Exposed as an attribute as well; callers may read ``.message`` directly.
        self.message = msg
class Url:
    """Plain record holding the scheme, host and path pieces of a request URL."""

    def __init__(self, host, path, schema):
        # schema keeps its "://" suffix; path keeps its leading "/".
        self.schema = schema
        self.host = host
        self.path = path
def sha256base64(data):
    """Return the SHA-256 digest of *data* (a bytes-like object) as Base64 text."""
    raw_digest = hashlib.sha256(data).digest()
    return base64.b64encode(raw_digest).decode(encoding='utf-8')
def parse_url(requset_url):
    """Split an absolute URL into a ``Url(host, path, schema)`` record.

    ``schema`` keeps the trailing ``"://"`` and ``path`` starts at the first
    ``"/"`` after the host, e.g. ``"https://h.com/a/b"`` gives
    host ``"h.com"``, path ``"/a/b"``, schema ``"https://"``.

    Raises:
        AssembleHeaderException: if the URL has no ``"://"`` separator, has
            no ``"/"`` after the host, or has an empty host.
    """
    # find() instead of index(): index() raised a bare ValueError for these
    # malformed inputs before the intended AssembleHeaderException check ran.
    stidx = requset_url.find("://")
    if stidx < 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    host = requset_url[stidx + 3:]
    schema = requset_url[:stidx + 3]
    edidx = host.find("/")
    # -1: no path after the host; 0: empty host (e.g. "https:///x").
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    path = host[edidx:]
    host = host[:edidx]
    return Url(host, path, schema)
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
    """Return *requset_url* with signed ``host``/``date``/``authorization`` query params.

    The host line, the date line and the request line ("<method> <path> HTTP/1.1"),
    joined by newlines, are signed with HMAC-SHA256 keyed by *api_secret*.
    The Base64 signature and *api_key* are packed into an authorization string,
    which is itself Base64-encoded before being URL-encoded into the query.

    Args:
        requset_url: absolute URL; it must contain a path after the host.
        method: HTTP method placed in the signed request line.
        api_key: key identifier embedded in the authorization string.
        api_secret: HMAC key.

    Returns:
        ``requset_url + "?" + <urlencoded host, date, authorization>``.

    Raises:
        Whatever ``parse_url`` raises for a malformed URL.
    """
    u = parse_url(requset_url)
    host = u.host
    path = u.path
    # RFC 1123 GMT date taken straight from the epoch clock. The previous
    # datetime.now() -> time.mktime() round trip goes through naive local
    # time, which is ambiguous around DST transitions.
    date = format_date_time(time.time())
    signature_origin = f"host: {host}\ndate: {date}\n{method} {path} HTTP/1.1"
    signature_digest = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_digest).decode(encoding='utf-8')
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature_sha}"'
    )
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
    values = {
        "host": host,
        "date": date,
        "authorization": authorization,
    }
    return requset_url + "?" + urlencode(values)
def get_Body(appid, text, style, uid="39769795890"):
    """Build the JSON request body for the embedding endpoint.

    Args:
        appid: application id placed in ``header.app_id``.
        text: any JSON-serialisable object. It is JSON-encoded, UTF-8 encoded
            and Base64-encoded into ``payload.messages.text``.
        style: value for ``parameter.emb.domain``.
        uid: value for ``header.uid``; defaults to the previously hard-coded id.

    Returns:
        The request body as a plain dict, ready for ``requests.post(json=...)``.
    """
    # Encode once; the previous version also built an unused copy of this payload.
    encoded_text = base64.b64encode(json.dumps(text).encode('utf-8')).decode()
    return {
        "header": {
            "app_id": appid,
            "uid": uid,
            "status": 3,
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8",
                },
            },
        },
        "payload": {
            "messages": {
                "text": encoded_text,
            },
        },
    }
def get_embp_embedding(text, appid, apikey, apisecret,
                       host='https://emb-cn-huabei-1.xf-yun.com/', timeout=30):
    """POST an embedding request and return the raw response body as text.

    Args:
        text: JSON-serialisable request content, passed to ``get_Body``.
        appid, apikey, apisecret: service credentials.
        host: endpoint URL; defaults to the previously hard-coded endpoint.
        timeout: seconds passed to ``requests.post``. ``None`` restores the old
            wait-forever behaviour.

    Returns:
        ``response.text``. The HTTP status is not checked here; the body is
        left for ``parser_Message`` to interpret.
    """
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    # Without a timeout, a stalled connection blocks indefinitely, and the
    # caller's retry logic never gets a chance to run.
    response = requests.post(url, json=content,
                             headers={'content-type': "application/json"},
                             timeout=timeout)
    return response.text
def parser_Message(message):
    """Decode an embedding response body into a NumPy vector.

    Args:
        message: JSON text with ``header.code`` and, on success,
            ``payload.feature.text`` holding Base64-encoded little-endian float32 data.

    Returns:
        A 1-D float32 array when ``header.code == 0``. Otherwise the error is
        printed and ``None`` is returned.
    """
    data = json.loads(message)
    code = data['header']['code']
    if code == 0:
        encoded = data["payload"]["feature"]["text"]
        little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
        return np.frombuffer(base64.b64decode(encoded), dtype=little_endian_f32)
    print(f'请求错误: {code}, {data}')
    return None
def load_qa_data(file_path):
    """Load QA-pair records from a JSON Lines file.

    Args:
        file_path: UTF-8 file with one JSON value per line.

    Returns:
        A list of the decoded records, in file order. Blank or whitespace-only
        lines are skipped.
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines (e.g. an extra newline at EOF) instead of letting
            # json.loads("") raise JSONDecodeError.
            if stripped:
                qa_pairs.append(json.loads(stripped))
    return qa_pairs
def save_embeddings(embeddings, file_path):
    """Overwrite *file_path* with *embeddings* serialised as one JSON document.

    Non-ASCII keys are written as-is (``ensure_ascii=False``). Values must be
    JSON-serialisable, e.g. lists rather than NumPy arrays.
    """
    with open(file_path, mode='w', encoding='utf-8') as handle:
        json.dump(embeddings, handle, ensure_ascii=False)
def get_embedding_for_text(text, appid, apikey, apisecret):
    """Embed *text*, sent as a single user message, and return the parsed vector.

    Returns whatever ``parser_Message`` returns for the response: a vector on
    success, ``None`` when the service reports a non-zero code.
    """
    # Key order (content, then role) is kept because it affects the exact
    # JSON bytes that get Base64-encoded into the request.
    conversation = {"messages": [{"content": text, "role": "user"}]}
    raw_response = get_embp_embedding(conversation, appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(raw_response)
def load_embeddings(file_path):
    """Load a line-delimited embedding cache into a single dict.

    Each non-blank line should be a JSON object; all objects are merged in
    file order, so a later line wins for a repeated key. Lines that are not
    valid JSON objects, such as a partially written last line after an
    interrupted append, are reported and skipped. The rest of the cache still
    loads, and callers can recompute the missing entries.

    Args:
        file_path: path to the cache file. A missing file is reported and
            treated as empty.

    Returns:
        dict mapping keys to the decoded JSON values (lists for embeddings).
    """
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:  # ignore blank lines
                    continue
                try:
                    record = json.loads(stripped)
                except json.JSONDecodeError:
                    record = None
                if not isinstance(record, dict):
                    print(f"跳过 {file_path} 第 {lineno} 行: 无法解析为 JSON 对象")
                    continue
                embeddings.update(record)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings
def save_embedding_line_by_line(qa, embedding, file_path):
    """Append ``{qa: embedding}`` to *file_path* as one JSON line.

    *embedding* must provide ``.tolist()`` (e.g. a NumPy array). When it is
    ``None`` nothing is written and the file is not opened.
    """
    if embedding is None:
        return
    record = {qa: embedding.tolist()}  # NumPy arrays are not JSON-serialisable
    with open(file_path, 'a', encoding='utf-8') as sink:
        json.dump(record, sink, ensure_ascii=False)
        sink.write("\n")  # one embedding per line
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5, retry_delay=5):
    """Embed *question*, retrying on exceptions or an error response.

    Args:
        question: text to embed.
        appid, apikey, apisecret: service credentials.
        max_retries: total number of attempts; ``<= 0`` makes no attempt.
        retry_delay: seconds to sleep between attempts. There is no sleep
            after the final attempt.

    Returns:
        The embedding from the first successful attempt, or ``None`` if every
        attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
        except Exception as e:
            # Broad on purpose: network, JSON and response-shape errors are all
            # treated as "try again" at this boundary.
            print(f"请求错误: {e}")
        else:
            if embedding is not None:
                return embedding
        # Only announce and wait when another attempt will actually follow.
        if attempt < max_retries:
            print(f"重试第 {attempt} 次...")
            time.sleep(retry_delay)
    print(f"获取'{question}' 的embedding失败")
    return None
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    """Embed each record's ``'input'`` text and append new vectors to *file_path*.

    Questions already present in the cache loaded from *file_path* are skipped.
    Each newly computed embedding is appended immediately, so an interrupted
    run can resume where it stopped.

    Args:
        qa_pairs: iterable of dicts; each must have an ``'input'`` key.
        appid, apikey, apisecret: service credentials.
        file_path: line-delimited embedding cache to read and append to.
        qps_limit: approximate request rate. After each question that was
            sent, the loop sleeps ``1 / qps_limit`` seconds. ``None`` or a
            value ``<= 0`` disables the sleep instead of raising.
    """
    all_embeddings = load_embeddings(file_path)  # resume from any existing cache
    interval = 1 / qps_limit if qps_limit is not None and qps_limit > 0 else 0
    for qa in qa_pairs:
        question = qa['input']
        if question in all_embeddings:
            print(f"'{question}' 的embedding已存在跳过计算")
            continue
        print(f"计算'{question}' 的embedding...")
        embedding = get_embedding_with_retry(question, appid, apikey, apisecret)
        if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)
            # Also remember it in memory so duplicate questions later in
            # qa_pairs are not recomputed.
            all_embeddings[question] = embedding
        if interval:
            time.sleep(interval)  # stay under the QPS limit
if __name__ == '__main__':
    # Paths
    qa_file = "output/train_optimized_multiple.jsonl"  # source QA-pair file (JSON Lines)
    # Embedding cache. Despite the .json suffix, it is appended to one JSON
    # object per line.
    embedding_file = "output/qa_embeddings.json"
    # NOTE(review): these credentials are committed in source control. They can
    # now be overridden via environment variables; rotate the keys and drop the
    # literal fallbacks.
    appid = os.environ.get("XF_APPID", "f0f73de5")
    api_secret = os.environ.get("XF_API_SECRET", "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk")
    api_key = os.environ.get("XF_API_KEY", "5773f6f95563708de994d17b7ea5d414")
    # Appending to the cache fails if its directory does not exist yet.
    os.makedirs(os.path.dirname(embedding_file) or ".", exist_ok=True)
    # Load data
    qa_pairs = load_qa_data(qa_file)
    # Compute and save embeddings
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)
    print(f"已保存所有问答对的embedding到 {embedding_file}")