diff --git a/src/auto_reverse/agent.py b/src/auto_reverse/agent.py new file mode 100644 index 0000000..16c94de --- /dev/null +++ b/src/auto_reverse/agent.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import json +from typing import Any + +from auto_reverse.tools import Registry, tool_schemas + +MAX_ITERATIONS = 25 + + +class Agent: + """Conversational Claude tool-use loop driving browser/flows/doc tools.""" + + def __init__(self, client: Any, registry: Registry, model: str, system: str) -> None: + self._client = client + self._registry = registry + self._model = model + self._system = system + self._messages: list[dict[str, Any]] = [] + + def run_turn(self, user_message: str) -> str: + self._messages.append({"role": "user", "content": user_message}) + for _ in range(MAX_ITERATIONS): + response = self._client.messages.create( + model=self._model, + max_tokens=4096, + system=self._system, + tools=tool_schemas(self._registry), + messages=self._messages, + ) + self._messages.append( + {"role": "assistant", "content": self._serialize(response.content)} + ) + tool_uses = [b for b in response.content if b.type == "tool_use"] + if not tool_uses: + return self._text_of(response.content) + results: list[dict[str, Any]] = [] + for block in tool_uses: + results.append(self._run_tool(block)) + self._messages.append({"role": "user", "content": results}) + return "(stopped: reached max tool iterations)" + + def _run_tool(self, block: Any) -> dict[str, Any]: + entry = self._registry.get(block.name) + if entry is None: + output: Any = {"error": f"unknown tool {block.name}"} + else: + _, handler = entry + try: + output = handler(block.input) + except Exception as exc: # tool failure -> structured error so the agent can re-plan + output = {"error": str(exc)} + return { + "type": "tool_result", + "tool_use_id": block.id, + "content": json.dumps(output), + } + + @staticmethod + def _serialize(content: list[Any]) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for b in content: + if b.type == "text": + out.append({"type": "text", "text": b.text}) + elif b.type == "tool_use": + out.append( + {"type": "tool_use", "id": b.id, "name": b.name, "input": b.input} + ) + return out + + @staticmethod + def _text_of(content: list[Any]) -> str: + return "".join(b.text for b in content if b.type == "text").strip() diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..4ad5da9 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,53 @@ +from types import SimpleNamespace + +from auto_reverse.agent import Agent + + +def _text_block(text): + return SimpleNamespace(type="text", text=text) + + +def _tool_use(tool_id, name, inp): + return SimpleNamespace(type="tool_use", id=tool_id, name=name, input=inp) + + +class FakeMessages: + def __init__(self, scripted): + self._scripted = list(scripted) + self.calls = [] + + def create(self, **kwargs): + self.calls.append(kwargs) + content = self._scripted.pop(0) + stop = "tool_use" if any(b.type == "tool_use" for b in content) else "end_turn" + return SimpleNamespace(content=content, stop_reason=stop, role="assistant") + + +class FakeClient: + def __init__(self, scripted): + self.messages = FakeMessages(scripted) + + +def test_agent_executes_tool_then_returns_text(): + scripted = [ + [_tool_use("t1", "flows_search", {"query": "users"})], + [_text_block("Found the users endpoint.")], + ] + client = FakeClient(scripted) + registry = { + "flows_search": ( + {"name": "flows_search", "input_schema": {"type": "object"}}, + lambda inp: {"endpoints": [{"path": "/api/users"}]}, + ) + } + agent = Agent(client, registry, model="m", system="s") + reply = agent.run_turn("map users") + assert "users endpoint" in reply + # the tool result was fed back: second create call has >= 3 messages + assert len(client.messages.calls[1]["messages"]) >= 3 + + +def test_agent_plain_text_no_tools(): + client = FakeClient([[_text_block("Hello!")]]) + agent = Agent(client, {}, model="m", system="s") + assert agent.run_turn("hi") == "Hello!"