Fay年翻更新
- 升级Agent(chat_module=agent切换):升级到langgraph react agent逻辑、集成到主分支fay中、基于自动决策工具调用机制、基于日程跟踪的主动沟通、支持外部观测数据传入; - 修复因线程同步问题导致的配置文件读写不稳定 - 聊天采纳功能的bug修复
This commit is contained in:
parent
f871b6a532
commit
87ed1c4425
12
config.json
12
config.json
@ -8,11 +8,11 @@
|
||||
"hobby": "\u53d1\u5446",
|
||||
"job": "\u52a9\u7406",
|
||||
"name": "\u83f2\u83f2",
|
||||
"voice": "\u6653\u6653(azure)",
|
||||
"voice": "abin",
|
||||
"zodiac": "\u86c7"
|
||||
},
|
||||
"interact": {
|
||||
"QnA": "",
|
||||
"QnA": "qa.csv",
|
||||
"maxInteractTime": 15,
|
||||
"perception": {
|
||||
"chat": 10,
|
||||
@ -21,8 +21,7 @@
|
||||
"indifferent": 10,
|
||||
"join": 10
|
||||
},
|
||||
"playSound": true,
|
||||
"sound_synthesis_enabled": false,
|
||||
"playSound": false,
|
||||
"visualization": false
|
||||
},
|
||||
"items": [],
|
||||
@ -34,12 +33,11 @@
|
||||
"url": ""
|
||||
},
|
||||
"record": {
|
||||
"channels": 0,
|
||||
"device": "",
|
||||
"enabled": true
|
||||
"enabled": false
|
||||
},
|
||||
"wake_word": "\u4f60\u597d",
|
||||
"wake_word_enabled": true,
|
||||
"wake_word_enabled": false,
|
||||
"wake_word_type": "front"
|
||||
}
|
||||
}
|
@ -29,6 +29,7 @@ from llm import nlp_xingchen
|
||||
from llm import nlp_langchain
|
||||
from llm import nlp_ollama_api
|
||||
from llm import nlp_coze
|
||||
from llm.agent import fay_agent
|
||||
from core import member_db
|
||||
import threading
|
||||
import functools
|
||||
@ -60,7 +61,8 @@ modules = {
|
||||
"nlp_xingchen": nlp_xingchen,
|
||||
"nlp_langchain": nlp_langchain,
|
||||
"nlp_ollama_api": nlp_ollama_api,
|
||||
"nlp_coze": nlp_coze
|
||||
"nlp_coze": nlp_coze,
|
||||
"nlp_agent": fay_agent
|
||||
|
||||
}
|
||||
|
||||
@ -145,9 +147,9 @@ class FeiFei:
|
||||
uid = member_db.new_instance().find_user(username)
|
||||
|
||||
#记录用户问题
|
||||
content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
||||
content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
|
||||
if wsa_server.get_web_instance().is_connected(username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid}, "Username" : username})
|
||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||
|
||||
#确定是否命中q&a
|
||||
answer = self.__get_answer(interact.interleaver, interact.data["msg"])
|
||||
@ -163,18 +165,17 @@ class FeiFei:
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", ""))
|
||||
|
||||
# qa_service.QAService().record_qapair(interact.data["msg"], text)#沟通记录缓存到qa文件
|
||||
else:
|
||||
text = answer
|
||||
|
||||
#记录回复
|
||||
self.write_to_file("./logs", "answer_result.txt", text)
|
||||
content_db.new_instance().add_content('fay','speak',text, username, uid)
|
||||
content_id = content_db.new_instance().add_content('fay','speak',text, username, uid)
|
||||
|
||||
#文字输出:面板、聊天窗、log、数字人
|
||||
if wsa_server.get_web_instance().is_connected(username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": text, "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'})
|
||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid}, "Username" : username})
|
||||
wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"fay","content":text, "username":username, "uid":uid, "id":content_id}, "Username" : username})
|
||||
if len(textlist) > 1:
|
||||
i = 1
|
||||
while i < len(textlist):
|
||||
@ -198,8 +199,6 @@ class FeiFei:
|
||||
if member_db.new_instance().is_username_exist(username) == "notexists":
|
||||
member_db.new_instance().add_user(username)
|
||||
uid = member_db.new_instance().find_user(username)
|
||||
|
||||
#TODO 这里可以通过qa来触发指定的脚本操作,如ppt翻页等
|
||||
|
||||
if interact.data.get("text"):
|
||||
#记录回复
|
||||
@ -217,7 +216,8 @@ class FeiFei:
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
|
||||
#声音输出
|
||||
MyThread(target=self.say, args=[interact, text]).start()
|
||||
MyThread(target=self.say, args=[interact, text]).start()
|
||||
|
||||
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
@ -319,9 +319,6 @@ class FeiFei:
|
||||
if audio_url is not None:
|
||||
file_name = 'sample-' + str(int(time.time() * 1000)) + '.wav'
|
||||
result = self.download_wav(audio_url, './samples/', file_name)
|
||||
|
||||
elif not wsa_server.get_instance().get_client_output(interact.data.get('user')):
|
||||
result = None
|
||||
elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
|
||||
util.printInfo(1, interact.data.get('user'), '合成音频...')
|
||||
tm = time.time()
|
||||
|
@ -83,6 +83,7 @@ class Member_Db:
|
||||
else:
|
||||
return "notexists"
|
||||
|
||||
#根据username查询uid
|
||||
def find_user(self, username):
|
||||
conn = sqlite3.connect('user_profiles.db')
|
||||
c = conn.cursor()
|
||||
@ -93,6 +94,20 @@ class Member_Db:
|
||||
return 0
|
||||
else:
|
||||
return result[0]
|
||||
|
||||
#根据uid查询username
|
||||
def find_username_by_uid(self, uid):
|
||||
conn = sqlite3.connect('user_profiles.db')
|
||||
c = conn.cursor()
|
||||
c.execute('SELECT username FROM T_Member WHERE id = ?', (uid,))
|
||||
result = c.fetchone()
|
||||
conn.close()
|
||||
if result is None:
|
||||
return 0
|
||||
else:
|
||||
return result[0]
|
||||
|
||||
|
||||
|
||||
@synchronized
|
||||
def query(self, sql):
|
||||
|
@ -33,7 +33,7 @@ class QAService:
|
||||
|
||||
def question(self, query_type, text):
|
||||
if query_type == 'qa':
|
||||
answer_dict = self.__read_qna(cfg.config['interact']['QnA'])
|
||||
answer_dict = self.__read_qna(cfg.config['interact'].get('QnA'))
|
||||
answer, action = self.__get_keyword(answer_dict, text, query_type)
|
||||
if action:
|
||||
MyThread(target=self.__run, args=[action]).start()
|
||||
@ -61,7 +61,7 @@ class QAService:
|
||||
if len(row) >= 2:
|
||||
qna.append([row[0].split(";"), row[1], row[2] if len(row) >= 3 else None])
|
||||
except Exception as e:
|
||||
util.log(1, 'qa文件没有指定,不匹配qa')
|
||||
pass
|
||||
return qna
|
||||
|
||||
def record_qapair(self, question, answer):
|
||||
|
@ -46,6 +46,8 @@ class Recorder:
|
||||
self.username = 'User' #默认用户,子类实现时会重写
|
||||
self.channels = 1
|
||||
self.sample_rate = 16000
|
||||
self.is_reading = False
|
||||
self.stream = None
|
||||
|
||||
def asrclient(self):
|
||||
if self.ASRMode == "ali":
|
||||
@ -204,7 +206,7 @@ class Recorder:
|
||||
cfg.load_config()
|
||||
record = cfg.config['source']['record']
|
||||
if not record['enabled'] and not self.is_remote:
|
||||
time.sleep(0.1)
|
||||
time.sleep(1)
|
||||
continue
|
||||
self.is_reading = True
|
||||
data = stream.read(1024, exception_on_overflow=False)
|
||||
|
@ -14,6 +14,7 @@ from utils import util, config_util, stream_util
|
||||
from core.wsa_server import MyServer
|
||||
from core import wsa_server
|
||||
from core import socket_bridge_service
|
||||
from llm.agent import agent_service
|
||||
|
||||
feiFei: fay_core.FeiFei = None
|
||||
recorderListener: Recorder = None
|
||||
@ -96,9 +97,10 @@ class RecorderListener(Recorder):
|
||||
try:
|
||||
while self.is_reading:
|
||||
time.sleep(0.1)
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.paudio.terminate()
|
||||
if self.stream is not None:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.paudio.terminate()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
util.log(1, "请检查设备是否有误,再重新启动!")
|
||||
@ -186,7 +188,7 @@ def device_socket_keep_alive():
|
||||
if wsa_server.get_web_instance().is_connected(value.username):
|
||||
wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username})
|
||||
except Exception as serr:
|
||||
util.printInfo(3, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
||||
util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key))
|
||||
value.stop()
|
||||
delkey = key
|
||||
break
|
||||
@ -222,6 +224,8 @@ def accept_audio_device_output_connect():
|
||||
|
||||
#数字人端请求获取最新的自动播放消息,若自动播放服务关闭会自动退出自动播放
|
||||
def start_auto_play_service(): #TODO 评估一下有无优化的空间
|
||||
if config_util.config['source'].get('automatic_player_url') is None or config_util.config['source'].get('automatic_player_status') is None:
|
||||
return
|
||||
url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item"
|
||||
user = "User" #TODO 临时固死了
|
||||
is_auto_server_error = False
|
||||
@ -290,6 +294,11 @@ def stop():
|
||||
socket_service_instance = None
|
||||
except:
|
||||
pass
|
||||
|
||||
if config_util.key_chat_module == "agent":
|
||||
util.log(1, '正在关闭agent服务...')
|
||||
agent_service.agent_stop()
|
||||
|
||||
util.log(1, '正在关闭核心服务...')
|
||||
feiFei.stop()
|
||||
util.log(1, '服务已关闭!')
|
||||
@ -325,18 +334,22 @@ def start():
|
||||
record = config_util.config['source']['record']
|
||||
if record['enabled']:
|
||||
util.log(1, '开启录音服务...')
|
||||
recorderListener = RecorderListener(record['device'], feiFei) # 监听麦克风
|
||||
recorderListener = RecorderListener('device', feiFei) # 监听麦克风
|
||||
recorderListener.start()
|
||||
|
||||
#启动声音沟通接口服务
|
||||
util.log(1,'启动声音沟通接口服务...')
|
||||
deviceSocketThread = MyThread(target=accept_audio_device_output_connect)
|
||||
deviceSocketThread.start()
|
||||
|
||||
socket_service_instance = socket_bridge_service.new_instance()
|
||||
socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service)
|
||||
socket_bridge_service_Thread.start()
|
||||
|
||||
#启动agent服务
|
||||
if config_util.key_chat_module == "agent":
|
||||
util.log(1,'启动agent服务...')
|
||||
agent_service.agent_start()
|
||||
|
||||
#启动自动播放服务
|
||||
util.log(1,'启动自动播放服务...')
|
||||
MyThread(target=start_auto_play_service).start()
|
||||
|
@ -46,6 +46,7 @@ def verify_password(username, password):
|
||||
if username in users and users[username] == password:
|
||||
return username
|
||||
|
||||
|
||||
def __get_template():
|
||||
try:
|
||||
return render_template('index.html')
|
||||
@ -68,6 +69,7 @@ def __get_device_list():
|
||||
print(f"Error getting device list: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@__app.route('/api/submit', methods=['post'])
|
||||
def api_submit():
|
||||
data = request.values.get('data')
|
||||
@ -252,7 +254,7 @@ def api_send():
|
||||
if not username or not msg:
|
||||
return jsonify({'result': 'error', 'message': '用户名和消息内容不能为空'})
|
||||
interact = Interact("text", 1, {'user': username, 'msg': msg})
|
||||
util.printInfo(3, "文字发送按钮", '{}'.format(interact.data["msg"]), time.time())
|
||||
util.printInfo(1, username, '[文字发送按钮]{}'.format(interact.data["msg"]), time.time())
|
||||
fay_booter.feiFei.on_interact(interact)
|
||||
return '{"result":"successful"}'
|
||||
except json.JSONDecodeError:
|
||||
@ -263,11 +265,12 @@ def api_send():
|
||||
# 获取指定用户的消息记录
|
||||
@__app.route('/api/get-msg', methods=['post'])
|
||||
def api_get_Msg():
|
||||
data = request.form.get('data')
|
||||
if not data:
|
||||
return jsonify({'list': [], 'message': '未提供数据'})
|
||||
try:
|
||||
data = json.loads(data)
|
||||
data = request.form.get('data')
|
||||
if data is None:
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = json.loads(data)
|
||||
uid = member_db.new_instance().find_user(data["username"])
|
||||
contentdb = content_db.new_instance()
|
||||
if uid == 0:
|
||||
@ -310,7 +313,7 @@ def api_send_v1_chat_completions():
|
||||
model = data.get('model', 'fay')
|
||||
observation = data.get('observation', '')
|
||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': observation})
|
||||
util.printInfo(3, "文字沟通接口", '{}'.format(interact.data["msg"]), time.time())
|
||||
util.printInfo(1, username, '[文字沟通接口]{}'.format(interact.data["msg"]), time.time())
|
||||
text = fay_booter.feiFei.on_interact(interact)
|
||||
|
||||
if model == 'fay-streaming':
|
||||
@ -393,7 +396,7 @@ def stream_response(text):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
time.sleep(0.1)
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
|
||||
def non_streaming_response(last_content, text):
|
||||
|
@ -190,6 +190,7 @@ class FayInterface {
|
||||
}
|
||||
if (vueInstance.selectedUser && data.panelReply.username === vueInstance.selectedUser[1]) {
|
||||
vueInstance.messages.push({
|
||||
id: data.panelReply.id,
|
||||
username: data.panelReply.username,
|
||||
content: data.panelReply.content,
|
||||
type: data.panelReply.type,
|
||||
|
@ -14,7 +14,7 @@
|
||||
<body >
|
||||
<div id="app" class="main_bg">
|
||||
<div class="main_left">
|
||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||
</div>
|
||||
|
||||
<div class="main_left_menu">
|
||||
|
@ -14,7 +14,7 @@
|
||||
<body>
|
||||
<div class="main_bg" id="app">
|
||||
<div class="main_left">
|
||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/logo.png') }}" alt="">
|
||||
<div class="main_left_logo" ><img src="{{ url_for('static',filename='images/Logo.png') }}" alt="">
|
||||
</div>
|
||||
|
||||
<div class="main_left_menu">
|
||||
|
133
llm/agent/agent_service.py
Normal file
133
llm/agent/agent_service.py
Normal file
@ -0,0 +1,133 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
import datetime
|
||||
import time
|
||||
import os
|
||||
from scheduler.thread_manager import MyThread
|
||||
from core import member_db
|
||||
from core.interact import Interact
|
||||
from utils import util
|
||||
import fay_booter
|
||||
|
||||
scheduled_tasks = {}
|
||||
agent_running = False
|
||||
|
||||
|
||||
# 数据库初始化
|
||||
def init_db():
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS timer (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
time TEXT NOT NULL,
|
||||
repeat_rule TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
uid INTEGER
|
||||
)
|
||||
''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
|
||||
# 插入测试数据
|
||||
def insert_test_data():
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("INSERT INTO timer (time, repeat_rule, content) VALUES (?, ?, ?)", ('16:20', '1010001', 'Meeting Reminder'))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# 解析重复规则返回待执行时间,None代表不在今天的待执行计划
|
||||
def parse_repeat_rule(rule, task_time):
|
||||
today = datetime.datetime.now()
|
||||
if rule == '0000000': # 不重复
|
||||
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||
if task_datetime > today:
|
||||
return task_datetime
|
||||
else:
|
||||
return None
|
||||
for i, day in enumerate(rule):
|
||||
if day == '1' and today.weekday() == i:
|
||||
task_datetime = datetime.datetime.combine(today.date(), task_time)
|
||||
if task_datetime > today:
|
||||
return task_datetime
|
||||
return None
|
||||
|
||||
# 执行任务
|
||||
def execute_task(task_time, id, content, uid):
|
||||
username = member_db.new_instance().find_username_by_uid(uid=uid)
|
||||
if not username:
|
||||
username = "User"
|
||||
interact = Interact("text", 1, {'user': username, 'msg': "执行任务->\n" + content, 'observation': ""})
|
||||
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||
text = fay_booter.feiFei.on_interact(interact)
|
||||
if text is not None and id in scheduled_tasks:
|
||||
del scheduled_tasks[id]
|
||||
# 如果不重复,执行后删除记录
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM timer WHERE repeat_rule = '0000000' AND id = ?", (id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
# 30秒扫描一次数据库,当扫描到今天的不存在于定时任务列表的记录,则添加到定时任务列表。执行完的记录从定时任务列表中清除。
|
||||
def check_and_execute():
|
||||
while agent_running:
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM timer")
|
||||
rows = cursor.fetchall()
|
||||
for row in rows:
|
||||
id, task_time_str, repeat_rule, content, uid = row
|
||||
task_time = datetime.datetime.strptime(task_time_str, '%H:%M').time()
|
||||
next_execution = parse_repeat_rule(repeat_rule, task_time)
|
||||
|
||||
if next_execution and id not in scheduled_tasks:
|
||||
timer_thread = threading.Timer((next_execution - datetime.datetime.now()).total_seconds(), execute_task, [next_execution, id, content, uid])
|
||||
timer_thread.start()
|
||||
scheduled_tasks[id] = timer_thread
|
||||
|
||||
conn.close()
|
||||
time.sleep(30) # 30秒扫描一次
|
||||
|
||||
# agent启动
|
||||
def agent_start():
|
||||
global agent_running
|
||||
|
||||
agent_running = True
|
||||
#初始计划
|
||||
if not os.path.exists("./timer.db"):
|
||||
init_db()
|
||||
content ="""执行任务->
|
||||
你是一个数字人,你的责任是陪伴主人生活、工作:
|
||||
1、在每天早上8点提醒主人起床;
|
||||
2、每天12:00及18:30提醒主人吃饭;
|
||||
3、每天21:00陪主人聊聊天;
|
||||
4、每天23:00提醒主人睡觉。
|
||||
"""
|
||||
interact = Interact("text", 1, {'user': 'User', 'msg': content, 'observation': ""})
|
||||
util.printInfo(3, "系统", '执行任务:{}'.format(interact.data["msg"]), time.time())
|
||||
text = fay_booter.feiFei.on_interact(interact)
|
||||
if text is None:
|
||||
util.printInfo(3, "系统", '任务执行失败', time.time())
|
||||
|
||||
check_and_execute_thread = MyThread(target=check_and_execute)
|
||||
check_and_execute_thread.start()
|
||||
|
||||
|
||||
|
||||
def agent_stop():
|
||||
global agent_running
|
||||
global scheduled_tasks
|
||||
# 取消所有定时任务
|
||||
for task in scheduled_tasks.values():
|
||||
task.cancel()
|
||||
agent_running = False
|
||||
scheduled_tasks = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
agent_start()
|
99
llm/agent/fay_agent.py
Normal file
99
llm/agent/fay_agent.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import time
|
||||
from llm.agent.tools.MyTimer import MyTimer
|
||||
from llm.agent.tools.Weather import Weather
|
||||
from llm.agent.tools.QueryTimerDB import QueryTimerDB
|
||||
from llm.agent.tools.DeleteTimer import DeleteTimer
|
||||
from llm.agent.tools.QueryTime import QueryTime
|
||||
from llm.agent.tools.PythonExecutor import PythonExecutor
|
||||
from llm.agent.tools.WebPageRetriever import WebPageRetriever
|
||||
from llm.agent.tools.WebPageScraper import WebPageScraper
|
||||
from llm.agent.tools.ToRemind import ToRemind
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
import utils.config_util as cfg
|
||||
from utils import util
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from core import content_db
|
||||
from core import member_db
|
||||
|
||||
class FayAgentCore():
|
||||
def __init__(self, uid=0, observation=""):
|
||||
self.observation=observation
|
||||
cfg.load_config()
|
||||
os.environ["OPENAI_API_KEY"] = cfg.key_gpt_api_key
|
||||
os.environ["OPENAI_API_BASE"] = cfg.gpt_base_url
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
os.environ["LANGCHAIN_API_KEY"] = "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"
|
||||
os.environ["LANGCHAIN_PROJECT"] = "pr-best-artist-21"
|
||||
|
||||
#创建llm
|
||||
self.llm = ChatOpenAI(model=cfg.gpt_model_engine)
|
||||
|
||||
#创建agent graph
|
||||
my_timer = MyTimer(uid=uid)#传入uid
|
||||
weather_tool = Weather()
|
||||
query_timer_db_tool = QueryTimerDB()
|
||||
delete_timer_tool = DeleteTimer()
|
||||
python_executor = PythonExecutor()
|
||||
web_page_retriever = WebPageRetriever()
|
||||
web_page_scraper = WebPageScraper()
|
||||
to_remind = ToRemind()
|
||||
self.tools = [my_timer, weather_tool, query_timer_db_tool, delete_timer_tool, python_executor, web_page_retriever, web_page_scraper, to_remind]
|
||||
self.attr_info = ", ".join(f"{key}: {value}" for key, value in cfg.config["attribute"].items())
|
||||
self.prompt_template = """
|
||||
现在时间是:{now_time}。你是一个数字人,负责协助主人处理问题和陪伴主人生活、工作。你的个人资料是:{attr_info}。通过外部设备观测到:{observation}。\n请依据以信息为主人服务。
|
||||
""".format(now_time=QueryTime().run(""), attr_info=self.attr_info, observation=self.observation)
|
||||
self.memory = MemorySaver()
|
||||
self.agent = create_react_agent(self.llm, self.tools, checkpointer=self.memory)
|
||||
|
||||
self.total_tokens = 0
|
||||
self.total_cost = 0
|
||||
|
||||
#载入记忆
|
||||
def get_history_messages(self, uid):
|
||||
chat_history = []
|
||||
history = content_db.new_instance().get_list('all','desc', 100, uid)
|
||||
if history and len(history) > 0:
|
||||
i = 0
|
||||
while i < len(history):
|
||||
if history[i][0] == "member":
|
||||
chat_history.append(HumanMessage(content=history[i][2], user=member_db.new_instance().find_username_by_uid(uid=uid)))
|
||||
else:
|
||||
chat_history.append(AIMessage(content=history[i][2]))
|
||||
i += 1
|
||||
return chat_history
|
||||
|
||||
|
||||
def run(self, input_text, uid=0):
|
||||
result = ""
|
||||
messages = self.get_history_messages(uid)
|
||||
messages.insert(0, SystemMessage(self.prompt_template))
|
||||
messages.append(HumanMessage(content=input_text))
|
||||
|
||||
try:
|
||||
for chunk in self.agent.stream(
|
||||
{"messages": messages}, {"configurable": {"thread_id": "tid{}".format(uid)}}
|
||||
):
|
||||
if chunk.get("agent"):
|
||||
if chunk['agent']['messages'][0].content:
|
||||
result = chunk['agent']['messages'][0].content
|
||||
cb = chunk['agent']['messages'][0].response_metadata['token_usage']['total_tokens']
|
||||
self.total_tokens = self.total_tokens + cb
|
||||
|
||||
util.log(1, "本次消耗token:{},共消耗token:{}".format(cb, self.total_tokens))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return result
|
||||
|
||||
def question(cont, uid=0, observation=""):
|
||||
starttime = time.time()
|
||||
agent = FayAgentCore(uid=uid, observation=observation)
|
||||
response_text = agent.run(cont, uid)
|
||||
util.log(1, "接口调用耗时 :" + str(time.time() - starttime))
|
||||
return response_text
|
||||
if __name__ == "__main__":
|
||||
agent = FayAgentCore()
|
||||
print(agent.run("你好"))
|
41
llm/agent/tools/DeleteTimer.py
Normal file
41
llm/agent/tools/DeleteTimer.py
Normal file
@ -0,0 +1,41 @@
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from llm.agent import agent_service
|
||||
|
||||
|
||||
class DeleteTimer(BaseTool):
|
||||
name: str = "DeleteTimer"
|
||||
description: str = "用于删除某一个日程,接受任务id作为参数,如:2"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def _run(self, para) -> str:
|
||||
try:
|
||||
id = int(para)
|
||||
except ValueError:
|
||||
return "输入的 ID 无效,必须是数字。"
|
||||
|
||||
try:
|
||||
with sqlite3.connect('timer.db') as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM timer WHERE id = ?", (id,))
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
return f"数据库错误: {e}"
|
||||
|
||||
if id in agent_service.scheduled_tasks:
|
||||
agent_service.scheduled_tasks[id].cancel()
|
||||
del agent_service.scheduled_tasks[id]
|
||||
|
||||
return f"任务 {id} 取消成功。"
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = DeleteTimer()
|
||||
result = tool.run("1")
|
||||
print(result)
|
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
100
llm/agent/tools/KnowledgeBaseResponder.py
Normal file
@ -0,0 +1,100 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.indexes.vectorstore import VectorstoreIndexCreator, VectorStoreIndexWrapper
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
from langchain_openai import ChatOpenAI
|
||||
import hashlib
|
||||
#若要使用请自行配置
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
|
||||
index_name = "knowledge_data"
|
||||
folder_path = "agent/tools/KnowledgeBaseResponder/knowledge_base"
|
||||
local_persist_path = "agent/tools/KnowledgeBaseResponder"
|
||||
md5_file_path = os.path.join(local_persist_path, "pdf_md5.txt")
|
||||
#
|
||||
class KnowledgeBaseResponder(BaseTool):
|
||||
name = "KnowledgeBaseResponder"
|
||||
description = """此工具用于连接本地知识库获取问题答案,使用时请传入相关问题作为参数,例如:“草梅最适合的生长温度”"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
|
||||
def _run(self, para: str) -> str:
|
||||
self.save_all()
|
||||
result = self.question(para)
|
||||
return result
|
||||
|
||||
def generate_file_md5(self, file_path):
|
||||
hasher = hashlib.md5()
|
||||
with open(file_path, 'rb') as afile:
|
||||
buf = afile.read()
|
||||
hasher.update(buf)
|
||||
return hasher.hexdigest()
|
||||
|
||||
def load_md5_list(self):
|
||||
if os.path.exists(md5_file_path):
|
||||
with open(md5_file_path, 'r') as file:
|
||||
return {line.split(",")[0]: line.split(",")[1].strip() for line in file}
|
||||
return {}
|
||||
|
||||
def update_md5_list(self, file_name, md5_value):
|
||||
md5_list = self.load_md5_list()
|
||||
md5_list[file_name] = md5_value
|
||||
with open(md5_file_path, 'w') as file:
|
||||
for name, md5 in md5_list.items():
|
||||
file.write(f"{name},{md5}\n")
|
||||
|
||||
def load_all_pdfs(self, folder_path):
|
||||
md5_list = self.load_md5_list()
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith(".pdf"):
|
||||
file_path = os.path.join(folder_path, file_name)
|
||||
file_md5 = self.generate_file_md5(file_path)
|
||||
if file_name not in md5_list or md5_list[file_name] != file_md5:
|
||||
print(f"正在加载 {file_name} 到索引...")
|
||||
self.load_pdf_and_save_to_index(file_path, index_name)
|
||||
self.update_md5_list(file_name, file_md5)
|
||||
|
||||
def get_index_path(self, index_name):
|
||||
return os.path.join(local_persist_path, index_name)
|
||||
|
||||
def load_pdf_and_save_to_index(self, file_path, index_name):
|
||||
loader = PyPDFLoader(file_path)
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
index = VectorstoreIndexCreator(embedding=embedding, vectorstore_kwargs={"persist_directory": self.get_index_path(index_name)}).from_loaders([loader])
|
||||
index.vectorstore.persist()
|
||||
|
||||
def load_index(self, index_name):
|
||||
index_path = self.get_index_path(index_name)
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
vectordb = Chroma(persist_directory=index_path, embedding_function=embedding)
|
||||
return VectorStoreIndexWrapper(vectorstore=vectordb)
|
||||
|
||||
def save_all(self):
|
||||
self.load_all_pdfs(folder_path)
|
||||
|
||||
def question(self, cont):
|
||||
try:
|
||||
info = cont
|
||||
index = self.load_index(index_name)
|
||||
llm = ChatOpenAI(model="gpt-4-0125-preview")
|
||||
ans = index.query(info, llm, chain_type="map_reduce")
|
||||
return ans
|
||||
except Exception as e:
|
||||
print(f"请求失败: {e}")
|
||||
return "抱歉,我现在太忙了,休息一会,请稍后再试。"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = KnowledgeBaseResponder()
|
||||
info = tool.run("草莓最适合的生长温度")
|
||||
print(info)
|
@ -0,0 +1 @@
|
||||
|
56
llm/agent/tools/MyTimer.py
Normal file
56
llm/agent/tools/MyTimer.py
Normal file
@ -0,0 +1,56 @@
|
||||
import abc
|
||||
import sqlite3
|
||||
import re
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
class MyTimer(BaseTool, abc.ABC):
|
||||
name: str = "MyTimer"
|
||||
description: str = ("用于设置日程。接受3个参数,格式为: HH:MM|YYYYYYY|事项内容,所用标点符号必须为标准的英文字符。"
|
||||
"其中,'HH:MM' 表示时间(24小时制),'YYYYYYY' 表示循环规则(每位代表一天,从星期一至星期日,1为循环,0为不循环,"
|
||||
"如'1000100'代表每周一和周五循环),'事项内容' 是提醒的具体内容。返回例子:15:15|0000000|提醒主人叫咖啡")
|
||||
uid: int = 0
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para: str) -> str:
|
||||
# 拆分输入字符串
|
||||
parts = para.split("|")
|
||||
if len(parts) != 3:
|
||||
return f"输入格式错误,当前字符串{para} len:{len(parts)}。请按照 HH:MM|YYYYYYY|事项内容 格式提供参数,如:15:15|0000001|提醒主人叫咖啡。"
|
||||
|
||||
time = parts[0].strip("'")
|
||||
repeat_rule = parts[1].strip("'")
|
||||
content = parts[2].strip("'")
|
||||
|
||||
# 验证时间格式
|
||||
if not re.match(r'^[0-2][0-9]:[0-5][0-9]$', time):
|
||||
return "时间格式错误。请按照'HH:MM'格式提供时间。"
|
||||
|
||||
# 验证循环规则格式
|
||||
if not re.match(r'^[01]{7}$', repeat_rule):
|
||||
return "循环规则格式错误。请提供长度为7的0和1组成的字符串。"
|
||||
|
||||
# 验证事项内容
|
||||
if not isinstance(content, str) or not content:
|
||||
return "事项内容必须为非空字符串。"
|
||||
|
||||
# 数据库操作
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute("INSERT INTO timer (time, repeat_rule, content, uid) VALUES (?, ?, ?, ?)", (time, repeat_rule, content, self.uid))
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
return f"数据库错误: {e}"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return "日程设置成功"
|
||||
|
||||
if __name__ == "__main__":
|
||||
my_timer = MyTimer()
|
||||
result = my_timer._run("15:15|0000001|提醒主人叫咖啡")
|
||||
print(result)
|
42
llm/agent/tools/PythonExecutor.py
Normal file
42
llm/agent/tools/PythonExecutor.py
Normal file
@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
import subprocess
|
||||
import tempfile
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
class PythonExecutor(BaseTool):
|
||||
name: str = "python_executor"
|
||||
description: str = "此工具用于执行传入的 Python 代码片段,并返回执行结果"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def _run(self, code: str) -> str:
|
||||
if not code:
|
||||
return "代码不能为空"
|
||||
|
||||
try:
|
||||
# 创建临时文件以写入代码
|
||||
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmpfile:
|
||||
tmpfile_path = tmpfile.name
|
||||
tmpfile.write(code.encode())
|
||||
|
||||
# 使用 subprocess 执行 Python 代码文件
|
||||
result = subprocess.run(['python', tmpfile_path], capture_output=True, text=True)
|
||||
os.remove(tmpfile_path) # 删除临时文件
|
||||
|
||||
if result.returncode == 0:
|
||||
return f"执行成功:\n{result.stdout}"
|
||||
else:
|
||||
return f"执行失败,错误信息:\n{result.stderr}"
|
||||
|
||||
except Exception as e:
|
||||
return f"执行代码时发生错误:{str(e)}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
python_executor = PythonExecutor()
|
||||
code_snippet = """
|
||||
print("Hello, world!")
|
||||
"""
|
||||
execution_result = python_executor.run(code_snippet)
|
||||
print(execution_result)
|
46
llm/agent/tools/QueryTime.py
Normal file
46
llm/agent/tools/QueryTime.py
Normal file
@ -0,0 +1,46 @@
|
||||
import abc
|
||||
import math
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
class QueryTime(BaseTool, abc.ABC):
|
||||
name: str = "QueryTime"
|
||||
description: str = "用于查询当前日期、星期几及时间"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para) -> str:
|
||||
# 获取当前时间
|
||||
now = datetime.now()
|
||||
# 获取当前日期
|
||||
today = now.date()
|
||||
# 获取星期几的信息
|
||||
week_day = today.strftime("%A")
|
||||
# 将星期几的英文名称转换为中文
|
||||
week_day_zh = {
|
||||
"Monday": "星期一",
|
||||
"Tuesday": "星期二",
|
||||
"Wednesday": "星期三",
|
||||
"Thursday": "星期四",
|
||||
"Friday": "星期五",
|
||||
"Saturday": "星期六",
|
||||
"Sunday": "星期日",
|
||||
}.get(week_day, "未知")
|
||||
# 将日期格式化为字符串
|
||||
date_str = today.strftime("%Y年%m月%d日")
|
||||
|
||||
# 将时间格式化为字符串
|
||||
time_str = now.strftime("%H:%M")
|
||||
|
||||
return "现在时间是:{0} {1} {2}".format(time_str, week_day_zh, date_str)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = QueryTime()
|
||||
result = tool.run("")
|
||||
print(result)
|
41
llm/agent/tools/QueryTimerDB.py
Normal file
41
llm/agent/tools/QueryTimerDB.py
Normal file
@ -0,0 +1,41 @@
|
||||
import abc
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
import ast
|
||||
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class QueryTimerDB(BaseTool, abc.ABC):
|
||||
name: str = "QueryTimerDB"
|
||||
description: str = "用于查询所有日程,返回的数据里包含3个参数:时间、循环规则(如:'1000100'代表星期一和星期五循环,'0000000'代表不循环)、执行的事项"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
|
||||
def _run(self, para) -> str:
|
||||
conn = sqlite3.connect('timer.db')
|
||||
cursor = conn.cursor()
|
||||
# 执行查询
|
||||
cursor.execute("SELECT * FROM timer")
|
||||
# 获取所有记录
|
||||
rows = cursor.fetchall()
|
||||
# 拼接结果
|
||||
result = ""
|
||||
for row in rows:
|
||||
result = result + str(row) + "\n"
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = QueryTimerDB()
|
||||
result = tool.run("")
|
||||
print(result)
|
26
llm/agent/tools/SendToPanel.py
Normal file
26
llm/agent/tools/SendToPanel.py
Normal file
@ -0,0 +1,26 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
import fay_booter
|
||||
|
||||
class SendToPanel(BaseTool, abc.ABC):
|
||||
name = "SendToPanel"
|
||||
description = "用于给主人面板发送消息,使用时请传入消息内容作为参数。"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para) -> str:
|
||||
fay_booter.feiFei.send_to_panel(para)
|
||||
return "成功给主人,发送消息:{}".format(para)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = SendToPanel()
|
||||
result = tool.run("归纳一下近年关于“经济发展”的论文的特点和重点")
|
||||
print(result)
|
31
llm/agent/tools/SendWX.py
Normal file
31
llm/agent/tools/SendWX.py
Normal file
@ -0,0 +1,31 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
import requests
|
||||
import json
|
||||
|
||||
url = "http://127.0.0.1:4008/send"
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
data = {
|
||||
"message": "你好",
|
||||
"receiver": "@2efc4e10cf2eafd0b0125930e4b96ed0cebffa75b2fd272590e38763225a282b"
|
||||
}
|
||||
|
||||
|
||||
class SendWX(BaseTool, abc.ABC):
|
||||
name = "SendWX"
|
||||
description = "给主人微信发送消息,传入参数是:('消息内容')"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para) -> str:
|
||||
global data
|
||||
data['message'] = para
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
return "成功给主人,发送微信消息:{}".format(para)
|
||||
|
33
llm/agent/tools/ToRemind.py
Normal file
33
llm/agent/tools/ToRemind.py
Normal file
@ -0,0 +1,33 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
import re
|
||||
import random
|
||||
class ToRemind(BaseTool, abc.ABC):
|
||||
name: str = "ToRemind"
|
||||
description: str = ("用于实时发送信息提醒主人做某事项(不能带时间),传入事项内容作为参数。")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para: str) -> str:
|
||||
|
||||
para = para.replace("提醒", "回复")
|
||||
demo = [
|
||||
"主人!是时候(事项内容)了喔~",
|
||||
"亲爱的主人,现在是(事项内容)的时候啦!",
|
||||
"嘿,主人,该(事项内容)了哦~",
|
||||
"温馨提醒:(事项内容)的时间到啦,主人!",
|
||||
"小提醒:主人,现在可以(事项内容)了~"
|
||||
]
|
||||
|
||||
return f"直接以中文友善{para},如"+ random.choice(demo)
|
||||
|
||||
if __name__ == "__main__":
|
||||
my_timer = ToRemind()
|
||||
result = my_timer._run("提醒主人叫咖啡")
|
||||
print(result)
|
53
llm/agent/tools/Weather.py
Normal file
53
llm/agent/tools/Weather.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from langchain.tools import BaseTool
|
||||
from urllib.parse import quote
|
||||
|
||||
class Weather(BaseTool):
|
||||
name: str = "weather"
|
||||
description: str = "此工具用于获取天气预报信息,需传入英文的城市名,参数格式:Guangzhou"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
|
||||
def _run(self, para: str) -> str:
|
||||
try:
|
||||
if not para:
|
||||
return "参数不能为空"
|
||||
encoded_city = quote(para)
|
||||
|
||||
api_url = f"http://api.openweathermap.org/data/2.5/weather?q={encoded_city}&appid=272fcb70d2c4e6f5134c2dce7d091df6"
|
||||
response = requests.get(api_url)
|
||||
if response.status_code == 200:
|
||||
weather_data = response.json()
|
||||
# 提取天气信息
|
||||
temperature_kelvin = weather_data['main']['temp']
|
||||
temperature_celsius = temperature_kelvin - 273.15
|
||||
min_temperature_kelvin = weather_data['main']['temp_min']
|
||||
max_temperature_kelvin = weather_data['main']['temp_max']
|
||||
min_temperature_celsius = min_temperature_kelvin - 273.15
|
||||
max_temperature_celsius = max_temperature_kelvin - 273.15
|
||||
description = weather_data['weather'][0]['description']
|
||||
wind_speed = weather_data['wind']['speed']
|
||||
|
||||
# 构建天气描述
|
||||
weather_description = f"今天天气:{description},气温:{temperature_celsius:.2f}摄氏度,风速:{wind_speed} m/s。"
|
||||
|
||||
return f"天气预报信息:{weather_description}"
|
||||
else:
|
||||
return f"无法获取天气预报信息,状态码:{response.status_code}"
|
||||
except Exception as e:
|
||||
return f"发生错误:{str(e)}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weather_tool = Weather()
|
||||
weather_info = weather_tool.run("Guangzhou")
|
||||
print(weather_info)
|
42
llm/agent/tools/WebPageRetriever.py
Normal file
42
llm/agent/tools/WebPageRetriever.py
Normal file
@ -0,0 +1,42 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
import requests
|
||||
|
||||
class WebPageRetriever(BaseTool, abc.ABC):
|
||||
name: str = "WebPageRetriever"
|
||||
description: str = "专门用于通过Bing搜索API快速检索和获取与特定查询词条相关的网页信息。使用时请传入需要查询的关键词作为参数。"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para) -> str:
|
||||
query = para
|
||||
subscription_key = ""#请自行进行补充
|
||||
if not subscription_key:
|
||||
print("请填写bing v7的subscription_key")
|
||||
return '请填写bing v7的subscription_key'
|
||||
|
||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
headers = {'Ocp-Apim-Subscription-Key': subscription_key}
|
||||
params = {'q': query, 'mkt': 'en-US'}
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
web_pages = data.get('webPages', {})
|
||||
return web_pages
|
||||
except Exception as e:
|
||||
print("Http Error:", e)
|
||||
return 'bing v7查询有误'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = WebPageRetriever()
|
||||
result = tool.run("归纳一下近年关于“经济发展”的论文的特点和重点")
|
||||
print(result)
|
35
llm/agent/tools/WebPageScraper.py
Normal file
35
llm/agent/tools/WebPageScraper.py
Normal file
@ -0,0 +1,35 @@
|
||||
from bs4 import BeautifulSoup
|
||||
import abc
|
||||
from typing import Any
|
||||
from langchain.tools import BaseTool
|
||||
import requests
|
||||
|
||||
class WebPageScraper(BaseTool, abc.ABC):
|
||||
name: str = "WebPageScraper"
|
||||
description: str = "此工具用于获取网页内容,使用时请传入需要查询的网页地址作为参数,如:https://www.baidu.com/。"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# 用例中没有用到 arun 不予具体实现
|
||||
pass
|
||||
|
||||
def _run(self, para) -> str:
|
||||
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
|
||||
try:
|
||||
response = requests.get(para, headers=headers, timeout=10, verify=True)
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
return soup
|
||||
except requests.exceptions.SSLCertVerificationError:
|
||||
return 'SSL证书验证失败'
|
||||
except requests.exceptions.Timeout:
|
||||
return '请求超时'
|
||||
except Exception as e:
|
||||
print("Http Error:", e)
|
||||
return '无法获取该网页内容'
|
||||
|
||||
if __name__ == "__main__":
|
||||
tool = WebPageScraper()
|
||||
result = tool.run("https://book.douban.com/review/14636204")
|
||||
print(result)
|
@ -17,11 +17,14 @@ pytz
|
||||
gevent~=22.10.1
|
||||
edge_tts
|
||||
pydub
|
||||
langchain==0.0.336
|
||||
chromadb
|
||||
tenacity==8.2.3
|
||||
pygame
|
||||
scipy
|
||||
flask-httpauth
|
||||
opencv-python
|
||||
psutil
|
||||
psutil
|
||||
langchain
|
||||
langchain_openai
|
||||
langgraph
|
||||
langchain-community
|
@ -43,7 +43,7 @@ baidu_emotion_secret_key=
|
||||
|
||||
|
||||
|
||||
#NLP多选一:lingju、gpt、rasa、VisualGLM、rwkv、xingchen、langchain 、ollama_api、privategpt、coze
|
||||
#NLP多选一:agent、lingju、gpt、rasa、VisualGLM、rwkv、xingchen、langchain 、ollama_api、privategpt、coze
|
||||
chat_module= gpt
|
||||
|
||||
#灵聚 服务密钥(NLP多选1) https://open.lingju.ai
|
||||
|
715
test/test_langchain.ipynb
Normal file
715
test/test_langchain.ipynb
Normal file
File diff suppressed because one or more lines are too long
44
test/test_langserve.py
Normal file
44
test/test_langserve.py
Normal file
@ -0,0 +1,44 @@
|
||||
from fastapi import FastAPI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langserve import add_routes
|
||||
import os
|
||||
os.environ["OPENAI_API_KEY"] = "sk-"
|
||||
os.environ["OPENAI_API_BASE"] = "https://cn.api.zyai.online/v1"
|
||||
|
||||
# 1. Create prompt template
|
||||
system_template = "Translate the following into {language}:"
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
('system', system_template),
|
||||
('user', '{text}')
|
||||
])
|
||||
|
||||
# 2. Create model
|
||||
model = ChatOpenAI()
|
||||
|
||||
# 3. Create parser
|
||||
parser = StrOutputParser()
|
||||
|
||||
# 4. Create chain
|
||||
chain = prompt_template | model | parser
|
||||
|
||||
|
||||
# 4. App definition
|
||||
app = FastAPI(
|
||||
title="LangChain Server",
|
||||
version="1.0",
|
||||
description="A simple API server using LangChain's Runnable interfaces",
|
||||
)
|
||||
|
||||
# 5. Adding chain route
|
||||
add_routes(
|
||||
app,
|
||||
chain,
|
||||
path="/chain",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="localhost", port=8000)
|
@ -2,7 +2,7 @@ import requests
|
||||
import json
|
||||
|
||||
def test_gpt(prompt):
|
||||
url = 'http://faycontroller.yaheen.com:5000/v1/chat/completions' # 替换为您的接口地址
|
||||
url = 'http://127.0.0.1:5000/v1/chat/completions' # 替换为您的接口地址
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer YOUR_API_KEY', # 如果您的接口需要身份验证
|
||||
|
@ -2,6 +2,16 @@ import os
|
||||
import json
|
||||
import codecs
|
||||
from configparser import ConfigParser
|
||||
import functools
|
||||
from threading import Lock
|
||||
|
||||
lock = Lock()
|
||||
def synchronized(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with lock:
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
config: json = None
|
||||
system_config: ConfigParser = None
|
||||
@ -40,6 +50,7 @@ coze_api_key = None
|
||||
start_mode = None
|
||||
fay_url = None
|
||||
|
||||
@synchronized
|
||||
def load_config():
|
||||
global config
|
||||
global system_config
|
||||
@ -116,8 +127,11 @@ def load_config():
|
||||
coze_api_key = system_config.get('key', 'coze_api_key')
|
||||
start_mode = system_config.get('key', 'start_mode')
|
||||
fay_url = system_config.get('key', 'fay_url')
|
||||
|
||||
#读取用户配置
|
||||
config = json.load(codecs.open('config.json', encoding='utf-8'))
|
||||
|
||||
@synchronized
|
||||
def save_config(config_data):
|
||||
global config
|
||||
config = config_data
|
||||
|
Loading…
Reference in New Issue
Block a user