OliveSensorAPI/IOTLLM/generate_data/EC_process/Embedding_merge.py

218 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import hashlib
import hmac
import json
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity
import os
class AssembleHeaderException(Exception):
    """Error raised when a request URL cannot be prepared for signing.

    Attributes:
        message: The description passed to the constructor.
    """

    def __init__(self, msg):
        # Hand the message to Exception as well, so that str(exc), exc.args
        # and tracebacks show it instead of an empty string.
        super().__init__(msg)
        self.message = msg
class Url:
    """Simple holder for the three parts of a split request URL."""

    def __init__(self, host, path, schema):
        # host:   the part between the scheme separator and the path
        # path:   the remainder of the URL starting at the first "/"
        # schema: the scheme prefix of the URL
        self.host = host
        self.path = path
        self.schema = schema

    def __repr__(self):
        # Readable form for logs and debugging sessions.
        return "Url(host={!r}, path={!r}, schema={!r})".format(
            self.host, self.path, self.schema
        )
# Hash the input with SHA-256 and return the digest as base64 text.
def sha256base64(data):
    """Return the base64-encoded SHA-256 digest of ``data``.

    Args:
        data: Bytes-like object to hash.

    Returns:
        The digest as a base64 ``str``.
    """
    raw_digest = hashlib.sha256(data).digest()
    return base64.b64encode(raw_digest).decode('utf-8')
def parse_url(requset_url):
    """Split a request URL into schema, host and path.

    Args:
        requset_url: Full URL such as ``"https://host.example/path"``.

    Returns:
        A ``Url`` whose ``schema`` includes the ``"://"`` separator, whose
        ``host`` is the text up to the first ``"/"`` after the separator, and
        whose ``path`` starts at that ``"/"``.

    Raises:
        AssembleHeaderException: If the URL has no ``"://"`` separator, an
            empty host, or no ``"/"`` after the host.
    """
    # str.find returns -1 instead of raising, so every malformed URL is
    # reported through the same exception type.
    scheme_end = requset_url.find("://")
    if scheme_end < 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)

    schema = requset_url[:scheme_end + 3]
    remainder = requset_url[scheme_end + 3:]

    path_start = remainder.find("/")
    # -1 means there is no path at all; 0 means the host is empty.
    if path_start <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)

    return Url(remainder[:path_start], remainder[path_start:], schema)
# Build the authenticated request URL.
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
    """Return ``requset_url`` with signed authentication query parameters.

    An HMAC-SHA256 signature is computed over the host, the current HTTP date
    and the request line, wrapped in an authorization string, base64-encoded
    and appended to the URL as ``host``, ``date`` and ``authorization``
    query parameters.

    Args:
        requset_url: Endpoint URL to sign.
        method: HTTP method placed in the signed request line.
        api_key: Key embedded in the authorization string.
        api_secret: Secret used as the HMAC key.

    Returns:
        The original URL followed by ``"?"`` and the url-encoded parameters.
    """
    parsed = parse_url(requset_url)

    # HTTP-date string for the current moment.
    stamp = format_date_time(time.mktime(datetime.now().timetuple()))

    to_sign = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
        parsed.host, stamp, method, parsed.path
    )
    mac = hmac.new(
        api_secret.encode('utf-8'),
        to_sign.encode('utf-8'),
        digestmod=hashlib.sha256,
    )
    signature_b64 = base64.b64encode(mac.digest()).decode('utf-8')

    auth_plain = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        api_key, "hmac-sha256", "host date request-line", signature_b64)
    auth_b64 = base64.b64encode(auth_plain.encode('utf-8')).decode('utf-8')

    query = urlencode({
        "host": parsed.host,
        "date": stamp,
        "authorization": auth_b64,
    })
    return requset_url + "?" + query
def get_Body(appid, text, style):
    """Build the request body for the embedding endpoint.

    Args:
        appid: Application id stored under ``header.app_id``.
        text: JSON-serialisable object; it is dumped to JSON, UTF-8 encoded
            and base64-encoded into ``payload.messages.text``.
        style: Value stored under ``parameter.emb.domain``.

    Returns:
        A ``dict`` ready to be sent as the JSON request body.
    """
    # Serialise once; the previous version also built an unused copy.
    encoded_text = base64.b64encode(json.dumps(text).encode('utf-8')).decode()
    return {
        "header": {
            "app_id": appid,
            "uid": "39769795890",
            "status": 3,
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8",
                },
            },
        },
        "payload": {
            "messages": {
                "text": encoded_text,
            },
        },
    }
# Send the embedding request and return the raw response.
def get_embp_embedding(text, appid, apikey, apisecret, timeout=30):
    """POST an embedding request and return the response body as text.

    Args:
        text: Message object forwarded to ``get_Body``.
        appid: Application id placed in the request header section.
        apikey: API key used to sign the URL.
        apisecret: API secret used to sign the URL.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. Without it a stalled connection would block the
            whole run indefinitely.

    Returns:
        The raw response text (expected to be JSON).
    """
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    response = requests.post(
        url,
        json=content,
        headers={'content-type': "application/json"},
        timeout=timeout,
    )
    return response.text
# Parse the response and return the decoded vector.
def parser_Message(message):
    """Decode the embedding vector from a JSON response string.

    Args:
        message: JSON text containing ``header.code`` and, on success,
            a base64 string at ``payload.feature.text``.

    Returns:
        A NumPy array of little-endian float32 values, or ``None`` (after
        printing the error) when ``header.code`` is non-zero.
    """
    reply = json.loads(message)
    status = reply['header']['code']

    # Guard clause: a non-zero code means the service rejected the request.
    if status != 0:
        print(f'请求错误: {status}, {reply}')
        return None

    raw_bytes = base64.b64decode(reply["payload"]["feature"]["text"])
    little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
    return np.frombuffer(raw_bytes, dtype=little_endian_f32)
# Load the question/answer records.
def load_qa_data(file_path):
    """Read a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines are skipped, so a trailing newline or an
    empty separator line does not abort loading.

    Args:
        file_path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-empty line, in file order.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        stripped_lines = (raw_line.strip() for raw_line in f)
        return [json.loads(line) for line in stripped_lines if line]
# Write the embeddings to a file.
def save_embeddings(embeddings, file_path):
    """Serialise ``embeddings`` as JSON into ``file_path``, replacing its contents.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(file_path, mode='w', encoding='utf-8') as sink:
        json.dump(obj=embeddings, fp=sink, ensure_ascii=False)
# Fetch the embedding for one piece of text.
def get_embedding_for_text(text, appid, apikey, apisecret):
    """Request the embedding of ``text`` and return the parsed vector.

    The text is wrapped as a single ``user`` message, sent through
    ``get_embp_embedding`` and decoded by ``parser_Message``.

    Returns:
        Whatever ``parser_Message`` returns for the response.
    """
    user_message = {"content": text, "role": "user"}
    request_payload = {"messages": [user_message]}
    raw_reply = get_embp_embedding(
        request_payload, appid=appid, apikey=apikey, apisecret=apisecret
    )
    return parser_Message(raw_reply)
# Load previously stored embeddings line by line.
def load_embeddings(file_path):
    """Merge every JSON object found in a JSON Lines file into one dict.

    Blank lines are ignored. When the same key appears on several lines the
    last occurrence wins. A missing file is reported on stdout and yields an
    empty dict.

    Args:
        file_path: Path of the UTF-8 encoded embeddings file.

    Returns:
        A dict combining all objects stored in the file.
    """
    merged = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as src:
            for stripped in map(str.strip, src):
                if not stripped:
                    continue  # skip empty lines
                merged.update(json.loads(stripped))
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return merged
# Append one embedding to the file as its own line.
def save_embedding_line_by_line(qa, embedding, file_path):
    """Append ``{qa: embedding}`` as one JSON line to ``file_path``.

    Nothing is written when ``embedding`` is ``None``.

    Args:
        qa: Key under which the vector is stored.
        embedding: Object exposing ``tolist()`` (e.g. a NumPy array), or ``None``.
        file_path: Path of the UTF-8 encoded output file, opened in append mode.
    """
    if embedding is None:
        return

    # Arrays are not JSON-serialisable; convert to a plain list first.
    record = {qa: embedding.tolist()}
    with open(file_path, 'a', encoding='utf-8') as out:
        json.dump(record, out, ensure_ascii=False)
        out.write("\n")  # one record per line
# Fetch the embedding for one question, retrying on failure.
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5, retry_delay=5):
    """Request an embedding, making up to ``max_retries`` attempts.

    An attempt fails when the request raises or when no vector is returned.
    Between attempts the function announces the retry and waits
    ``retry_delay`` seconds. After the final failed attempt it neither
    announces a retry nor sleeps, because no further attempt follows.

    Args:
        question: Text to embed.
        appid: Application id forwarded to the request.
        apikey: API key forwarded to the request.
        apisecret: API secret forwarded to the request.
        max_retries: Maximum number of attempts.
        retry_delay: Seconds to wait before the next attempt.

    Returns:
        The embedding, or ``None`` if every attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:
            # Deliberately broad: any failure of a single request is treated
            # as retryable so that one bad call does not stop the batch.
            print(f"请求错误: {e}")
        if attempt < max_retries:
            print(f"重试第 {attempt} 次...")
            time.sleep(retry_delay)  # back off before the next attempt
    print(f"获取'{question}' 的embedding失败")
    return None
# Compute the embedding of every question and store each one as it arrives.
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    """Embed the ``input`` of each record and append new results to ``file_path``.

    Questions already present in the file (or handled earlier in this run) are
    skipped. After every request the function pauses for ``1 / qps_limit``
    seconds to throttle the request rate.

    Args:
        qa_pairs: Iterable of mappings that each contain an ``'input'`` key.
        appid: Application id forwarded to the request.
        apikey: API key forwarded to the request.
        apisecret: API secret forwarded to the request.
        file_path: JSON Lines file used both as cache and as output.
        qps_limit: Maximum number of requests per second.
    """
    known = load_embeddings(file_path)  # results from earlier runs, if any
    pause = 1 / qps_limit  # seconds between consecutive requests

    for record in qa_pairs:
        question = record['input']
        if question not in known:
            print(f"计算'{question}' 的embedding...")
            vector = get_embedding_with_retry(question, appid, apikey, apisecret)
            if vector is not None:
                save_embedding_line_by_line(question, vector, file_path)
                known[question] = vector  # remember it to skip duplicates
            time.sleep(pause)  # stay within the rate limit
        else:
            print(f"'{question}' 的embedding已存在跳过计算")
if __name__ == '__main__':
    # Input and output locations.
    qa_file = "output/train_optimized_multiple.jsonl"  # source question/answer pairs
    embedding_file = "output/qa_embeddings.json"  # embedding store (one JSON object per line)

    # SECURITY NOTE: these credentials were committed in plain text. They are
    # now overridable through environment variables; the literals remain only
    # as a fallback so existing runs keep working. They should be rotated and
    # removed from the source.
    appid = os.environ.get("XF_APPID", "f0f73de5")
    api_secret = os.environ.get("XF_API_SECRET", "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk")
    api_key = os.environ.get("XF_API_KEY", "5773f6f95563708de994d17b7ea5d414")

    # Load the records.
    qa_pairs = load_qa_data(qa_file)

    # Compute the embeddings and store them.
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)
    print(f"已保存所有问答对的embedding到 {embedding_file}")