4cfad5ae0f
- Brand-new UI
- Overhauled the WebSocket logic, improving the stability and resource overhead of the connection between the digital human and the UI
- Overhauled the wake-up logic, providing a stable normal wake-up mode and a wake-word (prefix) wake-up mode
- Improved audio capture quality, with support for multi-channel microphone pickup
- Improved the hand-off mechanism with the auto-play server, providing a stable integration mode that remains compatible with older UE projects
- The digital-human interface now outputs robot facial expressions, to support expression output in the new Fay UI and on microcontroller-based digital humans
- Uses a more advanced audio-duration calculation so the logic after playback finishes can be timed more precisely (see the sketch below)
- Fixed a bug where clicking the close button would exit the program
- Fixed an error when enabling the microphone on devices without one
- Added a configuration option for the server host address to make server deployment easier
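The audio-duration item above comes down to reading a clip's real length from its header rather than estimating it, so the post-playback logic fires at the right moment. A minimal sketch, assuming WAV output and Python's standard wave module (the helper name and the choice of wave are illustrative assumptions, not Fay's actual implementation):

import wave

def wav_duration_seconds(path: str) -> float:
    # Duration = number of audio frames / sampling rate, read from the WAV header.
    with wave.open(path, "rb") as wf:
        return wf.getnframes() / float(wf.getframerate())

# e.g. schedule the "playback finished" handling after wav_duration_seconds("reply.wav") seconds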
187 lines
5.9 KiB
Python
import gc
import json
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessor
from typing import Union, Tuple


class InvalidScoreLogitsProcessor(LogitsProcessor):
    # Runs on every decoding step; if the logits degenerate to NaN/Inf,
    # reset them and force a single fixed token (id 5) so generation can continue.
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        metadata, content = response.split("\n", maxsplit=1)
        if not metadata.strip():
            # Plain text reply: the metadata line is empty and content is the answer itself.
            content = content.strip()
            content = content.replace("[[训练时间]]", "2023年")
        else:
            if use_tool:
                # Tool call: metadata is the function name; drop the first and last
                # lines of content (the code-fence markers) to keep the expression.
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                # eval() the generated tool_call(...) expression to recover its kwargs.
                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content
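
# Illustrative behaviour of process_response (not part of the original file):
#   a tool-call turn such as
#       'get_weather\n```python\ntool_call(city="Beijing")\n```'
#   parsed with use_tool=True returns
#       {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
#   while a plain turn ('\nHello!') has an empty metadata line and is returned
#   as the stripped answer text ("Hello!").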


@torch.inference_mode()
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    tools = params["tools"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)
    messages = process_chatglm_messages(messages, tools=tools)
    query, role = messages[-1]["content"], messages[-1]["role"]

    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        print(f"Input length larger than {model.config.seq_length}")

    # Stop at EOS or when the model starts a new "<|user|>" turn.
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        # A trailing "�" means the last token decodes to an incomplete UTF-8
        # sequence; wait for the next chunk before emitting anything.
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only last stream result contains finish_reason, we set finish_reason as stop
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()
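
# Illustrative `params` for generate_stream_chatglm3 (keys mirror the reads at the
# top of the function; the message objects are assumed to expose .role / .content /
# .function_call as in the surrounding API layer):
#   params = {
#       "messages": request_messages,   # last entry is the current query
#       "tools": None,
#       "temperature": 0.8,
#       "top_p": 0.8,
#       "repetition_penalty": 1.1,
#       "max_tokens": 512,
#       "echo": False,
#   }
# Chunks where "<|observation|>" is detected report finish_reason "function_call";
# the final chunk always reports "stop".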


def process_chatglm_messages(messages, tools=None):
    # Convert incoming message objects (with .role/.content/.function_call
    # attributes) into the dict format expected by build_chat_input.
    _messages = messages
    messages = []
    if tools:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": tools
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            # Tool results are fed back to the model as "observation" turns.
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )

        elif role == "assistant" and func_call is not None:
            # Split a multi-part assistant turn back into (metadata, content) pairs.
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})
    return messages
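
# Illustrative mapping performed by process_chatglm_messages (shown as dicts for
# brevity; the function actually reads attributes on message objects): a tool result
#   role="function", content='{"temperature": 22}'
# becomes
#   {"role": "observation", "content": '{"temperature": 22}'}
# and passing `tools` prepends a system message advertising them to the model.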


def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Non-streaming variant: drain the stream and return only the final chunk.
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue

            break

    return reply, stop_found
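
# Illustrative behaviour of apply_stopping_strings (not part of the original file):
#   apply_stopping_strings("Sunny today<|observation|>", ["<|observation|>"])
#       -> ("Sunny today", True)    # full stop string found and cut off
#   apply_stopping_strings("Sunny today<|obs", ["<|observation|>"])
#       -> ("Sunny today", False)   # dangling prefix trimmed while streaming continues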