import base64
import hashlib
import hmac
import json
import os
import random
import time
from datetime import datetime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import numpy as np
import requests
from sklearn.metrics.pairwise import cosine_similarity


class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be turned into a signed request URL."""

    def __init__(self, msg):
        # Forward msg to Exception so str(exc) and tracebacks show it.
        super().__init__(msg)
        self.message = msg


class Url:
    """Pieces of a request URL as split by parse_url()."""

    def __init__(self, host, path, schema):
        self.host = host      # host part, without scheme or path
        self.path = path      # path part, always starting with "/"
        self.schema = schema  # scheme including the "://" separator


def sha256base64(data):
    """Return the SHA-256 digest of the bytes ``data``, base64-encoded as str."""
    digest = hashlib.sha256(data).digest()
    return base64.b64encode(digest).decode(encoding='utf-8')


def parse_url(requset_url):
    """Split ``requset_url`` into scheme, host and path.

    Returns:
        Url: with ``schema`` (e.g. "https://"), ``host`` and ``path``.

    Raises:
        AssembleHeaderException: if the URL has no "://" or nothing/no "/"
            after the host (the path must be present, e.g. a trailing "/").
    """
    stidx = requset_url.find("://")
    if stidx < 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    host = requset_url[stidx + 3:]
    schema = requset_url[:stidx + 3]
    # find() instead of index(): a missing "/" must reach the explicit error
    # below rather than escape as a bare ValueError.
    edidx = host.find("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    return Url(host[:edidx], host[edidx:], schema)


def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
    """Build the authenticated request URL.

    The string "host: ...\\ndate: ...\\n<method> <path> HTTP/1.1" is signed
    with HMAC-SHA256 keyed by ``api_secret``. The signature and ``api_key``
    are packed into an authorization string, base64-encoded, and appended
    together with ``host`` and ``date`` as query parameters.

    Returns:
        str: ``requset_url`` + "?" + urlencoded auth parameters.

    Raises:
        AssembleHeaderException: propagated from parse_url() for a bad URL.
    """
    u = parse_url(requset_url)
    host = u.host
    path = u.path
    # format_date_time renders an HTTP date ("Wdy, DD Mon YYYY HH:MM:SS GMT")
    # from the epoch seconds obtained via mktime on the local time.
    now = datetime.now()
    date = format_date_time(time.mktime(now.timetuple()))

    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path)
    signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

    authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        api_key, "hmac-sha256", "host date request-line", signature_sha)
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

    values = {
        "host": host,
        "date": date,
        "authorization": authorization,
    }
    return requset_url + "?" + urlencode(values)


def get_Body(appid, text, style):
    """Build the JSON request body for the embedding endpoint.

    Args:
        appid: application id placed in the request header.
        text: JSON-serializable payload; it is JSON-encoded, then base64-encoded.
        style: value used as ``parameter.emb.domain`` (e.g. "para").
    """
    body = {
        "header": {
            "app_id": appid,
            "uid": "39769795890",
            "status": 3,
        },
        "parameter": {
            "emb": {
                "domain": style,
                "feature": {
                    "encoding": "utf8",
                },
            },
        },
        "payload": {
            "messages": {
                "text": base64.b64encode(json.dumps(text).encode('utf-8')).decode(),
            },
        },
    }
    return body


def get_embp_embedding(text, appid, apikey, apisecret, timeout=30):
    """POST one embedding request and return the raw response text.

    Args:
        timeout: seconds passed to requests.post; without one a stalled
            connection would block the caller indefinitely.
    """
    host = 'https://emb-cn-huabei-1.xf-yun.com/'
    url = assemble_ws_auth_url(host, method='POST', api_key=apikey, api_secret=apisecret)
    content = get_Body(appid, text, "para")
    response = requests.post(url, json=content, headers={'content-type': "application/json"},
                             timeout=timeout)
    return response.text


def parser_Message(message):
    """Decode an embedding response.

    Returns:
        numpy.ndarray | None: little-endian float32 vector decoded from
        ``payload.feature.text`` (base64), or None (after printing the error)
        when ``header.code`` is non-zero.
    """
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        return None
    text_base = data["payload"]["feature"]["text"]
    text_data = base64.b64decode(text_base)
    dt = np.dtype(np.float32).newbyteorder("<")
    return np.frombuffer(text_data, dtype=dt)


def load_qa_data(file_path):
    """Load QA pairs from a JSONL file (one JSON object per line).

    Blank lines are skipped instead of raising JSONDecodeError.
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                qa_pairs.append(json.loads(stripped))
    return qa_pairs


def save_embeddings(embeddings, file_path):
    """Write all embeddings as a single JSON object on one line.

    numpy arrays are converted to lists (json cannot serialize them). A
    trailing newline is written so later save_embedding_line_by_line()
    appends start on a new line and the file stays readable by
    load_embeddings().
    """
    serializable = {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in embeddings.items()
    }
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(serializable, f, ensure_ascii=False)
        f.write("\n")


def get_embedding_for_text(text, appid, apikey, apisecret):
    """Request the embedding of ``text`` sent as a single user message.

    Returns the vector from parser_Message(), or None on an API error code.
    """
    desc = {"messages": [{"content": text, "role": "user"}]}
    res = get_embp_embedding(desc, appid=appid, apikey=apikey, apisecret=apisecret)
    return parser_Message(res)


def load_embeddings(file_path):
    """Load embeddings stored as JSON objects, one per line, merged into a dict.

    A missing file is reported and yields an empty dict.
    """
    embeddings = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # ignore blank lines
                    embeddings.update(json.loads(line.strip()))
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在,将创建新文件")
    return embeddings


def save_embedding_line_by_line(qa, embedding, file_path):
    """Append ``{qa: embedding}`` as one JSON line; no-op when embedding is None."""
    if embedding is not None:
        embedding_as_list = embedding.tolist()  # numpy array -> JSON-serializable list
        with open(file_path, 'a', encoding='utf-8') as f:
            json.dump({qa: embedding_as_list}, f, ensure_ascii=False)
            f.write("\n")  # one embedding per line


def get_embedding_with_retry(question, appid, apikey, apisecret, max_retries=5):
    """Fetch the embedding of ``question``, retrying up to ``max_retries`` times.

    Both failure modes count as a failed attempt: an exception (network,
    malformed response) and a None result (API returned a non-zero code).
    Waits 5 s between attempts.

    Returns:
        The embedding, or None if every attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            embedding = get_embedding_for_text(question, appid, apikey, apisecret)
            if embedding is not None:
                return embedding
        except Exception as e:
            print(f"请求错误: {e}")
        # Count failures on both paths; counting only exceptions could loop
        # forever (and with no delay) on a persistent API error code.
        if attempt < max_retries:
            print(f"重试第 {attempt} 次...")
            time.sleep(5)  # wait 5 seconds before each retry
    print(f"获取'{question}' 的embedding失败")
    return None


def get_and_save_embeddings(qa_pairs, appid, apikey, apisecret, file_path, qps_limit=2):
    """Compute and append embeddings for each ``qa['input']`` not yet in file_path.

    Already-stored questions are skipped, so an interrupted run can be resumed.
    After each computed question the loop sleeps 1/qps_limit seconds.

    Raises:
        ValueError: if qps_limit is not positive.
    """
    if qps_limit <= 0:
        raise ValueError(f"qps_limit must be positive, got {qps_limit!r}")
    all_embeddings = load_embeddings(file_path)  # resume from any existing results
    interval = 1 / qps_limit  # seconds between requests to honour the QPS limit

    for qa in qa_pairs:
        question = qa['input']
        if question in all_embeddings:
            print(f"'{question}' 的embedding已存在,跳过计算")
            continue

        print(f"计算'{question}' 的embedding...")
        embedding = get_embedding_with_retry(question, appid, apikey, apisecret)
        if embedding is not None:
            save_embedding_line_by_line(question, embedding, file_path)
            # Track in memory so duplicate questions later in qa_pairs are skipped.
            all_embeddings[question] = embedding
        time.sleep(interval)


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # source QA-pair file
    embedding_file = "output/qa_embeddings.json"  # embedding output file
    # NOTE(review): these credentials were committed in source. Prefer the
    # environment variables; rotate the keys and remove the fallbacks.
    appid = os.environ.get("XFYUN_APPID", "f0f73de5")
    api_secret = os.environ.get("XFYUN_API_SECRET", "YzkyYjQwMTU0MGZjMmUzMGE1Y2ZjYzBk")
    api_key = os.environ.get("XFYUN_API_KEY", "5773f6f95563708de994d17b7ea5d414")

    qa_pairs = load_qa_data(qa_file)
    get_and_save_embeddings(qa_pairs, appid, api_key, api_secret, embedding_file)
    print(f"已保存所有问答对的embedding到 {embedding_file}")