import os
import time
import math

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.memory import VectorStoreRetrieverMemory
import faiss
from langchain.docstore import InMemoryDocstore
from langchain.vectorstores import FAISS
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, initialize_agent, agent_types
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from agent.tools.MyTimer import MyTimer
from agent.tools.QueryTime import QueryTime
from agent.tools.Weather import Weather
from agent.tools.Calculator import Calculator
from agent.tools.CheckSensor import CheckSensor
from agent.tools.Switch import Switch
from agent.tools.Knowledge import Knowledge
from agent.tools.Say import Say
from agent.tools.QueryTimerDB import QueryTimerDB
from agent.tools.DeleteTimer import DeleteTimer
from agent.tools.GetSwitchLog import GetSwitchLog
from agent.tools.getOnRunLinkage import getOnRunLinkage
from agent.tools.SetChatStatus import SetChatStatus

from langchain.callbacks import get_openai_callback
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.memory import ConversationBufferWindowMemory

import utils.config_util as utils
from core import wsa_server
import fay_booter
from utils import util


class FayAgentCore:
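    """LangChain-based agent core for Fay.

    Wires an OpenAI chat model, a FAISS-backed time-weighted memory and a set
    of smart-home tools together, and routes each user utterance either to a
    plain chat LLMChain or to the tool-using agent.
    """
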
    def __init__(self):
        utils.load_config()
        os.environ['OPENAI_API_KEY'] = utils.key_gpt_api_key

        # Use OpenAI embeddings
        embedding_size = 1536  # output dimension of OpenAIEmbeddings
        index = faiss.IndexFlatL2(embedding_size)
        embedding_fn = OpenAIEmbeddings()

        # Create the LLM
        self.llm = ChatOpenAI(model="gpt-4-1106-preview", verbose=True)

        # Create the vector store
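        # Note on the normalization below: OpenAI embeddings are unit-length, so
        # the L2 distance between two of them lies in [0, 2] and equals sqrt(2)
        # for orthogonal (unrelated) vectors; mapping score -> 1 - score/sqrt(2)
        # gives 1.0 for identical texts and about 0 for unrelated ones.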
        def relevance_score_fn(score: float) -> float:
            return 1.0 - score / math.sqrt(2)

        vectorstore = FAISS(embedding_fn, index, InMemoryDocstore({}), {}, relevance_score_fn=relevance_score_fn)

        # Create the memory (the retrieval scheme of the Stanford generative-agents
        # "town": recency, relevance and importance)
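        # The time-weighted retriever roughly scores each memory as
        # semantic similarity + a recency decay + any extra metadata keys
        # (here "importance"), so relevant, recent and important memories
        # surface first; k=3 keeps the three best matches.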
        retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, other_score_keys=["importance"], k=3)
        self.agent_memory = VectorStoreRetrieverMemory(memory_key="history", retriever=retriever)

        # Save basic profile information to memory
        utils.load_config()
        attr_info = ", ".join(f"{key}: {value}" for key, value in utils.config["attribute"].items())
        self.agent_memory.save_context({"input": "我的基本信息是?"}, {"output": attr_info})

        # Recent chat turns are also kept in plain memory
        self.chat_history = []

        # Instantiate the agent tools
        my_timer = MyTimer()
        query_time_tool = QueryTime()
        weather_tool = Weather()
        calculator_tool = Calculator()
        check_sensor_tool = CheckSensor()
        switch_tool = Switch()
        knowledge_tool = Knowledge()
        say_tool = Say()
        query_timer_db_tool = QueryTimerDB()
        delete_timer_tool = DeleteTimer()
        get_switch_log = GetSwitchLog()
        get_on_run_linkage = getOnRunLinkage()
        set_chat_status_tool = SetChatStatus()

        self.tools = [
            Tool(name=tool.name, func=tool.run, description=tool.description)
            for tool in (
                my_timer, query_time_tool, weather_tool, calculator_tool,
                check_sensor_tool, switch_tool, knowledge_tool, say_tool,
                query_timer_db_tool, delete_timer_tool, get_switch_log,
                get_on_run_linkage, set_chat_status_tool,
            )
        ]
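        # Each Tool entry simply exposes a custom tool's name, run() and
        # description to the agent; the implementations live under agent/tools.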

        # The agent executes task-oriented requests
        self.agent = initialize_agent(
            agent=agent_types.AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
            tools=self.tools, llm=self.llm, verbose=True,
            max_history=5, handle_parsing_errors=True)
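        # handle_parsing_errors=True feeds unparsable LLM output back to the
        # model instead of raising, so one malformed step does not abort the run.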

        # The LLM chain handles plain chat
        self.is_chat = False  # chat-mode flag

        # Track whether the Say tool was called during the current turn
        self.is_use_say_tool = False
        self.say_tool_text = ""

        self.total_tokens = 0
        self.total_cost = 0

    def format_history_str(self, history_dict):
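        """Build the history block injected into the prompts.

        `history_dict` is the dict returned by the memory's
        load_memory_variables(); its "history" value is split back into
        input/output pairs, then the recent in-memory chat turns are appended.
        """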
        result = ""
        history_string = history_dict['history']

        # Split the retrieved memory string into individual records
        lines = history_string.split('input:')

        # Collect the parsed history entries here
        formatted_history = []

        # Parse the memory-stream format
        for line in lines:
            if "output" in line:
                input_line = line.split("output:")[0].strip()
                output_line = line.split("output:")[1].strip()
                formatted_history.append({"input": input_line, "output": output_line})

        # Render the memory stream as a string (at most three records)
        result += "-以下是与用户说话关连度最高的记忆:\n"
        for i in range(len(formatted_history)):
            if i >= 3:
                break
            line = formatted_history[i]
            result += f"--input:{line['input']}\n--output:{line['output']}\n"
        if len(formatted_history) == 0:
            result += "--没有记录\n"

        # Append the in-memory chat history
        formatted_history = []
        for line in self.chat_history:
            formatted_history.append({"input": line[0], "output": line[1]})

        # Render the in-memory history as a string
        result += "\n-以下刚刚的对话:\n"
        for i in range(len(formatted_history)):
            line = formatted_history[i]
            result += f"--input:{line['input']}\n--output:{line['output']}\n"
        if len(formatted_history) == 0:
            result += "--没有记录\n"

        return result

    def get_llm_chain(self, history):
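        """Build the LLMChain used for plain chat.

        The prompt lists the available tools (except Say and SetChatStatus),
        embeds the current time and the retrieved history, and tells the model
        to reply with "agent: ..." whenever a tool is needed so that run() can
        hand the turn over to the agent.
        """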
        tools_prompt = "["
        tool_names = [tool.name for tool in self.tools if tool.name != SetChatStatus().name and tool.name != Say().name]
        tools_prompt += "、".join(tool_names) + "]"

        template = """
        你是一个智能家居系统中的AI,负责协助主人处理日常事务和智能设备的操作。当主人提出要求时,如果需要使用特定的工具或执行特定的操作,请严格回复“agent: {human_input}”字符串。如果主人只是进行普通对话或询问信息,直接以文本内容回答即可。你可以使用的工具或执行的任务包括:
        """ + tools_prompt + "等。" + """
        现在时间是:now_time
        请依据以下信息回复主人:
        chat_history

        input:
        {human_input}
        output:""".replace("chat_history", history).replace("now_time", QueryTime().run(""))

        prompt = PromptTemplate(
            input_variables=["human_input"], template=template
        )

        llm_chain = LLMChain(
            llm=self.llm,
            prompt=prompt,
            verbose=True
        )

        return llm_chain

    def run(self, input_text):
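        """Handle one user utterance.

        Depending on self.is_chat the input goes to the chat chain or to the
        tool-using agent; a chat reply of the form "agent: ..." hands the turn
        over to the agent. Token usage is accumulated, the exchange is saved to
        the memory stream and to chat_history, and the method returns
        (is_use_say_tool, reply_text).
        """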
        self.is_use_say_tool = False
        self.say_tool_text = ""

        result = ""
        history = self.agent_memory.load_memory_variables({"input": input_text.replace('主人语音说了:', '').replace('主人文字说了:', '')})
        history = self.format_history_str(history)
        try:
            # Decide whether to run in chat mode or agent mode; the two modes
            # hand over to each other automatically at runtime
            if self.is_chat:
                llm_chain = self.get_llm_chain(history)
                with get_openai_callback() as cb:
                    result = llm_chain.predict(human_input=input_text.replace('主人语音说了:', '').replace('主人文字说了:', ''))
                    self.total_tokens = self.total_tokens + cb.total_tokens
                    self.total_cost = self.total_cost + cb.total_cost
                    util.log(1, "本次消耗token:{}, Cost (USD):{},共消耗token:{}, Cost (USD):{}".format(cb.total_tokens, cb.total_cost, self.total_tokens, self.total_cost))

if "agent:" in result.lower() or not self.is_chat:
|
|
|
|
|
self.is_chat = False
|
2024-01-01 22:53:06 +08:00
|
|
|
|
input_text = result.lower().replace("agent:", "") if "agent:" in result.lower() else input_text.replace('主人语音说了:', '').replace('主人文字说了:', '')
|
|
|
|
|
agent_prompt = """
|
|
|
|
|
现在时间是:{now_time}。请依据以下信息为主人服务 :
|
|
|
|
|
{history}
|
|
|
|
|
input:{input_text}
|
|
|
|
|
output:
|
|
|
|
|
""".format(history=history, input_text=input_text, now_time=QueryTime().run(""))
|
|
|
|
|
print(agent_prompt)
|
|
|
|
|
with get_openai_callback() as cb:
|
|
|
|
|
result = self.agent.run(agent_prompt)
|
|
|
|
|
self.total_tokens = self.total_tokens + cb.total_tokens
|
|
|
|
|
self.total_cost = self.total_cost + cb.total_cost
|
|
|
|
|
util.log(1, "本次消耗token:{}, Cost (USD):{},共消耗token:{}, Cost (USD):{}".format(cb.total_tokens, cb.total_cost, self.total_tokens, self.total_cost))
|
|
|
|
|
|
2023-12-14 10:38:08 +08:00
|
|
|
|
        except Exception as e:
            print(e)

result = "执行完毕" if result is None or result == "N/A" else result
|
|
|
|
|
chat_text = self.say_tool_text if self.is_use_say_tool else result
|
2023-12-12 18:23:43 +08:00
|
|
|
|
|
2024-01-01 22:53:06 +08:00
|
|
|
|
        # Save the exchange to the memory stream and to the chat history
        self.agent_memory.save_context({"input": input_text.replace('主人语音说了:', '').replace('主人文字说了:', '')}, {"output": result})
        self.chat_history.append((input_text.replace('主人语音说了:', '').replace('主人文字说了:', ''), chat_text))
        if len(self.chat_history) > 5:
            self.chat_history.pop(0)

        return self.is_use_say_tool, chat_text


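# Minimal manual check: requires the system config loaded by utils.load_config()
# to provide a valid OpenAI API key (utils.key_gpt_api_key).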
if __name__ == "__main__":
    agent = FayAgentCore()
    print(agent.run("你好"))