olivebot/ai_module/yuan1_0/inspurai.py
xszyou 77b2c0a479 2023.02.10
1、加入了NLP三选一:xfaiui、yuan、chatgpt;
2、修复由pyaudio bug导致的录音错误;
3、修改远程音频输入输出设备同时只允许连1个;
2023-02-10 19:50:47 +08:00

195 lines
6.6 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
from ai_module.yuan1_0.url_config import submit_request, reply_request
def set_yuan_account(user, phone):
    """Store the Yuan account in the ``YUAN_ACCOUNT`` environment variable.

    The value written is ``user`` and ``phone`` joined by ``'||'``.
    """
    account = '||'.join((user, phone))
    os.environ['YUAN_ACCOUNT'] = account
class Example:
    """A single few-shot sample: an input text paired with its expected output.

    Each instance is tagged with a random hexadecimal identifier at
    construction time so it can be looked up or removed later.
    """

    def __init__(self, inp, out):
        self.input = inp
        self.output = out
        # uuid4().hex is a 32-character lowercase hex string.
        self.id = uuid.uuid4().hex

    def get_input(self):
        """Return the input half of the pair."""
        return self.input

    def get_output(self):
        """Return the output half of the pair."""
        return self.output

    def get_id(self):
        """Return the identifier generated when the example was created."""
        return self.id

    def as_dict(self):
        """Return the example as a dict with keys ``input``, ``output``, ``id``."""
        # Go through the accessors so subclasses overriding them are honoured.
        return dict(
            input=self.get_input(),
            output=self.get_output(),
            id=self.get_id(),
        )
class Yuan:
    """The main class for a user to interface with the Inspur Yuan API.

    An instance holds the sampling parameters, the prefixes/suffixes used to
    format prompts, and a collection of few-shot ``Example`` objects that are
    prepended to every query.
    """

    def __init__(self,
                 engine='base_10B',
                 temperature=0.9,
                 max_tokens=100,
                 input_prefix='',
                 input_suffix='\n',
                 output_prefix='答:',
                 output_suffix='\n\n',
                 append_output_prefix_to_query=False,
                 topK=1,
                 topP=0.9,
                 frequencyPenalty=1.2,
                 responsePenalty=1.2,
                 noRepeatNgramSize=2):
        """Store the request parameters and prompt-formatting affixes.

        The engine name is not validated here; a previously disabled check
        listed 'base_10B', 'translate' and 'dialog' as the expected values.
        """
        # Few-shot examples keyed by their id.
        self.examples = {}
        self.engine = engine
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.topK = topK
        self.topP = topP
        self.frequencyPenalty = frequencyPenalty
        self.responsePenalty = responsePenalty
        self.noRepeatNgramSize = noRepeatNgramSize
        self.input_prefix = input_prefix
        self.input_suffix = input_suffix
        self.output_prefix = output_prefix
        self.output_suffix = output_suffix
        self.append_output_prefix_to_query = append_output_prefix_to_query
        # Text that separates one formatted example from the next input.
        self.stop = (output_suffix + input_prefix).strip()

    def add_example(self, ex):
        """Register a few-shot example, keyed by its id.

        ``ex`` must be an ``Example`` instance; otherwise an
        ``AssertionError`` is raised (kept for backward compatibility).
        """
        assert isinstance(ex, Example), "Please create an Example object."
        self.examples[ex.get_id()] = ex

    def delete_example(self, id):
        """Remove the example with the given id; unknown ids are ignored."""
        self.examples.pop(id, None)

    def get_example(self, id):
        """Return the example with the given id, or ``None`` if absent."""
        return self.examples.get(id)

    def get_all_examples(self):
        """Return every example as a dict mapping id -> ``as_dict()`` form."""
        return {k: v.as_dict() for k, v in self.examples.items()}

    def get_prime_text(self):
        """Concatenate all formatted examples into the priming text."""
        return "".join(
            self.format_example(ex) for ex in self.examples.values())

    def get_engine(self):
        """Return the engine specified for the API."""
        return self.engine

    def get_temperature(self):
        """Return the temperature specified for the API."""
        return self.temperature

    def get_max_tokens(self):
        """Return the max tokens specified for the API."""
        return self.max_tokens

    def craft_query(self, prompt):
        """Build the full query: priming text, then the wrapped prompt.

        When ``append_output_prefix_to_query`` is set, the output prefix is
        appended after the prompt.
        """
        q = (self.get_prime_text()
             + self.input_prefix + prompt + self.input_suffix)
        if self.append_output_prefix_to_query:
            q = q + self.output_prefix
        return q

    def format_example(self, ex):
        """Format one input/output pair with the configured affixes."""
        return (self.input_prefix + ex.get_input() + self.input_suffix
                + self.output_prefix + ex.get_output() + self.output_suffix)

    def response(self,
                 query,
                 engine='base_10B',
                 max_tokens=20,
                 temperature=0.9,
                 topP=0.1,
                 topK=1,
                 frequencyPenalty=1.0,
                 responsePenalty=1.0,
                 noRepeatNgramSize=0):
        """Submit the query and return the raw reply from the API.

        Calls ``submit_request`` to obtain a request id and passes it to
        ``reply_request``. Any exception raised by either call propagates
        to the caller unchanged.
        """
        requestId = submit_request(query, temperature, topP, topK, max_tokens,
                                   engine, frequencyPenalty, responsePenalty,
                                   noRepeatNgramSize)
        return reply_request(requestId)

    def del_special_chars(self, msg):
        """Return ``msg`` with the known special markers removed."""
        # NOTE(review): the page this file was taken from warns about
        # invisible Unicode characters. The empty-looking entries below may
        # originally have held non-printing code points -- verify against the
        # raw file before relying on this list.
        special_chars = ['<unk>', '<eod>', '#', '', '', '', ' ']
        for char in special_chars:
            # Replacing an empty string with an empty string is a no-op.
            if char:
                msg = msg.replace(char, '')
        return msg

    def submit_API(self, prompt, trun=None):
        """Submit prompt to the Yuan API and obtain a pure text reply.

        :prompt: Question or any content a user may input.
        :trun: Optional stop token (str) or list of stop tokens. The reply is
            cut before the first occurrence of each token, applied in order.
            Empty and non-string entries are skipped.
        :return: pure text response.
        """
        query = self.craft_query(prompt)
        res = self.response(query,
                            engine=self.engine,
                            max_tokens=self.max_tokens,
                            temperature=self.temperature,
                            topP=self.topP,
                            topK=self.topK,
                            frequencyPenalty=self.frequencyPenalty,
                            responsePenalty=self.responsePenalty,
                            noRepeatNgramSize=self.noRepeatNgramSize)
        if 'resData' in res and res['resData'] is not None:
            txt = res['resData']
        else:
            txt = '模型返回为空,请尝试修改输入'
        # Post-processing specific to the translation model.
        if self.engine == 'translate':
            txt = txt.replace(' ##', '').replace(' "', '"').replace(": ", ":").replace(" ,", ",") \
                .replace('英文:', '').replace('文:', '').replace("( ", "(").replace(" )", ")")
        else:
            txt = txt.replace(' ', '')
        txt = self.del_special_chars(txt)
        # trun: truncate the model output at any of several stop tokens.
        if isinstance(trun, str):
            trun = [trun]
        if isinstance(trun, list):
            for tr in trun:
                # Skip unusable entries instead of aborting the whole loop.
                if isinstance(tr, str) and tr:
                    # partition() returns the full text when tr is absent.
                    txt = txt.partition(tr)[0]
        return txt