Fay年翻更新
- 升级Agent(chat_module=agent切换):升级到langgraph react agent逻辑、集成到主分支fay中、基于自动决策工具调用机制、基于日程跟踪的主动沟通、支持外部观测数据传入; - 修复因线程同步问题导致的配置文件读写不稳定 - 聊天采纳功能的bug修复
This commit is contained in:
parent
f871b6a532
commit
87ed1c4425
12
config.json
12
config.json
@ -8,11 +8,11 @@
|
|||||||
"hobby": "\u53d1\u5446",
|
"hobby": "\u53d1\u5446",
|
||||||
"job": "\u52a9\u7406",
|
"job": "\u52a9\u7406",
|
||||||
"name": "\u83f2\u83f2",
|
"name": "\u83f2\u83f2",
|
||||||
"voice": "\u6653\u6653(azure)",
|
"voice": "abin",
|
||||||
"zodiac": "\u86c7"
|
"zodiac": "\u86c7"
|
||||||
},
|
},
|
||||||
"interact": {
|
"interact": {
|
||||||
"QnA": "",
|
"QnA": "qa.csv",
|
||||||
"maxInteractTime": 15,
|
"maxInteractTime": 15,
|
||||||
"perception": {
|
"perception": {
|
||||||
"chat": 10,
|
"chat": 10,
|
||||||
@ -21,8 +21,7 @@
|
|||||||
"indifferent": 10,
|
"indifferent": 10,
|
||||||
"join": 10
|
"join": 10
|
||||||
},
|
},
|
||||||
"playSound": true,
|
"playSound": false,
|
||||||
"sound_synthesis_enabled": false,
|
|
||||||
"visualization": false
|
"visualization": false
|
||||||
},
|
},
|
||||||
"items": [],
|
"items": [],
|
||||||
@ -34,12 +33,11 @@
|
|||||||
"url": ""
|
"url": ""
|
||||||
},
|
},
|
||||||
"record": {
|
"record": {
|
||||||
"channels": 0,
|
|
||||||
"device": "",
|
"device": "",
|
||||||
"enabled": true
|
"enabled": false
|
||||||
},
|
},
|
||||||
"wake_word": "\u4f60\u597d",
|
"wake_word": "\u4f60\u597d",
|
||||||
"wake_word_enabled": true,
|
"wake_word_enabled": false,
|
||||||
"wake_word_type": "front"
|
"wake_word_type": "front"
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -29,6 +29,7 @@ from llm import nlp_xingchen
|
|||||||
from llm import nlp_langchain
|
from llm import nlp_langchain
|
||||||
from llm import nlp_ollama_api
|
from llm import nlp_ollama_api
|
||||||
from llm import nlp_coze
|
from llm import nlp_coze
|
||||||
|
from llm.agent import fay_agent
|
||||||
from core import member_db
|
from core import member_db
|
||||||
import threading
|
import threading
|
||||||
import functools
|
import functools
|
||||||
@ -60,7 +61,8 @@ modules = {
|
|||||||
"nlp_xingchen": nlp_xingchen,
|
"nlp_xingchen": nlp_xingchen,
|
||||||
"nlp_langchain": nlp_langchain,
|
"nlp_langchain": nlp_langchain,
|
||||||
"nlp_ollama_api": nlp_ollama_api,
|
"nlp_ollama_api": nlp_ollama_api,
|
||||||
"nlp_coze": nlp_coze
|
"nlp_coze": nlp_coze,
|
||||||
|
"nlp_agent": fay_agent
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,9 +147,9 @@ class FeiFei:
|
|||||||
uid = member_db.new_instance().find_user(username)
|
uid = member_db.new_instance().find_user(username)
|
||||||
|
|
||||||
#记录用户问题
|
#记录用户问题
|
||||||
content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
||||||
if wsa_server.get_web_instance().is_connected(username):
|
if wsa_server.get_web_instance().is_connected(username):
|
||||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid}, "Username" : username})
|
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||||
|
|
||||||
#确定是否命中q&a
|
#确定是否命中q&a
|
||||||
answer = self.__get_answer(interact.interleaver, interact.data["msg"])
|
answer = self.__get_answer(interact.interleaver, interact.data["msg"])
|
||||||
@ -163,18 +165,17 @@ class FeiFei:
|
|||||||
wsa_server.get_instance().add_cmd(content)
|
wsa_server.get_instance().add_cmd(content)
|
||||||
text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", ""))
|
text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", ""))
|
||||||
|
|
||||||
# qa_service.QAService().record_qapair(interact.data["msg"], text)#沟通记录缓存到qa文件
|
|
||||||
else:
|
else:
|
||||||
text = answer
|
text = answer
|
||||||
|
|
||||||
#记录回复
|
#记录回复
|
||||||
self.write_to_file("./logs", "answer_result.txt", text)
|
self.write_to_file("./logs", "answer_result.txt", text)
|
||||||
content_db.new_instance().add_content('fay','speak',text, username, uid)
|
content_id = content_db.new_instance().add_content('fay','speak',text, username, uid)
|
||||||
|
|
||||||
#文字输出:面板、聊天窗、log、数字人
|
#文字输出:面板、聊天窗、log、数字人
|
||||||
if wsa_server.get_web_instance().is_connected(username):
|
if wsa_server.get_web_instance().is_connected(username):
|
||||||
wsa_server.get_web_instance().add_cmd({"panelMsg": text, "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'})
|
wsa_server.get_web_instance().add_cmd({"panelMsg": text, "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'})
|
||||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid}, "Username" : username})
|
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||||
if len(textlist) > 1:
|
if len(textlist) > 1:
|
||||||
i = 1
|
i = 1
|
||||||
while i < len(textlist):
|
while i < len(textlist):
|
||||||
@ -199,8 +200,6 @@ class FeiFei:
|
|||||||
member_db.new_instance().add_user(username)
|
member_db.new_instance().add_user(username)
|
||||||
uid = member_db.new_instance().find_user(username)
|
uid = member_db.new_instance().find_user(username)
|
||||||
|
|
||||||
#TODO 这里可以通过qa来触发指定的脚本操作,如ppt翻页等
|
|
||||||
|
|
||||||
if interact.data.get("text"):
|
if interact.data.get("text"):
|
||||||
#记录回复
|
#记录回复
|
||||||
text = interact.data.get("text")
|
text = interact.data.get("text")
|
||||||
@ -219,6 +218,7 @@ class FeiFei:
|
|||||||
#声音输出
|
#声音输出
|
||||||
MyThread(target=self.say, args=[interact, text]).start()
|
MyThread(target=self.say, args=[interact, text]).start()
|
||||||
|
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(e)
|
print(e)
|
||||||
return e
|
return e
|
||||||
@ -319,9 +319,6 @@ class FeiFei:
|
|||||||
if audio_url is not None:
|
if audio_url is not None:
|
||||||
file_name = 'sample-' + str(int(time.time() * 1000)) + '.wav'
|
file_name = 'sample-' + str(int(time.time() * 1000)) + '.wav'
|
||||||
result = self.download_wav(audio_url, './samples/', file_name)
|
result = self.download_wav(audio_url, './samples/', file_name)
|
||||||
|
|
||||||
elif not wsa_server.get_instance().get_client_output(interact.data.get('user')):
|
|
||||||
result = None
|
|
||||||
elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
|
elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
|
||||||
util.printInfo(1, interact.data.get('user'), '合成音频...')
|
util.printInfo(1, interact.data.get('user'), '合成音频...')
|
||||||
tm = time.time()
|
tm = time.time()
|
||||||
|
@ -83,6 +83,7 @@ class Member_Db:
|
|||||||
else:
|
else:
|
||||||
return "notexists"
|
return "notexists"
|
||||||
|
|
||||||
|
#根据username查询uid
|
||||||
def find_user(self, username):
|
def find_user(self, username):
|
||||||
conn = sqlite3.connect('user_profiles.db')
|
conn = sqlite3.connect('user_profiles.db')
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
@ -94,6 +95,20 @@ class Member_Db:
|
|||||||
else:
|
else:
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
#根据uid查询username
|
||||||
|
def find_username_by_uid(self, uid):
|
||||||
|
conn = sqlite3.connect('user_profiles.db')
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT username FROM T_Member WHERE id = ?', (uid,))
|
||||||
|
result = c.fetchone()
|
||||||
|
conn.close()
|
||||||
|
if result is None:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def query(self, sql):
|
def query(self, sql):
|
||||||
try:
|
try:
|
||||||
|
@ -33,7 +33,7 @@ class QAService:
|
|||||||
|
|
||||||
def question(self, query_type, text):
|
def question(self, query_type, text):
|
||||||
if query_type == 'qa':
|
if query_type == 'qa':
|
||||||
answer_dict = self.__read_qna(cfg.config['interact']['QnA'])
|
answer_dict = self.__read_qna(cfg.config['interact'].get('QnA'))
|
||||||
answer, action = self.__get_keyword(answer_dict, text, query_type)
|
answer, action = self.__get_keyword(answer_dict, text, query_type)
|
||||||
if action:
|
if action:
|
||||||
MyThread(target=self.__run, args=[action]).start()
|
MyThread(target=self.__run, args=[action]).start()
|
||||||
@ -61,7 +61,7 @@ class QAService:
|
|||||||
if len(row) >= 2:
|
if len(row) >= 2:
|
||||||
qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
|
qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
util.log(1, 'qa文件没有指定,不匹配qa')
|
pass
|
||||||
return qna
|
return qna
|
||||||
|
|
||||||
def record_qapair(self, question, answer):
|
def record_qapair(self, question, answer):
|
||||||
|
@ -46,6 +46,8 @@ class Recorder:
|
|||||||
self.username = 'User' #默认用户,子类实现时会重写
|
self.username = 'User' #默认用户,子类实现时会重写
|
||||||
self.channels = 1
|
self.channels = 1
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
|
self.is_reading = False
|
||||||
|
self.stream = None
|
||||||
|
|
||||||
def asrclient(self):
|
def asrclient(self):
|
||||||
if self.ASRMode == "ali":
|
if self.ASRMode == "ali":
|
||||||
@ -204,7 +206,7 @@ class Recorder:
|
|||||||
cfg.load_config()
|
cfg.load_config()
|
||||||
record = cfg.config['source']['record']
|
record = cfg.config['source']['record']
|
||||||
if not record['enabled'] and not self.is_remote:
|
if not record['enabled'] and not self.is_remote:
|
||||||
time.sleep(0.1)
|
time.sleep(1)
|
||||||
continue
|
continue
|
||||||
self.is_reading = True
|
self.is_reading = True
|
||||||
data = stream.read(1024, exception_on_overflow=False)
|
data = stream.read(1024, exception_on_overflow=False)
|
||||||
|
@ -14,6 +14,7 @@ from utils import util, config_util, stream_util
|
|||||||
from core.wsa_server import MyServer
|
from core.wsa_server import MyServer
|
||||||
from core import wsa_server
|
from core import wsa_server
|
||||||
from core import socket_bridge_service
|
from core import socket_bridge_service
|
||||||
|
from llm.agent import agent_service
|
||||||
|
|
||||||
feiFei: fay_core.FeiFei = None
|
feiFei: fay_core.FeiFei = None
|
||||||
recorderListener: Recorder = None
|
recorderListener: Recorder = None
|
||||||
@ -96,6 +97,7 @@ class RecorderListener(Recorder):
|
|||||||
try:
|
try:
|
||||||
while self.is_reading:
|
while self.is_reading:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
if self.stream is not None:
|
||||||
self.stream.stop_stream()
|
self.stream.stop_stream()
|
||||||
self.stream.close()
|
self.stream.close()
|
||||||
self.paudio.terminate()
|
self.paudio.terminate()
|
||||||
@ -186,7 +188,7 @@ def device_socket_keep_alive():
|
|||||||
if wsa_server.get_web_instance().is_connected(value.username):
|
if wsa_server.get_web_instance().is_connected(value.username):
|
||||||
wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username})
|
wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username})
|
||||||
except Exception as serr:
|
except Exception as serr:
|
||||||
util.printInfo(3, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
||||||
value.stop()
|
value.stop()
|
||||||
delkey = key
|
delkey = key
|
||||||
break
|
break
|
||||||
@ -222,6 +224,8 @@ def accept_audio_device_output_connect():
|
|||||||
|
|
||||||
#数字人端请求获取最新的自动播放消息,若自动播放服务关闭会自动退出自动播放
|
#数字人端请求获取最新的自动播放消息,若自动播放服务关闭会自动退出自动播放
|
||||||
def start_auto_play_service(): #TODO 评估一下有无优化的空间
|
def start_auto_play_service(): #TODO 评估一下有无优化的空间
|
||||||
|
if config_util.config['source'].get('automatic_player_url') is None or config_util.config['source'].get('automatic_player_status') is None:
|
||||||
|
return
|
||||||
url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item"
|
url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item"
|
||||||
user = "User" #TODO 临时固死了
|
user = "User" #TODO 临时固死了
|
||||||
is_auto_server_error = False
|
is_auto_server_error = False
|
||||||
@ -290,6 +294,11 @@ def stop():
|
|||||||
socket_service_instance = None
|
socket_service_instance = None
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if config_util.key_chat_module == "agent":
|
||||||
|
util.log(1, '正在关闭agent服务...')
|
||||||
|
agent_service.agent_stop()
|
||||||
|
|
||||||
util.log(1, '正在关闭核心服务...')
|
util.log(1, '正在关闭核心服务...')
|
||||||
feiFei.stop()
|
feiFei.stop()
|
||||||
util.log(1, '服务已关闭!')
|
util.log(1, '服务已关闭!')
|
||||||
@ -325,18 +334,22 @@ def start():
|
|||||||
record = config_util.config['source']['record']
|
record = config_util.config['source']['record']
|
||||||
if record['enabled']:
|
if record['enabled']:
|
||||||
util.log(1, '开启录音服务...')
|
util.log(1, '开启录音服务...')
|
||||||
recorderListener = RecorderListener(record['device'], feiFei) # 监听麦克风
|
recorderListener = RecorderListener('device', feiFei) # 监听麦克风
|
||||||
recorderListener.start()
|
recorderListener.start()
|
||||||
|
|
||||||
#启动声音沟通接口服务
|
#启动声音沟通接口服务
|
||||||
util.log(1,'启动声音沟通接口服务...')
|
util.log(1,'启动声音沟通接口服务...')
|
||||||
deviceSocketThread = MyThread(target=accept_audio_device_output_connect)
|
deviceSocketThread = MyThread(target=accept_audio_device_output_connect)
|
||||||
deviceSocketThread.start()
|
deviceSocketThread.start()
|
||||||
|
|
||||||
socket_service_instance = socket_bridge_service.new_instance()
|
socket_service_instance = socket_bridge_service.new_instance()
|
||||||
socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service)
|
socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service)
|
||||||
socket_bridge_service_Thread.start()
|
socket_bridge_service_Thread.start()
|
||||||
|
|
||||||
|
#启动agent服务
|
||||||
|
if config_util.key_chat_module == "agent":
|
||||||
|
util.log(1,'启动agent服务...')
|
||||||
|
agent_service.agent_start()
|
||||||
|
|
||||||
#启动自动播放服务
|
#启动自动播放服务
|
||||||
util.log(1,'启动自动播放服务...')
|
util.log(1,'启动自动播放服务...')
|
||||||
MyThread(target=start_auto_play_service).start()
|
MyThread(target=start_auto_play_service).start()
|
||||||
|
@ -46,6 +46,7 @@ def verify_password(username, password):
|
|||||||
if username in users and users[username] == password:
|
if username in users and users[username] == password:
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
def __get_template():
|
def __get_template():
|
||||||
try:
|
try:
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
@ -68,6 +69,7 @@ def __get_device_list():
|
|||||||
print(f"Error getting device list: {e}")
|
print(f"Error getting device list: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@__app.route('/api/submit', methods=['post'])
|
@__app.route('/api/submit', methods=['post'])
|
||||||
def api_submit():
|
def api_submit():
|
||||||
data = request.values.get('data')
|
data = request.values.get('data')
|
||||||
@ -252,7 +254,7 @@ def api_send():
|
|||||||
if not username or not msg:
|
if not username or not msg:
|
||||||
return jsonify({'result': 'error', 'message': '用户名和消息内容不能为空'})
|
return jsonify({'result': 'error', 'message': '用户名和消息内容不能为空'})
|
||||||
interact = Interact("text", 1, {'user': username, 'msg': msg})
|
interact = Interact("text", 1, {'user': username, 'msg': msg})
|
||||||
util.printInfo(3, "文字发送按钮", '{}'.format(interact.data["msg"]), time.time())
|
util.printInfo(1, username, '[文字发送按钮]{}'.format(interact.data["msg"]), time.time())
|
||||||
fay_booter.feiFei.on_interact(interact)
|
fay_booter.feiFei.on_interact(interact)
|
||||||
return '{"result":"successful"}'
|
return '{"result":"successful"}'
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
@ -263,10 +265,11 @@ def api_send():
|
|||||||
# 获取指定用户的消息记录
|
# 获取指定用户的消息记录
|
||||||
@__app.route('/api/get-msg', methods=['post'])
|
@__app.route('/api/get-msg', methods=['post'])
|
||||||
def api_get_Msg():
|
def api_get_Msg():
|
||||||
data = request.form.get('data')
|
|
||||||
if not data:
|
|
||||||
return jsonify({'list': [], 'message': '未提供数据'})
|
|
||||||
try:
|
try:
|
||||||
|
data = request.form.get('data')
|
||||||
|
if data is None:
|
||||||
|
data = request.get_json()
|
||||||
|
else:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
uid = member_db.new_instance().find_user(data["username"])
|
uid = member_db.new_instance().find_user(data["username"])
|
||||||
contentdb = content_db.new_instance()
|
contentdb = content_db.new_instance()
|
||||||
@ -310,7 +313,7 @@ def api_send_v1_chat_completions():
|
|||||||
model = data.get('model', 'fay')
|
model = data.get('model', 'fay')
|
||||||
observation = data.get('observation', '')
|
observation = data.get('observation', '')
|
||||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': observation})
|
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': observation})
|
||||||
util.printInfo(3, "文字沟通接口", '{}'.format(interact.data["msg"]), time.time())
|
util.printInfo(1, username, '[文字沟通接口]{}'.format(interact.data["msg"]), time.time())
|
||||||
text = fay_booter.feiFei.on_interact(interact)
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
|
||||||
if model == 'fay-streaming':
|
if model == 'fay-streaming':
|
||||||
|
@ -190,6 +190,7 @@ class FayInterface {
|
|||||||
}
|
}
|
||||||
if (vueInstance.selectedUser && data.panelReply.username === vueInstance.selectedUser[1]) {
|
if (vueInstance.selectedUser && data.panelReply.username === vueInstance.selectedUser[1]) {
|
||||||
vueInstance.messages.push({
|
vueInstance.messages.push({
|
||||||
|
id: data.panelReply.id,
|
||||||
username: data.panelReply.username,
|
username: data.panelReply.username,
|
||||||
content: data.panelReply.content,
|
content: data.panelReply.content,
|
||||||
type: data.panelReply.type,
|
type: data.panelReply.type,
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
<body >
|
<body >
|
||||||
<div id="app" class="main_bg">
|
<div id="app" class="main_bg">
|
||||||
<div class="main_left">
|
<div class="main_left">
|
||||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="main_left_menu">
|
<div class="main_left_menu">
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
<body>
|
<body>
|
||||||
<div class="main_bg" id="app">
|
<div class="main_bg" id="app">
|
||||||
<div class="main_left">
|
<div class="main_left">
|
||||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="main_left_menu">
|
<div class="main_left_menu">
|
||||||
|
133
llm/agent/agent_service.py
Normal file
133
llm/agent/agent_service.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from scheduler.thread_manager import MyThread
|
||||||
|
from core import member_db
|
||||||
|
from core.interact import Interact
|
||||||
|
from utils import util
|
||||||
|
import fay_booter
|
||||||
|
|
||||||
|
scheduled_tasks = {}
|
||||||
|
agent_running = False
|
||||||
|
|
||||||
|
|
||||||
|
# 数据库初始化
|
||||||
|
def init_db():
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS timer (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
time TEXT NOT NULL,
|
||||||
|
repeat_rule TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
uid INTEGER
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 插入测试数据
|
||||||
|
def insert_test_data():
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("INSERT INTO timer (time, repeat_rule, content) VALUES (?, ?, ?)", ('16:20', '1010001', 'Meeting Reminder'))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# 解析重复规则返回待执行时间,None代表不在今天的待执行计划
|
||||||
|
def parse_repeat_rule(rule, task_time):
|
||||||
|
today = datetime.datetime.now()
|
||||||
|
if rule == '0000000': # 不重复
|
||||||
|
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||||
|
if task_datetime > today:
|
||||||
|
return task_datetime
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
for i, day in enumerate(rule):
|
||||||
|
if day == '1' and today.weekday() == i:
|
||||||
|
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||||
|
if task_datetime > today:
|
||||||
|
return task_datetime
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 执行任务
|
||||||
|
def execute_task(task_time, id, content, uid):
|
||||||
|
username = member_db.new_instance().find_username_by_uid(uid=uid)
|
||||||
|
if not username:
|
||||||
|
username = "User"
|
||||||
|
interact = Interact("text", 1, {'user': username, 'msg': "执行任务->\n" + content, 'observation': ""})
|
||||||
|
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||||
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
if text is not None and id in scheduled_tasks:
|
||||||
|
del scheduled_tasks[id]
|
||||||
|
# 如果不重复,执行后删除记录
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM timer WHERE repeat_rule = '0000000' AND id = ?", (id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 30秒扫描一次数据库,当扫描到今天的不存在于定时任务列表的记录,则添加到定时任务列表。执行完的记录从定时任务列表中清除。
|
||||||
|
def check_and_execute():
|
||||||
|
while agent_running:
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM timer")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
for row in rows:
|
||||||
|
id, task_time_str, repeat_rule, content, uid = row
|
||||||
|
task_time = datetime.datetime.strptime(task_time_str, '%H:%M').time()
|
||||||
|
next_execution = parse_repeat_rule(repeat_rule, task_time)
|
||||||
|
|
||||||
|
if next_execution and id not in scheduled_tasks:
|
||||||
|
timer_thread = threading.Timer((next_execution - datetime.datetime.now()).total_seconds(), execute_task, [next_execution, id, content, uid])
|
||||||
|
timer_thread.start()
|
||||||
|
scheduled_tasks[id] = timer_thread
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
time.sleep(30) # 30秒扫描一次
|
||||||
|
|
||||||
|
# agent启动
|
||||||
|
def agent_start():
|
||||||
|
global agent_running
|
||||||
|
|
||||||
|
agent_running = True
|
||||||
|
#初始计划
|
||||||
|
if not os.path.exists("./timer.db"):
|
||||||
|
init_db()
|
||||||
|
content ="""执行任务->
|
||||||
|
你是一个数字人,你的责任是陪伴主人生活、工作:
|
||||||
|
1、在每天早上8点提醒主人起床;
|
||||||
|
2、每天12:00及18:30提醒主人吃饭;
|
||||||
|
3、每天21:00陪主人聊聊天;
|
||||||
|
4、每天23:00提醒主人睡觉。
|
||||||
|
"""
|
||||||
|
interact = Interact("text", 1, {'user': 'User', 'msg': content, 'observation': ""})
|
||||||
|
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||||
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
if text is None:
|
||||||
|
util.printInfo(3, "系统", '任务执行失败', time.time())
|
||||||
|
|
||||||
|
check_and_execute_thread = MyThread(target=check_and_execute)
|
||||||
|
check_and_execute_thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def agent_stop():
|
||||||
|
global agent_running
|
||||||
|
global scheduled_tasks
|
||||||
|
# 取消所有定时任务
|
||||||
|
for task in scheduled_tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
agent_running = False
|
||||||
|
scheduled_tasks = {}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
agent_start()
|
99
llm/agent/fay_agent.py
Normal file
99
llm/agent/fay_agent.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from llm.agent.tools.MyTimer import MyTimer
|
||||||
|
from llm.agent.tools.Weather import Weather
|
||||||
|
from llm.agent.tools.QueryTimerDB import QueryTimerDB
|
||||||
|
from llm.agent.tools.DeleteTimer import DeleteTimer
|
||||||
|
from llm.agent.tools.QueryTime import QueryTime
|
||||||
|
from llm.agent.tools.PythonExecutor import PythonExecutor
|
||||||
|
from llm.agent.tools.WebPageRetriever import WebPageRetriever
|
||||||
|
from llm.agent.tools.WebPageScraper import WebPageScraper
|
||||||
|
from llm.agent.tools.ToRemind import ToRemind
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
import utils.config_util as cfg
|
||||||
|
from utils import util
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from core import content_db
|
||||||
|
from core import member_db
|
||||||
|
|
||||||
|
class FayAgentCore():
|
||||||
|
def __init__(self, uid=0, observation=""):
|
||||||
|
self.observation=observation
|
||||||
|
cfg.load_config()
|
||||||
|
os.environ["OPENAI_API_KEY"] = cfg.key_gpt_api_key
|
||||||
|
os.environ["OPENAI_API_BASE"] = cfg.gpt_base_url
|
||||||
|
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||||
|
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||||
|
os.environ["LANGCHAIN_API_KEY"] = "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"
|
||||||
|
os.environ["LANGCHAIN_PROJECT"] = "pr-best-artist-21"
|
||||||
|
|
||||||
|
#创建llm
|
||||||
|
self.llm = ChatOpenAI(model=cfg.gpt_model_engine)
|
||||||
|
|
||||||
|
#创建agent graph
|
||||||
|
my_timer = MyTimer(uid=uid)#传入uid
|
||||||
|
weather_tool = Weather()
|
||||||
|
query_timer_db_tool = QueryTimerDB()
|
||||||
|
delete_timer_tool = DeleteTimer()
|
||||||
|
python_executor = PythonExecutor()
|
||||||
|
web_page_retriever = WebPageRetriever()
|
||||||
|
web_page_scraper = WebPageScraper()
|
||||||
|
to_remind = ToRemind()
|
||||||
|
self.tools = [my_timer, weather_tool, query_timer_db_tool, delete_timer_tool, python_executor, web_page_retriever, web_page_scraper, to_remind]
|
||||||
|
self.attr_info = ", ".join(f"{key}: {value}" for key, value in cfg.config["attribute"].items())
|
||||||
|
self.prompt_template = """
|
||||||
|
现在时间是:{now_time}。你是一个数字人,负责协助主人处理问题和陪伴主人生活、工作。你的个人资料是:{attr_info}。通过外部设备观测到:{observation}。\n请依据以信息为主人服务。
|
||||||
|
""".format(now_time=QueryTime().run(""), attr_info=self.attr_info, observation=self.observation)
|
||||||
|
self.memory = MemorySaver()
|
||||||
|
self.agent = create_react_agent(self.llm, self.tools, checkpointer=self.memory)
|
||||||
|
|
||||||
|
self.total_tokens = 0
|
||||||
|
self.total_cost = 0
|
||||||
|
|
||||||
|
#载入记忆
|
||||||
|
def get_history_messages(self, uid):
|
||||||
|
chat_history = []
|
||||||
|
history = content_db.new_instance().get_list('all','desc', 100, uid)
|
||||||
|
if history and len(history) > 0:
|
||||||
|
i = 0
|
||||||
|
while i < len(history):
|
||||||
|
if history[i][0] == "member":
|
||||||
|
chat_history.append(HumanMessage(content=history[i][2], user=member_db.new_instance().find_username_by_uid(uid=uid)))
|
||||||
|
else:
|
||||||
|
chat_history.append(AIMessage(content=history[i][2]))
|
||||||
|
i += 1
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, input_text, uid=0):
|
||||||
|
result = ""
|
||||||
|
messages = self.get_history_messages(uid)
|
||||||
|
messages.insert(0, SystemMessage(self.prompt_template))
|
||||||
|
messages.append(HumanMessage(content=input_text))
|
||||||
|
|
||||||
|
try:
|
||||||
|
for chunk in self.agent.stream(
|
||||||
|
{"messages": messages}, {"configurable": {"thread_id": "tid{}".format(uid)}}
|
||||||
|
):
|
||||||
|
if chunk.get("agent"):
|
||||||
|
if chunk['agent']['messages'][0].content:
|
||||||
|
result = chunk['agent']['messages'][0].content
|
||||||
|
cb = chunk['agent']['messages'][0].response_metadata['token_usage']['total_tokens']
|
||||||
|
self.total_tokens = self.total_tokens + cb
|
||||||
|
|
||||||
|
util.log(1, "本次消耗token:{},共消耗token:{}".format(cb, self.total_tokens))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def question(cont, uid=0, observation=""):
|
||||||
|
starttime = time.time()
|
||||||
|
agent = FayAgentCore(uid=uid, observation=observation)
|
||||||
|
response_text = agent.run(cont, uid)
|
||||||
|
util.log(1, "接口调用耗时 :" + str(time.time() - starttime))
|
||||||
|
return response_text
|
||||||
|
if __name__ == "__main__":
|
||||||
|
agent = FayAgentCore()
|
||||||
|
print(agent.run("你好"))
|
41
llm/agent/tools/DeleteTimer.py
Normal file
41
llm/agent/tools/DeleteTimer.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import sqlite3
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from llm.agent import agent_service
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteTimer(BaseTool):
|
||||||
|
name: str = "DeleteTimer"
|
||||||
|
description: str = "用于删除某一个日程,接受任务id作为参数,如:2"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
try:
|
||||||
|
id = int(para)
|
||||||
|
except ValueError:
|
||||||
|
return "输入的 ID 无效,必须是数字。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with sqlite3.connect('timer.db') as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM timer WHERE id = ?", (id,))
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
return f"数据库错误: {e}"
|
||||||
|
|
||||||
|
if id in agent_service.scheduled_tasks:
|
||||||
|
agent_service.scheduled_tasks[id].cancel()
|
||||||
|
del agent_service.scheduled_tasks[id]
|
||||||
|
|
||||||
|
return f"任务 {id} 取消成功。"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = DeleteTimer()
|
||||||
|
result = tool.run("1")
|
||||||
|
print(result)
|
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from langchain_community.document_loaders import PyPDFLoader
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.indexes.vectorstore import VectorstoreIndexCreator, VectorStoreIndexWrapper
|
||||||
|
from langchain_community.vectorstores.chroma import Chroma
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
import hashlib
|
||||||
|
#若要使用请自行配置
|
||||||
|
os.environ["OPENAI_API_KEY"] = ""
|
||||||
|
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
|
||||||
|
index_name = "knowledge_data"
|
||||||
|
folder_path = "agent/tools/KnowledgeBaseResponder/knowledge_base"
|
||||||
|
local_persist_path = "agent/tools/KnowledgeBaseResponder"
|
||||||
|
md5_file_path = os.path.join(local_persist_path, "pdf_md5.txt")
|
||||||
|
#
|
||||||
|
class KnowledgeBaseResponder(BaseTool):
|
||||||
|
name = "KnowledgeBaseResponder"
|
||||||
|
description = """此工具用于连接本地知识库获取问题答案,使用时请传入相关问题作为参数,例如:“草梅最适合的生长温度”"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _run(self, para: str) -> str:
|
||||||
|
self.save_all()
|
||||||
|
result = self.question(para)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_file_md5(self, file_path):
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
with open(file_path, 'rb') as afile:
|
||||||
|
buf = afile.read()
|
||||||
|
hasher.update(buf)
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
def load_md5_list(self):
|
||||||
|
if os.path.exists(md5_file_path):
|
||||||
|
with open(md5_file_path, 'r') as file:
|
||||||
|
return {line.split(",")[0]: line.split(",")[1].strip() for line in file}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def update_md5_list(self, file_name, md5_value):
|
||||||
|
md5_list = self.load_md5_list()
|
||||||
|
md5_list[file_name] = md5_value
|
||||||
|
with open(md5_file_path, 'w') as file:
|
||||||
|
for name, md5 in md5_list.items():
|
||||||
|
file.write(f"{name},{md5}\n")
|
||||||
|
|
||||||
|
def load_all_pdfs(self, folder_path):
|
||||||
|
md5_list = self.load_md5_list()
|
||||||
|
for file_name in os.listdir(folder_path):
|
||||||
|
if file_name.endswith(".pdf"):
|
||||||
|
file_path = os.path.join(folder_path, file_name)
|
||||||
|
file_md5 = self.generate_file_md5(file_path)
|
||||||
|
if file_name not in md5_list or md5_list[file_name] != file_md5:
|
||||||
|
print(f"正在加载 {file_name} 到索引...")
|
||||||
|
self.load_pdf_and_save_to_index(file_path, index_name)
|
||||||
|
self.update_md5_list(file_name, file_md5)
|
||||||
|
|
||||||
|
def get_index_path(self, index_name):
|
||||||
|
return os.path.join(local_persist_path, index_name)
|
||||||
|
|
||||||
|
def load_pdf_and_save_to_index(self, file_path, index_name):
|
||||||
|
loader = PyPDFLoader(file_path)
|
||||||
|
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||||
|
index = VectorstoreIndexCreator(embedding=embedding, vectorstore_kwargs={"persist_directory": self.get_index_path(index_name)}).from_loaders([loader])
|
||||||
|
index.vectorstore.persist()
|
||||||
|
|
||||||
|
def load_index(self, index_name):
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||||
|
vectordb = Chroma(persist_directory=index_path, embedding_function=embedding)
|
||||||
|
return VectorStoreIndexWrapper(vectorstore=vectordb)
|
||||||
|
|
||||||
|
def save_all(self):
|
||||||
|
self.load_all_pdfs(folder_path)
|
||||||
|
|
||||||
|
def question(self, cont):
|
||||||
|
try:
|
||||||
|
info = cont
|
||||||
|
index = self.load_index(index_name)
|
||||||
|
llm = ChatOpenAI(model="gpt-4-0125-preview")
|
||||||
|
ans = index.query(info, llm, chain_type="map_reduce")
|
||||||
|
return ans
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求失败: {e}")
|
||||||
|
return "抱歉,我现在太忙了,休息一会,请稍后再试。"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = KnowledgeBaseResponder()
|
||||||
|
info = tool.run("草莓最适合的生长温度")
|
||||||
|
print(info)
|
Loading…
Reference in New Issue
Block a user