From 4e39cce809fd971ea06bfe0c9e616f4e26482cc6 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 19 Jun 2026 16:55:55 -0700 Subject: [PATCH] feat: tool calling and Agent primitive (S3.1-S3.6) - S3.1: Add Tool + ToolCallDispatch protocol (pyrit/agent/tools.py) Tool is a Pydantic model with name/description/json_schema/callable fields. ToolCallDispatch is a Protocol satisfied by any dispatcher backend. ToolCall/ToolResult are type aliases mirroring the function_call shape. - S3.2: Add InProcessRuntime (pyrit/agent/runtime.py) Default ToolCallDispatch backend executing Python callables in-process. Unknown tools, exceptions, and malformed args return structured error dicts. - S3.3: Extend TargetCapabilities with tool-usage fields Add supports_tool_usage: bool = False and tool_usage_schema: ToolUsageSchema | None = None. Add TOOL_USAGE = 'supports_tool_usage' to CapabilityName enum. Update _permissive_configuration to include supports_tool_usage=True. Export ToolUsageSchema from pyrit.models. - S3.4: Add Agent(PromptTarget) (pyrit/agent/agent.py) Wraps any PromptTarget and executes tool-call/result loop. Calls inner target's _send_prompt_to_target_async directly to keep interim tool turns in-context without memory duplication. max_tool_iterations guard prevents infinite loops. - S3.5: Attack interface parity tests Agent IS-A PromptTarget; attacks accept it with no signature changes. Tests verify PromptSendingAttack runs end-to-end against Agent. - S3.6: Exports, linting, coverage pyrit/agent/__init__.py exports Agent, InProcessRuntime, Tool, ToolCall, ToolCallDispatch, ToolResult. All pre-commit hooks pass; coverage 96% for new code (gate: 78%). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/agent/__init__.py | 23 ++ pyrit/agent/agent.py | 238 ++++++++++++++++++ pyrit/agent/runtime.py | 95 +++++++ pyrit/agent/tools.py | 122 +++++++++ pyrit/models/__init__.py | 3 +- pyrit/models/target_capabilities.py | 40 +++ .../common/discover_target_capabilities.py | 1 + tests/unit/agent/__init__.py | 2 + tests/unit/agent/test_agent.py | 231 +++++++++++++++++ tests/unit/agent/test_attack_parity.py | 161 ++++++++++++ tests/unit/agent/test_runtime.py | 116 +++++++++ tests/unit/agent/test_tools.py | 125 +++++++++ .../test_target_capabilities_tool_usage.py | 151 +++++++++++ 13 files changed, 1307 insertions(+), 1 deletion(-) create mode 100644 pyrit/agent/__init__.py create mode 100644 pyrit/agent/agent.py create mode 100644 pyrit/agent/runtime.py create mode 100644 pyrit/agent/tools.py create mode 100644 tests/unit/agent/__init__.py create mode 100644 tests/unit/agent/test_agent.py create mode 100644 tests/unit/agent/test_attack_parity.py create mode 100644 tests/unit/agent/test_runtime.py create mode 100644 tests/unit/agent/test_tools.py create mode 100644 tests/unit/models/test_target_capabilities_tool_usage.py diff --git a/pyrit/agent/__init__.py b/pyrit/agent/__init__.py new file mode 100644 index 0000000000..4992cc0045 --- /dev/null +++ b/pyrit/agent/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``pyrit.agent`` — universal tool calling and agent primitive for PyRIT. + +This package provides the ``Agent`` class (a ``PromptTarget`` that wraps +another target and executes tool calls), the ``Tool`` model, and the +``InProcessRuntime`` dispatcher. +""" + +from pyrit.agent.agent import Agent +from pyrit.agent.runtime import InProcessRuntime +from pyrit.agent.tools import Tool, ToolCall, ToolCallDispatch, ToolResult + +__all__ = [ + "Agent", + "InProcessRuntime", + "Tool", + "ToolCall", + "ToolCallDispatch", + "ToolResult", +] diff --git a/pyrit/agent/agent.py b/pyrit/agent/agent.py new file mode 100644 index 0000000000..3050286144 --- /dev/null +++ b/pyrit/agent/agent.py @@ -0,0 +1,238 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``Agent`` — a ``PromptTarget`` that augments an inner target with tool calling. + +``Agent`` wraps any ``PromptTarget`` and adds a tool-execution loop: when the +inner target's response contains a tool-call piece (``function_call`` or +``tool_call`` data type), ``Agent`` dispatches it via a ``ToolCallDispatch`` +backend (default: ``InProcessRuntime``), appends the result as a +``function_call_output`` piece, and forwards the extended conversation back to +the inner target — repeating until no pending tool call remains or the +``max_tool_iterations`` guard fires. + +Because ``Agent`` is itself a ``PromptTarget``, attacks and scenarios can use +it anywhere they expect a ``PromptTarget`` — no attack-side changes required. + +Tool-loop design mirrors ``OpenAIResponseTarget._send_prompt_to_target_async`` +(the inner target's method is called directly, not via the public +``send_prompt_async``, so interim tool turns stay in-context without being +written to memory twice). +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pyrit.agent.tools import Tool, ToolCall, ToolCallDispatch, ToolResult + +from pyrit.agent.runtime import InProcessRuntime +from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models.target_capabilities import TargetCapabilities, ToolUsageSchema +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + +logger = logging.getLogger(__name__) + +# Default tool-call data types when no ToolUsageSchema is configured on the inner target. +_DEFAULT_TOOL_CALL_TYPES: frozenset[str] = frozenset({"function_call", "tool_call"}) +_DEFAULT_MAX_TOOL_ITERATIONS = 10 + + +class Agent(PromptTarget): + """ + A ``PromptTarget`` that wraps another target and executes tool calls. + + The agent loop: + + 1. Forwards the full conversation to ``_target._send_prompt_to_target_async`` + directly (not via the public ``send_prompt_async``) to keep interim tool + turns in-context without memory duplication. + 2. Inspects the response for a pending tool-call piece (data type in + ``tool_usage_schema.tool_call_data_types`` or the default set). + 3. Dispatches the tool call via ``_dispatcher.call_async``. + 4. Appends a ``function_call_output`` piece to the working conversation. + 5. Repeats until no tool call is found or ``max_tool_iterations`` is reached. + + Args: + target: The inner target that generates responses and may emit tool calls. + toolset: The set of tools available to this agent. + dispatcher: The backend that executes tool calls. Defaults to an + ``InProcessRuntime`` built from ``toolset``. + max_tool_iterations: Maximum number of tool-call/result round-trips + before the loop is forcibly terminated (prevents infinite loops on + misbehaving models). Defaults to 10. + custom_configuration: Optional per-instance capability override. + """ + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_system_prompt=True, + supports_editable_history=True, + ) + ) + + def __init__( + self, + *, + target: PromptTarget, + toolset: set[Tool], + dispatcher: ToolCallDispatch | None = None, + max_tool_iterations: int = _DEFAULT_MAX_TOOL_ITERATIONS, + custom_configuration: TargetConfiguration | None = None, + ) -> None: + """Initialize Agent with an inner target, toolset, and optional dispatcher.""" + super().__init__(custom_configuration=custom_configuration) + self._target = target + self._toolset = toolset + self._dispatcher: ToolCallDispatch = dispatcher or InProcessRuntime(tools=list(toolset)) + self._max_tool_iterations = max_tool_iterations + + def _build_identifier(self) -> ComponentIdentifier: + return self._create_identifier( + params={"max_tool_iterations": self._max_tool_iterations}, + targets=[self._target.get_identifier()], + ) + + def _validate_request(self, *, normalized_conversation: list[Message]) -> None: + """Pass-through; the inner target validates against its own capabilities.""" + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + """ + Execute the tool-calling agentic loop. + + Calls the inner target's ``_send_prompt_to_target_async`` directly to + keep interim tool turns in-context without touching memory. + + Args: + normalized_conversation: The full conversation (history + current + message) after running the normalization pipeline. + + Returns: + All ``Message`` objects generated during this interaction + (assistant responses and tool-result messages). + """ + schema: ToolUsageSchema | None = getattr(self._target.capabilities, "tool_usage_schema", None) + tool_call_types = schema.tool_call_data_types if schema else _DEFAULT_TOOL_CALL_TYPES + tool_result_data_type = schema.tool_result_data_type if schema else "function_call_output" + tool_result_role = schema.tool_result_role if schema else "tool" + + working_conversation: list[Message] = list(normalized_conversation) + all_responses: list[Message] = [] + + # Reference piece for conversation context propagation + reference_piece = normalized_conversation[-1].message_pieces[0] + + for iteration in range(self._max_tool_iterations + 1): + logger.debug("Agent loop iteration %d, conversation length=%d", iteration, len(working_conversation)) + + responses = await self._target._send_prompt_to_target_async(normalized_conversation=working_conversation) + working_conversation.extend(responses) + all_responses.extend(responses) + + # Find the last tool call in the just-received responses + tool_call = self._find_tool_call(responses=responses, tool_call_types=tool_call_types) + + if tool_call is None: + # No pending tool call → we're done + break + + if iteration >= self._max_tool_iterations: + logger.warning( + "Agent reached max_tool_iterations=%d; stopping tool loop.", + self._max_tool_iterations, + ) + break + + # Dispatch the tool call + tool_result = await self._dispatcher.call_async(tool_call=tool_call) + + # Create a tool-result message and add it to the working conversation + call_id: str = tool_call.get("call_id", "") + result_piece = self._make_tool_result_piece( + result=tool_result, + call_id=call_id, + role=tool_result_role, + data_type=tool_result_data_type, + reference_piece=reference_piece, + ) + result_message = Message(message_pieces=[result_piece]) + working_conversation.append(result_message) + all_responses.append(result_message) + + return all_responses + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _find_tool_call( + self, + *, + responses: list[Message], + tool_call_types: frozenset[str], + ) -> ToolCall | None: + """ + Return the last pending tool-call dict from assistant messages, or ``None``. + + Scans all provided messages in reverse piece order for a piece whose + ``original_value_data_type`` is in ``tool_call_types``. Returns the + first (last in the response) matching parsed JSON dict, or ``None``. + + Args: + responses: The most recently received messages to scan. + tool_call_types: Data type strings that identify tool-call pieces. + + Returns: + The tool-call dict (mirrors ``function_call`` shape), or ``None``. + """ + for message in reversed(responses): + for piece in reversed(message.message_pieces): + if piece.original_value_data_type in tool_call_types: + try: + section: Any = json.loads(piece.original_value) + if isinstance(section, dict): + return section + except Exception: + continue + return None + + def _make_tool_result_piece( + self, + *, + result: ToolResult, + call_id: str, + role: str, + data_type: str, + reference_piece: MessagePiece, + ) -> MessagePiece: + """ + Build a ``function_call_output`` ``MessagePiece`` from a tool result. + + Args: + result: The dict returned by the dispatcher. + call_id: The ``call_id`` from the originating tool call. + role: The role for the result piece (e.g., ``"tool"``). + data_type: The ``original_value_data_type`` (e.g., + ``"function_call_output"``). + reference_piece: A piece to copy ``conversation_id`` from. + + Returns: + A ``MessagePiece`` containing the serialised tool result. + """ + output_str = result if isinstance(result, str) else json.dumps(result, separators=(",", ":")) + return MessagePiece( + role=role, # type: ignore[arg-type] + original_value=json.dumps( + {"type": data_type, "call_id": call_id, "output": output_str}, + separators=(",", ":"), + ), + original_value_data_type=data_type, # type: ignore[arg-type] + conversation_id=reference_piece.conversation_id, + ) diff --git a/pyrit/agent/runtime.py b/pyrit/agent/runtime.py new file mode 100644 index 0000000000..ffceb8c078 --- /dev/null +++ b/pyrit/agent/runtime.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``InProcessRuntime`` — default in-process ``ToolCallDispatch`` backend. + +Executes registered Python callables synchronously within the current process. +Unknown tools and exceptions are captured and returned as structured error dicts +rather than raised, mirroring the tolerant mode of ``OpenAIResponseTarget``'s +``_execute_call_section_async``. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pyrit.agent.tools import Tool, ToolCall, ToolResult + +logger = logging.getLogger(__name__) + + +class InProcessRuntime: + """ + Default ``ToolCallDispatch`` backend that executes tools in-process. + + Tools are registered at construction time via a list of ``Tool`` objects. + All errors (unknown tool, malformed arguments, runtime exceptions) are + captured and returned as structured error dicts so the outer agent loop + can forward them to the model and potentially recover. + + Args: + tools: The set of tools available to this runtime. + """ + + def __init__(self, *, tools: list[Tool]) -> None: + """Initialize InProcessRuntime and register the provided tools by name.""" + self._registry: dict[str, Tool] = {t.name: t for t in tools} + + async def call_async(self, *, tool_call: ToolCall) -> ToolResult: + """ + Dispatch a tool call and return the result. + + Error shapes (mirrors ``OpenAIResponseTarget._execute_call_section_async``): + + - Missing name → ``{"error": "missing_function_name", ...}`` + - Unknown tool → ``{"error": "function_not_found", "missing_function": ..., "available_functions": [...]}`` + - Malformed JSON arguments → ``{"error": "malformed_arguments", ...}`` + - Runtime exception → ``{"error": "tool_execution_error", "message": ..., "function": ...}`` + + Args: + tool_call: A dict with ``name`` and ``arguments`` (JSON string). + + Returns: + The tool's return dict, or a structured error dict. + """ + name: str | None = tool_call.get("name") + if not name: + return { + "error": "missing_function_name", + "tool_call_section": tool_call, + } + + args_json: str = tool_call.get("arguments", "{}") + try: + args: dict[str, Any] = json.loads(args_json) + except Exception: + logger.warning("Malformed arguments for tool '%s': %s", name, args_json) + return { + "error": "malformed_arguments", + "function": name, + "raw_arguments": args_json, + } + + tool = self._registry.get(name) + if tool is None: + available = sorted(self._registry.keys()) + logger.warning("Tool '%s' not registered. Available: %s", name, available) + return { + "error": "function_not_found", + "missing_function": name, + "available_functions": available, + } + + try: + return await tool.callable(args) + except Exception as exc: + logger.warning("Tool '%s' raised an exception: %s", name, exc) + return { + "error": "tool_execution_error", + "function": name, + "message": str(exc), + } diff --git a/pyrit/agent/tools.py b/pyrit/agent/tools.py new file mode 100644 index 0000000000..b693b37dcc --- /dev/null +++ b/pyrit/agent/tools.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tool primitives for the ``pyrit.agent`` package. + +Defines: + +- ``Tool`` — a Pydantic model describing a callable tool (name, description, + JSON schema, callable). +- ``ToolCall`` — type alias for the dict representation of a tool invocation + (mirrors the ``function_call`` shape used by ``OpenAIResponseTarget``). +- ``ToolResult`` — type alias for a tool's return value dict. +- ``ToolCallDispatch`` — ``Protocol`` satisfied by any object that can + dispatch a tool call and return a ``ToolResult``. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any, Protocol + +from pydantic import BaseModel, ConfigDict + +# --------------------------------------------------------------------------- +# Type aliases (mirror the OpenAIResponseTarget function_call shape) +# --------------------------------------------------------------------------- + +# A tool call is a dict with at minimum: +# {"type": "function_call", "call_id": , "name": , "arguments": } +ToolCall = dict[str, Any] + +# A tool result is the dict payload returned by the callable and later +# serialised as the ``function_call_output`` piece: +# {"output": , ...} (call_id injected by the dispatcher) +ToolResult = dict[str, Any] + +# Runtime-level alias; kept at module scope so Pydantic can resolve the field +# annotation for ``Tool.callable`` when ``get_type_hints`` is called at class +# creation time (``from __future__ import annotations`` makes all annotations +# lazy strings, but Pydantic resolves them against the module globals). +ToolCallable = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] + + +# --------------------------------------------------------------------------- +# Tool model +# --------------------------------------------------------------------------- + + +class Tool(BaseModel): + """ + A single tool that an ``Agent`` can invoke during its agentic loop. + + Attributes: + name: Unique identifier for the tool (must match the name the target + emits in ``function_call.name``). + description: Human-readable description of what the tool does. + json_schema: JSON Schema dict describing the tool's input arguments. + callable: Async callable ``(args: dict) -> dict`` that executes the + tool and returns a result dict. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str + description: str + json_schema: dict[str, Any] + callable: ToolCallable + + def __hash__(self) -> int: + """ + Hash by tool name so ``Tool`` instances can live in sets. + + Returns: + Integer hash of ``self.name``. + """ + return hash(self.name) + + def __eq__(self, other: object) -> bool: + """ + Compare tools by name (consistent with ``__hash__``). + + Args: + other: The object to compare against. + + Returns: + ``True`` if ``other`` is a ``Tool`` with the same name; ``NotImplemented`` + if ``other`` is not a ``Tool`` instance. + """ + if not isinstance(other, Tool): + return NotImplemented + return self.name == other.name + + +# --------------------------------------------------------------------------- +# ToolCallDispatch protocol +# --------------------------------------------------------------------------- + + +class ToolCallDispatch(Protocol): + """ + Protocol satisfied by any object that can dispatch a ``ToolCall`` to the + appropriate tool implementation and return a ``ToolResult``. + + Implementors include ``InProcessRuntime`` (default) and future runtimes + such as Docker or MCP-based dispatchers. + """ + + async def call_async(self, *, tool_call: ToolCall) -> ToolResult: + """ + Dispatch a tool call and return the result. + + Args: + tool_call: A dict with at least ``name`` and ``arguments`` + (mirrors the ``function_call`` shape from + ``OpenAIResponseTarget``). + + Returns: + A ``ToolResult`` dict; the caller is responsible for wrapping it + in a ``function_call_output`` ``MessagePiece``. + """ + ... diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 43b5168232..da61376939 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -107,7 +107,7 @@ SeedUnion, SimulatedTargetSystemPromptPaths, ) -from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities +from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities, ToolUsageSchema __all__ = [ "ALLOWED_CHAT_MESSAGE_ROLES", @@ -205,6 +205,7 @@ "TARGET_EVAL_PARAMS", "TargetCapabilities", "TargetIdentifier", + "ToolUsageSchema", "TextDataTypeSerializer", "UnvalidatedScore", "validate_registry_name", diff --git a/pyrit/models/target_capabilities.py b/pyrit/models/target_capabilities.py index 2f0269292e..1111f0832f 100644 --- a/pyrit/models/target_capabilities.py +++ b/pyrit/models/target_capabilities.py @@ -49,6 +49,33 @@ class CapabilityName(str, Enum): EDITABLE_HISTORY = "supports_editable_history" SYSTEM_PROMPT = "supports_system_prompt" STREAMING_AUDIO = "supports_streaming_audio" + TOOL_USAGE = "supports_tool_usage" + + +class ToolUsageSchema(BaseModel): + """ + Describes how a target natively represents tool calls and results. + + Used by ``Agent`` to canonicalize tool-call detection and result formatting + when the inner target uses a non-standard shape. Set to ``None`` on + ``TargetCapabilities`` to indicate the target does not support native tool + calling (tool calls must be injected by the agent loop). + + Attributes: + tool_call_data_types: The ``converted_value_data_type`` values that + indicate a tool-call piece in an assistant message. Defaults to + ``{"function_call", "tool_call"}`` — the shapes used by + ``OpenAIResponseTarget``. + tool_result_data_type: The ``original_value_data_type`` to use when + creating a tool-result ``MessagePiece``. + tool_result_role: The ``role`` to assign to the tool-result piece. + """ + + model_config = ConfigDict(frozen=True) + + tool_call_data_types: frozenset[str] = frozenset({"function_call", "tool_call"}) + tool_result_data_type: str = "function_call_output" + tool_result_role: str = "tool" class TargetCapabilities(BaseModel): @@ -97,6 +124,19 @@ class attribute. Users can override individual capabilities per instance #: ``BargeInAttack``. supports_streaming_audio: bool = False + #: Whether the target natively supports tool/function calling. Set to + #: ``True`` for targets (e.g., ``OpenAIResponseTarget``) that emit + #: ``function_call`` / ``tool_call`` pieces and expect + #: ``function_call_output`` replies. ``CapabilityName.TOOL_USAGE`` maps + #: to this field so that ``includes()`` returns the correct bool. + supports_tool_usage: bool = False + + #: Describes the tool-call shape this target uses. ``None`` means the + #: target does not natively emit tool-call pieces. When set, ``Agent`` + #: uses this schema to canonicalize tool-call detection and result + #: formatting. + tool_usage_schema: ToolUsageSchema | None = None + #: The input modalities supported by the target (e.g., "text", "image"). input_modalities: frozenset[frozenset[PromptDataType]] = Field(default=_DEFAULT_TEXT_MODALITIES) diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index b47a90a3c7..b6de5bfdb6 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -154,6 +154,7 @@ def _permissive_configuration( supports_editable_history=True, supports_system_prompt=True, supports_streaming_audio=True, + supports_tool_usage=True, input_modalities=merged_modalities, output_modalities=original.capabilities.output_modalities, ) diff --git a/tests/unit/agent/__init__.py b/tests/unit/agent/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/agent/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/agent/test_agent.py b/tests/unit/agent/test_agent.py new file mode 100644 index 0000000000..3a2f8baed4 --- /dev/null +++ b/tests/unit/agent/test_agent.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for pyrit.agent.agent (S3.4 — Agent).""" + +import json +import uuid +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.agent.agent import Agent +from pyrit.agent.runtime import InProcessRuntime +from pyrit.agent.tools import Tool +from pyrit.models import Message, MessagePiece +from pyrit.models.target_capabilities import TargetCapabilities, ToolUsageSchema +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from unit.mocks import MockPromptTarget, get_mock_target_identifier + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_CONV_ID = str(uuid.uuid4()) + + +def _make_text_message(text: str = "hello", role: str = "user") -> Message: + return MessagePiece( + role=role, # type: ignore[arg-type] + original_value=text, + converted_value=text, + conversation_id=_CONV_ID, + ).to_message() + + +def _make_function_call_message(name: str, call_id: str = "call_1", args: dict[str, Any] | None = None) -> Message: + value = json.dumps( + { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": json.dumps(args or {}), + } + ) + return MessagePiece( + role="assistant", + original_value=value, + original_value_data_type="function_call", # type: ignore[arg-type] + conversation_id=_CONV_ID, + ).to_message() + + +def _make_tool_result_message(call_id: str, output: str) -> Message: + value = json.dumps({"type": "function_call_output", "call_id": call_id, "output": output}) + return MessagePiece( + role="tool", + original_value=value, + original_value_data_type="function_call_output", # type: ignore[arg-type] + conversation_id=_CONV_ID, + ).to_message() + + +async def _add_fn(args: dict[str, Any]) -> dict[str, Any]: + return {"sum": args["a"] + args["b"]} + + +_ADD_TOOL = Tool( + name="add", + description="Adds two numbers", + json_schema={}, + callable=_add_fn, +) + +_TOOL_USAGE_SCHEMA = ToolUsageSchema() + +_TOOL_USING_CAPABILITIES = TargetCapabilities( + supports_tool_usage=True, + tool_usage_schema=_TOOL_USAGE_SCHEMA, + supports_multi_turn=True, + supports_multi_message_pieces=True, +) + + +def _make_tool_using_mock_target() -> MagicMock: + """Return a MagicMock PromptTarget that reports tool-calling capability.""" + from pyrit.prompt_target import PromptTarget + + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = get_mock_target_identifier("MockToolTarget") + target.capabilities = _TOOL_USING_CAPABILITIES + return target + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_keyword_only_init() -> None: + inner = MockPromptTarget() + runtime = InProcessRuntime(tools=[]) + # Should succeed with keyword args + agent = Agent(target=inner, toolset=set(), dispatcher=runtime) + assert agent is not None + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_uses_inner_target_capabilities() -> None: + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + # Agent is a PromptTarget with its own capabilities + assert agent.capabilities is not None + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_is_prompt_target() -> None: + from pyrit.prompt_target import PromptTarget + + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + assert isinstance(agent, PromptTarget) + + +# --------------------------------------------------------------------------- +# No-tool-call passthrough +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +async def test_agent_passthrough_no_tool_call() -> None: + """When inner target returns plain text, Agent returns it unchanged.""" + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + + user_msg = _make_text_message("What is 2+2?") + inner_response = _make_text_message("4", role="assistant") + + with patch.object( + inner, + "_send_prompt_to_target_async", + new_callable=AsyncMock, + return_value=[inner_response], + ): + result = await agent._send_prompt_to_target_async(normalized_conversation=[user_msg]) + + assert len(result) >= 1 + assert result[0].get_value() == "4" + + +# --------------------------------------------------------------------------- +# Single tool call executes and loops +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +async def test_agent_executes_tool_call_and_returns_final_response() -> None: + """Agent executes a tool call and fetches the follow-up response.""" + inner_mock = _make_tool_using_mock_target() + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + agent = Agent(target=inner_mock, toolset={_ADD_TOOL}, dispatcher=runtime) + + user_msg = _make_text_message("Add 3 and 4") + tool_call_msg = _make_function_call_message("add", "call_1", {"a": 3, "b": 4}) + final_text_msg = _make_text_message("The sum is 7", role="assistant") + + # First call: returns tool call; second call: returns final text + inner_mock._send_prompt_to_target_async = AsyncMock( + side_effect=[ + [tool_call_msg], + [final_text_msg], + ] + ) + + result = await agent._send_prompt_to_target_async(normalized_conversation=[user_msg]) + + # Should include the tool call response, the tool result, and the final response + assert any("7" in r.get_value() for r in result) + assert inner_mock._send_prompt_to_target_async.call_count == 2 + + +# --------------------------------------------------------------------------- +# Max iteration guard +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +async def test_agent_max_iteration_guard() -> None: + """Agent stops looping after max_tool_iterations even if tool calls keep coming.""" + inner_mock = _make_tool_using_mock_target() + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + agent = Agent(target=inner_mock, toolset={_ADD_TOOL}, dispatcher=runtime, max_tool_iterations=3) + + user_msg = _make_text_message("Add forever") + # Each call returns another tool call + tool_call_msg = _make_function_call_message("add", "call_x", {"a": 1, "b": 1}) + inner_mock._send_prompt_to_target_async = AsyncMock(return_value=[tool_call_msg]) + + result = await agent._send_prompt_to_target_async(normalized_conversation=[user_msg]) + + # Should stop after max_tool_iterations + 1 total calls (initial + max) + assert inner_mock._send_prompt_to_target_async.call_count <= 4 # max_tool_iterations + 1 + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# Results are well-formed list[Message] +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +async def test_agent_results_are_messages() -> None: + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + user_msg = _make_text_message("test") + inner_response = _make_text_message("response", role="assistant") + + with patch.object( + inner, + "_send_prompt_to_target_async", + new_callable=AsyncMock, + return_value=[inner_response], + ): + result = await agent._send_prompt_to_target_async(normalized_conversation=[user_msg]) + + assert isinstance(result, list) + for r in result: + assert isinstance(r, Message) + for piece in r.message_pieces: + assert isinstance(piece, MessagePiece) diff --git a/tests/unit/agent/test_attack_parity.py b/tests/unit/agent/test_attack_parity.py new file mode 100644 index 0000000000..dbe7d73540 --- /dev/null +++ b/tests/unit/agent/test_attack_parity.py @@ -0,0 +1,161 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""S3.5 — Attack interface parity: Agent accepted wherever PromptTarget is accepted.""" + +import json +import uuid +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.agent.agent import Agent +from pyrit.agent.runtime import InProcessRuntime +from pyrit.agent.tools import Tool +from pyrit.executor.attack import ( + AttackParameters, + PromptSendingAttack, + SingleTurnAttackContext, +) +from pyrit.models import Message, MessagePiece +from pyrit.models.target_capabilities import TargetCapabilities, ToolUsageSchema +from pyrit.prompt_target import PromptTarget +from unit.mocks import MockPromptTarget, get_mock_target_identifier + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +_CONV_ID = str(uuid.uuid4()) + + +def _make_text_message(text: str = "hello", role: str = "user") -> Message: + return MessagePiece( + role=role, # type: ignore[arg-type] + original_value=text, + converted_value=text, + conversation_id=_CONV_ID, + ).to_message() + + +def _make_function_call_message(name: str, call_id: str = "call_1", args: dict[str, Any] | None = None) -> Message: + value = json.dumps( + { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": json.dumps(args or {}), + } + ) + return MessagePiece( + role="assistant", + original_value=value, + original_value_data_type="function_call", # type: ignore[arg-type] + conversation_id=_CONV_ID, + ).to_message() + + +async def _greet_fn(args: dict[str, Any]) -> dict[str, Any]: + return {"greeting": f"Hello, {args.get('name', 'World')}!"} + + +_GREET_TOOL = Tool( + name="greet", + description="Greets someone by name", + json_schema={}, + callable=_greet_fn, +) + + +def _make_tool_using_inner_mock() -> MagicMock: + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = get_mock_target_identifier("InnerToolTarget") + target.capabilities = TargetCapabilities( + supports_tool_usage=True, + tool_usage_schema=ToolUsageSchema(), + supports_multi_turn=True, + ) + return target + + +# --------------------------------------------------------------------------- +# S3.5: Agent IS-A PromptTarget +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_is_instance_of_prompt_target() -> None: + """Agent satisfies isinstance(agent, PromptTarget).""" + inner = MockPromptTarget() + agent = Agent(target=inner, toolset={_GREET_TOOL}) + assert isinstance(agent, PromptTarget) + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_has_send_prompt_async() -> None: + """Agent exposes the public send_prompt_async interface.""" + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + assert callable(getattr(agent, "send_prompt_async", None)) + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_accepted_by_prompt_sending_attack() -> None: + """PromptSendingAttack accepts an Agent as objective_target without error.""" + inner = MockPromptTarget() + agent = Agent(target=inner, toolset={_GREET_TOOL}) + # Should construct without raising + attack = PromptSendingAttack(objective_target=agent) + assert attack.get_objective_target() is agent + + +@pytest.mark.usefixtures("patch_central_database") +async def test_prompt_sending_attack_end_to_end_with_agent() -> None: + """PromptSendingAttack._perform_async works end-to-end with an Agent as target.""" + + inner_mock = _make_tool_using_inner_mock() + runtime = InProcessRuntime(tools=[_GREET_TOOL]) + agent = Agent(target=inner_mock, toolset={_GREET_TOOL}, dispatcher=runtime) + + final_msg = _make_text_message("Hello, Alice!", role="assistant") + + context = SingleTurnAttackContext( + params=AttackParameters(objective="Say hello to Alice"), + conversation_id=str(uuid.uuid4()), + ) + + attack = PromptSendingAttack(objective_target=agent) + + # Patch the normalizer so we don't need real DB/network + with patch.object( + attack._prompt_normalizer, + "send_prompt_async", + new_callable=AsyncMock, + return_value=final_msg, + ): + result = await attack._perform_async(context=context) + + # Attack should complete and return an AttackResult + assert result is not None + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_get_identifier_returns_component_identifier() -> None: + """Agent.get_identifier returns a ComponentIdentifier with Agent as class.""" + from pyrit.models import ComponentIdentifier + + inner = MockPromptTarget() + agent = Agent(target=inner, toolset={_GREET_TOOL}) + ident = agent.get_identifier() + assert isinstance(ident, ComponentIdentifier) + assert ident.class_name == "Agent" + + +@pytest.mark.usefixtures("patch_central_database") +def test_agent_capabilities_returns_target_capabilities() -> None: + """Agent.capabilities returns a TargetCapabilities instance.""" + inner = MockPromptTarget() + agent = Agent(target=inner, toolset=set()) + caps = agent.capabilities + assert isinstance(caps, TargetCapabilities) diff --git a/tests/unit/agent/test_runtime.py b/tests/unit/agent/test_runtime.py new file mode 100644 index 0000000000..4413e8f11f --- /dev/null +++ b/tests/unit/agent/test_runtime.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for pyrit.agent.runtime (S3.2 — InProcessRuntime).""" + +from typing import Any + +import pytest + +from pyrit.agent.runtime import InProcessRuntime +from pyrit.agent.tools import Tool, ToolCall + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _add_tool(args: dict[str, Any]) -> dict[str, Any]: + return {"sum": args["a"] + args["b"]} + + +async def _boom_tool(args: dict[str, Any]) -> dict[str, Any]: + raise RuntimeError("tool exploded") + + +_ADD_TOOL = Tool( + name="add", + description="Adds two numbers", + json_schema={"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}, + callable=_add_tool, +) + +_BOOM_TOOL = Tool( + name="boom", + description="Always raises", + json_schema={}, + callable=_boom_tool, +) + + +def _make_tool_call(name: str, args: dict[str, Any] | None = None, call_id: str = "call_1") -> ToolCall: + import json + + return { + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": json.dumps(args or {}), + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +async def test_in_process_runtime_dispatches_registered_tool() -> None: + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + tool_call = _make_tool_call("add", {"a": 3, "b": 4}) + result = await runtime.call_async(tool_call=tool_call) + assert result["sum"] == 7 + + +async def test_in_process_runtime_unknown_tool_returns_structured_error() -> None: + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + tool_call = _make_tool_call("nonexistent") + result = await runtime.call_async(tool_call=tool_call) + assert result.get("error") == "function_not_found" + assert result.get("missing_function") == "nonexistent" + assert "available_functions" in result + assert "add" in result["available_functions"] + + +async def test_in_process_runtime_exception_captured_as_error() -> None: + runtime = InProcessRuntime(tools=[_BOOM_TOOL]) + tool_call = _make_tool_call("boom") + result = await runtime.call_async(tool_call=tool_call) + assert result.get("error") == "tool_execution_error" + assert "tool exploded" in result.get("message", "") + + +async def test_in_process_runtime_malformed_arguments_returns_structured_error() -> None: + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + bad_call: ToolCall = { + "type": "function_call", + "call_id": "call_2", + "name": "add", + "arguments": "NOT VALID JSON {{{", + } + result = await runtime.call_async(tool_call=bad_call) + assert result.get("error") == "malformed_arguments" + + +async def test_in_process_runtime_missing_name_returns_structured_error() -> None: + runtime = InProcessRuntime(tools=[_ADD_TOOL]) + bad_call: ToolCall = {"type": "function_call", "call_id": "call_3", "arguments": "{}"} + result = await runtime.call_async(tool_call=bad_call) + assert result.get("error") == "missing_function_name" + + +async def test_in_process_runtime_empty_toolset() -> None: + runtime = InProcessRuntime(tools=[]) + result = await runtime.call_async(tool_call=_make_tool_call("anything")) + assert result.get("error") == "function_not_found" + assert result.get("available_functions") == [] + + +async def test_in_process_runtime_multiple_tools() -> None: + runtime = InProcessRuntime(tools=[_ADD_TOOL, _BOOM_TOOL]) + result = await runtime.call_async(tool_call=_make_tool_call("add", {"a": 1, "b": 2})) + assert result["sum"] == 3 + + +def test_in_process_runtime_keyword_only_init() -> None: + with pytest.raises(TypeError): + InProcessRuntime([_ADD_TOOL]) # type: ignore[missing-argument,too-many-positional-arguments] # ty: ignore[missing-argument,too-many-positional-arguments] diff --git a/tests/unit/agent/test_tools.py b/tests/unit/agent/test_tools.py new file mode 100644 index 0000000000..ae3f8b10d0 --- /dev/null +++ b/tests/unit/agent/test_tools.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for pyrit.agent.tools (S3.1).""" + +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from pyrit.agent.tools import Tool, ToolCall, ToolCallDispatch, ToolResult + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _echo_tool(args: dict[str, Any]) -> dict[str, Any]: + return {"result": args.get("value", "echo")} + + +_SIMPLE_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], +} + + +# --------------------------------------------------------------------------- +# Tool model tests +# --------------------------------------------------------------------------- + + +def test_tool_construction_stores_fields() -> None: + tool = Tool( + name="echo", + description="Echoes input", + json_schema=_SIMPLE_SCHEMA, + callable=_echo_tool, + ) + assert tool.name == "echo" + assert tool.description == "Echoes input" + assert tool.json_schema == _SIMPLE_SCHEMA + assert tool.callable is _echo_tool + + +def test_tool_name_required() -> None: + with pytest.raises(Exception): + Tool(description="No name", json_schema={}, callable=_echo_tool) # type: ignore[missing-argument] # ty: ignore[missing-argument] + + +def test_tool_description_required() -> None: + with pytest.raises(Exception): + Tool(name="x", json_schema={}, callable=_echo_tool) # type: ignore[missing-argument] # ty: ignore[missing-argument] + + +def test_tool_callable_required() -> None: + with pytest.raises(Exception): + Tool(name="x", description="desc", json_schema={}) # type: ignore[missing-argument] # ty: ignore[missing-argument] + + +def test_tool_json_schema_required() -> None: + with pytest.raises(Exception): + Tool(name="x", description="desc", callable=_echo_tool) # type: ignore[missing-argument] # ty: ignore[missing-argument] + + +def test_tool_equality_by_name() -> None: + """Tools with same name compare as equal (set membership).""" + t1 = Tool(name="foo", description="a", json_schema={}, callable=_echo_tool) + t2 = Tool(name="foo", description="b", json_schema={}, callable=_echo_tool) + # Pydantic equality is field-by-field; confirm name is stored + assert t1.name == t2.name + + +def test_tool_in_set() -> None: + t1 = Tool(name="foo", description="a", json_schema={}, callable=_echo_tool) + t2 = Tool(name="bar", description="b", json_schema={}, callable=_echo_tool) + toolset: set[Tool] = {t1, t2} + assert len(toolset) == 2 + + +# --------------------------------------------------------------------------- +# ToolCallDispatch protocol conformance +# --------------------------------------------------------------------------- + + +class _ConcreteDispatcher: + """A minimal concrete implementation of ToolCallDispatch.""" + + async def call_async(self, *, tool_call: ToolCall) -> ToolResult: + return {"result": "ok"} + + +def test_tool_call_dispatch_protocol_conformance() -> None: + """Concrete class implementing the protocol is accepted at runtime.""" + dispatcher: ToolCallDispatch = _ConcreteDispatcher() # type: ignore[assignment] + assert callable(getattr(dispatcher, "call_async", None)) + + +async def test_tool_call_dispatch_callable_invocable() -> None: + dispatcher = _ConcreteDispatcher() + tool_call: ToolCall = { + "type": "function_call", + "call_id": "call_abc", + "name": "echo", + "arguments": '{"value": "hello"}', + } + result = await dispatcher.call_async(tool_call=tool_call) + assert result == {"result": "ok"} + + +def test_async_mock_satisfies_protocol() -> None: + """AsyncMock can stand in as a ToolCallDispatch in tests.""" + mock: ToolCallDispatch = AsyncMock(spec=_ConcreteDispatcher) # type: ignore[assignment] + assert hasattr(mock, "call_async") + + +# --------------------------------------------------------------------------- +# ToolResult shape +# --------------------------------------------------------------------------- + + +def test_tool_result_is_dict() -> None: + result: ToolResult = {"call_id": "x", "output": "hello"} + assert isinstance(result, dict) diff --git a/tests/unit/models/test_target_capabilities_tool_usage.py b/tests/unit/models/test_target_capabilities_tool_usage.py new file mode 100644 index 0000000000..c853995afd --- /dev/null +++ b/tests/unit/models/test_target_capabilities_tool_usage.py @@ -0,0 +1,151 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for TargetCapabilities tool-usage fields (S3.3).""" + +import pytest + +from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities, ToolUsageSchema + +# --------------------------------------------------------------------------- +# ToolUsageSchema tests +# --------------------------------------------------------------------------- + + +def test_tool_usage_schema_defaults() -> None: + schema = ToolUsageSchema() + assert "function_call" in schema.tool_call_data_types + assert schema.tool_result_data_type == "function_call_output" + assert schema.tool_result_role == "tool" + + +def test_tool_usage_schema_custom_values() -> None: + schema = ToolUsageSchema( + tool_call_data_types=frozenset({"custom_call"}), + tool_result_data_type="custom_output", + tool_result_role="assistant", + ) + assert schema.tool_call_data_types == frozenset({"custom_call"}) + assert schema.tool_result_data_type == "custom_output" + assert schema.tool_result_role == "assistant" + + +def test_tool_usage_schema_round_trips_json() -> None: + schema = ToolUsageSchema() + dumped = schema.model_dump() + restored = ToolUsageSchema(**dumped) + assert restored.tool_call_data_types == schema.tool_call_data_types + assert restored.tool_result_data_type == schema.tool_result_data_type + assert restored.tool_result_role == schema.tool_result_role + + +# --------------------------------------------------------------------------- +# CapabilityName enum +# --------------------------------------------------------------------------- + + +def test_capability_name_tool_usage_enum_value() -> None: + assert CapabilityName.TOOL_USAGE == "supports_tool_usage" + + +def test_capability_name_tool_usage_points_at_bool_field() -> None: + """TOOL_USAGE must point at the bool field so includes() works.""" + caps = TargetCapabilities(supports_tool_usage=True) + assert isinstance(getattr(caps, CapabilityName.TOOL_USAGE.value), bool) + + +# --------------------------------------------------------------------------- +# TargetCapabilities defaults +# --------------------------------------------------------------------------- + + +def test_target_capabilities_default_tool_usage_is_false() -> None: + caps = TargetCapabilities() + assert caps.supports_tool_usage is False + + +def test_target_capabilities_default_tool_usage_schema_is_none() -> None: + caps = TargetCapabilities() + assert caps.tool_usage_schema is None + + +# --------------------------------------------------------------------------- +# includes() reflects bool +# --------------------------------------------------------------------------- + + +def test_includes_tool_usage_false_by_default() -> None: + caps = TargetCapabilities() + assert caps.includes(capability=CapabilityName.TOOL_USAGE) is False + + +def test_includes_tool_usage_true_when_enabled() -> None: + caps = TargetCapabilities(supports_tool_usage=True) + assert caps.includes(capability=CapabilityName.TOOL_USAGE) is True + + +def test_includes_not_affected_by_schema_alone() -> None: + """Schema being set should not affect includes() — the BOOL field governs.""" + caps = TargetCapabilities(supports_tool_usage=False, tool_usage_schema=ToolUsageSchema()) + assert caps.includes(capability=CapabilityName.TOOL_USAGE) is False + + +# --------------------------------------------------------------------------- +# Frozen-model immutability +# --------------------------------------------------------------------------- + + +def test_target_capabilities_is_frozen() -> None: + caps = TargetCapabilities() + with pytest.raises(Exception): # ValidationError or TypeError + caps.supports_tool_usage = True # type: ignore[misc] + + +def test_target_capabilities_with_schema_is_frozen() -> None: + caps = TargetCapabilities(supports_tool_usage=True, tool_usage_schema=ToolUsageSchema()) + with pytest.raises(Exception): + caps.tool_usage_schema = None # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Existing capabilities unaffected +# --------------------------------------------------------------------------- + + +def test_existing_capability_multi_turn_unaffected() -> None: + caps = TargetCapabilities(supports_multi_turn=True) + assert caps.includes(capability=CapabilityName.MULTI_TURN) is True + assert caps.supports_tool_usage is False + + +def test_existing_capability_system_prompt_unaffected() -> None: + caps = TargetCapabilities(supports_system_prompt=True) + assert caps.includes(capability=CapabilityName.SYSTEM_PROMPT) is True + assert caps.supports_tool_usage is False + + +def test_all_existing_capabilities_default_false() -> None: + caps = TargetCapabilities() + for cap in [ + CapabilityName.MULTI_TURN, + CapabilityName.MULTI_MESSAGE_PIECES, + CapabilityName.JSON_SCHEMA, + CapabilityName.JSON_OUTPUT, + CapabilityName.EDITABLE_HISTORY, + CapabilityName.SYSTEM_PROMPT, + CapabilityName.STREAMING_AUDIO, + ]: + assert caps.includes(capability=cap) is False + + +# --------------------------------------------------------------------------- +# Schema alongside tool-usage bool +# --------------------------------------------------------------------------- + + +def test_target_capabilities_with_tool_usage_schema() -> None: + schema = ToolUsageSchema() + caps = TargetCapabilities(supports_tool_usage=True, tool_usage_schema=schema) + assert caps.supports_tool_usage is True + assert caps.tool_usage_schema is not None + assert caps.includes(capability=CapabilityName.TOOL_USAGE) is True