93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
|
from simhash import Simhash
|
|||
|
from ai_module.yuan1_0.inspurai import Yuan, set_yuan_account,Example
|
|||
|
import heapq
|
|||
|
import sys
|
|||
|
from utils import config_util as cfg
|
|||
|
|
|||
|
class Yuan1Dialog:
    """Multi-turn chat wrapper around the Yuan 1.0 ``dialog`` engine.

    Keeps a bounded history of previous question/answer turns. Before each
    request, it temporarily attaches the past turns most relevant to the new
    prompt to the ``Yuan`` client as examples.
    """

    # Maximum number of past turns kept in ``self.h_dialog``; the oldest
    # turn is dropped once this is exceeded.
    MAX_HISTORY = 10

    def __init__(self, account, phone) -> None:
        """Register the account and create the dialog-engine client.

        :param account: account identifier passed to ``set_yuan_account``.
        :param phone: phone number passed to ``set_yuan_account``.
        """
        self.account = account
        self.phone = phone
        set_yuan_account(account, phone)
        self.yuan = Yuan(engine='dialog',
                         input_prefix="问:“",
                         input_suffix="”",
                         output_prefix="答:“",
                         output_suffix="”",
                         max_tokens=30,
                         append_output_prefix_to_query=True)
        # Past turns as ``Example`` objects, oldest first (new turns are
        # appended at the end, the oldest is removed from index 0).
        self.h_dialog = []

    @staticmethod
    def _simhash_similarity(query, text):
        """Return a similarity in [0, 1] between two texts based on SimHash.

        1.0 means identical fingerprints; 0.0 means every fingerprint bit differs.
        """
        q_hash = Simhash(query)
        t_hash = Simhash(text)
        # Normalise the Hamming distance by the fingerprint width ``f`` (in
        # bits). The former ``len(bin(value))`` also counted the ``'0b'``
        # prefix and varied with leading zero bits, so each history turn was
        # scored against a different denominator.
        return 1 - q_hash.distance(t_hash) / q_hash.f

    @staticmethod
    def _rank_by_relevance(similarities, top_n, gamma=0.8):
        """Return indices of the ``top_n`` highest-scoring turns, least relevant first.

        The score of turn ``i`` (out of ``n``, oldest first) is
        ``similarities[i] * gamma ** ((n - i) / n)``. With ``gamma < 1``,
        newer turns get a larger time-decay factor.
        """
        n = len(similarities)
        if n == 0:
            return []
        scores = [sim * gamma ** ((n - i) / n) for i, sim in enumerate(similarities)]
        top = heapq.nlargest(top_n, range(n), key=scores.__getitem__)
        # Least relevant first, so the most relevant example ends up last.
        top.reverse()
        return top

    def get_relative_qa(self, prompt, h_dialog, topN=2, gamma=0.8):
        """Select the past turns most relevant to ``prompt``.

        Each turn is compared with ``prompt`` through the SimHash similarity
        of ``"<input>|<output>"``. That similarity is weighted by a time-decay
        factor so that, among equally similar turns, newer ones win.

        :param prompt: the new user question.
        :param h_dialog: history of ``Example`` turns, oldest first.
        :param topN: maximum number of turns to return.
        :param gamma: time-decay base; the default 0.8 is the original value.
        :return: up to ``topN`` turns ordered from least to most relevant;
            an empty list when the history is empty.
        """
        if not h_dialog:
            return []
        similarities = [
            self._simhash_similarity(prompt, '|'.join((turn.input, turn.output)))
            for turn in h_dialog
        ]
        ranked = self._rank_by_relevance(similarities, topN, gamma)
        return [h_dialog[idx] for idx in ranked]

    def update_example(self, yuan, exs):
        """Add each example in ``exs`` to ``yuan``.

        :param yuan: the client to add the examples to.
        :param exs: examples to add.
        :return: ``(yuan, ids)`` where ``ids`` are the ids of the added
            examples, in order, so they can be deleted again later.
        """
        ex_ids = []
        for ex in exs:
            ex_ids.append(ex.get_id())
            yuan.add_example(ex)
        return yuan, ex_ids

    def dialog(self, prompt):
        """Send ``prompt`` to the dialog engine, using relevant history as examples.

        The selected past turns are attached to the client only for this
        request. They are always removed afterwards, even when the API call
        raises, so they cannot leak into later requests. The new turn is then
        appended to ``self.h_dialog``, which is capped at ``MAX_HISTORY``
        turns by dropping the oldest.

        :param prompt: the user's question.
        :return: the value returned by ``submit_API``.
        """
        yuan = self.yuan
        h_dialog = self.h_dialog
        exs = self.get_relative_qa(prompt, h_dialog)
        yuan, ex_ids = self.update_example(yuan, exs)
        try:
            # NOTE(review): ``trun`` presumably truncates the reply at "。";
            # confirm against inspurai.
            response = yuan.submit_API(prompt=prompt, trun="。")
        finally:
            for ex_id in ex_ids:
                yuan.delete_example(ex_id)
        h_dialog.append(Example(inp=prompt, out=response))
        if len(h_dialog) > self.MAX_HISTORY:
            del h_dialog[0]
        return response
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Manual smoke test: load credentials from the project config and send
    # one greeting through the dialog engine.
    cfg.load_config()
    bot = Yuan1Dialog(cfg.key_yuan_1_0_account, cfg.key_yuan_1_0_phone)
    print(bot.dialog("你好"))
|
|||
|
|