feat: Claude tool-use agent loop with graceful tool-error handling
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from auto_reverse.tools import Registry, tool_schemas
|
||||
|
||||
# Upper bound on model round-trips within a single Agent.run_turn call.
MAX_ITERATIONS = 25
|
||||
|
||||
|
||||
class Agent:
    """Conversational Claude tool-use loop driving browser/flows/doc tools.

    The message history is kept on the instance, so successive calls to
    ``run_turn`` continue the same conversation.
    """

    def __init__(self, client: Any, registry: Registry, model: str, system: str) -> None:
        """Store the collaborators and start with an empty history.

        Args:
            client: Object exposing ``messages.create(**kwargs)``.
            registry: Lookup whose ``get(name)`` returns ``None`` or a two-item
                entry with the tool handler in second position.
            model: Model identifier forwarded on every request.
            system: System prompt forwarded on every request.
        """
        self._client = client
        self._registry = registry
        self._model = model
        self._system = system
        self._messages: list[dict[str, Any]] = []

    def run_turn(self, user_message: str) -> str:
        """Send one user message and resolve tool calls until a text answer.

        Each round asks the model for a response, records it in the history,
        runs every requested tool, and feeds the results back as a user
        message.

        Args:
            user_message: Text appended to the history as the user's turn.

        Returns:
            The stripped text of the first response that requests no tools,
            or a fixed stop notice once ``MAX_ITERATIONS`` rounds have run.
        """
        self._messages.append({"role": "user", "content": user_message})
        for _ in range(MAX_ITERATIONS):
            response = self._client.messages.create(
                model=self._model,
                max_tokens=4096,
                system=self._system,
                tools=tool_schemas(self._registry),
                messages=self._messages,
            )
            self._messages.append(
                {"role": "assistant", "content": self._serialize(response.content)}
            )
            tool_uses = [b for b in response.content if b.type == "tool_use"]
            if not tool_uses:
                return self._text_of(response.content)
            results = [self._run_tool(block) for block in tool_uses]
            self._messages.append({"role": "user", "content": results})
        return "(stopped: reached max tool iterations)"

    def _run_tool(self, block: Any) -> dict[str, Any]:
        """Execute one ``tool_use`` block and build its ``tool_result``.

        Never raises for tool-level problems: an unknown tool, a handler
        exception, or handler output that cannot be JSON-encoded is reported
        as ``{"error": ...}`` content with ``is_error`` set, so the model can
        re-plan instead of the whole turn crashing.

        Args:
            block: Object with ``id``, ``name`` and ``input`` attributes.

        Returns:
            A ``tool_result`` dict whose ``content`` is a JSON string.
        """
        failed = False
        entry = self._registry.get(block.name)
        if entry is None:
            output: Any = {"error": f"unknown tool {block.name}"}
            failed = True
        else:
            _, handler = entry
            try:
                output = handler(block.input)
            except Exception as exc:  # tool failure -> structured error so the agent can re-plan
                # Some exceptions carry no message; fall back to the type name
                # so the error is never an empty string.
                output = {"error": str(exc) or type(exc).__name__}
                failed = True

        try:
            content = json.dumps(output)
        except (TypeError, ValueError) as exc:
            # Handler returned something json cannot encode (or a cycle).
            content = json.dumps(
                {"error": f"tool {block.name} returned non-JSON-serializable output: {exc}"}
            )
            failed = True

        result: dict[str, Any] = {
            "type": "tool_result",
            "tool_use_id": block.id,
            "content": content,
        }
        if failed:
            # Only added on failure so successful results keep their shape.
            result["is_error"] = True
        return result

    @staticmethod
    def _serialize(content: list[Any]) -> list[dict[str, Any]]:
        """Convert response blocks to plain dicts for the message history.

        Only ``text`` and ``tool_use`` blocks are kept; blocks of any other
        type are omitted.
        """
        out: list[dict[str, Any]] = []
        for b in content:
            if b.type == "text":
                out.append({"type": "text", "text": b.text})
            elif b.type == "tool_use":
                out.append(
                    {"type": "tool_use", "id": b.id, "name": b.name, "input": b.input}
                )
        return out

    @staticmethod
    def _text_of(content: list[Any]) -> str:
        """Return the concatenated text of all ``text`` blocks, stripped."""
        return "".join(b.text for b in content if b.type == "text").strip()
|
||||
@@ -0,0 +1,53 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from auto_reverse.agent import Agent
|
||||
|
||||
|
||||
def _text_block(text):
|
||||
return SimpleNamespace(type="text", text=text)
|
||||
|
||||
|
||||
def _tool_use(tool_id, name, inp):
|
||||
return SimpleNamespace(type="tool_use", id=tool_id, name=name, input=inp)
|
||||
|
||||
|
||||
class FakeMessages:
    """Scripted stand-in for a client's ``messages`` namespace.

    Each ``create`` call records its keyword arguments in ``calls`` and
    returns the next scripted content list wrapped in a response object.
    """

    def __init__(self, scripted):
        """Copy *scripted* (one content list per expected ``create`` call)."""
        self._scripted = list(scripted)
        self.calls = []

    def create(self, **kwargs):
        """Record the call and return the next scripted response.

        The ``messages`` list is snapshotted before recording: callers keep
        appending to the same list after the call, and storing the live
        object would make every recorded call show the final history.

        Raises:
            IndexError: If more calls are made than responses were scripted.
        """
        recorded = dict(kwargs)
        if isinstance(recorded.get("messages"), list):
            recorded["messages"] = list(recorded["messages"])
        self.calls.append(recorded)
        content = self._scripted.pop(0)
        stop = "tool_use" if any(b.type == "tool_use" for b in content) else "end_turn"
        return SimpleNamespace(content=content, stop_reason=stop, role="assistant")
|
||||
|
||||
|
||||
class FakeClient:
    """Client double that exposes only a scripted ``messages`` attribute."""

    def __init__(self, scripted):
        """Wrap *scripted* in a ``FakeMessages`` reachable as ``self.messages``."""
        fake_messages = FakeMessages(scripted)
        self.messages = fake_messages
|
||||
|
||||
|
||||
def test_agent_executes_tool_then_returns_text():
    """A requested tool is run and its result is sent back before the reply."""
    scripted = [
        [_tool_use("t1", "flows_search", {"query": "users"})],
        [_text_block("Found the users endpoint.")],
    ]
    client = FakeClient(scripted)
    registry = {
        "flows_search": (
            {"name": "flows_search", "input_schema": {"type": "object"}},
            lambda inp: {"endpoints": [{"path": "/api/users"}]},
        )
    }
    agent = Agent(client, registry, model="m", system="s")
    reply = agent.run_turn("map users")
    assert "users endpoint" in reply
    # exactly one tool round-trip: the tool request, then the final answer
    assert len(client.messages.calls) == 2
    # the tool result was fed back: second create call has >= 3 messages
    sent = client.messages.calls[1]["messages"]
    assert len(sent) >= 3
    # ...and one of them really carries the tool_result for the requested call,
    # not just any three messages
    tool_results = [
        part
        for message in sent
        if isinstance(message["content"], list)
        for part in message["content"]
        if part.get("type") == "tool_result"
    ]
    assert [r["tool_use_id"] for r in tool_results] == ["t1"]
    assert "/api/users" in tool_results[0]["content"]
|
||||
|
||||
|
||||
def test_agent_plain_text_no_tools():
    """A response without tool_use blocks is returned directly as the reply."""
    scripted_replies = [[_text_block("Hello!")]]
    agent = Agent(FakeClient(scripted_replies), {}, model="m", system="s")
    reply = agent.run_turn("hi")
    assert reply == "Hello!"
|
||||
Reference in New Issue
Block a user