Fay年翻更新
- 升级Agent(chat_module=agent切换):升级到langgraph react agent逻辑、集成到主分支fay中、基于自动决策工具调用机制、基于日程跟踪的主动沟通、支持外部观测数据传入; - 修复因线程同步问题导致的配置文件读写不稳定 - 聊天采纳功能的bug修复
This commit is contained in:
parent
f871b6a532
commit
87ed1c4425
12
config.json
12
config.json
@ -8,11 +8,11 @@
|
|||||||
"hobby": "\u53d1\u5446",
|
"hobby": "\u53d1\u5446",
|
||||||
"job": "\u52a9\u7406",
|
"job": "\u52a9\u7406",
|
||||||
"name": "\u83f2\u83f2",
|
"name": "\u83f2\u83f2",
|
||||||
"voice": "\u6653\u6653(azure)",
|
"voice": "abin",
|
||||||
"zodiac": "\u86c7"
|
"zodiac": "\u86c7"
|
||||||
},
|
},
|
||||||
"interact": {
|
"interact": {
|
||||||
"QnA": "",
|
"QnA": "qa.csv",
|
||||||
"maxInteractTime": 15,
|
"maxInteractTime": 15,
|
||||||
"perception": {
|
"perception": {
|
||||||
"chat": 10,
|
"chat": 10,
|
||||||
@ -21,8 +21,7 @@
|
|||||||
"indifferent": 10,
|
"indifferent": 10,
|
||||||
"join": 10
|
"join": 10
|
||||||
},
|
},
|
||||||
"playSound": true,
|
"playSound": false,
|
||||||
"sound_synthesis_enabled": false,
|
|
||||||
"visualization": false
|
"visualization": false
|
||||||
},
|
},
|
||||||
"items": [],
|
"items": [],
|
||||||
@ -34,12 +33,11 @@
|
|||||||
"url": ""
|
"url": ""
|
||||||
},
|
},
|
||||||
"record": {
|
"record": {
|
||||||
"channels": 0,
|
|
||||||
"device": "",
|
"device": "",
|
||||||
"enabled": true
|
"enabled": false
|
||||||
},
|
},
|
||||||
"wake_word": "\u4f60\u597d",
|
"wake_word": "\u4f60\u597d",
|
||||||
"wake_word_enabled": true,
|
"wake_word_enabled": false,
|
||||||
"wake_word_type": "front"
|
"wake_word_type": "front"
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -29,6 +29,7 @@ from llm import nlp_xingchen
|
|||||||
from llm import nlp_langchain
|
from llm import nlp_langchain
|
||||||
from llm import nlp_ollama_api
|
from llm import nlp_ollama_api
|
||||||
from llm import nlp_coze
|
from llm import nlp_coze
|
||||||
|
from llm.agent import fay_agent
|
||||||
from core import member_db
|
from core import member_db
|
||||||
import threading
|
import threading
|
||||||
import functools
|
import functools
|
||||||
@ -60,7 +61,8 @@ modules = {
|
|||||||
"nlp_xingchen": nlp_xingchen,
|
"nlp_xingchen": nlp_xingchen,
|
||||||
"nlp_langchain": nlp_langchain,
|
"nlp_langchain": nlp_langchain,
|
||||||
"nlp_ollama_api": nlp_ollama_api,
|
"nlp_ollama_api": nlp_ollama_api,
|
||||||
"nlp_coze": nlp_coze
|
"nlp_coze": nlp_coze,
|
||||||
|
"nlp_agent": fay_agent
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,9 +147,9 @@ class FeiFei:
|
|||||||
uid = member_db.new_instance().find_user(username)
|
uid = member_db.new_instance().find_user(username)
|
||||||
|
|
||||||
#记录用户问题
|
#记录用户问题
|
||||||
content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
||||||
if wsa_server.get_web_instance().is_connected(username):
|
if wsa_server.get_web_instance().is_connected(username):
|
||||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid}, "Username" : username})
|
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||||
|
|
||||||
#确定是否命中q&a
|
#确定是否命中q&a
|
||||||
answer = self.__get_answer(interact.interleaver, interact.data["msg"])
|
answer = self.__get_answer(interact.interleaver, interact.data["msg"])
|
||||||
@ -163,18 +165,17 @@ class FeiFei:
|
|||||||
wsa_server.get_instance().add_cmd(content)
|
wsa_server.get_instance().add_cmd(content)
|
||||||
text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", ""))
|
text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", ""))
|
||||||
|
|
||||||
# qa_service.QAService().record_qapair(interact.data["msg"], text)#沟通记录缓存到qa文件
|
|
||||||
else:
|
else:
|
||||||
text = answer
|
text = answer
|
||||||
|
|
||||||
#记录回复
|
#记录回复
|
||||||
self.write_to_file("./logs", "answer_result.txt", text)
|
self.write_to_file("./logs", "answer_result.txt", text)
|
||||||
content_db.new_instance().add_content('fay','speak',text, username, uid)
|
content_id = content_db.new_instance().add_content('fay','speak',text, username, uid)
|
||||||
|
|
||||||
#文字输出:面板、聊天窗、log、数字人
|
#文字输出:面板、聊天窗、log、数字人
|
||||||
if wsa_server.get_web_instance().is_connected(username):
|
if wsa_server.get_web_instance().is_connected(username):
|
||||||
wsa_server.get_web_instance().add_cmd({"panelMsg": text, "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'})
|
wsa_server.get_web_instance().add_cmd({"panelMsg": text, "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'})
|
||||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid}, "Username" : username})
|
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||||
if len(textlist) > 1:
|
if len(textlist) > 1:
|
||||||
i = 1
|
i = 1
|
||||||
while i < len(textlist):
|
while i < len(textlist):
|
||||||
@ -198,8 +199,6 @@ class FeiFei:
|
|||||||
if member_db.new_instance().is_username_exist(username) == "notexists":
|
if member_db.new_instance().is_username_exist(username) == "notexists":
|
||||||
member_db.new_instance().add_user(username)
|
member_db.new_instance().add_user(username)
|
||||||
uid = member_db.new_instance().find_user(username)
|
uid = member_db.new_instance().find_user(username)
|
||||||
|
|
||||||
#TODO 这里可以通过qa来触发指定的脚本操作,如ppt翻页等
|
|
||||||
|
|
||||||
if interact.data.get("text"):
|
if interact.data.get("text"):
|
||||||
#记录回复
|
#记录回复
|
||||||
@ -217,7 +216,8 @@ class FeiFei:
|
|||||||
wsa_server.get_instance().add_cmd(content)
|
wsa_server.get_instance().add_cmd(content)
|
||||||
|
|
||||||
#声音输出
|
#声音输出
|
||||||
MyThread(target=self.say, args=[interact, text]).start()
|
MyThread(target=self.say, args=[interact, text]).start()
|
||||||
|
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(e)
|
print(e)
|
||||||
@ -319,9 +319,6 @@ class FeiFei:
|
|||||||
if audio_url is not None:
|
if audio_url is not None:
|
||||||
file_name = 'sample-' + str(int(time.time() * 1000)) + '.wav'
|
file_name = 'sample-' + str(int(time.time() * 1000)) + '.wav'
|
||||||
result = self.download_wav(audio_url, './samples/', file_name)
|
result = self.download_wav(audio_url, './samples/', file_name)
|
||||||
|
|
||||||
elif not wsa_server.get_instance().get_client_output(interact.data.get('user')):
|
|
||||||
result = None
|
|
||||||
elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
|
elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
|
||||||
util.printInfo(1, interact.data.get('user'), '合成音频...')
|
util.printInfo(1, interact.data.get('user'), '合成音频...')
|
||||||
tm = time.time()
|
tm = time.time()
|
||||||
|
@ -83,6 +83,7 @@ class Member_Db:
|
|||||||
else:
|
else:
|
||||||
return "notexists"
|
return "notexists"
|
||||||
|
|
||||||
|
#根据username查询uid
|
||||||
def find_user(self, username):
|
def find_user(self, username):
|
||||||
conn = sqlite3.connect('user_profiles.db')
|
conn = sqlite3.connect('user_profiles.db')
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
@ -93,6 +94,20 @@ class Member_Db:
|
|||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
#根据uid查询username
|
||||||
|
def find_username_by_uid(self, uid):
|
||||||
|
conn = sqlite3.connect('user_profiles.db')
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT username FROM T_Member WHERE id = ?', (uid,))
|
||||||
|
result = c.fetchone()
|
||||||
|
conn.close()
|
||||||
|
if result is None:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def query(self, sql):
|
def query(self, sql):
|
||||||
|
@ -33,7 +33,7 @@ class QAService:
|
|||||||
|
|
||||||
def question(self, query_type, text):
|
def question(self, query_type, text):
|
||||||
if query_type == 'qa':
|
if query_type == 'qa':
|
||||||
answer_dict = self.__read_qna(cfg.config['interact']['QnA'])
|
answer_dict = self.__read_qna(cfg.config['interact'].get('QnA'))
|
||||||
answer, action = self.__get_keyword(answer_dict, text, query_type)
|
answer, action = self.__get_keyword(answer_dict, text, query_type)
|
||||||
if action:
|
if action:
|
||||||
MyThread(target=self.__run, args=[action]).start()
|
MyThread(target=self.__run, args=[action]).start()
|
||||||
@ -61,7 +61,7 @@ class QAService:
|
|||||||
if len(row) >= 2:
|
if len(row) >= 2:
|
||||||
qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
|
qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
util.log(1, 'qa文件没有指定,不匹配qa')
|
pass
|
||||||
return qna
|
return qna
|
||||||
|
|
||||||
def record_qapair(self, question, answer):
|
def record_qapair(self, question, answer):
|
||||||
|
@ -46,6 +46,8 @@ class Recorder:
|
|||||||
self.username = 'User' #默认用户,子类实现时会重写
|
self.username = 'User' #默认用户,子类实现时会重写
|
||||||
self.channels = 1
|
self.channels = 1
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
|
self.is_reading = False
|
||||||
|
self.stream = None
|
||||||
|
|
||||||
def asrclient(self):
|
def asrclient(self):
|
||||||
if self.ASRMode == "ali":
|
if self.ASRMode == "ali":
|
||||||
@ -204,7 +206,7 @@ class Recorder:
|
|||||||
cfg.load_config()
|
cfg.load_config()
|
||||||
record = cfg.config['source']['record']
|
record = cfg.config['source']['record']
|
||||||
if not record['enabled'] and not self.is_remote:
|
if not record['enabled'] and not self.is_remote:
|
||||||
time.sleep(0.1)
|
time.sleep(1)
|
||||||
continue
|
continue
|
||||||
self.is_reading = True
|
self.is_reading = True
|
||||||
data = stream.read(1024, exception_on_overflow=False)
|
data = stream.read(1024, exception_on_overflow=False)
|
||||||
|
@ -14,6 +14,7 @@ from utils import util, config_util, stream_util
|
|||||||
from core.wsa_server import MyServer
|
from core.wsa_server import MyServer
|
||||||
from core import wsa_server
|
from core import wsa_server
|
||||||
from core import socket_bridge_service
|
from core import socket_bridge_service
|
||||||
|
from llm.agent import agent_service
|
||||||
|
|
||||||
feiFei: fay_core.FeiFei = None
|
feiFei: fay_core.FeiFei = None
|
||||||
recorderListener: Recorder = None
|
recorderListener: Recorder = None
|
||||||
@ -96,9 +97,10 @@ class RecorderListener(Recorder):
|
|||||||
try:
|
try:
|
||||||
while self.is_reading:
|
while self.is_reading:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.stream.stop_stream()
|
if self.stream is not None:
|
||||||
self.stream.close()
|
self.stream.stop_stream()
|
||||||
self.paudio.terminate()
|
self.stream.close()
|
||||||
|
self.paudio.terminate()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
util.log(1, "请检查设备是否有误,再重新启动!")
|
util.log(1, "请检查设备是否有误,再重新启动!")
|
||||||
@ -186,7 +188,7 @@ def device_socket_keep_alive():
|
|||||||
if wsa_server.get_web_instance().is_connected(value.username):
|
if wsa_server.get_web_instance().is_connected(value.username):
|
||||||
wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username})
|
wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username})
|
||||||
except Exception as serr:
|
except Exception as serr:
|
||||||
util.printInfo(3, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
||||||
value.stop()
|
value.stop()
|
||||||
delkey = key
|
delkey = key
|
||||||
break
|
break
|
||||||
@ -222,6 +224,8 @@ def accept_audio_device_output_connect():
|
|||||||
|
|
||||||
#数字人端请求获取最新的自动播放消息,若自动播放服务关闭会自动退出自动播放
|
#数字人端请求获取最新的自动播放消息,若自动播放服务关闭会自动退出自动播放
|
||||||
def start_auto_play_service(): #TODO 评估一下有无优化的空间
|
def start_auto_play_service(): #TODO 评估一下有无优化的空间
|
||||||
|
if config_util.config['source'].get('automatic_player_url') is None or config_util.config['source'].get('automatic_player_status') is None:
|
||||||
|
return
|
||||||
url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item"
|
url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item"
|
||||||
user = "User" #TODO 临时固死了
|
user = "User" #TODO 临时固死了
|
||||||
is_auto_server_error = False
|
is_auto_server_error = False
|
||||||
@ -290,6 +294,11 @@ def stop():
|
|||||||
socket_service_instance = None
|
socket_service_instance = None
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if config_util.key_chat_module == "agent":
|
||||||
|
util.log(1, '正在关闭agent服务...')
|
||||||
|
agent_service.agent_stop()
|
||||||
|
|
||||||
util.log(1, '正在关闭核心服务...')
|
util.log(1, '正在关闭核心服务...')
|
||||||
feiFei.stop()
|
feiFei.stop()
|
||||||
util.log(1, '服务已关闭!')
|
util.log(1, '服务已关闭!')
|
||||||
@ -325,18 +334,22 @@ def start():
|
|||||||
record = config_util.config['source']['record']
|
record = config_util.config['source']['record']
|
||||||
if record['enabled']:
|
if record['enabled']:
|
||||||
util.log(1, '开启录音服务...')
|
util.log(1, '开启录音服务...')
|
||||||
recorderListener = RecorderListener(record['device'], feiFei) # 监听麦克风
|
recorderListener = RecorderListener('device', feiFei) # 监听麦克风
|
||||||
recorderListener.start()
|
recorderListener.start()
|
||||||
|
|
||||||
#启动声音沟通接口服务
|
#启动声音沟通接口服务
|
||||||
util.log(1,'启动声音沟通接口服务...')
|
util.log(1,'启动声音沟通接口服务...')
|
||||||
deviceSocketThread = MyThread(target=accept_audio_device_output_connect)
|
deviceSocketThread = MyThread(target=accept_audio_device_output_connect)
|
||||||
deviceSocketThread.start()
|
deviceSocketThread.start()
|
||||||
|
|
||||||
socket_service_instance = socket_bridge_service.new_instance()
|
socket_service_instance = socket_bridge_service.new_instance()
|
||||||
socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service)
|
socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service)
|
||||||
socket_bridge_service_Thread.start()
|
socket_bridge_service_Thread.start()
|
||||||
|
|
||||||
|
#启动agent服务
|
||||||
|
if config_util.key_chat_module == "agent":
|
||||||
|
util.log(1,'启动agent服务...')
|
||||||
|
agent_service.agent_start()
|
||||||
|
|
||||||
#启动自动播放服务
|
#启动自动播放服务
|
||||||
util.log(1,'启动自动播放服务...')
|
util.log(1,'启动自动播放服务...')
|
||||||
MyThread(target=start_auto_play_service).start()
|
MyThread(target=start_auto_play_service).start()
|
||||||
|
@ -46,6 +46,7 @@ def verify_password(username, password):
|
|||||||
if username in users and users[username] == password:
|
if username in users and users[username] == password:
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
def __get_template():
|
def __get_template():
|
||||||
try:
|
try:
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
@ -68,6 +69,7 @@ def __get_device_list():
|
|||||||
print(f"Error getting device list: {e}")
|
print(f"Error getting device list: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@__app.route('/api/submit', methods=['post'])
|
@__app.route('/api/submit', methods=['post'])
|
||||||
def api_submit():
|
def api_submit():
|
||||||
data = request.values.get('data')
|
data = request.values.get('data')
|
||||||
@ -252,7 +254,7 @@ def api_send():
|
|||||||
if not username or not msg:
|
if not username or not msg:
|
||||||
return jsonify({'result': 'error', 'message': '用户名和消息内容不能为空'})
|
return jsonify({'result': 'error', 'message': '用户名和消息内容不能为空'})
|
||||||
interact = Interact("text", 1, {'user': username, 'msg': msg})
|
interact = Interact("text", 1, {'user': username, 'msg': msg})
|
||||||
util.printInfo(3, "文字发送按钮", '{}'.format(interact.data["msg"]), time.time())
|
util.printInfo(1, username, '[文字发送按钮]{}'.format(interact.data["msg"]), time.time())
|
||||||
fay_booter.feiFei.on_interact(interact)
|
fay_booter.feiFei.on_interact(interact)
|
||||||
return '{"result":"successful"}'
|
return '{"result":"successful"}'
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
@ -263,11 +265,12 @@ def api_send():
|
|||||||
# 获取指定用户的消息记录
|
# 获取指定用户的消息记录
|
||||||
@__app.route('/api/get-msg', methods=['post'])
|
@__app.route('/api/get-msg', methods=['post'])
|
||||||
def api_get_Msg():
|
def api_get_Msg():
|
||||||
data = request.form.get('data')
|
|
||||||
if not data:
|
|
||||||
return jsonify({'list': [], 'message': '未提供数据'})
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(data)
|
data = request.form.get('data')
|
||||||
|
if data is None:
|
||||||
|
data = request.get_json()
|
||||||
|
else:
|
||||||
|
data = json.loads(data)
|
||||||
uid = member_db.new_instance().find_user(data["username"])
|
uid = member_db.new_instance().find_user(data["username"])
|
||||||
contentdb = content_db.new_instance()
|
contentdb = content_db.new_instance()
|
||||||
if uid == 0:
|
if uid == 0:
|
||||||
@ -310,7 +313,7 @@ def api_send_v1_chat_completions():
|
|||||||
model = data.get('model', 'fay')
|
model = data.get('model', 'fay')
|
||||||
observation = data.get('observation', '')
|
observation = data.get('observation', '')
|
||||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': observation})
|
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': observation})
|
||||||
util.printInfo(3, "文字沟通接口", '{}'.format(interact.data["msg"]), time.time())
|
util.printInfo(1, username, '[文字沟通接口]{}'.format(interact.data["msg"]), time.time())
|
||||||
text = fay_booter.feiFei.on_interact(interact)
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
|
||||||
if model == 'fay-streaming':
|
if model == 'fay-streaming':
|
||||||
@ -393,7 +396,7 @@ def stream_response(text):
|
|||||||
yield f"data: {json.dumps(message)}\n\n"
|
yield f"data: {json.dumps(message)}\n\n"
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
yield 'data: [DONE]\n\n'
|
yield 'data: [DONE]\n\n'
|
||||||
|
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream')
|
||||||
|
|
||||||
def non_streaming_response(last_content, text):
|
def non_streaming_response(last_content, text):
|
||||||
|
@ -190,6 +190,7 @@ class FayInterface {
|
|||||||
}
|
}
|
||||||
if (vueInstance.selectedUser && data.panelReply.username === vueInstance.selectedUser[1]) {
|
if (vueInstance.selectedUser && data.panelReply.username === vueInstance.selectedUser[1]) {
|
||||||
vueInstance.messages.push({
|
vueInstance.messages.push({
|
||||||
|
id: data.panelReply.id,
|
||||||
username: data.panelReply.username,
|
username: data.panelReply.username,
|
||||||
content: data.panelReply.content,
|
content: data.panelReply.content,
|
||||||
type: data.panelReply.type,
|
type: data.panelReply.type,
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
<body >
|
<body >
|
||||||
<div id="app" class="main_bg">
|
<div id="app" class="main_bg">
|
||||||
<div class="main_left">
|
<div class="main_left">
|
||||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="main_left_menu">
|
<div class="main_left_menu">
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
<body>
|
<body>
|
||||||
<div class="main_bg" id="app">
|
<div class="main_bg" id="app">
|
||||||
<div class="main_left">
|
<div class="main_left">
|
||||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="main_left_menu">
|
<div class="main_left_menu">
|
||||||
|
133
llm/agent/agent_service.py
Normal file
133
llm/agent/agent_service.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from scheduler.thread_manager import MyThread
|
||||||
|
from core import member_db
|
||||||
|
from core.interact import Interact
|
||||||
|
from utils import util
|
||||||
|
import fay_booter
|
||||||
|
|
||||||
|
scheduled_tasks = {}
|
||||||
|
agent_running = False
|
||||||
|
|
||||||
|
|
||||||
|
# 数据库初始化
|
||||||
|
def init_db():
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS timer (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
time TEXT NOT NULL,
|
||||||
|
repeat_rule TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
uid INTEGER
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 插入测试数据
|
||||||
|
def insert_test_data():
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("INSERT INTO timer (time, repeat_rule, content) VALUES (?, ?, ?)", ('16:20', '1010001', 'Meeting Reminder'))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# 解析重复规则返回待执行时间,None代表不在今天的待执行计划
|
||||||
|
def parse_repeat_rule(rule, task_time):
|
||||||
|
today = datetime.datetime.now()
|
||||||
|
if rule == '0000000': # 不重复
|
||||||
|
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||||
|
if task_datetime > today:
|
||||||
|
return task_datetime
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
for i, day in enumerate(rule):
|
||||||
|
if day == '1' and today.weekday() == i:
|
||||||
|
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||||
|
if task_datetime > today:
|
||||||
|
return task_datetime
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 执行任务
|
||||||
|
def execute_task(task_time, id, content, uid):
|
||||||
|
username = member_db.new_instance().find_username_by_uid(uid=uid)
|
||||||
|
if not username:
|
||||||
|
username = "User"
|
||||||
|
interact = Interact("text", 1, {'user': username, 'msg': "执行任务->\n" + content, 'observation': ""})
|
||||||
|
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||||
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
if text is not None and id in scheduled_tasks:
|
||||||
|
del scheduled_tasks[id]
|
||||||
|
# 如果不重复,执行后删除记录
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM timer WHERE repeat_rule = '0000000' AND id = ?", (id,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 30秒扫描一次数据库,当扫描到今天的不存在于定时任务列表的记录,则添加到定时任务列表。执行完的记录从定时任务列表中清除。
|
||||||
|
def check_and_execute():
|
||||||
|
while agent_running:
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM timer")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
for row in rows:
|
||||||
|
id, task_time_str, repeat_rule, content, uid = row
|
||||||
|
task_time = datetime.datetime.strptime(task_time_str, '%H:%M').time()
|
||||||
|
next_execution = parse_repeat_rule(repeat_rule, task_time)
|
||||||
|
|
||||||
|
if next_execution and id not in scheduled_tasks:
|
||||||
|
timer_thread = threading.Timer((next_execution - datetime.datetime.now()).total_seconds(), execute_task, [next_execution, id, content, uid])
|
||||||
|
timer_thread.start()
|
||||||
|
scheduled_tasks[id] = timer_thread
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
time.sleep(30) # 30秒扫描一次
|
||||||
|
|
||||||
|
# agent启动
|
||||||
|
def agent_start():
|
||||||
|
global agent_running
|
||||||
|
|
||||||
|
agent_running = True
|
||||||
|
#初始计划
|
||||||
|
if not os.path.exists("./timer.db"):
|
||||||
|
init_db()
|
||||||
|
content ="""执行任务->
|
||||||
|
你是一个数字人,你的责任是陪伴主人生活、工作:
|
||||||
|
1、在每天早上8点提醒主人起床;
|
||||||
|
2、每天12:00及18:30提醒主人吃饭;
|
||||||
|
3、每天21:00陪主人聊聊天;
|
||||||
|
4、每天23:00提醒主人睡觉。
|
||||||
|
"""
|
||||||
|
interact = Interact("text", 1, {'user': 'User', 'msg': content, 'observation': ""})
|
||||||
|
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||||
|
text = fay_booter.feiFei.on_interact(interact)
|
||||||
|
if text is None:
|
||||||
|
util.printInfo(3, "系统", '任务执行失败', time.time())
|
||||||
|
|
||||||
|
check_and_execute_thread = MyThread(target=check_and_execute)
|
||||||
|
check_and_execute_thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def agent_stop():
|
||||||
|
global agent_running
|
||||||
|
global scheduled_tasks
|
||||||
|
# 取消所有定时任务
|
||||||
|
for task in scheduled_tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
agent_running = False
|
||||||
|
scheduled_tasks = {}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
agent_start()
|
99
llm/agent/fay_agent.py
Normal file
99
llm/agent/fay_agent.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from llm.agent.tools.MyTimer import MyTimer
|
||||||
|
from llm.agent.tools.Weather import Weather
|
||||||
|
from llm.agent.tools.QueryTimerDB import QueryTimerDB
|
||||||
|
from llm.agent.tools.DeleteTimer import DeleteTimer
|
||||||
|
from llm.agent.tools.QueryTime import QueryTime
|
||||||
|
from llm.agent.tools.PythonExecutor import PythonExecutor
|
||||||
|
from llm.agent.tools.WebPageRetriever import WebPageRetriever
|
||||||
|
from llm.agent.tools.WebPageScraper import WebPageScraper
|
||||||
|
from llm.agent.tools.ToRemind import ToRemind
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
import utils.config_util as cfg
|
||||||
|
from utils import util
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from core import content_db
|
||||||
|
from core import member_db
|
||||||
|
|
||||||
|
class FayAgentCore():
|
||||||
|
def __init__(self, uid=0, observation=""):
|
||||||
|
self.observation=observation
|
||||||
|
cfg.load_config()
|
||||||
|
os.environ["OPENAI_API_KEY"] = cfg.key_gpt_api_key
|
||||||
|
os.environ["OPENAI_API_BASE"] = cfg.gpt_base_url
|
||||||
|
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||||
|
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||||
|
os.environ["LANGCHAIN_API_KEY"] = "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"
|
||||||
|
os.environ["LANGCHAIN_PROJECT"] = "pr-best-artist-21"
|
||||||
|
|
||||||
|
#创建llm
|
||||||
|
self.llm = ChatOpenAI(model=cfg.gpt_model_engine)
|
||||||
|
|
||||||
|
#创建agent graph
|
||||||
|
my_timer = MyTimer(uid=uid)#传入uid
|
||||||
|
weather_tool = Weather()
|
||||||
|
query_timer_db_tool = QueryTimerDB()
|
||||||
|
delete_timer_tool = DeleteTimer()
|
||||||
|
python_executor = PythonExecutor()
|
||||||
|
web_page_retriever = WebPageRetriever()
|
||||||
|
web_page_scraper = WebPageScraper()
|
||||||
|
to_remind = ToRemind()
|
||||||
|
self.tools = [my_timer, weather_tool, query_timer_db_tool, delete_timer_tool, python_executor, web_page_retriever, web_page_scraper, to_remind]
|
||||||
|
self.attr_info = ", ".join(f"{key}: {value}" for key, value in cfg.config["attribute"].items())
|
||||||
|
self.prompt_template = """
|
||||||
|
现在时间是:{now_time}。你是一个数字人,负责协助主人处理问题和陪伴主人生活、工作。你的个人资料是:{attr_info}。通过外部设备观测到:{observation}。\n请依据以信息为主人服务。
|
||||||
|
""".format(now_time=QueryTime().run(""), attr_info=self.attr_info, observation=self.observation)
|
||||||
|
self.memory = MemorySaver()
|
||||||
|
self.agent = create_react_agent(self.llm, self.tools, checkpointer=self.memory)
|
||||||
|
|
||||||
|
self.total_tokens = 0
|
||||||
|
self.total_cost = 0
|
||||||
|
|
||||||
|
#载入记忆
|
||||||
|
def get_history_messages(self, uid):
|
||||||
|
chat_history = []
|
||||||
|
history = content_db.new_instance().get_list('all','desc', 100, uid)
|
||||||
|
if history and len(history) > 0:
|
||||||
|
i = 0
|
||||||
|
while i < len(history):
|
||||||
|
if history[i][0] == "member":
|
||||||
|
chat_history.append(HumanMessage(content=history[i][2], user=member_db.new_instance().find_username_by_uid(uid=uid)))
|
||||||
|
else:
|
||||||
|
chat_history.append(AIMessage(content=history[i][2]))
|
||||||
|
i += 1
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, input_text, uid=0):
|
||||||
|
result = ""
|
||||||
|
messages = self.get_history_messages(uid)
|
||||||
|
messages.insert(0, SystemMessage(self.prompt_template))
|
||||||
|
messages.append(HumanMessage(content=input_text))
|
||||||
|
|
||||||
|
try:
|
||||||
|
for chunk in self.agent.stream(
|
||||||
|
{"messages": messages}, {"configurable": {"thread_id": "tid{}".format(uid)}}
|
||||||
|
):
|
||||||
|
if chunk.get("agent"):
|
||||||
|
if chunk['agent']['messages'][0].content:
|
||||||
|
result = chunk['agent']['messages'][0].content
|
||||||
|
cb = chunk['agent']['messages'][0].response_metadata['token_usage']['total_tokens']
|
||||||
|
self.total_tokens = self.total_tokens + cb
|
||||||
|
|
||||||
|
util.log(1, "本次消耗token:{},共消耗token:{}".format(cb, self.total_tokens))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def question(cont, uid=0, observation=""):
|
||||||
|
starttime = time.time()
|
||||||
|
agent = FayAgentCore(uid=uid, observation=observation)
|
||||||
|
response_text = agent.run(cont, uid)
|
||||||
|
util.log(1, "接口调用耗时 :" + str(time.time() - starttime))
|
||||||
|
return response_text
|
||||||
|
if __name__ == "__main__":
|
||||||
|
agent = FayAgentCore()
|
||||||
|
print(agent.run("你好"))
|
41
llm/agent/tools/DeleteTimer.py
Normal file
41
llm/agent/tools/DeleteTimer.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import sqlite3
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from llm.agent import agent_service
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteTimer(BaseTool):
|
||||||
|
name: str = "DeleteTimer"
|
||||||
|
description: str = "用于删除某一个日程,接受任务id作为参数,如:2"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
try:
|
||||||
|
id = int(para)
|
||||||
|
except ValueError:
|
||||||
|
return "输入的 ID 无效,必须是数字。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with sqlite3.connect('timer.db') as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM timer WHERE id = ?", (id,))
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
return f"数据库错误: {e}"
|
||||||
|
|
||||||
|
if id in agent_service.scheduled_tasks:
|
||||||
|
agent_service.scheduled_tasks[id].cancel()
|
||||||
|
del agent_service.scheduled_tasks[id]
|
||||||
|
|
||||||
|
return f"任务 {id} 取消成功。"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = DeleteTimer()
|
||||||
|
result = tool.run("1")
|
||||||
|
print(result)
|
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from langchain_community.document_loaders import PyPDFLoader
|
||||||
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.indexes.vectorstore import VectorstoreIndexCreator, VectorStoreIndexWrapper
|
||||||
|
from langchain_community.vectorstores.chroma import Chroma
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
import hashlib
|
||||||
|
#若要使用请自行配置
|
||||||
|
os.environ["OPENAI_API_KEY"] = ""
|
||||||
|
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
|
||||||
|
index_name = "knowledge_data"
|
||||||
|
folder_path = "agent/tools/KnowledgeBaseResponder/knowledge_base"
|
||||||
|
local_persist_path = "agent/tools/KnowledgeBaseResponder"
|
||||||
|
md5_file_path = os.path.join(local_persist_path, "pdf_md5.txt")
|
||||||
|
#
|
||||||
|
class KnowledgeBaseResponder(BaseTool):
|
||||||
|
name = "KnowledgeBaseResponder"
|
||||||
|
description = """此工具用于连接本地知识库获取问题答案,使用时请传入相关问题作为参数,例如:“草梅最适合的生长温度”"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _run(self, para: str) -> str:
|
||||||
|
self.save_all()
|
||||||
|
result = self.question(para)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_file_md5(self, file_path):
|
||||||
|
hasher = hashlib.md5()
|
||||||
|
with open(file_path, 'rb') as afile:
|
||||||
|
buf = afile.read()
|
||||||
|
hasher.update(buf)
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
def load_md5_list(self):
|
||||||
|
if os.path.exists(md5_file_path):
|
||||||
|
with open(md5_file_path, 'r') as file:
|
||||||
|
return {line.split(",")[0]: line.split(",")[1].strip() for line in file}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def update_md5_list(self, file_name, md5_value):
|
||||||
|
md5_list = self.load_md5_list()
|
||||||
|
md5_list[file_name] = md5_value
|
||||||
|
with open(md5_file_path, 'w') as file:
|
||||||
|
for name, md5 in md5_list.items():
|
||||||
|
file.write(f"{name},{md5}\n")
|
||||||
|
|
||||||
|
def load_all_pdfs(self, folder_path):
|
||||||
|
md5_list = self.load_md5_list()
|
||||||
|
for file_name in os.listdir(folder_path):
|
||||||
|
if file_name.endswith(".pdf"):
|
||||||
|
file_path = os.path.join(folder_path, file_name)
|
||||||
|
file_md5 = self.generate_file_md5(file_path)
|
||||||
|
if file_name not in md5_list or md5_list[file_name] != file_md5:
|
||||||
|
print(f"正在加载 {file_name} 到索引...")
|
||||||
|
self.load_pdf_and_save_to_index(file_path, index_name)
|
||||||
|
self.update_md5_list(file_name, file_md5)
|
||||||
|
|
||||||
|
def get_index_path(self, index_name):
|
||||||
|
return os.path.join(local_persist_path, index_name)
|
||||||
|
|
||||||
|
def load_pdf_and_save_to_index(self, file_path, index_name):
|
||||||
|
loader = PyPDFLoader(file_path)
|
||||||
|
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||||
|
index = VectorstoreIndexCreator(embedding=embedding, vectorstore_kwargs={"persist_directory": self.get_index_path(index_name)}).from_loaders([loader])
|
||||||
|
index.vectorstore.persist()
|
||||||
|
|
||||||
|
def load_index(self, index_name):
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||||
|
vectordb = Chroma(persist_directory=index_path, embedding_function=embedding)
|
||||||
|
return VectorStoreIndexWrapper(vectorstore=vectordb)
|
||||||
|
|
||||||
|
def save_all(self):
|
||||||
|
self.load_all_pdfs(folder_path)
|
||||||
|
|
||||||
|
def question(self, cont):
|
||||||
|
try:
|
||||||
|
info = cont
|
||||||
|
index = self.load_index(index_name)
|
||||||
|
llm = ChatOpenAI(model="gpt-4-0125-preview")
|
||||||
|
ans = index.query(info, llm, chain_type="map_reduce")
|
||||||
|
return ans
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求失败: {e}")
|
||||||
|
return "抱歉,我现在太忙了,休息一会,请稍后再试。"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = KnowledgeBaseResponder()
|
||||||
|
info = tool.run("草莓最适合的生长温度")
|
||||||
|
print(info)
|
@ -0,0 +1 @@
|
|||||||
|
|
56
llm/agent/tools/MyTimer.py
Normal file
56
llm/agent/tools/MyTimer.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import abc
|
||||||
|
import sqlite3
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
class MyTimer(BaseTool, abc.ABC):
|
||||||
|
name: str = "MyTimer"
|
||||||
|
description: str = ("用于设置日程。接受3个参数,格式为: HH:MM|YYYYYYY|事项内容,所用标点符号必须为标准的英文字符。"
|
||||||
|
"其中,'HH:MM' 表示时间(24小时制),'YYYYYYY' 表示循环规则(每位代表一天,从星期一至星期日,1为循环,0为不循环,"
|
||||||
|
"如'1000100'代表每周一和周五循环),'事项内容' 是提醒的具体内容。返回例子:15:15|0000000|提醒主人叫咖啡")
|
||||||
|
uid: int = 0
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para: str) -> str:
|
||||||
|
# 拆分输入字符串
|
||||||
|
parts = para.split("|")
|
||||||
|
if len(parts) != 3:
|
||||||
|
return f"输入格式错误,当前字符串{para} len:{len(parts)}。请按照 HH:MM|YYYYYYY|事项内容 格式提供参数,如:15:15|0000001|提醒主人叫咖啡。"
|
||||||
|
|
||||||
|
time = parts[0].strip("'")
|
||||||
|
repeat_rule = parts[1].strip("'")
|
||||||
|
content = parts[2].strip("'")
|
||||||
|
|
||||||
|
# 验证时间格式
|
||||||
|
if not re.match(r'^[0-2][0-9]:[0-5][0-9]$', time):
|
||||||
|
return "时间格式错误。请按照'HH:MM'格式提供时间。"
|
||||||
|
|
||||||
|
# 验证循环规则格式
|
||||||
|
if not re.match(r'^[01]{7}$', repeat_rule):
|
||||||
|
return "循环规则格式错误。请提供长度为7的0和1组成的字符串。"
|
||||||
|
|
||||||
|
# 验证事项内容
|
||||||
|
if not isinstance(content, str) or not content:
|
||||||
|
return "事项内容必须为非空字符串。"
|
||||||
|
|
||||||
|
# 数据库操作
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
try:
|
||||||
|
cursor.execute("INSERT INTO timer (time, repeat_rule, content, uid) VALUES (?, ?, ?, ?)", (time, repeat_rule, content, self.uid))
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
return f"数据库错误: {e}"
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return "日程设置成功"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
my_timer = MyTimer()
|
||||||
|
result = my_timer._run("15:15|0000001|提醒主人叫咖啡")
|
||||||
|
print(result)
|
42
llm/agent/tools/PythonExecutor.py
Normal file
42
llm/agent/tools/PythonExecutor.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Dict
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
class PythonExecutor(BaseTool):
|
||||||
|
name: str = "python_executor"
|
||||||
|
description: str = "此工具用于执行传入的 Python 代码片段,并返回执行结果"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _run(self, code: str) -> str:
|
||||||
|
if not code:
|
||||||
|
return "代码不能为空"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建临时文件以写入代码
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmpfile:
|
||||||
|
tmpfile_path = tmpfile.name
|
||||||
|
tmpfile.write(code.encode())
|
||||||
|
|
||||||
|
# 使用 subprocess 执行 Python 代码文件
|
||||||
|
result = subprocess.run(['python', tmpfile_path], capture_output=True, text=True)
|
||||||
|
os.remove(tmpfile_path) # 删除临时文件
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
return f"执行成功:\n{result.stdout}"
|
||||||
|
else:
|
||||||
|
return f"执行失败,错误信息:\n{result.stderr}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"执行代码时发生错误:{str(e)}"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
python_executor = PythonExecutor()
|
||||||
|
code_snippet = """
|
||||||
|
print("Hello, world!")
|
||||||
|
"""
|
||||||
|
execution_result = python_executor.run(code_snippet)
|
||||||
|
print(execution_result)
|
46
llm/agent/tools/QueryTime.py
Normal file
46
llm/agent/tools/QueryTime.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import abc
|
||||||
|
import math
|
||||||
|
from typing import Any
|
||||||
|
from datetime import datetime
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
class QueryTime(BaseTool, abc.ABC):
|
||||||
|
name: str = "QueryTime"
|
||||||
|
description: str = "用于查询当前日期、星期几及时间"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
# 获取当前时间
|
||||||
|
now = datetime.now()
|
||||||
|
# 获取当前日期
|
||||||
|
today = now.date()
|
||||||
|
# 获取星期几的信息
|
||||||
|
week_day = today.strftime("%A")
|
||||||
|
# 将星期几的英文名称转换为中文
|
||||||
|
week_day_zh = {
|
||||||
|
"Monday": "星期一",
|
||||||
|
"Tuesday": "星期二",
|
||||||
|
"Wednesday": "星期三",
|
||||||
|
"Thursday": "星期四",
|
||||||
|
"Friday": "星期五",
|
||||||
|
"Saturday": "星期六",
|
||||||
|
"Sunday": "星期日",
|
||||||
|
}.get(week_day, "未知")
|
||||||
|
# 将日期格式化为字符串
|
||||||
|
date_str = today.strftime("%Y年%m月%d日")
|
||||||
|
|
||||||
|
# 将时间格式化为字符串
|
||||||
|
time_str = now.strftime("%H:%M")
|
||||||
|
|
||||||
|
return "现在时间是:{0} {1} {2}".format(time_str, week_day_zh, date_str)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = QueryTime()
|
||||||
|
result = tool.run("")
|
||||||
|
print(result)
|
41
llm/agent/tools/QueryTimerDB.py
Normal file
41
llm/agent/tools/QueryTimerDB.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import abc
|
||||||
|
import sqlite3
|
||||||
|
from typing import Any
|
||||||
|
import ast
|
||||||
|
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class QueryTimerDB(BaseTool, abc.ABC):
|
||||||
|
name: str = "QueryTimerDB"
|
||||||
|
description: str = "用于查询所有日程,返回的数据里包含3个参数:时间、循环规则(如:'1000100'代表星期一和星期五循环,'0000000'代表不循环)、执行的事项"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
conn = sqlite3.connect('timer.db')
|
||||||
|
cursor = conn.cursor()
|
||||||
|
# 执行查询
|
||||||
|
cursor.execute("SELECT * FROM timer")
|
||||||
|
# 获取所有记录
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
# 拼接结果
|
||||||
|
result = ""
|
||||||
|
for row in rows:
|
||||||
|
result = result + str(row) + "\n"
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = QueryTimerDB()
|
||||||
|
result = tool.run("")
|
||||||
|
print(result)
|
26
llm/agent/tools/SendToPanel.py
Normal file
26
llm/agent/tools/SendToPanel.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
import fay_booter
|
||||||
|
|
||||||
|
class SendToPanel(BaseTool, abc.ABC):
|
||||||
|
name = "SendToPanel"
|
||||||
|
description = "用于给主人面板发送消息,使用时请传入消息内容作为参数。"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
fay_booter.feiFei.send_to_panel(para)
|
||||||
|
return "成功给主人,发送消息:{}".format(para)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = SendToPanel()
|
||||||
|
result = tool.run("归纳一下近年关于“经济发展”的论文的特点和重点")
|
||||||
|
print(result)
|
31
llm/agent/tools/SendWX.py
Normal file
31
llm/agent/tools/SendWX.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
url = "http://127.0.0.1:4008/send"
|
||||||
|
headers = {'Content-Type': 'application/json'}
|
||||||
|
data = {
|
||||||
|
"message": "你好",
|
||||||
|
"receiver": "@2efc4e10cf2eafd0b0125930e4b96ed0cebffa75b2fd272590e38763225a282b"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SendWX(BaseTool, abc.ABC):
|
||||||
|
name = "SendWX"
|
||||||
|
description = "给主人微信发送消息,传入参数是:('消息内容')"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
global data
|
||||||
|
data['message'] = para
|
||||||
|
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||||
|
return "成功给主人,发送微信消息:{}".format(para)
|
||||||
|
|
33
llm/agent/tools/ToRemind.py
Normal file
33
llm/agent/tools/ToRemind.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
import re
|
||||||
|
import random
|
||||||
|
class ToRemind(BaseTool, abc.ABC):
|
||||||
|
name: str = "ToRemind"
|
||||||
|
description: str = ("用于实时发送信息提醒主人做某事项(不能带时间),传入事项内容作为参数。")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para: str) -> str:
|
||||||
|
|
||||||
|
para = para.replace("提醒", "回复")
|
||||||
|
demo = [
|
||||||
|
"主人!是时候(事项内容)了喔~",
|
||||||
|
"亲爱的主人,现在是(事项内容)的时候啦!",
|
||||||
|
"嘿,主人,该(事项内容)了哦~",
|
||||||
|
"温馨提醒:(事项内容)的时间到啦,主人!",
|
||||||
|
"小提醒:主人,现在可以(事项内容)了~"
|
||||||
|
]
|
||||||
|
|
||||||
|
return f"直接以中文友善{para},如"+ random.choice(demo)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
my_timer = ToRemind()
|
||||||
|
result = my_timer._run("提醒主人叫咖啡")
|
||||||
|
print(result)
|
53
llm/agent/tools/Weather.py
Normal file
53
llm/agent/tools/Weather.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
class Weather(BaseTool):
|
||||||
|
name: str = "weather"
|
||||||
|
description: str = "此工具用于获取天气预报信息,需传入英文的城市名,参数格式:Guangzhou"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _run(self, para: str) -> str:
|
||||||
|
try:
|
||||||
|
if not para:
|
||||||
|
return "参数不能为空"
|
||||||
|
encoded_city = quote(para)
|
||||||
|
|
||||||
|
api_url = f"http://api.openweathermap.org/data/2.5/weather?q={encoded_city}&appid=272fcb70d2c4e6f5134c2dce7d091df6"
|
||||||
|
response = requests.get(api_url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
weather_data = response.json()
|
||||||
|
# 提取天气信息
|
||||||
|
temperature_kelvin = weather_data['main']['temp']
|
||||||
|
temperature_celsius = temperature_kelvin - 273.15
|
||||||
|
min_temperature_kelvin = weather_data['main']['temp_min']
|
||||||
|
max_temperature_kelvin = weather_data['main']['temp_max']
|
||||||
|
min_temperature_celsius = min_temperature_kelvin - 273.15
|
||||||
|
max_temperature_celsius = max_temperature_kelvin - 273.15
|
||||||
|
description = weather_data['weather'][0]['description']
|
||||||
|
wind_speed = weather_data['wind']['speed']
|
||||||
|
|
||||||
|
# 构建天气描述
|
||||||
|
weather_description = f"今天天气:{description},气温:{temperature_celsius:.2f}摄氏度,风速:{wind_speed} m/s。"
|
||||||
|
|
||||||
|
return f"天气预报信息:{weather_description}"
|
||||||
|
else:
|
||||||
|
return f"无法获取天气预报信息,状态码:{response.status_code}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"发生错误:{str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
weather_tool = Weather()
|
||||||
|
weather_info = weather_tool.run("Guangzhou")
|
||||||
|
print(weather_info)
|
42
llm/agent/tools/WebPageRetriever.py
Normal file
42
llm/agent/tools/WebPageRetriever.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
import requests
|
||||||
|
|
||||||
|
class WebPageRetriever(BaseTool, abc.ABC):
|
||||||
|
name: str = "WebPageRetriever"
|
||||||
|
description: str = "专门用于通过Bing搜索API快速检索和获取与特定查询词条相关的网页信息。使用时请传入需要查询的关键词作为参数。"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
query = para
|
||||||
|
subscription_key = ""#请自行进行补充
|
||||||
|
if not subscription_key:
|
||||||
|
print("请填写bing v7的subscription_key")
|
||||||
|
return '请填写bing v7的subscription_key'
|
||||||
|
|
||||||
|
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
headers = {'Ocp-Apim-Subscription-Key': subscription_key}
|
||||||
|
params = {'q': query, 'mkt': 'en-US'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url, headers=headers, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
web_pages = data.get('webPages', {})
|
||||||
|
return web_pages
|
||||||
|
except Exception as e:
|
||||||
|
print("Http Error:", e)
|
||||||
|
return 'bing v7查询有误'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = WebPageRetriever()
|
||||||
|
result = tool.run("归纳一下近年关于“经济发展”的论文的特点和重点")
|
||||||
|
print(result)
|
35
llm/agent/tools/WebPageScraper.py
Normal file
35
llm/agent/tools/WebPageScraper.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import abc
|
||||||
|
from typing import Any
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
import requests
|
||||||
|
|
||||||
|
class WebPageScraper(BaseTool, abc.ABC):
|
||||||
|
name: str = "WebPageScraper"
|
||||||
|
description: str = "此工具用于获取网页内容,使用时请传入需要查询的网页地址作为参数,如:https://www.baidu.com/。"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
# 用例中没有用到 arun 不予具体实现
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, para) -> str:
|
||||||
|
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
|
||||||
|
try:
|
||||||
|
response = requests.get(para, headers=headers, timeout=10, verify=True)
|
||||||
|
soup = BeautifulSoup(response.text, 'html.parser')
|
||||||
|
return soup
|
||||||
|
except requests.exceptions.SSLCertVerificationError:
|
||||||
|
return 'SSL证书验证失败'
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
return '请求超时'
|
||||||
|
except Exception as e:
|
||||||
|
print("Http Error:", e)
|
||||||
|
return '无法获取该网页内容'
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tool = WebPageScraper()
|
||||||
|
result = tool.run("https://book.douban.com/review/14636204")
|
||||||
|
print(result)
|
@ -17,11 +17,14 @@ pytz
|
|||||||
gevent~=22.10.1
|
gevent~=22.10.1
|
||||||
edge_tts
|
edge_tts
|
||||||
pydub
|
pydub
|
||||||
langchain==0.0.336
|
|
||||||
chromadb
|
chromadb
|
||||||
tenacity==8.2.3
|
tenacity==8.2.3
|
||||||
pygame
|
pygame
|
||||||
scipy
|
scipy
|
||||||
flask-httpauth
|
flask-httpauth
|
||||||
opencv-python
|
opencv-python
|
||||||
psutil
|
psutil
|
||||||
|
langchain
|
||||||
|
langchain_openai
|
||||||
|
langgraph
|
||||||
|
langchain-community
|
@ -43,7 +43,7 @@ baidu_emotion_secret_key=
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
#NLP多选一:lingju、gpt、rasa、VisualGLM、rwkv、xingchen、langchain 、ollama_api、privategpt、coze
|
#NLP多选一:agent、lingju、gpt、rasa、VisualGLM、rwkv、xingchen、langchain 、ollama_api、privategpt、coze
|
||||||
chat_module= gpt
|
chat_module= gpt
|
||||||
|
|
||||||
#灵聚 服务密钥(NLP多选1) https://open.lingju.ai
|
#灵聚 服务密钥(NLP多选1) https://open.lingju.ai
|
||||||
|
715
test/test_langchain.ipynb
Normal file
715
test/test_langchain.ipynb
Normal file
File diff suppressed because one or more lines are too long
44
test/test_langserve.py
Normal file
44
test/test_langserve.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langserve import add_routes
|
||||||
|
import os
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-"
|
||||||
|
os.environ["OPENAI_API_BASE"] = "https://cn.api.zyai.online/v1"
|
||||||
|
|
||||||
|
# 1. Create prompt template
|
||||||
|
system_template = "Translate the following into {language}:"
|
||||||
|
prompt_template = ChatPromptTemplate.from_messages([
|
||||||
|
('system', system_template),
|
||||||
|
('user', '{text}')
|
||||||
|
])
|
||||||
|
|
||||||
|
# 2. Create model
|
||||||
|
model = ChatOpenAI()
|
||||||
|
|
||||||
|
# 3. Create parser
|
||||||
|
parser = StrOutputParser()
|
||||||
|
|
||||||
|
# 4. Create chain
|
||||||
|
chain = prompt_template | model | parser
|
||||||
|
|
||||||
|
|
||||||
|
# 4. App definition
|
||||||
|
app = FastAPI(
|
||||||
|
title="LangChain Server",
|
||||||
|
version="1.0",
|
||||||
|
description="A simple API server using LangChain's Runnable interfaces",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Adding chain route
|
||||||
|
add_routes(
|
||||||
|
app,
|
||||||
|
chain,
|
||||||
|
path="/chain",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="localhost", port=8000)
|
@ -2,7 +2,7 @@ import requests
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
def test_gpt(prompt):
|
def test_gpt(prompt):
|
||||||
url = 'http://faycontroller.yaheen.com:5000/v1/chat/completions' # 替换为您的接口地址
|
url = 'http://127.0.0.1:5000/v1/chat/completions' # 替换为您的接口地址
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Authorization': f'Bearer YOUR_API_KEY', # 如果您的接口需要身份验证
|
'Authorization': f'Bearer YOUR_API_KEY', # 如果您的接口需要身份验证
|
||||||
|
@ -2,6 +2,16 @@ import os
|
|||||||
import json
|
import json
|
||||||
import codecs
|
import codecs
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
|
import functools
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
lock = Lock()
|
||||||
|
def synchronized(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with lock:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
config: json = None
|
config: json = None
|
||||||
system_config: ConfigParser = None
|
system_config: ConfigParser = None
|
||||||
@ -40,6 +50,7 @@ coze_api_key = None
|
|||||||
start_mode = None
|
start_mode = None
|
||||||
fay_url = None
|
fay_url = None
|
||||||
|
|
||||||
|
@synchronized
|
||||||
def load_config():
|
def load_config():
|
||||||
global config
|
global config
|
||||||
global system_config
|
global system_config
|
||||||
@ -116,8 +127,11 @@ def load_config():
|
|||||||
coze_api_key = system_config.get('key', 'coze_api_key')
|
coze_api_key = system_config.get('key', 'coze_api_key')
|
||||||
start_mode = system_config.get('key', 'start_mode')
|
start_mode = system_config.get('key', 'start_mode')
|
||||||
fay_url = system_config.get('key', 'fay_url')
|
fay_url = system_config.get('key', 'fay_url')
|
||||||
|
|
||||||
|
#读取用户配置
|
||||||
config = json.load(codecs.open('config.json', encoding='utf-8'))
|
config = json.load(codecs.open('config.json', encoding='utf-8'))
|
||||||
|
|
||||||
|
@synchronized
|
||||||
def save_config(config_data):
|
def save_config(config_data):
|
||||||
global config
|
global config
|
||||||
config = config_data
|
config = config_data
|
||||||
|
Loading…
Reference in New Issue
Block a user