218 lines
7.0 KiB
Python
218 lines
7.0 KiB
Python
import base64
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import random
|
||
import time
|
||
from datetime import datetime
|
||
from urllib.parse import urlencode
|
||
from wsgiref.handlers import format_date_time
|
||
|
||
import numpy as np
|
||
import requests
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
import os
|
||
|
||
|
||
class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be split into schema, host and path."""

    def __init__(self, msg):
        super().__init__(msg)
        # Human-readable description, also exposed as an attribute for callers.
        self.message = msg
|
||
|
||
|
||
class Url:
    """Plain container for the three pieces of a split request URL."""

    def __init__(self, host, path, schema):
        # host: network location; path: everything from the first "/" on;
        # schema: the scheme prefix as given by the caller.
        self.host, self.path, self.schema = host, path, schema
|
||
|
||
|
||
def sha256base64(data):
    """Return the SHA-256 digest of *data*, base64-encoded, as a str.

    *data* must be bytes-like; hashlib rejects str input with TypeError.
    """
    raw_digest = hashlib.sha256(data).digest()
    return base64.b64encode(raw_digest).decode('utf-8')
|
||
|
||
|
||
def parse_url(requset_url):
    """Split an absolute URL into a ``Url(host, path, schema)``.

    ``schema`` keeps the trailing ``"://"`` and ``path`` starts at the first
    ``/`` after the host. For example, ``"https://example.com/v1"`` gives
    schema ``"https://"``, host ``"example.com"`` and path ``"/v1"``.

    Raises:
        AssembleHeaderException: if the URL has no ``"://"`` separator, has
            no ``/`` after the host, or has an empty host.
    """
    # find() instead of index(): malformed URLs must reach the explicit
    # guards below rather than escaping as a bare ValueError.
    stidx = requset_url.find("://")
    if stidx < 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    schema = requset_url[:stidx + 3]
    rest = requset_url[stidx + 3:]
    edidx = rest.find("/")
    # -1: no path at all; 0: empty host (e.g. "https:///path").
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    return Url(rest[:edidx], rest[edidx:], schema)
|
||
|
||
|
||
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
    """Return *requset_url* with HMAC-SHA256 authentication query parameters.

    The string that gets signed has three lines: ``host: <host>``,
    ``date: <date>`` and ``<method> <path> HTTP/1.1``. It is signed with
    *api_secret* using HMAC-SHA256 and the signature is base64-encoded. That
    signature, together with *api_key*, goes into an ``authorization`` header
    string, which is itself base64-encoded. Finally ``host``, ``date`` and
    ``authorization`` are URL-encoded and appended after ``"?"``.

    Raises:
        AssembleHeaderException: propagated from parse_url for malformed URLs.
    """
    u = parse_url(requset_url)
    host = u.host
    path = u.path

    # RFC 1123 date (GMT) straight from the epoch. The former
    # datetime.now() -> timetuple() -> time.mktime() round-trip goes through
    # naive local time (tm_isdst=-1), which is ambiguous during a DST
    # fall-back hour and can shift the signed date by an hour.
    date = format_date_time(time.time())

    signature_origin = f"host: {host}\ndate: {date}\n{method} {path} HTTP/1.1"
    signature_raw = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_raw).decode('utf-8')

    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature_sha}"'
    )
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

    values = {
        "host": host,
        "date": date,
        "authorization": authorization,
    }
    return requset_url + "?" + urlencode(values)
|
||
|
||
|
||
def get_Body(appid, text, style, uid="39769795890"):
    """Build the JSON request body for the embedding endpoint.

    Args:
        appid: application id, placed in ``header.app_id``.
        text: any JSON-serializable object. It is JSON-encoded, UTF-8
            encoded and then base64-encoded into ``payload.messages.text``.
        style: value for ``parameter.emb.domain`` (this file passes "para").
        uid: value for ``header.uid``. The default is the id that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A dict suitable for ``requests.post(..., json=body)``.
    """
    # Encode the text once; the old version also built an unused copy.
    encoded_text = base64.b64encode(json.dumps(text).encode('utf-8')).decode()
    return {
        "header": {
            "app_id": appid,
            "uid": uid,
            "status": 3,
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8",
                },
            },
        },
        "payload": {
            "messages": {
                "text": encoded_text,
            },
        },
    }
|
||
|
||
|
||
def get_embp_embedding(text, appid, apikey, apisecret, timeout=30):
    """POST *text* to the embedding service and return the raw response body.

    Args:
        text: payload object, serialized into the request by get_Body.
        appid, apikey, apisecret: service credentials.
        timeout: seconds to wait for connect/read before giving up.

    Returns:
        The response body as text (JSON, decoded later by parser_Message).

    Raises:
        requests.exceptions.RequestException: on connection failure or
            when *timeout* is exceeded.
    """
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    # Without a timeout, requests can block forever on a stalled connection,
    # so get_embedding_with_retry's retry loop would never get a chance to run.
    response = requests.post(url, json=content,
                             headers={'content-type': "application/json"},
                             timeout=timeout)
    return response.text
|
||
|
||
|
||
def parser_Message(message):
    """Decode an embedding-service JSON response into a float32 vector.

    If ``header.code`` is non-zero, the code and full response are printed
    and None is returned. Otherwise ``payload.feature.text`` is
    base64-decoded and reinterpreted as little-endian float32 values (a
    read-only numpy array).
    """
    parsed = json.loads(message)
    status = parsed['header']['code']
    if status != 0:
        print(f'请求错误: {status}, {parsed}')
        return None

    raw_bytes = base64.b64decode(parsed["payload"]["feature"]["text"])
    little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
    return np.frombuffer(raw_bytes, dtype=little_endian_f32)
|
||
|
||
|
||
def load_qa_data(file_path):
    """Load question/answer records from a JSON Lines file.

    Every non-blank line is parsed with json.loads and the results are
    returned in file order. Blank lines are skipped, the same way
    load_embeddings treats them, instead of crashing json.loads('').
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
|
||
|
||
|
||
def save_embeddings(embeddings, file_path):
    """Overwrite *file_path* with *embeddings* written as a single JSON line.

    The line ends with a newline so the file stays valid JSON Lines. Records
    appended later by save_embedding_line_by_line then start on a new line,
    and load_embeddings can read everything back. Values must already be
    JSON-serializable (e.g. lists rather than numpy arrays).
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(embeddings, f, ensure_ascii=False)
        f.write("\n")
|
||
|
||
|
||
def get_embedding_for_text(text, appid, apikey, apisecret):
    """Return the embedding for one user *text*, or None on a service error.

    *text* is wrapped as a single-message chat payload, sent through
    get_embp_embedding, and the reply is decoded with parser_Message (which
    returns None when the response code is non-zero).
    """
    message = {"content": text, "role": "user"}
    raw_response = get_embp_embedding({"messages": [message]},
                                      appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(raw_response)
|
||
|
||
|
||
def load_embeddings(file_path):
    """Load previously saved embeddings from a JSON Lines file into one dict.

    Each non-blank line is parsed as a JSON object and merged into the
    result; later lines win on duplicate keys. A missing file yields an
    empty dict. Lines that are not valid JSON (e.g. a half-written line from
    an interrupted run) are reported and skipped, so the job can resume
    instead of failing on every restart.
    """
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:  # ignore blank lines
                    continue
                try:
                    embedding_data = json.loads(stripped)
                except json.JSONDecodeError as err:
                    print(f"跳过第 {lineno} 行无法解析的embedding: {err}")
                    continue
                embeddings.update(embedding_data)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings
|
||
|
||
|
||
def save_embedding_line_by_line(qa, embedding, file_path):
    """Append one ``{qa: embedding}`` record as a JSON line to *file_path*.

    *embedding* must provide ``tolist()`` (e.g. a numpy array). When it is
    None, nothing is written and the file is not opened.
    """
    if embedding is None:
        return
    record = {qa: embedding.tolist()}  # numpy arrays are not JSON-serializable
    with open(file_path, 'a', encoding='utf-8') as out:
        json.dump(record, out, ensure_ascii=False)
        out.write("\n")  # one record per line
|
||
|
||
|
||
def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5, retry_delay=5):
    """Fetch the embedding for *question*, retrying on failure.

    An attempt fails when get_embedding_for_text raises or returns None. Up
    to *max_retries* attempts are made, waiting *retry_delay* seconds between
    them. There is no wait or "retrying" message after the final attempt.

    Returns:
        The embedding, or None if every attempt failed (or max_retries <= 0).
    """
    for attempt in range(1, max_retries + 1):
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:  # any request/parse failure: log it and retry
            print(f"请求错误: {e}")
        if attempt < max_retries:
            print(f"重试第 {attempt} 次...")
            time.sleep(retry_delay)
    print(f"获取'{question}' 的embedding失败")
    return None
|
||
|
||
|
||
def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    """Compute and persist embeddings for every ``qa['input']`` not yet cached.

    Results already in *file_path* are loaded first, so an interrupted run
    resumes where it stopped. Each newly obtained embedding is appended to
    *file_path* immediately. After every fetch attempt, successful or not,
    the loop sleeps ``1 / qps_limit`` seconds to respect the service rate
    limit. Questions already cached are skipped without sleeping.
    """
    done = load_embeddings(file_path)
    pause = 1 / qps_limit

    for pair in qa_pairs:
        question = pair['input']
        if question in done:
            print(f"'{question}' 的embedding已存在,跳过计算")
            continue

        print(f"计算'{question}' 的embedding...")
        vector = get_embedding_with_retry(question, appid, apikey, apisecret)
        if vector is not None:
            save_embedding_line_by_line(question, vector, file_path)
            done[question] = vector  # also dedupes repeated questions in qa_pairs
        time.sleep(pause)
|
||
|
||
|
||
if __name__ == '__main__':
    # Paths
    qa_file = "output/train_optimized_multiple.jsonl"  # source Q&A pairs, one JSON object per line
    embedding_file = "output/qa_embeddings.json"  # embedding cache, one {question: vector} per line

    # Credentials: environment variables take precedence so secrets need not
    # live in source control. The literal fallbacks keep the script working
    # as before, but they are exposed in this file; rotate them and remove
    # the fallbacks once the environment is configured.
    appid = os.environ.get("XF_APPID", "f0f73de5")
    api_secret = os.environ.get("XF_API_SECRET", "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk")
    api_key = os.environ.get("XF_API_KEY", "5773f6f95563708de994d17b7ea5d414")

    # Load data
    qa_pairs = load_qa_data(qa_file)

    # Fetch and save embeddings
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)

    print(f"已保存所有问答对的embedding到 {embedding_file}")
|