diff --git a/doc/code/converters/6_selectively_converting.ipynb b/doc/code/converters/6_selectively_converting.ipynb index 78b1964324..77a2e5c329 100644 --- a/doc/code/converters/6_selectively_converting.ipynb +++ b/doc/code/converters/6_selectively_converting.ipynb @@ -215,7 +215,7 @@ "source": [ "# Convert words at specific positions (e.g., words 3, 4, and 5)\n", "converter = SelectiveTextConverter(\n", - " converter=Base64Converter(),\n", + " sub_converter=Base64Converter(),\n", " selection_strategy=WordIndexSelectionStrategy(indices=[3, 4, 5]),\n", ")\n", "\n", @@ -293,7 +293,7 @@ "source": [ "# Convert all numbers in the prompt\n", "converter = SelectiveTextConverter(\n", - " converter=Base64Converter(),\n", + " sub_converter=Base64Converter(),\n", " selection_strategy=WordRegexSelectionStrategy(pattern=r\"\\d+\"),\n", ")\n", "\n", @@ -373,7 +373,7 @@ "source": [ "# Convert the second half of the prompt\n", "converter = SelectiveTextConverter(\n", - " converter=ROT13Converter(),\n", + " sub_converter=ROT13Converter(),\n", " selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0),\n", ")\n", "\n", @@ -452,7 +452,7 @@ ], "source": [ "converter = SelectiveTextConverter(\n", - " converter=Base64Converter(),\n", + " sub_converter=Base64Converter(),\n", " selection_strategy=WordProportionSelectionStrategy(proportion=0.3, seed=42),\n", ")\n", "\n", @@ -530,7 +530,7 @@ "source": [ "# Convert specific sensitive words\n", "converter = SelectiveTextConverter(\n", - " converter=Base64Converter(),\n", + " sub_converter=Base64Converter(),\n", " selection_strategy=WordKeywordSelectionStrategy(keywords=[\"password\", \"secret\", \"confidential\"]),\n", ")\n", "\n", @@ -610,13 +610,13 @@ "source": [ "# First convert the first half to russian\n", "first_converter = SelectiveTextConverter(\n", - " converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"russian\"),\n", + " sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"russian\"),\n", " selection_strategy=WordPositionSelectionStrategy(start_proportion=0.0, end_proportion=0.5),\n", ")\n", "\n", "# Then converts the second half to spanish\n", "second_converter = SelectiveTextConverter(\n", - " converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"spanish\"),\n", + " sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"spanish\"),\n", " selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0),\n", ")\n", "\n", @@ -699,20 +699,20 @@ ], "source": [ "first_converter = SelectiveTextConverter(\n", - " converter=ToneConverter(converter_target=OpenAIChatTarget(), tone=\"angry\"),\n", + " sub_converter=ToneConverter(converter_target=OpenAIChatTarget(), tone=\"angry\"),\n", " selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0),\n", " preserve_tokens=True,\n", ")\n", "\n", "# Second converter auto-detects tokens from first converter\n", "second_converter = SelectiveTextConverter(\n", - " converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"spanish\"),\n", + " sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language=\"spanish\"),\n", " selection_strategy=TokenSelectionStrategy(), # Detects tokens from first converter\n", " preserve_tokens=True,\n", ")\n", "\n", "third_converter = SelectiveTextConverter(\n", - " converter=EmojiConverter(),\n", + " sub_converter=EmojiConverter(),\n", " selection_strategy=TokenSelectionStrategy(), # Detects tokens from second converter\n", " preserve_tokens=False,\n", ")\n", diff --git a/doc/code/converters/6_selectively_converting.py b/doc/code/converters/6_selectively_converting.py index 507ef47e43..2911a67da0 100644 --- a/doc/code/converters/6_selectively_converting.py +++ b/doc/code/converters/6_selectively_converting.py @@ -78,7 +78,7 @@ # %% # Convert words at specific positions (e.g., words 3, 4, and 5) converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=WordIndexSelectionStrategy(indices=[3, 4, 5]), ) @@ -101,7 +101,7 @@ # %% # Convert all numbers in the prompt converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=WordRegexSelectionStrategy(pattern=r"\d+"), ) @@ -126,7 +126,7 @@ # %% # Convert the second half of the prompt converter = SelectiveTextConverter( - converter=ROT13Converter(), + sub_converter=ROT13Converter(), selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0), ) @@ -150,7 +150,7 @@ # %% converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=WordProportionSelectionStrategy(proportion=0.3, seed=42), ) @@ -173,7 +173,7 @@ # %% # Convert specific sensitive words converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=WordKeywordSelectionStrategy(keywords=["password", "secret", "confidential"]), ) @@ -198,13 +198,13 @@ # %% # First convert the first half to russian first_converter = SelectiveTextConverter( - converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="russian"), + sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="russian"), selection_strategy=WordPositionSelectionStrategy(start_proportion=0.0, end_proportion=0.5), ) # Then converts the second half to spanish second_converter = SelectiveTextConverter( - converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="spanish"), + sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="spanish"), selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0), ) @@ -229,20 +229,20 @@ # %% first_converter = SelectiveTextConverter( - converter=ToneConverter(converter_target=OpenAIChatTarget(), tone="angry"), + sub_converter=ToneConverter(converter_target=OpenAIChatTarget(), tone="angry"), selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0), preserve_tokens=True, ) # Second converter auto-detects tokens from first converter second_converter = SelectiveTextConverter( - converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="spanish"), + sub_converter=TranslationConverter(converter_target=OpenAIChatTarget(), language="spanish"), selection_strategy=TokenSelectionStrategy(), # Detects tokens from first converter preserve_tokens=True, ) third_converter = SelectiveTextConverter( - converter=EmojiConverter(), + sub_converter=EmojiConverter(), selection_strategy=TokenSelectionStrategy(), # Detects tokens from second converter preserve_tokens=False, ) diff --git a/doc/code/scenarios/2_custom_scenario_parameters.ipynb b/doc/code/scenarios/2_custom_scenario_parameters.ipynb index c1c2e5667f..628c67d831 100644 --- a/doc/code/scenarios/2_custom_scenario_parameters.ipynb +++ b/doc/code/scenarios/2_custom_scenario_parameters.ipynb @@ -49,6 +49,14 @@ "id": "1", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "./AppData/Local/miniconda3/Lib/site-packages/requests/__init__.py:113: RequestsDependencyWarning: urllib3 (2.5.0) or chardet (7.4.3)/charset_normalizer (3.3.2) doesn't match a supported version!\n", + " warnings.warn(\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -63,7 +71,7 @@ "output_type": "stream", "text": [ "[pyrit:alembic] No new upgrade operations detected.\n", - "Parameter(name='max_turns', description='Maximum conversation turns for the persuasive_rta strategy.', default=5, param_type=, choices=None)\n" + "Parameter(name='max_turns', description='Maximum conversation turns for the persuasive_rta strategy.', default=5, param_type=, destination=)\n" ] } ], @@ -93,8 +101,9 @@ "- **name**: dict key in `self.params`, converted to `--kebab-case` for the CLI\n", "- **description**: shown in `--list-scenarios` and `--help`\n", "- **default**: value used when not supplied; deep-copied per run\n", - "- **param_type**: `str`, `int`, `float`, `bool`, `list[str]`, or `None` (raw passthrough)\n", - "- **choices**: optional tuple of allowed values (not supported with `list` types)\n", + "- **param_type**: `str`, `int`, `float`, `bool`, a `Literal[...]`/`Enum` (a\n", + " constrained scalar that carries its own allowed set), a `list[...]` of any of\n", + " those, or `None` (raw passthrough)\n", "\n", "A more complete declaration list might look like:" ] @@ -109,15 +118,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameter(name='objective', description='Goal the attack pursues', default=None, param_type=, choices=None)\n", - "Parameter(name='max_turns', description='Conversation cap', default=5, param_type=, choices=None)\n", - "Parameter(name='mode', description='Speed mode', default='fast', param_type=, choices=('fast', 'slow'))\n", - "Parameter(name='tags', description='Tag list', default=['default'], param_type=list[str], choices=None)\n" + "Parameter(name='objective', description='Goal the attack pursues', default=None, param_type=, destination=)\n", + "Parameter(name='max_turns', description='Conversation cap', default=5, param_type=, destination=)\n", + "Parameter(name='mode', description='Speed mode', default='fast', param_type=typing.Literal['fast', 'slow'], destination=)\n", + "Parameter(name='tags', description='Tag list', default=['default'], param_type=list[str], destination=)\n" ] } ], "source": [ - "from pyrit.common import Parameter\n", + "from typing import Literal\n", + "\n", + "from pyrit.models import Parameter\n", "\n", "# What a scenario author would return from supported_parameters():\n", "example_declarations = [\n", @@ -125,13 +136,12 @@ " Parameter(name=\"objective\", description=\"Goal the attack pursues\", param_type=str),\n", " # Scalar with default\n", " Parameter(name=\"max_turns\", description=\"Conversation cap\", default=5, param_type=int),\n", - " # Choices: behaves like an enum\n", + " # Constrained scalar: a Literal behaves like an enum (the type *is* the allowed set)\n", " Parameter(\n", " name=\"mode\",\n", " description=\"Speed mode\",\n", " default=\"fast\",\n", - " param_type=str,\n", - " choices=(\"fast\", \"slow\"),\n", + " param_type=Literal[\"fast\", \"slow\"],\n", " ),\n", " # List parameter\n", " Parameter(name=\"tags\", description=\"Tag list\", default=[\"default\"], param_type=list[str]),\n", @@ -257,6 +267,27 @@ "id": "8", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "TargetRegistry entry 'objective_scorer_chat' not found. Falling back to default OpenAIChatTarget.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using fallback default objective scorer: TrueFalseInverterScorer with chat target: OpenAIChatTarget\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "TextAdaptive: _EXCLUDED_TECHNIQUES entries ['prompt_sending'] are not in the current scenario-techniques catalog ['context_compliance', 'crescendo_history_lecture', 'crescendo_journalist_interview', 'crescendo_movie_director', 'crescendo_simulated', 'many_shot', 'pair', 'red_teaming', 'role_play', 'tap', 'violent_durian']; the exclusion is a no-op for those entries. Remove stale entries or update the catalog.\n" + ] + }, { "name": "stderr", "output_type": "stream", @@ -496,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.13" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/doc/code/scenarios/2_custom_scenario_parameters.py b/doc/code/scenarios/2_custom_scenario_parameters.py index 946ca2d5f5..1a23302297 100644 --- a/doc/code/scenarios/2_custom_scenario_parameters.py +++ b/doc/code/scenarios/2_custom_scenario_parameters.py @@ -67,13 +67,16 @@ # - **name**: dict key in `self.params`, converted to `--kebab-case` for the CLI # - **description**: shown in `--list-scenarios` and `--help` # - **default**: value used when not supplied; deep-copied per run -# - **param_type**: `str`, `int`, `float`, `bool`, `list[str]`, or `None` (raw passthrough) -# - **choices**: optional tuple of allowed values (not supported with `list` types) +# - **param_type**: `str`, `int`, `float`, `bool`, a `Literal[...]`/`Enum` (a +# constrained scalar that carries its own allowed set), a `list[...]` of any of +# those, or `None` (raw passthrough) # # A more complete declaration list might look like: # %% -from pyrit.common import Parameter +from typing import Literal + +from pyrit.models import Parameter # What a scenario author would return from supported_parameters(): example_declarations = [ @@ -81,13 +84,12 @@ Parameter(name="objective", description="Goal the attack pursues", param_type=str), # Scalar with default Parameter(name="max_turns", description="Conversation cap", default=5, param_type=int), - # Choices: behaves like an enum + # Constrained scalar: a Literal behaves like an enum (the type *is* the allowed set) Parameter( name="mode", description="Speed mode", default="fast", - param_type=str, - choices=("fast", "slow"), + param_type=Literal["fast", "slow"], ), # List parameter Parameter(name="tags", description="Tag list", default=["default"], param_type=list[str]), diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 6b366ec537..981ba271da 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -15,6 +15,7 @@ import base64 import inspect import mimetypes +import types import uuid from functools import lru_cache from pathlib import Path @@ -34,15 +35,12 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.common import REQUIRED_VALUE from pyrit.memory import data_serializer_factory from pyrit.models import PromptDataType - -# ``get_union_non_none_args`` is a general type-introspection utility used here to -# render parameter types for the catalog (a presentation concern owned by this -# service). +from pyrit.models.parameter import Parameter from pyrit.registry.components import ConverterRegistry -from pyrit.registry.components.converter_registry import _ConverterParameterMetadata -from pyrit.registry.resolution import get_union_non_none_args +from pyrit.registry.resolution import display_choices def _serialize_type(annotation: Any) -> str: @@ -60,11 +58,14 @@ def _serialize_type(annotation: Any) -> str: if get_origin(annotation) is Literal: args = get_args(annotation) return f"Literal[{', '.join(repr(a) for a in args)}]" - non_none = get_union_non_none_args(annotation) - if non_none is not None and len(non_none) == 1: - inner = _serialize_type(non_none[0]) - has_none = type(None) in get_args(annotation) - return f"Optional[{inner}]" if has_none else inner + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + args = get_args(annotation) + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1: + inner = _serialize_type(non_none[0]) + has_none = type(None) in args + return f"Optional[{inner}]" if has_none else inner if hasattr(annotation, "__name__"): return str(annotation.__name__) return str(annotation) @@ -133,33 +134,37 @@ async def list_converter_catalog_async(self) -> ConverterCatalogResponse: converter_type=metadata.class_name, supported_input_types=list(metadata.supported_input_types), supported_output_types=list(metadata.supported_output_types), - parameters=[self._build_parameter_schema(p) for p in metadata.parameters if p.coercible_from_string], + parameters=[self._build_parameter_schema(p) for p in metadata.parameters if p.is_string_coercible], is_llm_based=metadata.is_llm_based, description=metadata.class_description or None, ) - for metadata in self._registry.list_class_metadata() + for metadata in self._registry.get_all_registered_class_metadata() ] return ConverterCatalogResponse(items=items) @staticmethod - def _build_parameter_schema(parameter: _ConverterParameterMetadata) -> ConverterParameterSchema: + def _build_parameter_schema(parameter: Parameter) -> ConverterParameterSchema: """ - Map registry parameter metadata to the catalog DTO. + Map a derived ``Parameter`` to the catalog DTO. - Renders the raw annotation to a human-readable ``type_name`` for the - frontend (presentation concern owned by this service). + Renders the parameter's ``param_type`` to a human-readable ``type_name`` and + projects its allowed values (presentation concerns owned by this service). + Required-ness is read from the ``REQUIRED_VALUE`` sentinel default. Returns: ConverterParameterSchema: The parameter schema for the catalog entry. """ + required = parameter.default is REQUIRED_VALUE + default_value = None if required or parameter.default is None else str(parameter.default) + choices = display_choices(parameter.param_type) return ConverterParameterSchema( name=parameter.name, - type_name=_serialize_type(parameter.annotation), - required=parameter.required, - default_value=parameter.default_value, - choices=list(parameter.choices) if parameter.choices is not None else None, - description=parameter.description, + type_name=_serialize_type(parameter.param_type), + required=required, + default_value=default_value, + choices=[str(c) for c in choices] if choices is not None else None, + description=parameter.description or None, ) async def get_converter_async(self, *, converter_id: str) -> ConverterInstance | None: diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index eddad0f0f3..fcc2143928 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -21,18 +21,18 @@ import logging import shlex from pathlib import Path -from typing import TYPE_CHECKING, Any, get_origin +from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin from pyrit.common.cli_helpers import ( CONFIG_FILE_HELP, validate_log_level, validate_log_level_argparse, ) -from pyrit.common.parameter import Parameter, coerce_value if TYPE_CHECKING: from collections.abc import Callable + from pyrit.models.parameter import Parameter from pyrit.setup.configuration_loader import ScenarioConfig # --------------------------------------------------------------------------- @@ -569,20 +569,32 @@ def _arg_spec_from_parameter(*, param: Parameter) -> _ArgSpec: Returns: _ArgSpec: Spec with ``scenario__`` result key and a parser - that routes through ``pyrit.common.parameter.coerce_value``. + that routes through ``pyrit.models.parameter.coerce_value``. """ + from pyrit.models.parameter import Parameter + multi = get_origin(param.param_type) is list parser: Callable[[str], Any] | None if multi: - # Per-element coercion; v1 only ships list[str]. - parser = str - elif param.param_type is None or (param.param_type is str and param.choices is None): - # No coercion needed and no choices to enforce. + # Per-element coercion via a temporary scalar-typed Parameter. + type_args = get_args(param.param_type) + element_type = type_args[0] if type_args else str + element_param = Parameter( + name=param.name, + description=param.description, + param_type=element_type, + ) + + def parser(raw: str) -> Any: + return element_param.coerce_value(raw) + + elif param.param_type is None or param.param_type is str: + # No coercion needed (plain str / untyped passthrough). parser = None else: - # Coerce + validate (handles ints/floats/bools AND str-with-choices). + # Coerce + validate (handles ints/floats/bools AND Literal/Enum membership). def parser(raw: str) -> Any: - return coerce_value(param=param, raw_value=raw) + return param.coerce_value(raw) return _ArgSpec( flags=[_normalize_scenario_flag(name=param.name)], @@ -649,7 +661,9 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Param Maps the display ``param_type`` string ("int", "float", "bool", "str", "list[...]", "any") back to a concrete ``param_type`` so the shell parser - can apply per-element coercion and treat list params as ``multi_value``. + can apply per-element coercion and treat list params as ``multi_value``. A + ``choices`` list is reconstructed into a ``Literal[...]`` (the single source + of truth for an allowed set) — typed by the same base scalar. Args: api_params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. @@ -657,26 +671,33 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Param Returns: list[Parameter] | None: Parameter list when ``api_params`` is non-empty, else ``None``. """ + from pyrit.models.parameter import Parameter + if not api_params: return None type_map: dict[str, Any] = {"int": int, "float": float, "bool": bool, "str": str} parameters: list[Parameter] = [] for p in api_params: type_display = p.get("param_type", "") - if p.get("is_list"): - element_type = type_map.get(type_display.removeprefix("list[").rstrip("]"), str) - resolved_type: Any = list[element_type] # type: ignore[valid-type] - else: - resolved_type = type_map.get(type_display) + is_list = bool(p.get("is_list")) + base_name = type_display.removeprefix("list[").rstrip("]") if is_list else type_display + base_type = type_map.get(base_name, str) + raw_choices = p.get("choices") - choices: tuple[Any, ...] | None = tuple(raw_choices) if raw_choices else None + if raw_choices: + choice_param = Parameter(name=p["name"], description="", param_type=base_type) + members = tuple(choice_param.coerce_value(c) for c in raw_choices) + element_type: Any = Literal[members] # ty: ignore[invalid-type-form] + else: + element_type = type_map.get(base_name) if is_list else type_map.get(type_display) + + resolved_type: Any = list[element_type] if is_list else element_type # type: ignore[valid-type] parameters.append( Parameter( name=p["name"], description=p.get("description", ""), param_type=resolved_type, default=p.get("default"), - choices=choices, ) ) return parameters diff --git a/pyrit/common/__init__.py b/pyrit/common/__init__.py index e3891460a0..16c46384a2 100644 --- a/pyrit/common/__init__.py +++ b/pyrit/common/__init__.py @@ -9,6 +9,11 @@ directly, e.g.:: from pyrit.common.net_utility import get_httpx_client + +``Parameter`` is no longer part of ``pyrit.common``; it lives in ``pyrit.models``. +Accessing ``pyrit.common.Parameter`` (or ``from pyrit.common import Parameter``) +still resolves for one release but emits a ``DeprecationWarning``. Import from +``pyrit.models`` instead. This alias will be removed in 0.16.0. """ from pyrit.common.apply_defaults import ( @@ -22,9 +27,8 @@ ) from pyrit.common.brick_contract import enforce_keyword_only_init from pyrit.common.default_values import get_non_required_value, get_required_value -from pyrit.common.deprecation import print_deprecation_message +from pyrit.common.deprecation import module_deprecation_getattr, print_deprecation_message from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.common.parameter import Parameter from pyrit.common.singleton import Singleton from pyrit.common.utils import ( combine_dict, @@ -36,6 +40,16 @@ ) from pyrit.common.yaml_loadable import YamlLoadable +# ``Parameter`` moved to ``pyrit.models``. Resolve it lazily so that (a) ``pyrit.common`` +# stays free of the heavy ``pyrit.models`` import on the fast CLI path, and (b) the +# deprecated ``from pyrit.common import Parameter`` access emits a one-time warning. +__getattr__ = module_deprecation_getattr( + old_module="pyrit.common", + target_module="pyrit.models", + names=["Parameter"], + removed_in="0.16.0", +) + __all__ = [ "apply_defaults", "apply_defaults_to_method", @@ -49,7 +63,6 @@ "get_random_indices", "get_required_value", "is_in_ipython_session", - "Parameter", "print_deprecation_message", "REQUIRED_VALUE", "reset_default_values", diff --git a/pyrit/common/parameter.py b/pyrit/common/parameter.py index 24da1ec56a..d4f130256f 100644 --- a/pyrit/common/parameter.py +++ b/pyrit/common/parameter.py @@ -1,233 +1,37 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Unified parameter declaration and coercion helpers shared by initializers, scenarios, and CLI parsers.""" +""" +Deprecation shim — the parameter contract now lives in ``pyrit.models.parameter``. -import copy -from dataclasses import dataclass -from types import GenericAlias -from typing import Any, get_args, get_origin +Importing names from ``pyrit.common.parameter`` still works for one release but +emits a one-time ``DeprecationWarning`` per name. Import from +``pyrit.models.parameter`` (or ``pyrit.models``) instead. This shim will be +removed in 0.16.0. -_SUPPORTED_SCALAR_TYPES: tuple[type, ...] = (str, int, float, bool) +NOTE: When this shim is removed, also drop the ``pyrit.common.parameter`` entry +from ``KNOWN_COMMON_VIOLATIONS`` in ``tests/unit/models/test_import_boundary.py`` +if it has not already been removed. +""" +from __future__ import annotations -@dataclass(frozen=True) -class Parameter: - """ - Describes a parameter that a PyRIT component (initializer or scenario) accepts. +from pyrit.common.deprecation import module_deprecation_getattr - Args: - name (str): Parameter name; becomes the key in ``params`` and the - ``--kebab-case`` CLI flag. - description (str): Human-readable description shown in ``--help`` and - ``--list-*`` output. - default (Any): Default value when not supplied. Defaults to None. Must not - contain secrets; defaults are rendered verbatim by ``--list-scenarios``. - param_type (type | GenericAlias | None): Type for scenario-side coercion. - Supported: ``str``, ``int``, ``float``, ``bool``, ``list[str]``. None - means no coercion (the initializer convention). Defaults to None. - choices (tuple[Any, ...] | None): Optional allowed values. Coerced to - ``param_type`` and tuple-normalized so argparse, YAML, and runtime - membership checks see the same Python type. Defaults to None. - """ +__all__ = [ + "ComponentType", + "Parameter", + "ParameterDestination", + "RegistryReference", +] - name: str - description: str - default: Any = None - param_type: type | GenericAlias | None = None - choices: tuple[Any, ...] | None = None +__getattr__ = module_deprecation_getattr( + old_module="pyrit.common.parameter", + target_module="pyrit.models.parameter", + names=__all__, + removed_in="0.16.0", +) - def __post_init__(self) -> None: - """Tuple-ify ``choices`` and coerce them to ``param_type`` for scalar types.""" - if self.choices is not None and not isinstance(self.choices, tuple): - object.__setattr__(self, "choices", tuple(self.choices)) - # Lists with choices are rejected at declaration time, so list[T] is skipped here. - if self.choices is not None and self.param_type in (bool, int, float, str): - try: - coerced = tuple( - _coerce_choice_value(name=self.name, param_type=self.param_type, raw_value=c) for c in self.choices - ) - except ValueError: - # Leave choices alone; _validate_declarations surfaces the error. - return - object.__setattr__(self, "choices", coerced) - -def _coerce_choice_value(*, name: str, param_type: Any, raw_value: Any) -> Any: - """ - Coerce one declared choice to ``param_type``. - - Helper for ``Parameter.__post_init__``. ``param_type`` is typed ``Any`` - because the dataclass field is ``type | GenericAlias | None``; the caller - gates on scalar types before invoking this helper. - - Args: - name (str): Parameter name (used only in error messages). - param_type (Any): One of ``bool``, ``int``, ``float``, ``str``. - raw_value (Any): The choice value as declared by the author. - - Returns: - Any: The coerced choice value. - """ - if param_type is bool: - return coerce_bool(param_name=name, raw_value=raw_value) - if param_type is int: - return coerce_scalar(param_name=name, scalar_type=int, raw_value=raw_value) - if param_type is float: - return coerce_scalar(param_name=name, scalar_type=float, raw_value=raw_value) - return str(raw_value) - - -def validate_param_type(*, param: Parameter) -> None: - """ - Reject parameter declarations with an unsupported ``param_type``. - - Args: - param (Parameter): The parameter declaration. - - Raises: - ValueError: If ``param_type`` is not ``None``, ``str``, ``int``, - ``float``, ``bool``, or ``list[str]``. - """ - param_type = param.param_type - if param_type is None or param_type in _SUPPORTED_SCALAR_TYPES: - return - if get_origin(param_type) is list: - type_args = get_args(param_type) - element_type = type_args[0] if type_args else str - if element_type is str: - return - - raise ValueError( - f"Parameter '{param.name}' has unsupported param_type {param_type!r}. " - f"Supported types: str, int, float, bool, list[str], or None." - ) - - -def coerce_value(*, param: Parameter, raw_value: Any) -> Any: - """ - Coerce a raw value to ``param.param_type`` and validate against ``param.choices``. - - Args: - param (Parameter): The parameter declaration. - raw_value (Any): Value as supplied by CLI, YAML, or declared default. - - Returns: - Any: The coerced value. - - Raises: - ValueError: If coercion fails or the result is not in ``choices``. - """ - param_type = param.param_type - if param_type is None: - # Deep-copy so mutable raw values don't share identity with self.params. - value: Any = copy.deepcopy(raw_value) - elif param_type is bool: - value = coerce_bool(param_name=param.name, raw_value=raw_value) - elif param_type is int: - value = coerce_scalar(param_name=param.name, scalar_type=int, raw_value=raw_value) - elif param_type is float: - value = coerce_scalar(param_name=param.name, scalar_type=float, raw_value=raw_value) - elif param_type is str: - value = str(raw_value) - elif get_origin(param_type) is list: - value = coerce_list(param=param, raw_value=raw_value) - else: - raise ValueError( - f"Parameter '{param.name}' has unsupported param_type {param_type!r}. " - f"Supported types: str, int, float, bool, list[str]." - ) - - if param.choices is not None and value not in param.choices: - raise ValueError(f"Parameter '{param.name}' value {value!r} is not in declared choices {param.choices!r}.") - - return value - - -def coerce_scalar(*, param_name: str, scalar_type: type, raw_value: Any) -> Any: - """ - Coerce ``raw_value`` to ``int`` or ``float``, rejecting native ``bool`` inputs. - - Avoids ``int(True) == 1`` / ``float(False) == 0.0`` silent surprises. - - Args: - param_name (str): Parameter name for error messages. - scalar_type (type): ``int`` or ``float``. - raw_value (Any): Value to coerce. - - Returns: - Any: The coerced numeric value. - - Raises: - ValueError: If ``raw_value`` is a ``bool`` or cannot be coerced. - """ - if isinstance(raw_value, bool): - raise ValueError( - f"Parameter '{param_name}' expects {scalar_type.__name__} but received a bool ({raw_value!r})." - ) - try: - return scalar_type(raw_value) - except (TypeError, ValueError) as exc: - raise ValueError( - f"Parameter '{param_name}' could not be coerced to {scalar_type.__name__}: {raw_value!r} ({exc})." - ) from exc - - -def coerce_bool(*, param_name: str, raw_value: Any) -> bool: - """ - Parse ``raw_value`` as a boolean, avoiding the ``bool("false") is True`` argparse footgun. - - Accepts native ``bool`` and case-insensitive ``true``/``1``/``yes`` / - ``false``/``0``/``no`` strings. - - Args: - param_name (str): Parameter name for error messages. - raw_value (Any): Value to coerce. - - Returns: - bool: The coerced boolean. - - Raises: - ValueError: If ``raw_value`` is not a recognized boolean form. - """ - if isinstance(raw_value, bool): - return raw_value - if isinstance(raw_value, str): - normalized = raw_value.strip().lower() - if normalized in ("true", "1", "yes"): - return True - if normalized in ("false", "0", "no"): - return False - raise ValueError( - f"Parameter '{param_name}' expects bool but received {raw_value!r}. " - f"Accepted values: true/false, 1/0, yes/no (case-insensitive), or a native bool." - ) - - -def coerce_list(*, param: Parameter, raw_value: Any) -> list[Any]: - """ - Coerce a ``list[T]`` parameter (v1: only ``list[str]``). - - Args: - param (Parameter): Declaration with ``param_type`` like ``list[str]``. - raw_value (Any): Must be a list. - - Returns: - list[Any]: The coerced list. - - Raises: - ValueError: If ``raw_value`` is not a list or the element type isn't ``str``. - """ - if not isinstance(raw_value, list): - raise ValueError( - f"Parameter '{param.name}' expects a list but received {type(raw_value).__name__} ({raw_value!r})." - ) - - type_args = get_args(param.param_type) - element_type = type_args[0] if type_args else str - - if element_type is str: - return [str(item) for item in raw_value] - raise ValueError( - f"Parameter '{param.name}' has unsupported list element type {element_type!r}. Supported list types: list[str]." - ) +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 43b5168232..6cd3b84b1f 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -85,6 +85,12 @@ group_message_pieces_into_conversations, sort_message_pieces, ) +from pyrit.models.parameter import ( + ComponentType, + Parameter, + ParameterDestination, + RegistryReference, +) from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT @@ -129,6 +135,7 @@ "class_name_to_snake_case", "CapabilityName", "ComponentIdentifier", + "ComponentType", "compute_eval_hash", "config_hash", "ConverterIdentifier", @@ -170,10 +177,13 @@ "Modality", "NextMessageSystemPromptPaths", "ObjectiveTargetEvaluationIdentifier", + "Parameter", + "ParameterDestination", "PromptDataType", "PromptResponseError", "QuestionAnsweringDataset", "QuestionAnsweringEntry", + "RegistryReference", "QuestionChoice", "REGISTRY_NAME_PATTERN", "ScaleDescription", diff --git a/pyrit/models/identifiers/__init__.py b/pyrit/models/identifiers/__init__.py index b1c4934495..c6dd9278fc 100644 --- a/pyrit/models/identifiers/__init__.py +++ b/pyrit/models/identifiers/__init__.py @@ -30,6 +30,7 @@ ) from pyrit.models.identifiers.evaluation_markers import EvalMarker, Evaluate, Exclude, Include, Unwrap from pyrit.models.identifiers.identifier_filters import IdentifierFilter, IdentifierType +from pyrit.models.identifiers.param_markers import Param, ParamMarker from pyrit.models.identifiers.scorer_identifier import ScorerIdentifier from pyrit.models.identifiers.seed_identifier import SeedIdentifier from pyrit.models.identifiers.target_identifier import TargetIdentifier @@ -54,6 +55,8 @@ "Include", "ObjectiveTargetEvaluationIdentifier", "REGISTRY_NAME_PATTERN", + "Param", + "ParamMarker", "ScorerEvaluationIdentifier", "ScorerIdentifier", "SeedIdentifier", diff --git a/pyrit/models/identifiers/component_identifier.py b/pyrit/models/identifiers/component_identifier.py index 19834e3450..efedf26c7d 100644 --- a/pyrit/models/identifiers/component_identifier.py +++ b/pyrit/models/identifiers/component_identifier.py @@ -12,6 +12,8 @@ 2. Hash is content-addressed from behavioral params only. 3. Children carry their own hashes. 4. Adding optional params with None default is backward-compatible (None values excluded). + 5. Attributes are identity-bearing state: hashed like params, but excluded from + the eval hash and never passed to a constructor. """ from __future__ import annotations @@ -20,7 +22,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, ClassVar, get_args, get_origin +from typing import TYPE_CHECKING, Any, ClassVar, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, model_serializer, model_validator from typing_extensions import Self, TypeAliasType @@ -28,6 +30,9 @@ import pyrit from pyrit.common.deprecation import print_deprecation_message +if TYPE_CHECKING: + from pyrit.models.parameter import ComponentType + #: The set of value types allowed inside ``ComponentIdentifier.params``. Params #: must be JSON-serializable scalars (``str`` / ``int`` / ``float`` / ``bool`` / #: ``None``) or arbitrarily nested ``list`` / ``dict`` containers of those. This @@ -51,6 +56,7 @@ "eval_hash", "children", "params", + "attributes", "__type__", "__module__", } @@ -86,6 +92,7 @@ def _build_hash_dict( class_module: str, params: dict[str, JSONValue], children: dict[str, ComponentIdentifier | list[ComponentIdentifier]], + attributes: dict[str, JSONValue] | None = None, ) -> dict[str, Any]: """ Build the canonical dictionary used for hash computation. @@ -100,6 +107,8 @@ def _build_hash_dict( params (dict[str, JSONValue]): Behavioral parameters (non-None values only). children (dict[str, ComponentIdentifier | list[ComponentIdentifier]]): Child name to ComponentIdentifier or list of ComponentIdentifier. + attributes (dict[str, JSONValue] | None): Identity-bearing state (non-None values + only). Hashed like params but excluded from the eval hash. Returns: dict[str, Any]: The canonical dictionary for hashing. @@ -113,6 +122,14 @@ def _build_hash_dict( # won't change existing hashes, making the schema backward-compatible. hash_dict.update({key: value for key, value in sorted(params.items()) if value is not None}) + # Attributes sit under their own key (never inlined alongside params) and, + # like params, only contribute non-None values so an optional attribute with + # a None default stays hash-compatible. + if attributes: + attr_dict = {key: value for key, value in sorted(attributes.items()) if value is not None} + if attr_dict: + hash_dict[ComponentIdentifier.KEY_ATTRIBUTES] = attr_dict + # Children contribute their hashes, not their full structure. if children: children_hashes: dict[str, Any] = {} @@ -174,12 +191,20 @@ class ComponentIdentifier(BaseModel): serializes and hashes identically to a plain ``ComponentIdentifier`` built with the same params/children. Non-promoted members simply stay in ``params`` / ``children``. + Attributes: a third bucket alongside ``params`` and ``children`` for + identity-bearing **state** that is neither behavioral nor a constructor input — + e.g. a deployment / model version observed at runtime. Like params it feeds the + content hash, but it is excluded from the eval hash and is never used to build + the component. No identifier promotes attributes to a typed field today; they are + populated explicitly through the ``attributes`` dict. + Serialization: ``model_dump()`` returns a flat dict where reserved keys (``class_name``, ``class_module``, ``hash``, ``pyrit_version``, ``eval_hash``, - ``children``) sit at the top level alongside the inlined param values. This shape is - also the storage / REST format. Pass ``context={"max_value_length": N}`` to truncate - long string param values. ``model_validate()`` accepts the same flat shape (plus a - structured form with an explicit ``params`` dict). + ``children``, ``attributes``) sit at the top level alongside the inlined param + values. This shape is also the storage / REST format. Pass + ``context={"max_value_length": N}`` to truncate long string param values. + ``model_validate()`` accepts the same flat shape (plus a structured form with an + explicit ``params`` dict). Mutability: the model is frozen, but ``params`` and ``children`` are dicts whose contents are not deep-frozen — mutating them after construction creates an @@ -195,6 +220,7 @@ class ComponentIdentifier(BaseModel): KEY_EVAL_HASH: ClassVar[str] = "eval_hash" KEY_PYRIT_VERSION: ClassVar[str] = "pyrit_version" KEY_CHILDREN: ClassVar[str] = "children" + KEY_ATTRIBUTES: ClassVar[str] = "attributes" LEGACY_KEY_TYPE: ClassVar[str] = "__type__" LEGACY_KEY_MODULE: ClassVar[str] = "__module__" @@ -208,6 +234,10 @@ class ComponentIdentifier(BaseModel): params: dict[str, JSONValue] = Field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = Field(default_factory=dict) + #: Identity-bearing state that is hashed (like ``params``) but excluded from the + #: eval hash and never passed to a constructor — e.g. a runtime-resolved model + #: version. Same value rules as ``params`` (see ``JSONValue``). + attributes: dict[str, JSONValue] = Field(default_factory=dict) #: Content-addressed SHA256 hash. Computed automatically when ``None``; #: pass an explicit value to preserve a hash from DB storage where params #: may have been truncated. @@ -218,10 +248,43 @@ class ComponentIdentifier(BaseModel): #: to the identifier so it survives DB round-trips with truncated params. eval_hash: str | None = None + #: The registry family this identifier type builds. ``None`` on the base means + #: a plain ``ComponentIdentifier`` is never a buildable/resolvable reference; + #: each concrete leaf identifier (``TargetIdentifier`` / ``ConverterIdentifier`` + #: / ``ScorerIdentifier``) overrides it with its own ``ComponentType`` so a + #: child-identifier-typed field self-reports which registry resolves it. + component_type: ClassVar[ComponentType | None] = None + # ------------------------------------------------------------------ # Promotion (typed projection — derived from the subclass's own fields) # ------------------------------------------------------------------ + @staticmethod + def _child_identifier_type(annotation: Any) -> type[ComponentIdentifier] | None: + """ + Return the ``ComponentIdentifier`` subclass a field annotation denotes, if any. + + Handles a direct subclass, an ``Optional`` wrapper, and a ``list[...]`` of + subclasses (the two shapes promoted fields use for children). + + Args: + annotation (Any): The resolved field annotation (from + ``model_fields[name].annotation``). + + Returns: + type[ComponentIdentifier] | None: The child identifier subclass, or + ``None`` for a scalar (param) field. + """ + if get_origin(annotation) is list: + args = get_args(annotation) + inner = args[0] if args else None + return inner if isinstance(inner, type) and issubclass(inner, ComponentIdentifier) else None + + for candidate in get_args(annotation) or (annotation,): + if isinstance(candidate, type) and issubclass(candidate, ComponentIdentifier): + return candidate + return None + @staticmethod def _is_child_field(annotation: Any) -> bool: """ @@ -236,13 +299,7 @@ def _is_child_field(annotation: Any) -> bool: or a ``list`` thereof (optionally wrapped in ``| None``); ``False`` for scalar (param) fields. """ - if get_origin(annotation) is list: - args = get_args(annotation) - inner = args[0] if args else None - return isinstance(inner, type) and issubclass(inner, ComponentIdentifier) - - candidates: tuple[Any, ...] = get_args(annotation) or (annotation,) - return any(isinstance(c, type) and issubclass(c, ComponentIdentifier) for c in candidates) + return ComponentIdentifier._child_identifier_type(annotation) is not None @classmethod def _promoted_fields(cls) -> tuple[str, ...]: @@ -276,6 +333,71 @@ def _promoted_child_fields(cls) -> tuple[str, ...]: """ return tuple(n for n in cls._promoted_fields() if cls._is_child_field(cls.model_fields[n].annotation)) + @classmethod + def get_reference_component_types(cls) -> dict[str, ComponentType]: + """ + Map constructor-arg names to the component family each reference resolves to. + + A promoted field that is an included constructor parameter (explicit + ``Param.Include`` or unmarked) and is typed as a child identifier + contributes ``{arg_name: component_type}``, where ``arg_name`` is the marker + alias or the field name and ``component_type`` is the child identifier + type's own ``component_type``. ``Param.Exclude()`` fields, plain-value + fields, and any field typed as a base ``ComponentIdentifier`` (whose + ``component_type`` is ``None`` and is therefore not buildable) contribute + nothing. + + Returns: + dict[str, ComponentType]: Constructor-arg-name → referenced component type. + """ + from pyrit.models.identifiers.param_markers import ClassAttrMarker, ExcludeMarker, IncludeMarker, ParamMarker + + references: dict[str, ComponentType] = {} + for field_name in cls._promoted_fields(): + field = cls.model_fields[field_name] + marker = next((m for m in field.metadata if isinstance(m, ParamMarker)), None) + if isinstance(marker, (ExcludeMarker, ClassAttrMarker)): + continue + + child_type = cls._child_identifier_type(field.annotation) + if child_type is None or child_type.component_type is None: + continue + + arg_name = marker.alias if isinstance(marker, IncludeMarker) and marker.alias else field_name + references[arg_name] = child_type.component_type + + return references + + @classmethod + def get_class_attribute_values(cls, target_cls: type) -> dict[str, Any]: + """ + Read each ``Param.ClassAttr``-marked field's value off a target class. + + Lets a registry describe a *class* (with no configured instance) by + sourcing the marked fields directly from class attributes. For every + promoted field carrying a ``ClassAttrMarker``, reads the named class + attribute (defaulting to the field name upper-cased) from ``target_cls``. + + Args: + target_cls (type): The component class to read class attributes from. + + Returns: + dict[str, Any]: Field-name → class-attribute value, for each + ``Param.ClassAttr`` field. Missing attributes map to ``None``. + """ + from pyrit.models.identifiers.param_markers import ClassAttrMarker + + values: dict[str, Any] = {} + for field_name in cls._promoted_fields(): + field = cls.model_fields[field_name] + marker = next((m for m in field.metadata if isinstance(m, ClassAttrMarker)), None) + if marker is None: + continue + attr_name = marker.attr_name or field_name.upper() + values[field_name] = getattr(target_cls, attr_name, None) + + return values + # ------------------------------------------------------------------ # Validators # ------------------------------------------------------------------ @@ -339,6 +461,7 @@ def _normalize_input(cls, data: Any) -> Any: cls.KEY_PYRIT_VERSION, cls.KEY_EVAL_HASH, cls.KEY_CHILDREN, + cls.KEY_ATTRIBUTES, *promoted_fields, } @@ -428,6 +551,7 @@ def _promote_and_compute_hash(self) -> ComponentIdentifier: class_module=self.class_module, params=self.params, children=self.children, + attributes=self.attributes, ) object.__setattr__(self, "hash", config_hash(hash_dict)) return self @@ -473,6 +597,11 @@ def _serialize_flat(self, info: SerializationInfo) -> dict[str, Any]: serialized_children[name] = [c.model_dump(mode=mode, context=context) for c in child] result[self.KEY_CHILDREN] = serialized_children + if self.attributes: + result[self.KEY_ATTRIBUTES] = { + key: self._truncate_value(value=value, max_length=max_len) for key, value in self.attributes.items() + } + return result # ------------------------------------------------------------------ @@ -525,6 +654,7 @@ def with_eval_hash(self, eval_hash: str) -> ComponentIdentifier: class_module=self.class_module, params=self.params, children=self.children, + attributes=self.attributes, hash=self.hash, pyrit_version=self.pyrit_version, eval_hash=eval_hash, @@ -569,11 +699,14 @@ def __repr__(self) -> str: """ params_str = ", ".join(f"{k}={v!r}" for k, v in sorted(self.params.items())) children_str = ", ".join(f"{k}={v}" for k, v in sorted(self.children.items())) + attributes_str = ", ".join(f"{k}={v!r}" for k, v in sorted(self.attributes.items())) parts = [f"class={self.class_name}"] if params_str: parts.append(f"params=({params_str})") if children_str: parts.append(f"children=({children_str})") + if attributes_str: + parts.append(f"attributes=({attributes_str})") parts.append(f"hash={self.short_hash}") return f"ComponentIdentifier({', '.join(parts)})" @@ -588,6 +721,7 @@ def of( *, params: dict[str, Any] | None = None, children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, + attributes: dict[str, Any] | None = None, **promoted: Any, ) -> Self: """ @@ -602,6 +736,8 @@ def of( identifier. params: Optional behavioral params. children: Optional child identifiers. + attributes: Optional identity-bearing state (hashed, but excluded from + the eval hash and not a constructor input). ``None`` values dropped. **promoted: Optional promoted typed fields (for subclasses). Passed by name; ``None`` values are dropped. These are mirrored back into ``params`` / ``children`` automatically. @@ -611,6 +747,7 @@ def of( """ clean_params = {k: v for k, v in (params or {}).items() if v is not None} clean_children = {k: v for k, v in (children or {}).items() if v is not None} + clean_attributes = {k: v for k, v in (attributes or {}).items() if v is not None} clean_promoted = {k: v for k, v in promoted.items() if v is not None} return cls( @@ -618,6 +755,7 @@ def of( class_module=obj.__class__.__module__, params=clean_params, children=clean_children, + attributes=clean_attributes, **clean_promoted, ) @@ -681,7 +819,7 @@ def get_child_list(self, key: str) -> list[ComponentIdentifier]: def _collect_child_eval_hashes(self) -> set[str]: """ - Recursively collect all eval_hash values from child identifiers. + Recursively collect all eval_hash values from descendant identifiers. Returns: The set of non-empty eval_hash strings found in descendants. diff --git a/pyrit/models/identifiers/converter_identifier.py b/pyrit/models/identifiers/converter_identifier.py index 9371f52cd0..bdcf81ed41 100644 --- a/pyrit/models/identifiers/converter_identifier.py +++ b/pyrit/models/identifiers/converter_identifier.py @@ -5,16 +5,16 @@ from __future__ import annotations -from typing import Annotated - -from pydantic import Field +from typing import Annotated, ClassVar from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.models.identifiers.evaluation_markers import Evaluate +from pyrit.models.identifiers.param_markers import Param from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 TargetIdentifier, # runtime-required by Pydantic field annotations ) from pyrit.models.literals import PromptDataType # noqa: TC001 (runtime-required by Pydantic field annotations) +from pyrit.models.parameter import ComponentType class ConverterIdentifier(ComponentIdentifier): @@ -23,16 +23,27 @@ class ConverterIdentifier(ComponentIdentifier): Promotes the supported input/output data types; any converter-specific params stay in ``params``. The converter's own child slots — ``converter_target`` - (an LLM target) and ``sub_converters`` (nested converters) — are promoted to + (an LLM target) and ``sub_converter`` (a wrapped converter) — are promoted to typed fields. + + Build markers (``Param.*``) declare how these fields map to the converter's + constructor: the supported-type lists are class attributes sourced from the + converter class (``Param.ClassAttr``), while ``converter_target`` and + ``sub_converter`` are included constructor parameters whose identifier types + make them references resolved from the target and converter registries. """ - #: Input data types supported by this converter. - supported_input_types: Annotated[list[PromptDataType] | None, Evaluate.Include()] = None - #: Output data types produced by this converter. - supported_output_types: Annotated[list[PromptDataType] | None, Evaluate.Include()] = None + component_type: ClassVar[ComponentType] = ComponentType.CONVERTER + + #: Input data types supported by this converter (sourced from the + #: ``SUPPORTED_INPUT_TYPES`` class attribute, not a ctor arg). + supported_input_types: Annotated[list[PromptDataType] | None, Evaluate.Include(), Param.ClassAttr()] = None + #: Output data types produced by this converter (sourced from the + #: ``SUPPORTED_OUTPUT_TYPES`` class attribute, not a ctor arg). + supported_output_types: Annotated[list[PromptDataType] | None, Evaluate.Include(), Param.ClassAttr()] = None #: Target an LLM-backed converter calls (e.g., ``LLMGenericTextConverter``). - converter_target: Annotated[TargetIdentifier | None, Evaluate.Include()] = None - #: Nested converters a composite wraps (e.g., ``SelectiveTextConverter``), - #: typed recursively. - sub_converters: Annotated[list[ConverterIdentifier], Evaluate.Include()] = Field(default_factory=list) + converter_target: Annotated[TargetIdentifier | None, Evaluate.Include(), Param.Include()] = None + #: A nested converter a composite wraps (e.g., ``SelectiveTextConverter``), + #: typed recursively. An included constructor parameter resolved by name from + #: the converter registry. + sub_converter: Annotated[ConverterIdentifier | None, Evaluate.Include(), Param.Include()] = None diff --git a/pyrit/models/identifiers/param_markers.py b/pyrit/models/identifiers/param_markers.py new file mode 100644 index 0000000000..4d9ee447d5 --- /dev/null +++ b/pyrit/models/identifiers/param_markers.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Field-level *build* markers for strongly-typed identifiers. + +These markers are attached to identifier fields via ``typing.Annotated`` +metadata and declare — on the identifier itself — how each field maps to the +component's constructor. They are the build-time counterpart to the +``Evaluate.*`` markers (see ``evaluation_markers``): where ``Evaluate.*`` +governs the eval hash, ``Param.*`` governs how the registry derives the +``Parameter`` list and resolves constructor arguments. + +Usage:: + + class ConverterIdentifier(ComponentIdentifier): + supported_input_types: Annotated[ + list[PromptDataType] | None, Evaluate.Include(), Param.ClassAttr() + ] = None + converter_target: Annotated[ + TargetIdentifier | None, Evaluate.Include(), Param.Include() + ] = None + +Semantics (a field is, by default, an included constructor parameter named after +the field): + +* ``Param.Exclude()`` — the field is part of identity/eval but is **not** a + constructor input (e.g. a composite child slot with no 1:1 constructor arg). +* ``Param.ClassAttr(attr_name=...)`` — like ``Exclude`` (not a constructor input), + but additionally declares that the field's value, when describing the *class*, + is sourced from a class attribute. ``attr_name`` names that attribute; when + omitted it defaults to the field name upper-cased (e.g. the + ``supported_input_types`` field reads ``SUPPORTED_INPUT_TYPES``). A registry can + read these off the class without constructing an instance. +* ``Param.Include(alias=...)`` — the field **is** a constructor parameter. Whether + it is a coerced value or a registry **reference** is inferred from the field's + type: a child-identifier type (e.g. ``TargetIdentifier``) is resolved by name + from that kind's registry, while any other type is coerced from its raw value. + ``alias`` names the constructor arg when it differs from the identifier field + name. + +An unmarked field behaves like ``Param.Include()`` with no alias. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ParamMarker: + """Base class for all ``Param.*`` field markers.""" + + +@dataclass(frozen=True) +class ExcludeMarker(ParamMarker): + """Mark an identity/eval field that is **not** a constructor parameter.""" + + +@dataclass(frozen=True) +class ClassAttrMarker(ParamMarker): + """ + Mark an identity/eval field whose value is sourced from a class attribute. + + Like ``ExcludeMarker``, the field is **not** a constructor parameter. In + addition, it declares that the value (when describing the *class*, with no + configured instance) can be read off a class attribute, so a registry can + populate it without constructing an instance. + + Args: + attr_name (str | None): The class attribute name to read. ``None`` means + use the field name upper-cased (e.g. ``supported_input_types`` → + ``SUPPORTED_INPUT_TYPES``). + """ + + attr_name: str | None = None + + +@dataclass(frozen=True) +class IncludeMarker(ParamMarker): + """ + Mark an identity/eval field that **is** a constructor parameter. + + Whether the parameter is a plain coerced value or a registry **reference** is + inferred from the field's type: a child-identifier type (e.g. + ``TargetIdentifier``) is resolved by name from that kind's registry, while any + other type is coerced from its raw value. + + Args: + alias (str | None): The constructor arg name, when it differs from the + identifier field name. ``None`` means use the field name. + """ + + alias: str | None = None + + +class Param: + """Namespace for the field-level build markers (see module docstring).""" + + Exclude = ExcludeMarker + ClassAttr = ClassAttrMarker + Include = IncludeMarker diff --git a/pyrit/models/identifiers/scorer_identifier.py b/pyrit/models/identifiers/scorer_identifier.py index ceec10af49..8912230a27 100644 --- a/pyrit/models/identifiers/scorer_identifier.py +++ b/pyrit/models/identifiers/scorer_identifier.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Annotated +from typing import Annotated, ClassVar from pydantic import Field @@ -14,6 +14,7 @@ from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 TargetIdentifier, # runtime-required by Pydantic field annotations ) +from pyrit.models.parameter import ComponentType class ScorerIdentifier(ComponentIdentifier): @@ -25,6 +26,8 @@ class ScorerIdentifier(ComponentIdentifier): ``sub_scorers`` (nested scorers). """ + component_type: ClassVar[ComponentType] = ComponentType.SCORER + #: The scorer category (e.g., ``"true_false"`` or ``"float_scale"``). scorer_type: Annotated[str | None, Evaluate.Include()] = None #: Name of the aggregator function combining sub-scores (e.g., ``"AND_"``). diff --git a/pyrit/models/identifiers/target_identifier.py b/pyrit/models/identifiers/target_identifier.py index 4f146cce47..c5a910def2 100644 --- a/pyrit/models/identifiers/target_identifier.py +++ b/pyrit/models/identifiers/target_identifier.py @@ -5,12 +5,13 @@ from __future__ import annotations -from typing import Annotated +from typing import Annotated, ClassVar from pydantic import Field from pyrit.models.identifiers.component_identifier import ComponentIdentifier from pyrit.models.identifiers.evaluation_markers import Evaluate +from pyrit.models.parameter import ComponentType class TargetIdentifier(ComponentIdentifier): @@ -31,6 +32,8 @@ class TargetIdentifier(ComponentIdentifier): unwrapped so a multi-target hashes the same as its inner target. """ + component_type: ClassVar[ComponentType] = ComponentType.TARGET + #: Target endpoint URL. endpoint: Annotated[str | None, Evaluate.Exclude()] = None #: Model or deployment name used in API calls. diff --git a/pyrit/models/parameter.py b/pyrit/models/parameter.py new file mode 100644 index 0000000000..f2a8436cf8 --- /dev/null +++ b/pyrit/models/parameter.py @@ -0,0 +1,350 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Declarative parameter model for registry and scenario construction.""" + +from __future__ import annotations + +import copy +import types +from dataclasses import dataclass +from enum import Enum +from types import GenericAlias +from typing import Any, Literal, Union, get_args, get_origin + +_SUPPORTED_SCALAR_TYPES: tuple[type, ...] = (str, int, float, bool) + + +class ComponentType(str, Enum): + """ + The component family a registry reference resolves to. + + Each member maps one-to-one to a registry singleton that resolves references + of that family by name (``TARGET`` → ``TargetRegistry``, ``CONVERTER`` → + ``ConverterRegistry``, ``SCORER`` → ``ScorerRegistry``). + """ + + TARGET = "target" + CONVERTER = "converter" + SCORER = "scorer" + + +class ParameterDestination(str, Enum): + """Where a declarative parameter is consumed at build time.""" + + CONSTRUCTOR = "constructor" + REGISTERED = "registered" + + +@dataclass(frozen=True) +class RegistryReference: + """Self-describing reference to another registry-backed component.""" + + component_type: ComponentType + name: str | None = None + annotation: Any | None = None + + +@dataclass(frozen=True) +class Parameter: + """ + Describes a parameter that a PyRIT component accepts. + + ``param_type`` carries the value's type and its allowed set (a ``Literal[...]`` + or ``Enum`` *is* the allowed set). ``reference``, when set, marks the parameter + as a registry reference: its value is supplied *by name* and resolved to a + registered instance by the registry layer (``Parameter`` itself never resolves + references). + + ``coerce_value`` and ``validate`` are the only public behaviors; all coercion + branching lives behind them so callers never touch a free function. + """ + + name: str + description: str + default: Any = None + param_type: type | GenericAlias | None = None + reference: RegistryReference | None = None + destination: ParameterDestination = ParameterDestination.CONSTRUCTOR + + @property + def is_string_coercible(self) -> bool: + """ + Whether a single string token can be coerced to this parameter's value. + + True for a non-reference plain scalar (``str`` / ``int`` / ``float`` / + ``bool``) or ``Literal[...]`` parameter — exactly the forms a text field or + CLI token can supply. References and structured types (lists, enums, + arbitrary objects) are False and are surfaced/handled elsewhere. + + Returns: + bool: True when a string can be coerced to this parameter's value. + """ + if self.reference is not None: + return False + if self.param_type in _SUPPORTED_SCALAR_TYPES: + return True + return get_origin(self.param_type) is Literal + + def is_reference_to(self, component_type: ComponentType) -> bool: + """ + Whether this parameter is a registry reference to the given component family. + + A reference parameter is supplied by name and resolved to a registered + instance by the registry layer. This is the single source of truth for + "does this parameter point at a ``TARGET`` / ``CONVERTER`` / ``SCORER``", + so callers never re-derive it from ``reference`` internals. + + Args: + component_type (ComponentType): The component family to test against. + + Returns: + bool: True when this parameter is a reference to ``component_type``. + """ + return self.reference is not None and self.reference.component_type is component_type + + def coerce_value(self, raw_value: Any) -> Any: + """ + Coerce ``raw_value`` to this parameter's declared type. + + A reference parameter passes its value through unchanged (the registry + layer resolves it by name). Otherwise it branches by shape: ``None`` + passes through (deep-copied), a ``list`` coerces per element, and a scalar + form (including ``Literal``/``Enum``) coerces and validates membership. + Arbitrary defaulted types pass through unchanged. + + Args: + raw_value (Any): The raw value to coerce. + + Returns: + Any: The coerced value (a deep copy for the ``None`` passthrough, a + coerced list for list types, a coerced scalar for scalar types, or + the raw value unchanged for reference/arbitrary types). + + Raises: + ValueError: If the value cannot be coerced to a constrained scalar or + list element type. + """ + if self.reference is not None: + return raw_value + param_type = self.param_type + if param_type is None: + return copy.deepcopy(raw_value) + if get_origin(param_type) is list: + return _coerce_list(param_name=self.name, param_type=param_type, raw_value=raw_value) + if _is_scalar_param_type(param_type): + return _coerce_simple_value(param_name=self.name, annotation=param_type, raw_value=raw_value) + return raw_value + + def validate(self) -> None: + """ + Reject a declaration with an unsupported ``param_type``. + + Supported forms are a plain scalar, a constrained scalar + (``Literal``/``Enum``), a ``list`` of any of those, a registry reference, + or ``None``. An otherwise-unsupported type is tolerated only when the + parameter declares a default (the builder simply does not supply it, and + the value passes through unchanged). + + Raises: + ValueError: If ``param_type`` is unsupported and no default is declared. + """ + if self.reference is not None: + return + param_type = self.param_type + if param_type is None or _is_scalar_param_type(param_type): + return + if get_origin(param_type) is list: + type_args = get_args(param_type) + element_type = type_args[0] if type_args else str + if _is_scalar_param_type(element_type): + return + if self.default is not None: + return + + raise ValueError( + f"Parameter '{self.name}' has unsupported param_type {param_type!r}. " + f"Supported types: str, int, float, bool, Literal[...], Enum, a list of those, " + f"or None (or provide a default)." + ) + + +def _unwrap_optional(annotation: Any) -> Any: + """ + Reduce ``Optional[X]`` / ``X | None`` to ``X`` (only for single-member unions). + + Returns: + Any: ``X`` when ``annotation`` is a single-member optional union, otherwise the + annotation unchanged. + """ + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + non_none = [a for a in get_args(annotation) if a is not type(None)] + if len(non_none) == 1: + return non_none[0] + return annotation + + +def _is_enum_type(annotation: Any) -> bool: + """Return True when ``annotation`` is an ``Enum`` subclass.""" + return isinstance(annotation, type) and issubclass(annotation, Enum) + + +def _is_scalar_param_type(annotation: Any) -> bool: + """ + Return True when ``annotation`` is a coercible scalar form. + + A scalar form is a plain scalar (``str`` / ``int`` / ``float`` / ``bool``) or a + constrained scalar (``Literal[...]`` or an ``Enum`` subclass) that carries its + own allowed set. + + Returns: + bool: True when the annotation is a single coercible scalar form. + """ + if annotation in _SUPPORTED_SCALAR_TYPES: + return True + if get_origin(annotation) is Literal: + return True + return _is_enum_type(annotation) + + +def _coerce_simple_value(*, param_name: str, annotation: Any, raw_value: Any) -> Any: + """ + Coerce ``raw_value`` to a scalar ``annotation`` — the shared coercion core. + + Handles ``Optional[X]`` unwrap, ``Literal``/``Enum`` membership, and + int/float/bool/str. Anything else passes through unchanged. Both the + ``Parameter`` path (``coerce_value``) and the resolver's annotation path route + through this function so they cannot diverge on coerced values. + + Returns: + Any: The coerced value (a ``Literal``/``Enum`` member, an int/float/bool/str, or + the raw value unchanged for unsupported annotations). + + Raises: + ValueError: If the value is not a valid member of a ``Literal``/``Enum`` or + cannot be coerced to the annotated scalar type. + """ + annotation = _unwrap_optional(annotation) + if get_origin(annotation) is Literal: + return _coerce_literal(param_name=param_name, annotation=annotation, raw_value=raw_value) + if _is_enum_type(annotation): + return _coerce_enum(param_name=param_name, enum_type=annotation, raw_value=raw_value) + if annotation is bool: + return _coerce_bool(param_name=param_name, raw_value=raw_value) + if annotation is int: + return _coerce_scalar(param_name=param_name, scalar_type=int, raw_value=raw_value) + if annotation is float: + return _coerce_scalar(param_name=param_name, scalar_type=float, raw_value=raw_value) + if annotation is str: + return str(raw_value) + return raw_value + + +def _coerce_literal(*, param_name: str, annotation: Any, raw_value: Any) -> Any: + """ + Validate ``raw_value`` against a ``Literal`` and return the matching member. + + Returns: + Any: The matching ``Literal`` member. + + Raises: + ValueError: If ``raw_value`` does not match any allowed member. + """ + allowed = get_args(annotation) + for member in allowed: + if str(raw_value) == str(member): + return member + raise ValueError(f"Parameter '{param_name}' expected one of {[str(a) for a in allowed]}, got {raw_value!r}.") + + +def _coerce_enum(*, param_name: str, enum_type: type[Enum], raw_value: Any) -> Any: + """ + Validate ``raw_value`` against an ``Enum`` and return the matching member. + + Returns: + Any: The matching ``Enum`` member. + + Raises: + ValueError: If ``raw_value`` does not match any enum member by identity, value, or name. + """ + for member in enum_type: + if raw_value is member or str(raw_value) == str(member.value) or str(raw_value) == member.name: + return member + raise ValueError( + f"Parameter '{param_name}' expected one of {[member.name for member in enum_type]}, got {raw_value!r}." + ) + + +def _coerce_scalar(*, param_name: str, scalar_type: type, raw_value: Any) -> Any: + """ + Coerce ``raw_value`` to ``int`` or ``float`` while rejecting native ``bool`` inputs. + + Returns: + Any: The value coerced to ``scalar_type``. + + Raises: + ValueError: If ``raw_value`` is a native ``bool`` or cannot be coerced to ``scalar_type``. + """ + if isinstance(raw_value, bool): + raise ValueError( + f"Parameter '{param_name}' expects {scalar_type.__name__} but received a bool ({raw_value!r})." + ) + try: + return scalar_type(raw_value) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Parameter '{param_name}' could not be coerced to {scalar_type.__name__}: {raw_value!r} ({exc})." + ) from exc + + +def _coerce_bool(*, param_name: str, raw_value: Any) -> bool: + """ + Parse ``raw_value`` as a boolean, accepting the usual textual forms. + + Returns: + bool: The parsed boolean value. + + Raises: + ValueError: If ``raw_value`` cannot be interpreted as a boolean. + """ + if isinstance(raw_value, bool): + return raw_value + if isinstance(raw_value, str): + normalized = raw_value.strip().lower() + if normalized in ("true", "1", "yes"): + return True + if normalized in ("false", "0", "no"): + return False + raise ValueError( + f"Parameter '{param_name}' expects bool but received {raw_value!r}; could not interpret as a boolean. " + f"Accepted values: true/false, 1/0, yes/no (case-insensitive), or a native bool." + ) + + +def _coerce_list(*, param_name: str, param_type: Any, raw_value: Any) -> list[Any]: + """ + Coerce a ``list[T]`` parameter by coercing each element to ``T``. + + Returns: + list[Any]: The list with each element coerced to the declared element type. + + Raises: + ValueError: If ``raw_value`` is not a list or the element type is unsupported. + """ + if not isinstance(raw_value, list): + raise ValueError( + f"Parameter '{param_name}' expects a list but received {type(raw_value).__name__} ({raw_value!r})." + ) + + type_args = get_args(param_type) + element_type = type_args[0] if type_args else str + + if _is_scalar_param_type(element_type): + return [ + _coerce_simple_value(param_name=param_name, annotation=element_type, raw_value=item) for item in raw_value + ] + raise ValueError( + f"Parameter '{param_name}' has unsupported list element type {element_type!r}. " + f"Supported list element types: str, int, float, bool, or Literal[...]." + ) diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 4dc10901ad..e53f75b6dd 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -203,7 +203,7 @@ def _create_identifier( *, params: dict[str, Any] | None = None, converter_target: ComponentIdentifier | None = None, - sub_converters: list[ComponentIdentifier] | None = None, + sub_converter: ComponentIdentifier | None = None, ) -> ComponentIdentifier: """ Construct and return the converter identifier. @@ -222,8 +222,8 @@ def _create_identifier( the subclass (e.g., font, encoding_func). Merged into the base params. converter_target (ComponentIdentifier | None): The target an LLM-backed converter calls, promoted to ``ConverterIdentifier.converter_target``. - sub_converters (list[ComponentIdentifier] | None): Nested converters a - composite wraps, promoted to ``ConverterIdentifier.sub_converters``. + sub_converter (ComponentIdentifier | None): A nested converter a + composite wraps, promoted to ``ConverterIdentifier.sub_converter``. Returns: ComponentIdentifier: The identifier for this converter. @@ -234,7 +234,7 @@ def _create_identifier( supported_input_types=self.SUPPORTED_INPUT_TYPES, supported_output_types=self.SUPPORTED_OUTPUT_TYPES, converter_target=converter_target, - sub_converters=sub_converters, + sub_converter=sub_converter, ) @property diff --git a/pyrit/prompt_converter/selective_text_converter.py b/pyrit/prompt_converter/selective_text_converter.py index e056fb9663..7f05eee641 100644 --- a/pyrit/prompt_converter/selective_text_converter.py +++ b/pyrit/prompt_converter/selective_text_converter.py @@ -31,7 +31,7 @@ class SelectiveTextConverter(PromptConverter): >>> # Convert only words matching a pattern >>> strategy = WordRegexSelectionStrategy(pattern=r"\\d+") >>> converter = SelectiveTextConverter( - ... converter=Base64Converter(), + ... sub_converter=Base64Converter(), ... selection_strategy=strategy, ... preserve_tokens=True ... ) @@ -47,7 +47,7 @@ class SelectiveTextConverter(PromptConverter): def __init__( self, *, - converter: PromptConverter, + sub_converter: PromptConverter, selection_strategy: TextSelectionStrategy, preserve_tokens: bool = False, start_token: str = "⟪", @@ -58,7 +58,7 @@ def __init__( Initialize the selective text converter. Args: - converter (PromptConverter): The converter to apply to the selected text. + sub_converter (PromptConverter): The converter to apply to the selected text. selection_strategy (TextSelectionStrategy): The strategy for selecting which text to convert. Can be character-level or word-level strategy. preserve_tokens (bool): If True, wraps converted text with start/end tokens. @@ -78,9 +78,9 @@ def __init__( """ super().__init__() - self._validate_converter(converter=converter, selection_strategy=selection_strategy) + self._validate_converter(sub_converter=sub_converter, selection_strategy=selection_strategy) - self._converter = converter + self._sub_converter = sub_converter self._selection_strategy = selection_strategy self._preserve_tokens = preserve_tokens self._start_token = start_token @@ -103,20 +103,20 @@ def _build_identifier(self) -> ComponentIdentifier: "start_token": self._start_token, "end_token": self._end_token, }, - sub_converters=[self._converter.get_identifier()], + sub_converter=self._sub_converter.get_identifier(), ) def _validate_converter( self, *, - converter: PromptConverter, + sub_converter: PromptConverter, selection_strategy: TextSelectionStrategy, ) -> None: """ Validate the converter and selection strategy combination. Args: - converter (PromptConverter): The converter to validate. + sub_converter (PromptConverter): The converter to validate. selection_strategy (TextSelectionStrategy): The selection strategy to validate against. Raises: @@ -124,18 +124,18 @@ def _validate_converter( ValueError: If a word-level selection strategy is used with a WordLevelConverter that has a non-default word_selection_strategy. """ - if not converter.input_supported("text"): - raise ValueError(f"The converter {converter.__class__.__name__} does not support text input") - if not converter.output_supported("text"): - raise ValueError(f"The converter {converter.__class__.__name__} does not support text output") + if not sub_converter.input_supported("text"): + raise ValueError(f"The converter {sub_converter.__class__.__name__} does not support text input") + if not sub_converter.output_supported("text"): + raise ValueError(f"The converter {sub_converter.__class__.__name__} does not support text output") # Check for conflicting word selection strategies is_word_level_selection = isinstance(selection_strategy, WordSelectionStrategy) - if is_word_level_selection and isinstance(converter, WordLevelConverter): - has_non_default_strategy = not isinstance(converter._word_selection_strategy, AllWordsSelectionStrategy) + if is_word_level_selection and isinstance(sub_converter, WordLevelConverter): + has_non_default_strategy = not isinstance(sub_converter._word_selection_strategy, AllWordsSelectionStrategy) if has_non_default_strategy: raise ValueError( - f"Cannot use a WordSelectionStrategy with a {converter.__class__.__name__} that has a " + f"Cannot use a WordSelectionStrategy with a {sub_converter.__class__.__name__} that has a " f"non-default word_selection_strategy. When SelectiveTextConverter uses a word-level " f"strategy, it passes individual words to the wrapped converter, making the wrapped " f"converter's word selection strategy meaningless. Either use a character-level " @@ -161,7 +161,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text # If using TokenSelectionStrategy, delegate to convert_tokens_async if self._is_token_based: - result = await self._converter.convert_tokens_async( + result = await self._sub_converter.convert_tokens_async( prompt=prompt, input_type="text", start_token=self._start_token, @@ -201,7 +201,7 @@ async def _convert_word_level_async(self, *, prompt: str) -> ConverterResult: # Convert selected words for idx in selected_indices: - conversion_result = await self._converter.convert_async(prompt=words[idx], input_type="text") + conversion_result = await self._sub_converter.convert_async(prompt=words[idx], input_type="text") converted_word = conversion_result.output_text if self._preserve_tokens: @@ -234,7 +234,7 @@ async def _convert_char_level_async(self, *, prompt: str) -> ConverterResult: after_text = prompt[end_idx:] # Convert the selected region - conversion_result = await self._converter.convert_async(prompt=selected_text, input_type="text") + conversion_result = await self._sub_converter.convert_async(prompt=selected_text, input_type="text") converted_text = conversion_result.output_text if self._preserve_tokens: diff --git a/pyrit/prompt_converter/text_selection_strategy.py b/pyrit/prompt_converter/text_selection_strategy.py index 438f5f73ec..cdfd6abedb 100644 --- a/pyrit/prompt_converter/text_selection_strategy.py +++ b/pyrit/prompt_converter/text_selection_strategy.py @@ -38,14 +38,14 @@ class TokenSelectionStrategy(TextSelectionStrategy): Example: >>> first_converter = SelectiveTextConverter( - ... converter=Base64Converter(), + ... sub_converter=Base64Converter(), ... selection_strategy=WordPositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0), ... preserve_tokens=True ... ) >>> # Text after first converter: "hello world ⟪Y29udmVydGVk⟫" >>> >>> second_converter = SelectiveTextConverter( - ... converter=ROT13Converter(), + ... sub_converter=ROT13Converter(), ... selection_strategy=TokenSelectionStrategy(), # Auto-detect tokens ... preserve_tokens=True ... ) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 6fa61edfd1..b6cb751387 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -4,7 +4,6 @@ """Registry module for PyRIT class and object registries.""" from pyrit.registry.base import RegistryProtocol -from pyrit.registry.buildable_registry import BuildableRegistry from pyrit.registry.class_registries import ( BaseClassRegistry, ClassEntry, @@ -36,17 +35,18 @@ ScorerRegistry, TargetRegistry, ) +from pyrit.registry.registry import Registry from pyrit.registry.tag_query import TagQuery __all__ = [ "AttackTechniqueRegistry", "BaseClassRegistry", "BaseInstanceRegistry", - "BuildableRegistry", - "ConverterMetadata", "ConverterRegistry", + "ConverterMetadata", "DefaultInstanceRegistry", "InstanceRegistry", + "Registry", "RetrievableInstanceRegistry", "SupportsInstances", "ClassEntry", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index a1bc88ab81..f39e30e33c 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -10,14 +10,17 @@ from __future__ import annotations -from dataclasses import dataclass +import inspect +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterator, Mapping from typing_extensions import Self + from pyrit.models.parameter import Parameter + # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") @@ -36,12 +39,22 @@ class ClassRegistryEntry: class_description (str): Human-readable description, typically from the class docstring. registry_name (str): The suffix-stripped snake_case key used in the registry (e.g., "content_harms" for ContentHarmsScenario). + parameters (tuple[Parameter, ...]): The derived build contract for the class. + Buildable registries (e.g. converters) populate this from the constructor + signature; scenarios/initializers use their own ``supported_parameters`` + today and will migrate to this unified shape. + class_attributes (Mapping[str, Any]): Values sourced from class attributes + (declared on the identifier via ``Param.ClassAttr``), letting the entry + describe class-level facts — e.g. a converter's supported input/output + types — without constructing an instance. Empty for entries with none. """ class_name: str class_module: str class_description: str = "" registry_name: str = "" + parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) + class_attributes: Mapping[str, Any] = field(kw_only=True, default_factory=dict) @staticmethod def description_from_docstring(cls: type, *, fallback: str = "") -> str: @@ -58,6 +71,27 @@ def description_from_docstring(cls: type, *, fallback: str = "") -> str: cleaned = " ".join(doc.split()) return cleaned or fallback + @staticmethod + def summary_from_docstring(cls: type) -> str: + """ + Extract a short summary from the first paragraph of a class docstring. + + Uses the class's own docstring only (never an inherited one), normalizes + indentation, and collapses the first paragraph's whitespace onto one line. + Empty when the class has no docstring. This is the catalog-display + counterpart to ``description_from_docstring`` (which collapses the whole + docstring); buildable registries populate ``class_description`` from this + first-paragraph form. + + Returns: + str: The first-paragraph summary, or "" when there is no docstring. + """ + raw = cls.__doc__ + if not raw: + return "" + first_paragraph = inspect.cleandoc(raw).split("\n\n", 1)[0] + return " ".join(first_paragraph.split()) + @runtime_checkable class RegistryProtocol(Protocol[MetadataT]): diff --git a/pyrit/registry/buildable_registry.py b/pyrit/registry/buildable_registry.py deleted file mode 100644 index ebe1bb9e1e..0000000000 --- a/pyrit/registry/buildable_registry.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Buildable registry base for PyRIT. - -``BuildableRegistry`` is the universal registry capability: discover classes, -introspect them into metadata, and **build** configured instances from a type -name plus a flat argument dict. Construction routes through the shared -``resolve_constructor_args`` primitive, so simple values are coerced and -registry-reference parameters (e.g. a ``PromptTarget``) are resolved by name — -the same mechanism for every domain. - -Every PyRIT registry is buildable. Registries that additionally hold named -instances expose an ``instances`` property (an ``InstanceRegistry``); the -buildable layer itself only concerns the class catalog. -""" - -from __future__ import annotations - -from typing import TypeVar - -from pyrit.registry.class_registries.base_class_registry import BaseClassRegistry -from pyrit.registry.resolution import resolve_constructor_args - -T = TypeVar("T") -MetadataT = TypeVar("MetadataT") - - -class BuildableRegistry(BaseClassRegistry[T, MetadataT]): - """ - Registry base that can build instances from a type name and arguments. - - Extends the class-table infrastructure of ``BaseClassRegistry`` with a - construction path that routes through ``resolve_constructor_args``: string - values are coerced to their annotated scalar types and registry-reference - parameters are resolved by name from the owning domain's registry. A - registered factory, when present, is used as-is (its arguments are not - resolved, since a factory owns its own construction semantics). - - Type Parameters: - T: The type of classes being registered (e.g. ``PromptConverter``). - MetadataT: The metadata dataclass type (e.g. ``ConverterMetadata``). - """ - - def get_class_names(self) -> list[str]: - """ - Get a sorted list of all registered class names. - - Always reflects the class catalog, even on registries that also hold - instances (where the protocol surface ``get_names`` refers to instances on - the ``instances`` property, not here). - - Returns: - list[str]: The sorted class-catalog names. - """ - self._ensure_discovered() - return sorted(self._class_entries.keys()) - - def list_class_metadata( - self, - *, - include_filters: dict[str, object] | None = None, - exclude_filters: dict[str, object] | None = None, - ) -> list[MetadataT]: - """ - List metadata for all registered classes, optionally filtered. - - This is the class-catalog metadata (one entry per registered class), - distinct from any instance-level metadata a container registry exposes. - It always reflects the class catalog, even on container registries where - ``list_metadata`` refers to instances. - - Args: - include_filters (dict[str, object] | None): Filters items must match. - exclude_filters (dict[str, object] | None): Filters items must not match. - - Returns: - list[MetadataT]: Metadata describing each registered class. - """ - return BaseClassRegistry.list_metadata(self, include_filters=include_filters, exclude_filters=exclude_filters) - - def create_instance(self, name: str, **kwargs: object) -> T: - """ - Build a configured instance by class name. - - Arguments are resolved via ``resolve_constructor_args`` (coerce simple - strings, resolve registry references by name, raise on unknown params). - When the class is registered with a factory, the factory is invoked - directly with the given arguments instead. - - Args: - name (str): The class-catalog name to build. - **kwargs (object): Constructor arguments (simple values or registry - names for reference parameters). - - Returns: - T: The constructed instance. - - Raises: - KeyError: If the name is not registered. - ValueError: If an argument is not a valid constructor parameter, a - registry reference cannot be resolved, or a value cannot be coerced. - """ - entry = self._require_entry(name) - - if entry.factory is not None: - return entry.create_instance(**kwargs) - - raw_args = {**entry.default_kwargs, **kwargs} - resolved = resolve_constructor_args(cls=entry.registered_class, raw_args=raw_args) - return entry.registered_class(**resolved) diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 0679ab748b..b1668f6177 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -195,8 +195,8 @@ def _require_entry(self, name: str) -> ClassEntry[T]: """ Resolve a registered ``ClassEntry`` by name or raise. - Shared lookup used by ``get_class`` and by ``BuildableRegistry.create_instance`` - so the "not found" behavior (and its error message listing the class catalog) + Shared lookup used by ``get_class`` and ``create_instance`` so the + "not found" behavior (and its error message listing the class catalog) lives in one place. Args: diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 0300b00a06..56b270d943 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -13,7 +13,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple, get_origin +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, get_args, get_origin from pyrit.models import class_name_to_snake_case from pyrit.registry.base import ClassRegistryEntry @@ -25,6 +25,7 @@ discover_in_package, discover_subclasses_in_loaded_modules, ) +from pyrit.registry.resolution import display_choices if TYPE_CHECKING: from pyrit.scenario.core import Scenario @@ -222,7 +223,7 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet description=p.description, default=p.default, param_type=_param_type_display(p.param_type), - choices=[str(c) for c in p.choices] if p.choices else None, + choices=[str(c) for c in choices] if (choices := display_choices(p.param_type)) else None, is_list=get_origin(p.param_type) is list, ) for p in scenario_class.supported_parameters() @@ -272,6 +273,18 @@ def _param_type_display(param_type: Any) -> str: """ if param_type is None: return "any" + # A constrained scalar (Literal[...]) renders as its base scalar name so the + # display + API round-trip works; the allowed members travel via `choices`. + if get_origin(param_type) is Literal: + args = get_args(param_type) + return type(args[0]).__name__ if args else "str" + if get_origin(param_type) is list: + type_args = get_args(param_type) + element_type = type_args[0] if type_args else str + if get_origin(element_type) is Literal: + element_args = get_args(element_type) + element_name = type(element_args[0]).__name__ if element_args else "str" + return f"list[{element_name}]" # Detect parameterized generics (list[str], dict[str, int], ...) reliably across Python # versions: get_origin returns the unparameterized type for GenericAlias, None otherwise. # On some 3.10 builds GenericAlias passes isinstance(_, type), so we can't rely on that. diff --git a/pyrit/registry/components/__init__.py b/pyrit/registry/components/__init__.py index b12a22a9c4..38faacaeda 100644 --- a/pyrit/registry/components/__init__.py +++ b/pyrit/registry/components/__init__.py @@ -6,11 +6,11 @@ This package contains registries for PyRIT components (objects identified by a ``ComponentIdentifier``, such as converters, scorers, and targets). A component -registry is a ``BuildableRegistry`` class catalog that can build instances from +registry is a ``Registry`` class catalog that can build instances from classes and, when it retains pre-configured instances, also exposes them via an ``.instances`` property. -Shared capabilities and base classes (``BuildableRegistry``, ``InstanceRegistry``, +Shared capabilities and base classes (``Registry``, ``InstanceRegistry``, ``DefaultInstanceRegistry``) live at the top level of ``pyrit.registry``. """ diff --git a/pyrit/registry/components/converter_registry.py b/pyrit/registry/components/converter_registry.py index 7a0bc8b237..0e224ea17d 100644 --- a/pyrit/registry/components/converter_registry.py +++ b/pyrit/registry/components/converter_registry.py @@ -7,32 +7,31 @@ A single registry for ``PromptConverter`` that both: - **builds** converters from a type name plus arguments — discovering converter - classes, introspecting their constructor parameters, and constructing instances - via the shared resolver (so LLM converters can be built by passing a - ``converter_target`` registry name), and + classes, deriving their ``Parameter`` contract from the constructor enriched by + ``ConverterIdentifier``'s build markers, and constructing instances via the + shared resolver (so LLM converters can be built by passing a ``converter_target`` + registry name), and - **holds** pre-configured converter instances registered via initializers or the backend. -It is a ``BuildableRegistry``: the registry's own surface (``get_class``, -``get_class_names``, ``list_class_metadata``, ``create_instance``) is the buildable -class catalog. Pre-configured instances live under the ``instances`` property -(``register``, ``get``, ``get_all_instances``, ``get_names``), a -``DefaultInstanceRegistry``. +It is a ``Registry``: the registry's own surface (``get_class``, +``get_class_names``, ``get_all_registered_class_metadata``, ``create_instance``) +is the buildable class catalog. Pre-configured instances live under the +``instances`` property (``register``, ``get``, ``get_all_instances``, +``get_names``), a ``DefaultInstanceRegistry``. """ from __future__ import annotations -import inspect import logging -import re -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, NamedTuple, get_args, get_origin +from dataclasses import dataclass +from typing import TYPE_CHECKING +from pyrit.models.identifiers import ConverterIdentifier +from pyrit.models.parameter import ComponentType from pyrit.registry.base import ClassRegistryEntry -from pyrit.registry.buildable_registry import BuildableRegistry -from pyrit.registry.class_registries.base_class_registry import ClassEntry from pyrit.registry.instance_registry import DefaultInstanceRegistry, InstanceRegistry -from pyrit.registry.resolution import get_union_non_none_args, is_coercible_from_string +from pyrit.registry.registry import Registry if TYPE_CHECKING: from pyrit.prompt_converter import PromptConverter @@ -56,170 +55,39 @@ def _prompt_converter_type() -> type[PromptConverter]: return PromptConverter -class _ConverterParameterMetadata(NamedTuple): - """ - A converter constructor parameter described for dynamic construction. - - .. note:: - Transitional / internal. This bespoke shape is replaced by the unified - ``pyrit.common.parameter.Parameter`` contract on ``ClassRegistryEntry`` in - Phase 3 of the registry refactor and will be deleted then. Read the values - if you must, but do not build new APIs around this type or import it as a - stable public symbol. - - Carries raw introspection data so callers can build converters on the fly. - ``annotation`` is the parameter's raw type annotation; rendering it to a - human-readable string is a presentation concern left to the caller. - ``coercible_from_string`` is True when a string value can be coerced to the - annotated type. ``requires_llm`` is True when the parameter expects a - ``PromptTarget`` (i.e. the converter performs an LLM-based transformation). - - NamedTuple so consumers can read fields by name while the value stays - immutable (safe to cache inside a frozen ``ConverterMetadata``). - """ - - name: str - annotation: Any - required: bool - default_value: str | None - choices: tuple[str, ...] | None - description: str | None - coercible_from_string: bool - requires_llm: bool - - @dataclass(frozen=True) class ConverterMetadata(ClassRegistryEntry): """ Metadata describing a registered ``PromptConverter`` class. + Carries the derived ``parameters`` build contract (the same list the resolver + consumes to build an instance) and, via ``class_attributes`` on the base, the + converter's class-level supported input/output types. Presentation facts — the + supported types and whether the converter is LLM-based — are projected from + those rather than stored, so the entry can never drift from the class or the + contract. + Use ``ConverterRegistry.get_class()`` to get the actual class or ``create_instance()`` to build a configured instance. """ - # Input data types the converter accepts (stringified PromptDataType values). - supported_input_types: tuple[str, ...] = field(kw_only=True, default=()) - - # Output data types the converter produces (stringified PromptDataType values). - supported_output_types: tuple[str, ...] = field(kw_only=True, default=()) - - # Simple constructor parameters suitable for dynamic form generation. - # Transitional element type — replaced by ``Parameter`` in Phase 3. - parameters: tuple[_ConverterParameterMetadata, ...] = field(kw_only=True, default=()) - - # Whether the converter requires an LLM target. - is_llm_based: bool = field(kw_only=True, default=False) - - -def _requires_llm_target(annotation: Any) -> bool: - """ - Return True if the annotation expects a ``PromptTarget`` (or subclass). - - Handles unioned forms such as ``PromptTarget | None``. A converter parameter - with such an annotation indicates the converter performs an LLM-based - transformation. - - Returns: - bool: True if the annotation expects a ``PromptTarget``, False otherwise. - """ - if annotation is inspect.Parameter.empty: - return False - - from pyrit.prompt_target import PromptTarget - - candidates = get_union_non_none_args(annotation) - if candidates is None: - candidates = [annotation] - for candidate in candidates: - try: - if isinstance(candidate, type) and issubclass(candidate, PromptTarget): - return True - except TypeError: - continue - return False - - -def _parse_arg_descriptions(converter_class: type) -> dict[str, str]: - """ - Parse parameter descriptions from a Google-style docstring Args section. - - Returns: - dict[str, str]: Mapping of parameter names to their descriptions. - """ - doc = (converter_class.__init__.__doc__ or converter_class.__doc__ or "").strip() - match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) - if not match: - return {} - args_block = match.group(1) - # Detect indentation of first parameter line - indent_match = re.match(r"^(\s+)", args_block) - indent = indent_match.group(1) if indent_match else r"\s+" - pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" - descriptions: dict[str, str] = {} - for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): - descriptions[m.group(1)] = " ".join(m.group(2).split()) - return descriptions - - -def _extract_parameters(converter_class: type) -> tuple[_ConverterParameterMetadata, ...]: - """ - Extract constructor parameters from a converter class. - - Surfaces every settable constructor parameter (excluding ``self`` and - var-args) so a caller has the full picture for dynamic construction. Each - parameter records its raw ``annotation`` and a ``coercible_from_string`` flag - indicating whether a string value can be coerced to its type. + @property + def supported_input_types(self) -> tuple[str, ...]: + """Input data types the converter accepts (stringified ``PromptDataType`` values).""" + return tuple(str(dt) for dt in (self.class_attributes.get("supported_input_types") or ())) - Returns: - tuple[_ConverterParameterMetadata, ...]: The constructor parameters. - """ - try: - sig = inspect.signature(converter_class.__init__) - except (ValueError, TypeError): - return () - - arg_descriptions = _parse_arg_descriptions(converter_class) - - params: list[_ConverterParameterMetadata] = [] - for name, p in sig.parameters.items(): - if name in ("self", "args", "kwargs"): - continue - if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - continue - - no_default = p.default is inspect.Parameter.empty - is_sentinel = hasattr(p.default, "__class__") and "Sentinel" in type(p.default).__name__ - required = no_default or is_sentinel - - default_value: str | None = None - if not required and p.default is not None: - default_value = str(p.default) - - choices: tuple[str, ...] | None = None - choice_annotation = p.annotation - non_none_choice = get_union_non_none_args(choice_annotation) - if non_none_choice is not None and len(non_none_choice) == 1: - choice_annotation = non_none_choice[0] - if get_origin(choice_annotation) is Literal: - choices = tuple(str(a) for a in get_args(choice_annotation)) - - params.append( - _ConverterParameterMetadata( - name=name, - annotation=p.annotation, - required=required, - default_value=default_value, - choices=choices, - description=arg_descriptions.get(name), - coercible_from_string=is_coercible_from_string(p.annotation), - requires_llm=_requires_llm_target(p.annotation), - ) - ) + @property + def supported_output_types(self) -> tuple[str, ...]: + """Output data types the converter produces (stringified ``PromptDataType`` values).""" + return tuple(str(dt) for dt in (self.class_attributes.get("supported_output_types") or ())) - return tuple(params) + @property + def is_llm_based(self) -> bool: + """Whether the converter requires an LLM target (a TARGET reference parameter).""" + return any(p.is_reference_to(ComponentType.TARGET) for p in self.parameters) -class ConverterRegistry(BuildableRegistry["PromptConverter", ConverterMetadata]): +class ConverterRegistry(Registry["PromptConverter", ConverterMetadata]): """ Registry that discovers, builds, and holds ``PromptConverter`` instances. @@ -247,7 +115,11 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: instance_type=_prompt_converter_type ) - def _get_registry_name(self, cls: type) -> str: + def _identifier_type(self) -> type[ConverterIdentifier]: + """Return ``ConverterIdentifier`` so its ``Param.*`` markers drive derivation.""" + return ConverterIdentifier + + def _get_registry_name(self, cls: type[PromptConverter]) -> str: """ Use the exact class name as the catalog key. @@ -270,38 +142,12 @@ def _discover(self) -> None: continue if not issubclass(cls, PromptConverter) or cls is PromptConverter: continue - self._class_entries[name] = ClassEntry(registered_class=cls) - logger.debug(f"Registered converter class: {name}") - - def _build_metadata(self, name: str, entry: ClassEntry[PromptConverter]) -> ConverterMetadata: - """ - Build catalog metadata for a ``PromptConverter`` class. - - Args: - name (str): The catalog name (exact class name) of the converter. - entry (ClassEntry[PromptConverter]): The class entry being described. - - Returns: - ConverterMetadata: Metadata describing the converter class. - """ - converter_class = entry.registered_class - - # First paragraph of the docstring as a short description. - raw_doc = (converter_class.__doc__ or "").strip() - description = raw_doc.split("\n\n")[0].replace("\n", " ").strip() - - supported_input_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_INPUT_TYPES", ())) - supported_output_types = tuple(str(dt) for dt in getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ())) - - parameters = _extract_parameters(converter_class) - - return ConverterMetadata( - class_name=converter_class.__name__, - class_module=converter_class.__module__, - class_description=description, - registry_name=name, - supported_input_types=supported_input_types, - supported_output_types=supported_output_types, - parameters=parameters, - is_llm_based=any(p.requires_llm for p in parameters), - ) + # Key off the class itself (via _get_registry_name) rather than the + # __all__ export name so the catalog key always matches class_name, + # even if an export is ever aliased. + self.register_class(cls) + logger.debug(f"Registered converter class: {cls.__name__}") + + def _metadata_class(self) -> type[ConverterMetadata]: + """Return ``ConverterMetadata``; the base populates it from the common fields.""" + return ConverterMetadata diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py index 4982aac08c..58a7aa2354 100644 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -7,7 +7,7 @@ .. note:: **Legacy stack — do not build new registries on this.** New component - registries should subclass ``BuildableRegistry`` (a class catalog that can + registries should subclass ``Registry`` (a class catalog that can build instances by name) and hold pre-configured instances via the ``.instances`` property (a ``DefaultInstanceRegistry``). See ``ConverterRegistry`` for the target shape. This class and @@ -55,7 +55,7 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[ComponentIdentifier], Generic[T .. note:: **Legacy — do not subclass for new registries.** New component - registries subclass ``BuildableRegistry`` and expose retained instances + registries subclass ``Registry`` and expose retained instances via the ``.instances`` property (``DefaultInstanceRegistry``), which carries this same surface (``register``/``get``/``get_by_tag``/ ``add_tags``/``find_dependents_of_tag``/``list_metadata``). This class diff --git a/pyrit/registry/object_registries/retrievable_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py index 69e07fb0f3..b462c22012 100644 --- a/pyrit/registry/object_registries/retrievable_instance_registry.py +++ b/pyrit/registry/object_registries/retrievable_instance_registry.py @@ -7,7 +7,7 @@ .. note:: **Legacy stack — do not build new registries on this.** New component - registries subclass ``BuildableRegistry`` and retain instances via the + registries subclass ``Registry`` and retain instances via the ``.instances`` property (``DefaultInstanceRegistry``), which already provides ``get``/``get_entry``/``get_all_instances``. See ``ConverterRegistry`` for the target shape. This class remains only for the @@ -42,7 +42,7 @@ class RetrievableInstanceRegistry(BaseInstanceRegistry[T]): .. note:: **Legacy — do not subclass for new registries.** Use - ``BuildableRegistry`` + the ``.instances`` property + ``Registry`` + the ``.instances`` property (``DefaultInstanceRegistry``), which already exposes ``get``/``get_entry``/``get_all_instances``. Retained only for the not-yet-migrated ``ScorerRegistry`` and ``TargetRegistry``. diff --git a/pyrit/registry/registry.py b/pyrit/registry/registry.py new file mode 100644 index 0000000000..c2a9ff3c60 --- /dev/null +++ b/pyrit/registry/registry.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Registry base for PyRIT. + +``Registry`` is the universal registry capability: discover classes, introspect +them into metadata, and construct configured instances from a type name plus a +flat argument dict. Construction routes through the shared +``resolve_constructor_args`` primitive, so simple values are coerced and +registry-reference parameters (e.g. a ``PromptTarget``) are resolved by name — +the same mechanism for every domain. + +It owns a single add path: ``_discover()`` populates the catalog by calling +``register_class()``, which validates the class (its build contract must be +derivable and every reference parameter must map to a wired registry) before it +is stored. Validation therefore happens once, at registration time; there is no +separate post-hoc sweep. + +Every PyRIT registry is a ``Registry``. Registries that additionally hold named, +pre-built component objects expose an ``instances`` property (an +``InstanceRegistry``); the class catalog itself only concerns classes. After this +layering, "instance" only ever means a built component object — never the +registry singleton. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from pyrit.models import class_name_to_snake_case +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.resolution import ( + derive_parameters, + is_component_type_resolvable, + resolve_constructor_args, +) + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping + + from typing_extensions import Self + + from pyrit.models.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.parameter import Parameter + +T = TypeVar("T") +MetadataT = TypeVar("MetadataT", bound=ClassRegistryEntry) + + +def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: + """ + Get a value from a metadata object by key. + + Checks direct attributes first, then falls back to the ``params`` dict + (used by ComponentIdentifier). Returns a (found, value) tuple. + + Args: + metadata: The metadata object to look up. + key (str): The attribute or params key to find. + + Returns: + tuple: (True, value) if found, (False, None) otherwise. + """ + if hasattr(metadata, key): + return True, getattr(metadata, key) + + params = getattr(metadata, "params", None) + if isinstance(params, dict) and key in params: + return True, params[key] + + return False, None + + +def _matches_filters( + metadata: Any, + *, + include_filters: dict[str, Any] | None = None, + exclude_filters: dict[str, Any] | None = None, +) -> bool: + """ + Check if a metadata object matches all provided filters. + + Supports filtering on any property of the metadata dataclass or on keys + inside the ``params`` dict (for ComponentIdentifier metadata): + + - For simple types (str, int, bool): exact match comparison. + - For sequence types (list, tuple): checks if the filter value is contained. + + Items must match ALL include_filters (AND logic) and must NOT match ANY + exclude_filters. + + Args: + metadata: The metadata dataclass instance to check. + include_filters: Optional dict of filters that must ALL match. + exclude_filters: Optional dict of filters that must ALL NOT match. + + Returns: + bool: True if all include_filters match and no exclude_filters match. + """ + if include_filters: + for key, filter_value in include_filters.items(): + found, actual_value = _get_metadata_value(metadata, key) + if not found: + return False + if isinstance(actual_value, (list, tuple)): + if filter_value not in actual_value: + return False + elif actual_value != filter_value: + return False + + if exclude_filters: + for key, filter_value in exclude_filters.items(): + found, actual_value = _get_metadata_value(metadata, key) + if not found: + continue + if isinstance(actual_value, (list, tuple)): + if filter_value in actual_value: + return False + elif actual_value == filter_value: + return False + + return True + + +class Registry(ABC, Generic[T, MetadataT]): + """ + Standalone base for PyRIT registries: a validated class catalog that builds instances. + + Provides the common infrastructure every registry needs: + + - Lazy discovery of classes (deferred until first access). + - A single add path (``register_class``) that validates a class before storing it. + - Metadata caching keyed by registry name. + - Construction from a type name plus arguments (``create_instance``), routed + through ``resolve_constructor_args`` so string values are coerced and + registry-reference parameters are resolved by name from the owning domain. + - Singleton support via ``get_registry_singleton()``. + + Subclasses must implement: + + - ``_discover()`` — populate the catalog by calling ``register_class`` for each class. + - ``_metadata_class()`` — return the concrete metadata dataclass the base builds. + + Type Parameters: + T: The type of classes being registered (e.g. ``PromptConverter``). + MetadataT: The metadata dataclass type (e.g. ``ConverterMetadata``). + """ + + # Class-level singleton instances, keyed by registry class. + _singletons: dict[type, Registry[Any, Any]] = {} + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry. + + Args: + lazy_discovery (bool): If True, discovery is deferred until first access. + If False, discovery runs immediately in the constructor. + """ + self._classes: dict[str, type[T]] = {} + self._metadata_cache: dict[str, MetadataT] | None = None + self._discovered = False + self._lazy_discovery = lazy_discovery + + if not lazy_discovery: + self._discover() + self._discovered = True + + @classmethod + def get_registry_singleton(cls) -> Self: + """ + Get the singleton instance of this registry. + + Creates the instance on first call with default parameters. + + Returns: + The singleton instance of this registry class. + """ + if cls not in cls._singletons: + cls._singletons[cls] = cls() # type: ignore[ty:invalid-assignment] + return cls._singletons[cls] # type: ignore[ty:invalid-return-type] + + @classmethod + def reset_registry_singleton(cls) -> None: + """ + Reset the singleton instance. + + Useful for testing or when re-discovery is needed. + """ + if cls in cls._singletons: + del cls._singletons[cls] + + def _ensure_discovered(self) -> None: + """Ensure discovery has been performed. Runs discovery on first access.""" + if not self._discovered: + self._discover() + self._discovered = True + + @abstractmethod + def _discover(self) -> None: + """ + Perform discovery of registry classes. + + Subclasses implement this to populate the catalog by calling + ``self.register_class(cls)`` for each discovered class. + """ + + @abstractmethod + def _metadata_class(self) -> type[MetadataT]: + """ + Return the concrete metadata dataclass this registry builds. + + The base ``_build_metadata`` constructs this type from the common + ``ClassRegistryEntry`` fields. Subclasses whose metadata carries extra + fields beyond the common shape override ``_build_metadata`` instead. + + Returns: + type[MetadataT]: The metadata dataclass (e.g. ``ConverterMetadata``). + """ + + def _build_metadata(self, name: str, cls: type[T]) -> MetadataT: + """ + Build the metadata descriptor for a registered class. + + Populates the common ``ClassRegistryEntry`` fields — name/module, a + first-paragraph description, the derived ``Parameter`` build contract, and + any ``Param.ClassAttr`` class attributes — into the registry's + ``_metadata_class``. Subclasses needing extra fields override this. + + Args: + name (str): The catalog name (the registry key) for the class. + cls (type[T]): The registered class to describe. + + Returns: + MetadataT: A metadata descriptor for the registered class. + """ + metadata_class = self._metadata_class() + return metadata_class( + class_name=cls.__name__, + class_module=cls.__module__, + class_description=metadata_class.summary_from_docstring(cls), + registry_name=name, + parameters=self._derive_parameters(cls), + class_attributes=self._class_attributes(cls), + ) + + def _derive_parameters(self, cls: type[T]) -> tuple[Parameter, ...]: + """ + Derive the class's ``Parameter`` build contract under this registry's identifier. + + Args: + cls (type[T]): The class to introspect. + + Returns: + tuple[Parameter, ...]: The derived build contract. + """ + return tuple(derive_parameters(cls=cls, identifier_type=self._identifier_type())) + + def _class_attributes(self, cls: type[T]) -> Mapping[str, Any]: + """ + Read this registry's ``Param.ClassAttr`` class attributes off a class. + + Args: + cls (type[T]): The class to read class-level attributes from. + + Returns: + Mapping[str, Any]: Field-name → class-attribute value, empty when the + registry has no domain identifier. + """ + identifier_type = self._identifier_type() + if identifier_type is None: + return {} + return identifier_type.get_class_attribute_values(cls) + + def _identifier_type(self) -> type[ComponentIdentifier] | None: + """ + Return the domain identifier whose ``Param.*`` markers drive derivation. + + The base registry has no domain identifier, so no constructor parameter is + treated as a registry reference. Domain registries (e.g. + ``ConverterRegistry``) override this to return their identifier type so that + ``Param.Exclude`` / ``Param.Include`` markers are honored. + + Returns: + type[ComponentIdentifier] | None: The domain identifier type, or None. + """ + return None + + def _get_registry_name(self, cls: type[T]) -> str: + """ + Get the catalog name for a class. + + Subclasses can override this to customize name derivation. The default + converts CamelCase to snake_case. + + Args: + cls (type[T]): The class to get a name for. + + Returns: + str: The catalog name (snake_case identifier by default). + """ + return class_name_to_snake_case(cls.__name__) + + def _validate_class(self, cls: type[T]) -> None: + """ + Verify the registry can describe and build a class. + + Derives the class's ``Parameter`` contract (raising if its constructor + cannot be introspected) and checks that every reference parameter maps to a + registry the resolver knows how to query. This is the registration gate: a + class whose build contract does not line up with a resolvable reference + fails fast at ``register_class`` time instead of erroring only at build time. + + Args: + cls (type[T]): The class to validate. + + Raises: + ValueError: If the constructor cannot be introspected or a reference + parameter has no registry wired for its component type. + """ + # Derived here only to validate references; the metadata cache derives the + # contract again lazily in _build_metadata. The two happen at different + # lifecycle stages (register vs. first metadata access), and derivation is + # cheap, so the small duplication is deliberate rather than worth caching. + parameters = self._derive_parameters(cls) + for param in parameters: + if param.reference is not None and not is_component_type_resolvable(param.reference.component_type): + raise ValueError( + f"{cls.__name__}: reference parameter '{param.name}' has no registry wired for component type " + f"'{param.reference.component_type}'." + ) + + def register_class(self, cls: type[T], *, name: str | None = None) -> None: + """ + Add a class to the catalog after validating it. + + Registers a class *type* (not an instance) so the registry knows it exists + and can later build instances of it via ``create_instance``. The class is + validated by ``_validate_class`` before being stored, so the catalog never + holds a class whose build contract cannot be resolved. + + Args: + cls (type[T]): The class to register. + name (str | None): Optional custom catalog name. If not provided, it is + derived via ``_get_registry_name``. + + Raises: + ValueError: If the class fails validation. + """ + if name is None: + name = self._get_registry_name(cls) + self._validate_class(cls) + self._classes[name] = cls + self._metadata_cache = None + + def get_class(self, name: str) -> type[T]: + """ + Get a registered class by name. + + Args: + name (str): The catalog name. + + Returns: + type[T]: The registered class (the class itself, not an instance). + + Raises: + KeyError: If the name is not registered. + """ + self._ensure_discovered() + cls = self._classes.get(name) + if cls is None: + available = ", ".join(self.get_class_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + return cls + + def get_class_names(self) -> list[str]: + """ + Get a sorted list of all registered catalog names. + + Returns: + list[str]: Sorted catalog names. + """ + self._ensure_discovered() + return sorted(self._classes.keys()) + + def _ensure_metadata(self) -> dict[str, MetadataT]: + """ + Build (once) and return the metadata cache keyed by catalog name. + + Returns: + dict[str, MetadataT]: Metadata for every registered class, keyed by name. + """ + self._ensure_discovered() + if self._metadata_cache is None: + self._metadata_cache = { + name: self._build_metadata(name, cls) for name, cls in sorted(self._classes.items()) + } + return self._metadata_cache + + def get_all_registered_class_metadata( + self, + *, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, + ) -> list[MetadataT]: + """ + List metadata for all registered classes, optionally filtered. + + Supports filtering on any metadata property: + + - Simple types (str, int, bool): exact match. + - Sequence types (list, tuple): checks if the filter value is contained. + + Args: + include_filters (dict[str, object] | None): Filters that items must match + (AND logic). Keys are metadata property names. + exclude_filters (dict[str, object] | None): Filters that exclude an item + when matched. Keys are metadata property names. + + Returns: + list[MetadataT]: Metadata describing each registered class (filtered). + """ + metadata = list(self._ensure_metadata().values()) + if not include_filters and not exclude_filters: + return metadata + + return [ + m for m in metadata if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) + ] + + def get_registered_class_metadata(self, name: str) -> MetadataT | None: + """ + Get the metadata for a single registered class by name. + + Args: + name (str): The catalog name. + + Returns: + MetadataT | None: The metadata, or None if the name is not registered. + """ + return self._ensure_metadata().get(name) + + def get_class_metadata(self, cls: type[T]) -> MetadataT: + """ + Build metadata for any class (registered or not). + + Derives the catalog name via ``_get_registry_name`` and builds a fresh + descriptor. Useful for describing a class without registering it. + + Args: + cls (type[T]): The class to describe. + + Returns: + MetadataT: The metadata descriptor for the class. + """ + return self._build_metadata(self._get_registry_name(cls), cls) + + def create_instance(self, name: str, **kwargs: object) -> T: + """ + Build a configured instance by class name. + + Looks up the catalogued class, resolves the given arguments via + ``resolve_constructor_args`` (coerce simple strings, resolve registry + references by name, raise on unknown params), and constructs the object. + + Args: + name (str): The catalog name to build. + **kwargs (object): Constructor arguments (simple values or registry + names for reference parameters). + + Returns: + T: The constructed instance. + + Raises: + KeyError: If the name is not registered. + ValueError: If an argument is not a valid constructor parameter, a + registry reference cannot be resolved, or a value cannot be coerced. + """ + cls = self.get_class(name) + resolved = resolve_constructor_args( + cls=cls, + raw_args=dict(kwargs), + identifier_type=self._identifier_type(), + ) + return cls(**resolved) + + def __contains__(self, name: str) -> bool: + """ + Check if a name is registered. + + Returns: + bool: True if the name is registered, False otherwise. + """ + self._ensure_discovered() + return name in self._classes + + def __len__(self) -> int: + """ + Get the count of registered classes. + + Returns: + int: The number of registered classes. + """ + self._ensure_discovered() + return len(self._classes) + + def __iter__(self) -> Iterator[str]: + """ + Iterate over registered names. + + Returns: + Iterator[str]: An iterator over sorted registered names. + """ + return iter(self.get_class_names()) diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py index 1b6a75ad75..99b5c61b53 100644 --- a/pyrit/registry/resolution.py +++ b/pyrit/registry/resolution.py @@ -2,209 +2,241 @@ # Licensed under the MIT license. """ -Constructor-argument resolution for PyRIT registries. - -This is the shared mechanism that lets any registry build an instance from a -type name plus a flat dict of arguments. Build inputs are exactly two kinds: - -- **Simple values** — strings/ints/floats/bools (and ``Literal`` choices) that - can be coerced to the constructor's annotated type. -- **Registry references** — a parameter whose annotation is a domain base type - (``PromptTarget``, ``PromptConverter``, ``Scorer``) is supplied *by name* and - resolved from that domain's registry. An already-constructed instance passes - through unchanged. - -Unknown parameters raise, so a caller (form, agent, attack strategy) gets a -clear error instead of having values silently dropped. - -This module performs no eager heavy imports and never imports ``pyrit.backend``: -the resolvable-registry lookups are done lazily so it can be reused anywhere. +The constructor <-> ``Parameter`` contract bridge for PyRIT registries. + +This module is the single place that translates between a component class's +``__init__`` and the declarative ``Parameter`` contract carried by its domain +identifier. It has three responsibilities: + +- **Derive** (``derive_parameters``): read the constructor signature, enriched + by the identifier's ``Param.*`` build markers, into a ``list[Parameter]``. A + parameter the identifier promotes as a reference to another registry (an + included field typed as a child identifier, e.g. ``TargetIdentifier``) becomes + a registry **reference**; every other parameter becomes a plain value parameter + whose ``param_type`` is the annotation with ``Optional[X]`` reduced to ``X``. +- **Resolve** (``resolve_constructor_args``): derive the contract for a class + and turn a flat dict of raw arguments into constructor-ready keyword arguments — + coercing simple string values via ``Parameter.coerce_value`` and resolving + registry-reference parameters by name from the owning domain's registry. +- **Present** (``display_choices``): project a constrained-scalar ``param_type`` + into its allowed-value display tuple. + +The identifier is the declarative blueprint; this module is where the registry +reads and applies it. It performs no eager heavy imports and never imports +``pyrit.backend``: registry lookups are done lazily so it can be reused anywhere. """ from __future__ import annotations import inspect +import re import types -from typing import TYPE_CHECKING, Any, Literal, Protocol, Union, get_args, get_origin +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, Union, get_args, get_origin + +from pyrit.common.apply_defaults import REQUIRED_VALUE, _RequiredValueSentinel +from pyrit.models.parameter import ComponentType, Parameter, RegistryReference if TYPE_CHECKING: from collections.abc import Callable -# Scalar Python types whose string values can be coerced to the real type. -_SIMPLE_TYPES: set[type] = {str, int, float, bool} + from pyrit.models.identifiers.component_identifier import ComponentIdentifier +# Constructor parameters that never describe a settable build input. +_SKIPPED_PARAM_NAMES: frozenset[str] = frozenset({"self", "args", "kwargs"}) -class _NamedInstanceRegistry(Protocol): - """Structural type for a registry that resolves stored instances by name.""" +#: A runtime type-annotation object as seen on a constructor parameter or a +#: ``Parameter.param_type``: a concrete ``type``, a typing special form +#: (``X | None`` / ``Optional`` / ``Union`` / ``Literal``), or +#: ``inspect.Parameter.empty`` for an unannotated parameter. Aliased to ``Any`` +#: because no single static type captures all of these; the name documents intent. +TypeAnnotation: TypeAlias = Any - def get(self, name: str) -> Any | None: - """Return the instance registered under ``name``, or None.""" - ... - def get_names(self) -> list[str]: - """Return the sorted names of registered instances.""" - ... +# --------------------------------------------------------------------------- +# Derive: component class -> list[Parameter] +# --------------------------------------------------------------------------- -def get_union_non_none_args(annotation: Any) -> list[Any] | None: +def _unwrap_optional(annotation: TypeAnnotation) -> TypeAnnotation: """ - Return the non-``None`` members of a union annotation, or None if not a union. - - Handles both ``typing.Union[X, None]`` and PEP 604 ``X | None``. This is a - general type-introspection utility (not presentation), reused by coercion, - registry-reference detection, and callers that need to render a type. - - Args: - annotation (Any): The type annotation to inspect. + Reduce ``Optional[X]`` / ``X | None`` to ``X`` (only for single-member unions). Returns: - list[Any] | None: The non-None union members, or None when the annotation - is not a union. + TypeAnnotation: ``X`` when ``annotation`` is a single-member optional union, + otherwise the annotation unchanged. """ origin = get_origin(annotation) if origin is Union or origin is types.UnionType: - return [a for a in get_args(annotation) if a is not type(None)] - return None + non_none = [a for a in get_args(annotation) if a is not type(None)] + if len(non_none) == 1: + return non_none[0] + return annotation -def is_coercible_from_string(annotation: Any) -> bool: +def _parse_arg_descriptions(cls: type) -> dict[str, str]: """ - Return True if a string value can be coerced to the annotated type. - - Covers the scalar types in ``_SIMPLE_TYPES`` (str/int/float/bool), - ``Literal`` annotations, and an ``Optional`` wrapping one of those. + Parse parameter descriptions from a Google-style docstring ``Args`` section. Returns: - bool: True if the annotation is coercible from a string, False otherwise. + dict[str, str]: Mapping of parameter names to their descriptions. """ - if annotation in _SIMPLE_TYPES: - return True - if get_origin(annotation) is Literal: - return True - non_none = get_union_non_none_args(annotation) - if non_none is not None: - return len(non_none) == 1 and is_coercible_from_string(non_none[0]) - return False - - -def _resolvable_registries() -> list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: + doc = (cls.__init__.__doc__ or cls.__doc__ or "").strip() + match = re.search(r"Args:\s*\n(.*?)(?:\n\s*\n|\n\s*Returns:|\n\s*Raises:|\Z)", doc, re.DOTALL) + if not match: + return {} + args_block = match.group(1) + indent_match = re.match(r"^(\s+)", args_block) + indent = indent_match.group(1) if indent_match else r"\s+" + pattern = rf"^{indent}(\w+)\s*(?:\([^)]*\))?\s*:\s*(.+?)(?=\n{indent}\w|\Z)" + descriptions: dict[str, str] = {} + for m in re.finditer(pattern, args_block, re.DOTALL | re.MULTILINE): + descriptions[m.group(1)] = " ".join(m.group(2).split()) + return descriptions + + +def _default_for(param: inspect.Parameter) -> Any: """ - Return the (base type -> registry singleton getter) pairs that can be resolved by name. + Return the ``Parameter.default`` for a constructor parameter. - A constructor parameter whose annotation is (a subclass of) one of these base - types is supplied by name and looked up in the paired registry. Imports are - deferred so this core module stays import-light and free of cycles. + A parameter with no default or the ``REQUIRED_VALUE`` sentinel is required, and + is represented with ``REQUIRED_VALUE`` so consumers can detect it uniformly. Returns: - list[tuple[type, Callable[[], _NamedInstanceRegistry]]]: The resolvable - domain base types paired with a callable returning their registry singleton. + Any: The parameter's default value, or ``REQUIRED_VALUE`` when it is required. """ - from pyrit.prompt_converter import PromptConverter - from pyrit.prompt_target import PromptTarget - from pyrit.registry.components import ConverterRegistry - from pyrit.registry.object_registries import ( - ScorerRegistry, - TargetRegistry, - ) - from pyrit.score.scorer import Scorer + if param.default is inspect.Parameter.empty or isinstance(param.default, _RequiredValueSentinel): + return REQUIRED_VALUE + return param.default - return [ - (PromptTarget, TargetRegistry.get_registry_singleton), - (PromptConverter, lambda: ConverterRegistry.get_registry_singleton().instances), - (Scorer, ScorerRegistry.get_registry_singleton), - ] - -def get_resolvable_registry_getter(annotation: Any) -> Callable[[], _NamedInstanceRegistry] | None: +def derive_parameters(*, cls: type, identifier_type: type[ComponentIdentifier] | None = None) -> list[Parameter]: """ - Return the registry-singleton getter for a registry-reference annotation. + Derive the declarative ``Parameter`` list for ``cls`` from its constructor. - The annotation matches when it is (or unions, e.g. ``X | None``, to) a subclass - of a resolvable domain base type. A parameter with such an annotation is - supplied by name and resolved from the returned registry. + Performs the single ``inspect.signature`` call of the build pipeline and maps + each settable constructor parameter to a ``Parameter``: parameters the + identifier promotes as references carry a ``RegistryReference``; plain + parameters carry an ``Optional``-unwrapped ``param_type``. Parameter order + follows the constructor signature. Args: - annotation (Any): The parameter's type annotation. + cls (type): The component class whose ``__init__`` drives derivation. + identifier_type (type[ComponentIdentifier] | None): The domain identifier + whose ``Param.*`` markers declare which parameters are registry + references. When None, no parameter is treated as a reference. Returns: - Callable[[], _NamedInstanceRegistry] | None: A callable returning the - registry singleton, or None when the annotation is not a registry reference. + list[Parameter]: One ``Parameter`` per settable constructor parameter. + + Raises: + ValueError: If the constructor signature cannot be inspected. """ - if annotation is inspect.Parameter.empty: - return None + try: + sig = inspect.signature(cls.__init__) + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to inspect __init__ signature for '{cls.__name__}': {e}") from e - candidates = get_union_non_none_args(annotation) - if candidates is None: - candidates = [annotation] + reference_overrides = identifier_type.get_reference_component_types() if identifier_type is not None else {} + descriptions = _parse_arg_descriptions(cls) + + parameters: list[Parameter] = [] + for name, param in sig.parameters.items(): + if name in _SKIPPED_PARAM_NAMES: + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + continue + + annotation = param.annotation + component_type = reference_overrides.get(name) + description = descriptions.get(name, "") + default = _default_for(param) + + if component_type is not None: + parameters.append( + Parameter( + name=name, + description=description, + default=default, + reference=RegistryReference(component_type=component_type, annotation=annotation), + ) + ) + else: + param_type = None if annotation is inspect.Parameter.empty else _unwrap_optional(annotation) + parameters.append(Parameter(name=name, description=description, default=default, param_type=param_type)) - for base_type, getter in _resolvable_registries(): - for candidate in candidates: - try: - if isinstance(candidate, type) and issubclass(candidate, base_type): - return getter - except TypeError: - continue - return None + return parameters + + +# --------------------------------------------------------------------------- +# Resolve: derived Parameters + raw args -> constructor keyword arguments +# --------------------------------------------------------------------------- + + +class _NamedInstanceRegistry(Protocol): + """Structural type for a registry that resolves stored instances by name.""" + + def get(self, name: str) -> Any | None: + """Return the instance registered under ``name``, or None.""" + ... + + def get_names(self) -> list[str]: + """Return the sorted names of registered instances.""" + ... -def is_registry_reference(annotation: Any) -> bool: +# TODO (Phase 4 — Target/Scorer migration): this function is deliberately left +# in its current, slightly awkward shape until Target/Scorer become unified +# ``Registry`` instances. It wants to be a flat ``ComponentType -> Registry class`` +# mapping, but it can't be one yet because the three families don't share a uniform +# name->instance surface: ``ConverterRegistry`` is a ``Registry`` whose instances +# live under ``.instances``, while ``TargetRegistry``/``ScorerRegistry`` are still +# legacy object registries whose singleton *is* the instance registry (hence the +# ``.instances`` hop for converters but not the others). Once Target/Scorer migrate +# onto ``Registry`` + ``.instances`` (Phase 4), collapse this into a single mapping +# to the registry classes and fold ``is_component_type_resolvable`` into the base +# ``Registry`` as a private method. +def _registry_getter_for_component_type(component_type: ComponentType) -> Callable[[], _NamedInstanceRegistry] | None: """ - Return True if the annotation is a registry reference (resolved by name). + Return the getter for the registry singleton that resolves a component family. + + This is the one place that must import the concrete registries, so it stays in + the resolve layer (the derive layer never imports them). It is the inverse of + the identifier's self-reported ``component_type``: given that family, return the + registry that resolves its references by name. Returns: - bool: True if a value for this parameter is supplied by name and resolved - from a registry, False otherwise. + Callable[[], _NamedInstanceRegistry] | None: The registry getter, or None + when no registry is wired for ``component_type``. """ - return get_resolvable_registry_getter(annotation) is not None + from pyrit.registry.components import ConverterRegistry + from pyrit.registry.object_registries import ScorerRegistry, TargetRegistry + + if component_type is ComponentType.TARGET: + return TargetRegistry.get_registry_singleton + if component_type is ComponentType.CONVERTER: + return lambda: ConverterRegistry.get_registry_singleton().instances + if component_type is ComponentType.SCORER: + return ScorerRegistry.get_registry_singleton + return None -def coerce_string_to_annotation(*, value: str, annotation: Any) -> Any: +def is_component_type_resolvable(component_type: ComponentType) -> bool: """ - Coerce a string value to the annotated scalar type (int/float/bool/Literal). + Return whether a registry is wired to resolve references of ``component_type``. - ``Optional[X]`` / ``X | None`` is unwrapped to ``X`` first. A ``Literal`` value - is validated against the allowed members and returned as the matching member - (so an int literal comes back as an ``int``); other ``str`` values pass through - unchanged. + This is the registration-time gate used by buildable registries: a reference + parameter whose component type has no paired registry can never be resolved by + name and should fail fast instead of erroring only at build time. - Args: - value (str): The raw string value. - annotation (Any): The parameter's type annotation. + NOTE: This belongs on the ``Registry`` base as a private method; it lives here + for now only because it wraps ``_registry_getter_for_component_type``. Both move + together in Phase 4 (see that function's note). Returns: - Any: The value coerced to the annotated type, or the original string when - no numeric/boolean/Literal coercion applies. - - Raises: - ValueError: If the value cannot be interpreted as the annotated type, or is - not one of the allowed members of an annotated ``Literal``. + bool: True when references of ``component_type`` can be resolved by name. """ - if annotation is inspect.Parameter.empty: - return value - - non_none = get_union_non_none_args(annotation) - if non_none is not None and len(non_none) == 1: - annotation = non_none[0] - - if get_origin(annotation) is Literal: - allowed = get_args(annotation) - for member in allowed: - if value == str(member): - return member - raise ValueError(f"expected one of {[str(a) for a in allowed]}, got {value!r}") - - if annotation is int: - return int(value) - if annotation is float: - return float(value) - if annotation is bool: - lowered = value.strip().lower() - if lowered in ("true", "1", "yes"): - return True - if lowered in ("false", "0", "no"): - return False - raise ValueError(f"cannot interpret {value!r} as a boolean") - return value + return _registry_getter_for_component_type(component_type) is not None def _resolve_registry_reference( @@ -249,60 +281,85 @@ def _resolve_registry_reference( ) -def resolve_constructor_args(*, cls: type, raw_args: dict[str, Any]) -> dict[str, Any]: +def resolve_constructor_args( + *, cls: type, raw_args: dict[str, Any], identifier_type: type[ComponentIdentifier] | None = None +) -> dict[str, Any]: """ Resolve a flat argument dict into constructor-ready keyword arguments. - For each argument: validate it is a real constructor parameter (unless the - constructor accepts ``**kwargs``); resolve registry-reference parameters by - name; coerce simple string values to their annotated scalar type; pass - everything else through unchanged. + Derives the ``Parameter`` contract for ``cls`` (the single + ``inspect.signature`` call) and applies it to ``raw_args``. For each raw + argument: validate it is a declared parameter; resolve registry-reference + parameters by name; coerce simple string values via + ``Parameter.coerce_value``; pass everything else through unchanged. Args: - cls (type): The class whose ``__init__`` signature drives resolution. + cls (type): The class being built. raw_args (dict[str, Any]): The raw argument values (e.g. from a form or agent). + identifier_type (type[ComponentIdentifier] | None): The domain identifier + whose ``Param.*`` markers declare which parameters are registry + references. When None, no parameter is treated as a reference. Returns: dict[str, Any]: Arguments ready to pass to ``cls(**resolved)``. Raises: - ValueError: If the signature cannot be inspected, an argument is not a - valid constructor parameter, a registry reference cannot be resolved, - or a simple value cannot be coerced. + ValueError: If an argument is not a declared parameter, a registry + reference cannot be resolved, or a simple value cannot be coerced. """ - try: - sig = inspect.signature(cls.__init__) - except (ValueError, TypeError) as e: - raise ValueError(f"Failed to inspect __init__ signature for '{cls.__name__}': {e}") from e - - accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) - valid_params = { - param_name: p - for param_name, p in sig.parameters.items() - if param_name != "self" and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - } + by_name = {param.name: param for param in derive_parameters(cls=cls, identifier_type=identifier_type)} resolved: dict[str, Any] = {} for name, value in raw_args.items(): - param = valid_params.get(name) - if param is None and not accepts_var_kwargs: + param = by_name.get(name) + if param is None: raise ValueError( - f"Unknown parameter '{name}' for '{cls.__name__}'. Valid parameters: {sorted(valid_params.keys())}" + f"Unknown parameter '{name}' for '{cls.__name__}'. Valid parameters: {sorted(by_name.keys())}" ) - annotation = param.annotation if param is not None else inspect.Parameter.empty - - registry_getter = get_resolvable_registry_getter(annotation) - if registry_getter is not None: - resolved[name] = _resolve_registry_reference( - value=value, getter=registry_getter, owner=cls.__name__, name=name - ) - elif isinstance(value, str) and is_coercible_from_string(annotation): + if param.reference is not None: + getter = _registry_getter_for_component_type(param.reference.component_type) + if getter is None: + raise ValueError( + f"{cls.__name__}.{name}: no registry is wired for component type " + f"'{param.reference.component_type}'." + ) + resolved[name] = _resolve_registry_reference(value=value, getter=getter, owner=cls.__name__, name=name) + elif isinstance(value, str) and param.is_string_coercible: try: - resolved[name] = coerce_string_to_annotation(value=value, annotation=annotation) + resolved[name] = param.coerce_value(value) except (ValueError, TypeError) as e: raise ValueError(f"Parameter '{name}' of '{cls.__name__}': {e}") from e else: resolved[name] = value return resolved + + +# --------------------------------------------------------------------------- +# Present: param_type -> allowed-value display tuple +# --------------------------------------------------------------------------- + + +def display_choices(param_type: TypeAnnotation) -> tuple[Any, ...] | None: + """ + Derive the allowed-value display list from a constrained-scalar ``param_type``. + + This is the presentation projection of an allowed set: a ``Parameter`` stores + the constraint as a ``Literal[...]`` / ``Enum`` type, and serializers render the + members on demand instead of reading a separate field. ``Optional[X]`` / + ``X | None`` is unwrapped first. + + Args: + param_type (TypeAnnotation): The parameter's type annotation. + + Returns: + tuple[Any, ...] | None: The allowed members for a constrained scalar + (``Literal`` args or ``Enum`` member values), or None when unconstrained. + """ + unwrapped = _unwrap_optional(param_type) + if get_origin(unwrapped) is Literal: + return get_args(unwrapped) + if isinstance(unwrapped, type) and issubclass(unwrapped, Enum): + return tuple(member.value for member in unwrapped) + return None diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index a93f3098c1..7994d82d5b 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -18,7 +18,7 @@ import sys from types import ModuleType -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import ( AtomicAttack, diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 7b50cef237..1a014778dc 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -3,7 +3,7 @@ """Core scenario classes for running attack configurations.""" -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 4611128eda..29aea1290e 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, cast, get_origin +from typing import TYPE_CHECKING, Any, ClassVar, cast try: # Built-in on Python 3.11+. Fall back to the ``exceptiongroup`` backport on 3.10 @@ -29,15 +29,15 @@ from tqdm.auto import tqdm -from pyrit.common import REQUIRED_VALUE, Parameter, apply_defaults +from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.parameter import coerce_value, validate_param_type from pyrit.common.utils import to_sha256 from pyrit.executor.attack import AttackExecutor from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup +from pyrit.models.parameter import Parameter from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import TargetRequirements @@ -451,7 +451,7 @@ def set_params_from_args(self, *, args: dict[str, Any]) -> None: # Stash unknowns so _validate_params can list them all at once. coerced[name] = raw_value continue - coerced[name] = coerce_value(param=param, raw_value=raw_value) + coerced[name] = param.coerce_value(raw_value) self._validate_params(params=coerced, declared=declared) @@ -463,7 +463,7 @@ def set_params_from_args(self, *, args: dict[str, Any]) -> None: # without an explicit default land as None, and the scenario raises # a domain-specific error at run time if it cannot proceed. coerced[param.name] = ( - copy.deepcopy(coerce_value(param=param, raw_value=param.default)) if param.default is not None else None + copy.deepcopy(param.coerce_value(param.default)) if param.default is not None else None ) self.params = coerced @@ -477,9 +477,8 @@ def _validate_declarations(self, *, declared: list[Parameter]) -> None: Raises: ValueError: If declarations contain duplicate names, an - unsupported ``param_type``, ``choices`` not coercible to - ``param_type``, or a default that fails coercion / is not - in ``choices``. + unsupported ``param_type``, or a default that fails coercion + (including membership for a constrained scalar). """ seen: set[str] = set() for param in declared: @@ -488,44 +487,18 @@ def _validate_declarations(self, *, declared: list[Parameter]) -> None: seen.add(param.name) try: - validate_param_type(param=param) + param.validate() except ValueError as exc: raise ValueError(f"Scenario '{type(self).__name__}' {exc}") from exc - if param.choices is not None and get_origin(param.param_type) is list: - # argparse `nargs='+'` applies choices per-item; core checks the whole list. - # Reject the combination until we reconcile the semantics. - raise ValueError( - f"Scenario '{type(self).__name__}' parameter '{param.name}' declares choices on a list " - f"param_type ({param.param_type!r}); this combination is not supported. " - f"Use a scalar param_type with choices, or omit choices on list params." - ) - - if param.choices is not None and param.param_type is not None: - # Each choice must be coercible — fail at declaration time, not user time. - for choice in param.choices: - try: - coerce_value(param=param, raw_value=choice) - except ValueError as exc: - raise ValueError( - f"Scenario '{type(self).__name__}' parameter '{param.name}' choice " - f"{choice!r} is not coercible to {param.param_type!r}: {exc}" - ) from exc - if param.default is not None: try: - coerced_default = coerce_value(param=param, raw_value=param.default) + param.coerce_value(param.default) except ValueError as exc: raise ValueError( f"Scenario '{type(self).__name__}' parameter '{param.name}' has an invalid default: {exc}" ) from exc - if param.choices is not None and coerced_default not in param.choices: - raise ValueError( - f"Scenario '{type(self).__name__}' parameter '{param.name}' default " - f"{param.default!r} is not in declared choices {param.choices!r}." - ) - def _validate_params(self, *, params: dict[str, Any], declared: list[Parameter]) -> None: """ Validate supplied params against the scenario's declarations. diff --git a/pyrit/scenario/scenarios/adaptive/text_adaptive.py b/pyrit/scenario/scenarios/adaptive/text_adaptive.py index e942bd4b1e..bedd275f01 100644 --- a/pyrit/scenario/scenarios/adaptive/text_adaptive.py +++ b/pyrit/scenario/scenarios/adaptive/text_adaptive.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, ClassVar from pyrit.common import apply_defaults -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter from pyrit.registry.object_registries.attack_technique_registry import ( AttackTechniqueRegistry, ) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index e591f580c8..03c6f79698 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from pyrit.common import Parameter, apply_defaults +from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. from pyrit.common.path import ( EXECUTOR_RED_TEAM_PATH, @@ -21,7 +21,7 @@ AttackAdversarialConfig, AttackScoringConfig, ) -from pyrit.models import SeedAttackGroup +from pyrit.models import Parameter, SeedAttackGroup from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index 0c0b4f6fb2..b0e0a58070 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, ClassVar from pyrit.analytics import get_cached_results_for_technique -from pyrit.common import Parameter, apply_defaults +from pyrit.common import apply_defaults from pyrit.executor.attack import AttackScoringConfig from pyrit.models import ( AttackOutcome, @@ -19,6 +19,7 @@ ScenarioResult, SeedAttackGroup, ) +from pyrit.models.parameter import Parameter from pyrit.registry import AttackTechniqueRegistry, TargetRegistry from pyrit.registry.tag_query import TagQuery from pyrit.scenario.core.atomic_attack import AtomicAttack diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 0cf9d1afcf..d2951a7c0c 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -4,7 +4,7 @@ """PyRIT initializers package.""" from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter from pyrit.setup.initializers.airt import AIRTInitializer from pyrit.setup.initializers.components.scenario_techniques import ScenarioTechniqueInitializer from pyrit.setup.initializers.components.scorers import ScorerInitializer diff --git a/pyrit/setup/initializers/components/scorers.py b/pyrit/setup/initializers/components/scorers.py index 29d1ac98a5..af0b060a9e 100644 --- a/pyrit/setup/initializers/components/scorers.py +++ b/pyrit/setup/initializers/components/scorers.py @@ -23,7 +23,7 @@ from azure.ai.contentsafety.models import TextCategory -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter from pyrit.registry import ScorerRegistry, TargetRegistry from pyrit.score import ( AzureContentFilterScorer, diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index d81c15220a..022f63efff 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -20,8 +20,8 @@ from typing import Any from pyrit.auth import get_azure_openai_auth, get_azure_token_provider -from pyrit.common.parameter import Parameter from pyrit.models.identifiers import TARGET_EVAL_PARAM_FALLBACKS, TARGET_EVAL_PARAMS +from pyrit.models.parameter import Parameter from pyrit.prompt_target import ( AzureMLChatTarget, OpenAIChatTarget, diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index d92707c92a..f9802647c0 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -16,7 +16,7 @@ from pyrit.common.apply_defaults import get_global_default_values from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.parameter import Parameter +from pyrit.models.parameter import Parameter def __getattr__(name: str) -> type: diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 2760433b44..4cc21e52fd 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -33,9 +33,9 @@ @pytest.fixture(autouse=True) def reset_registry(): """Reset the converter registry before each test.""" - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() yield - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() class TestListConverters: diff --git a/tests/unit/common/test_parameter.py b/tests/unit/common/test_parameter.py index 92341aa457..5950044960 100644 --- a/tests/unit/common/test_parameter.py +++ b/tests/unit/common/test_parameter.py @@ -1,115 +1,49 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Unit tests for the unified Parameter dataclass.""" +"""The parameter contract moved to ``pyrit.models.parameter``. -import pytest - -from pyrit.common import Parameter - - -class TestParameter: - """Tests for pyrit.common.Parameter.""" - - def test_minimal_construction(self) -> None: - """Parameter requires only name and description.""" - p = Parameter(name="x", description="some param") - - assert p.name == "x" - assert p.description == "some param" - assert p.default is None - assert p.param_type is None - assert p.choices is None - - def test_full_construction(self) -> None: - """All fields can be supplied.""" - p = Parameter( - name="max_turns", - description="turn cap", - default=5, - param_type=int, - choices=(1, 5, 10), - ) - - assert p.default == 5 - assert p.param_type is int - assert p.choices == (1, 5, 10) - - def test_parameter_is_hashable(self) -> None: - """Frozen dataclass means Parameters can live in sets and dict keys.""" - p = Parameter(name="x", description="d") +These tests pin the deprecation shims: importing ``Parameter`` (or the coercion +helpers) from ``pyrit.common`` / ``pyrit.common.parameter`` must still resolve to +the canonical object but emit a ``DeprecationWarning``. +""" - # If hash() raised, this set construction would fail. - assert {p} == {p} +import importlib - def test_choices_list_is_normalized_to_tuple(self) -> None: - """A list passed for choices is coerced to a tuple to keep the dataclass hashable.""" - p = Parameter(name="x", description="d", choices=["a", "b", "c"]) - - assert p.choices == ("a", "b", "c") - assert isinstance(p.choices, tuple) - - # And the resulting Parameter is still hashable. - _ = hash(p) - - def test_choices_none_stays_none(self) -> None: - """Default None choices is preserved (no spurious tuple coercion).""" - p = Parameter(name="x", description="d") - - assert p.choices is None - - def test_choices_coerced_to_int_param_type(self) -> None: - """Stringy int choices are coerced so argparse and runtime both see ints.""" - p = Parameter(name="x", description="d", param_type=int, choices=("1", "5", "10")) - - assert p.choices == (1, 5, 10) - assert all(isinstance(c, int) for c in p.choices) - - def test_choices_coerced_to_bool_param_type(self) -> None: - p = Parameter(name="x", description="d", param_type=bool, choices=("true", "false")) - - assert p.choices == (True, False) +import pytest - def test_choices_uncoercible_left_unchanged(self) -> None: - """Uncoercible choices are left as-is so _validate_declarations can surface a clear error.""" - p = Parameter(name="x", description="d", param_type=int, choices=("not-a-number", "5")) +import pyrit.common +import pyrit.common.parameter as common_parameter +from pyrit.models.parameter import Parameter as CanonicalParameter - # Original tuple preserved. The downstream validator emits the friendly - # "scenario X parameter Y choice Z is not coercible" error. - assert p.choices == ("not-a-number", "5") - def test_choices_skipped_for_none_param_type(self) -> None: - """When param_type is None (raw passthrough) choices stay as-declared.""" - p = Parameter(name="x", description="d", choices=("a", "b")) +def test_parameter_from_common_parameter_warns_and_resolves(): + # Reload to reset the shim's one-time "already warned" state so the warning + # fires deterministically regardless of earlier imports in the session. + importlib.reload(common_parameter) - assert p.choices == ("a", "b") + with pytest.warns(DeprecationWarning, match=r"pyrit\.models\.parameter\.Parameter"): + resolved = common_parameter.Parameter - def test_list_param_type_accepted(self) -> None: - """``param_type=list[str]`` is accepted (GenericAlias, not type).""" - p = Parameter(name="datasets", description="d", param_type=list[str]) + assert resolved is CanonicalParameter - assert p.param_type == list[str] - def test_parameter_is_immutable(self) -> None: - """Frozen dataclass rejects field assignment after construction.""" - p = Parameter(name="x", description="d") +def test_parameter_from_common_package_warns_and_resolves(): + importlib.reload(pyrit.common) - with pytest.raises((AttributeError, TypeError)): - p.name = "y" # type: ignore[misc] + with pytest.warns(DeprecationWarning, match=r"pyrit\.models\.Parameter"): + resolved = pyrit.common.Parameter + assert resolved is CanonicalParameter -class TestCoerceValuePassthroughDeepcopy: - """``coerce_value`` deep-copies raw passthrough values for ``param_type=None``.""" - def test_param_type_none_returns_distinct_object(self) -> None: - """A mutable raw value must not share identity with the coerced result.""" - from pyrit.common.parameter import coerce_value +def test_common_parameter_unknown_name_raises_attribute_error(): + importlib.reload(common_parameter) - raw = ["a", "b"] - coerced = coerce_value(param=Parameter(name="opts", description="d"), raw_value=raw) + missing_attr = "does_not_exist" + with pytest.raises(AttributeError): + getattr(common_parameter, missing_attr) - assert coerced == raw - assert coerced is not raw - raw.append("c") - assert coerced == ["a", "b"] +def test_parameter_no_longer_in_common_all(): + assert "Parameter" not in pyrit.common.__all__ diff --git a/tests/unit/models/identifiers/test_component_identifier.py b/tests/unit/models/identifiers/test_component_identifier.py index ec51120619..292ac94e5b 100644 --- a/tests/unit/models/identifiers/test_component_identifier.py +++ b/tests/unit/models/identifiers/test_component_identifier.py @@ -1414,7 +1414,7 @@ def test_with_eval_hash_returns_new_instance(self): class TestComponentIdentifierReservedKeyCollision: @pytest.mark.parametrize( "reserved", - ["class_name", "class_module", "hash", "pyrit_version", "eval_hash", "children", "params"], + ["class_name", "class_module", "hash", "pyrit_version", "eval_hash", "children", "params", "attributes"], ) def test_reserved_param_name_rejected_in_normalized_shape(self, reserved): with pytest.raises(ValidationError, match="reserved names"): @@ -1506,3 +1506,79 @@ def test_usable_in_set(self): b = ComponentIdentifier(class_name="Foo", class_module="m", params={"a": 1}) s = {a, b} assert len(s) == 1 + + +class TestComponentIdentifierAttributes: + """The ``attributes`` bucket: hashed identity state, excluded from the eval hash, never a constructor input.""" + + def test_attribute_is_part_of_identity_hash(self): + """Adding an attribute changes the content hash (it is part of identity).""" + base = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}) + with_attr = ComponentIdentifier( + class_name="Foo", class_module="m", params={"x": 1}, attributes={"model_version": "v2"} + ) + assert base.hash != with_attr.hash + + def test_different_attributes_produce_different_hashes(self): + a = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"model_version": "v1"}) + b = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"model_version": "v2"}) + assert a.hash != b.hash + + def test_empty_attributes_hash_matches_no_attributes(self): + base = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}) + empty = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}, attributes={}) + assert base.hash == empty.hash + + def test_none_valued_attribute_excluded_from_hash(self): + """A None-valued attribute does not change the hash (backward-compatible additions).""" + base = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}) + with_none = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}, attributes={"opt": None}) + assert base.hash == with_none.hash + + def test_attribute_excluded_from_eval_hash(self): + """Attributes feed the identity hash but not the eval hash.""" + no_attr = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}) + with_attr = ComponentIdentifier( + class_name="Foo", class_module="m", params={"x": 1}, attributes={"model_version": "v2"} + ) + assert _build_eval_dict(no_attr, child_eval_rules={}) == _build_eval_dict(with_attr, child_eval_rules={}) + + def test_attribute_distinct_from_same_named_param(self): + """An ``attributes`` entry and a same-named ``params`` entry are not interchangeable.""" + as_param = ComponentIdentifier(class_name="Foo", class_module="m", params={"version": "v2"}) + as_attr = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"version": "v2"}) + assert as_param.hash != as_attr.hash + + def test_serialize_nests_attributes_under_key(self): + ident = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"region": "eastus"}) + dumped = ident.model_dump() + assert dumped["attributes"] == {"region": "eastus"} + + def test_serialize_omits_attributes_key_when_empty(self): + ident = ComponentIdentifier(class_name="Foo", class_module="m", params={"x": 1}) + assert "attributes" not in ident.model_dump() + + def test_roundtrip_preserves_attributes_and_hash(self): + ident = ComponentIdentifier( + class_name="Foo", class_module="m", params={"x": 1}, attributes={"region": "eastus"} + ) + rebuilt = ComponentIdentifier.model_validate(ident.model_dump()) + assert rebuilt.attributes == {"region": "eastus"} + assert rebuilt.hash == ident.hash + + def test_of_factory_drops_none_attributes(self): + class _Dummy: + pass + + ident = ComponentIdentifier.of(_Dummy(), attributes={"region": "eastus", "drop": None}) + assert ident.attributes == {"region": "eastus"} + + def test_with_eval_hash_preserves_attributes(self): + ident = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"region": "eastus"}) + updated = ident.with_eval_hash("abc123") + assert updated.attributes == {"region": "eastus"} + assert updated.hash == ident.hash + + def test_repr_includes_attributes(self): + ident = ComponentIdentifier(class_name="Foo", class_module="m", attributes={"region": "eastus"}) + assert "attributes=(region='eastus')" in repr(ident) diff --git a/tests/unit/models/identifiers/test_typed_identifier.py b/tests/unit/models/identifiers/test_typed_identifier.py index 4922e4d4d4..3f7ad4e005 100644 --- a/tests/unit/models/identifiers/test_typed_identifier.py +++ b/tests/unit/models/identifiers/test_typed_identifier.py @@ -15,6 +15,7 @@ SeedIdentifier, TargetIdentifier, ) +from pyrit.models.parameter import ComponentType def _target_identifier() -> ComponentIdentifier: @@ -199,14 +200,13 @@ def test_promoted_children_typed_per_field(self): params={}, children={ "converter_target": target_child, - "sub_converters": [sub_converter_child], + "sub_converter": sub_converter_child, }, ) cd = ConverterIdentifier.from_component_identifier(ci) assert isinstance(cd.converter_target, TargetIdentifier) assert cd.converter_target.endpoint == "https://obj" - assert isinstance(cd.sub_converters, list) - assert all(isinstance(c, ConverterIdentifier) for c in cd.sub_converters) + assert isinstance(cd.sub_converter, ConverterIdentifier) assert cd.hash == ci.hash @@ -220,6 +220,61 @@ def test_promoted_fields(self): assert sd.prompt_target.endpoint == "https://c" +class TestComponentType: + """Each leaf identifier self-reports its registry family; the base reports none.""" + + def test_base_is_not_buildable(self): + assert ComponentIdentifier.component_type is None + + @pytest.mark.parametrize( + "identifier_type, expected", + [ + (TargetIdentifier, ComponentType.TARGET), + (ConverterIdentifier, ComponentType.CONVERTER), + (ScorerIdentifier, ComponentType.SCORER), + ], + ) + def test_leaf_component_type(self, identifier_type, expected): + assert identifier_type.component_type is expected + + def test_converter_reference_args_map_to_target(self): + # converter_target (typed TargetIdentifier) and sub_converter (typed + # ConverterIdentifier) are both Param.Include buildable references; the + # Param.ClassAttr type lists are not. + assert ConverterIdentifier.get_reference_component_types() == { + "converter_target": ComponentType.TARGET, + "sub_converter": ComponentType.CONVERTER, + } + + def test_base_identifier_has_no_reference_args(self): + assert ComponentIdentifier.get_reference_component_types() == {} + + +class TestClassAttributeValues: + """``get_class_attribute_values`` reads Param.ClassAttr fields off a target class.""" + + def test_reads_converter_supported_types(self): + class _FakeConverter: + SUPPORTED_INPUT_TYPES = ["text"] + SUPPORTED_OUTPUT_TYPES = ["text", "image_path"] + + values = ConverterIdentifier.get_class_attribute_values(_FakeConverter) + assert values == { + "supported_input_types": ["text"], + "supported_output_types": ["text", "image_path"], + } + + def test_missing_attribute_maps_to_none(self): + class _NoTypes: + pass + + values = ConverterIdentifier.get_class_attribute_values(_NoTypes) + assert values == {"supported_input_types": None, "supported_output_types": None} + + def test_base_identifier_has_no_class_attributes(self): + assert ComponentIdentifier.get_class_attribute_values(object) == {} + + class TestDirectConstruction: """Building a typed identifier by hand yields a valid ComponentIdentifier.""" diff --git a/tests/unit/models/test_parameter.py b/tests/unit/models/test_parameter.py new file mode 100644 index 0000000000..4e522b74a7 --- /dev/null +++ b/tests/unit/models/test_parameter.py @@ -0,0 +1,310 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for the unified Parameter model and its coercion methods.""" + +from enum import Enum +from typing import Literal + +import pytest + +from pyrit.models import Parameter +from pyrit.models.parameter import ComponentType, RegistryReference, _is_scalar_param_type +from pyrit.registry.resolution import display_choices + + +class _Speed(Enum): + FAST = "fast" + SLOW = "slow" + + +class _Unsupported: + """Stand-in for an arbitrary (non-scalar, non-registry) constructor type.""" + + +class TestParameter: + """Tests for the five-field pyrit.common.Parameter.""" + + def test_minimal_construction(self) -> None: + p = Parameter(name="x", description="some param") + + assert p.name == "x" + assert p.description == "some param" + assert p.default is None + assert p.param_type is None + assert p.reference is None + + def test_full_construction(self) -> None: + p = Parameter( + name="max_turns", + description="turn cap", + default=5, + param_type=Literal[1, 5, 10], + ) + + assert p.default == 5 + assert p.param_type == Literal[1, 5, 10] + + def test_parameter_is_hashable(self) -> None: + """Frozen dataclass means Parameters can live in sets and dict keys.""" + p = Parameter(name="x", description="d") + + assert {p} == {p} + + def test_list_param_type_accepted(self) -> None: + """``param_type=list[str]`` is accepted (GenericAlias, not type).""" + p = Parameter(name="datasets", description="d", param_type=list[str]) + + assert p.param_type == list[str] + + def test_parameter_is_immutable(self) -> None: + p = Parameter(name="x", description="d") + + with pytest.raises((AttributeError, TypeError)): + p.name = "y" # type: ignore[misc] + + +class TestIsScalarParamType: + """``_is_scalar_param_type`` recognizes plain and constrained scalars.""" + + @pytest.mark.parametrize("annotation", [str, int, float, bool, Literal["a", "b"], _Speed]) + def test_scalar_forms(self, annotation: object) -> None: + assert _is_scalar_param_type(annotation) is True + + @pytest.mark.parametrize("annotation", [None, list[str], list[int], _Unsupported]) + def test_non_scalar_forms(self, annotation: object) -> None: + assert _is_scalar_param_type(annotation) is False + + +class TestDisplayChoices: + """``display_choices`` (now in the registry layer) derives the allowed set from the type.""" + + def test_literal_returns_args(self) -> None: + assert display_choices(Literal["fast", "slow"]) == ("fast", "slow") + + def test_optional_literal_unwrapped(self) -> None: + assert display_choices(Literal["a", "b"] | None) == ("a", "b") + + def test_enum_returns_member_values(self) -> None: + assert display_choices(_Speed) == ("fast", "slow") + + @pytest.mark.parametrize("annotation", [None, str, int, list[str]]) + def test_unconstrained_returns_none(self, annotation: object) -> None: + assert display_choices(annotation) is None + + +class TestIsStringCoercible: + """``Parameter.is_string_coercible`` reflects whether a string token can supply the value.""" + + @pytest.mark.parametrize("param_type", [str, int, float, bool, Literal["a", "b"]]) + def test_coercible_value_types(self, param_type: object) -> None: + p = Parameter(name="x", description="d", param_type=param_type) + assert p.is_string_coercible is True + + @pytest.mark.parametrize("param_type", [None, list[str], _Speed, _Unsupported]) + def test_non_coercible_value_types(self, param_type: object) -> None: + p = Parameter(name="x", description="d", param_type=param_type) + assert p.is_string_coercible is False + + def test_reference_is_never_coercible(self) -> None: + p = Parameter( + name="target", + description="d", + reference=RegistryReference(component_type=ComponentType.TARGET), + ) + assert p.is_string_coercible is False + + +class TestIsReferenceTo: + """``Parameter.is_reference_to`` is the single predicate for "points at this component family".""" + + def test_matching_component_type_is_true(self) -> None: + p = Parameter( + name="converter_target", + description="d", + reference=RegistryReference(component_type=ComponentType.TARGET), + ) + assert p.is_reference_to(ComponentType.TARGET) is True + + def test_other_component_type_is_false(self) -> None: + p = Parameter( + name="converter_target", + description="d", + reference=RegistryReference(component_type=ComponentType.TARGET), + ) + assert p.is_reference_to(ComponentType.SCORER) is False + + def test_non_reference_is_false(self) -> None: + p = Parameter(name="x", description="d", param_type=int) + assert p.is_reference_to(ComponentType.TARGET) is False + + +class TestCoerceValueScalars: + """``Parameter.coerce_value`` coerces plain scalars.""" + + def test_int(self) -> None: + p = Parameter(name="n", description="d", param_type=int) + assert p.coerce_value("5") == 5 + + def test_float(self) -> None: + p = Parameter(name="r", description="d", param_type=float) + assert p.coerce_value("0.25") == 0.25 + + def test_bool(self) -> None: + p = Parameter(name="flag", description="d", param_type=bool) + assert p.coerce_value("yes") is True + + def test_str_passthrough(self) -> None: + p = Parameter(name="s", description="d", param_type=str) + assert p.coerce_value("hello") == "hello" + + def test_int_invalid_raises(self) -> None: + p = Parameter(name="n", description="d", param_type=int) + with pytest.raises(ValueError, match="could not be coerced to int"): + p.coerce_value("not-a-number") + + def test_bool_invalid_raises(self) -> None: + p = Parameter(name="flag", description="d", param_type=bool) + with pytest.raises(ValueError, match="boolean"): + p.coerce_value("maybe") + + +class TestCoerceValueConstrainedScalars: + """``Parameter.coerce_value`` validates membership for Literal / Enum.""" + + def test_literal_member(self) -> None: + p = Parameter(name="speed", description="d", param_type=Literal["fast", "slow"]) + assert p.coerce_value("fast") == "fast" + + def test_literal_coerces_to_member_type(self) -> None: + p = Parameter(name="n", description="d", param_type=Literal[1, 5, 10]) + result = p.coerce_value("5") + assert result == 5 + assert isinstance(result, int) + + def test_literal_invalid_raises(self) -> None: + p = Parameter(name="speed", description="d", param_type=Literal["fast", "slow"]) + with pytest.raises(ValueError, match="one of"): + p.coerce_value("medium") + + def test_enum_by_value(self) -> None: + p = Parameter(name="speed", description="d", param_type=_Speed) + assert p.coerce_value("fast") is _Speed.FAST + + def test_enum_by_member(self) -> None: + p = Parameter(name="speed", description="d", param_type=_Speed) + assert p.coerce_value(_Speed.SLOW) is _Speed.SLOW + + def test_enum_invalid_raises(self) -> None: + p = Parameter(name="speed", description="d", param_type=_Speed) + with pytest.raises(ValueError, match="one of"): + p.coerce_value("medium") + + +class TestCoerceValueLists: + """``Parameter.coerce_value`` coerces each element of a ``list[...]`` param.""" + + def test_list_str(self) -> None: + p = Parameter(name="ds", description="d", param_type=list[str]) + assert p.coerce_value(["a", "b"]) == ["a", "b"] + + def test_list_int(self) -> None: + p = Parameter(name="ns", description="d", param_type=list[int]) + assert p.coerce_value(["1", "2", "3"]) == [1, 2, 3] + + def test_list_literal_membership(self) -> None: + p = Parameter(name="modes", description="d", param_type=list[Literal["a", "b"]]) + assert p.coerce_value(["a", "b", "a"]) == ["a", "b", "a"] + + def test_list_literal_invalid_raises(self) -> None: + p = Parameter(name="modes", description="d", param_type=list[Literal["a", "b"]]) + with pytest.raises(ValueError, match="one of"): + p.coerce_value(["a", "z"]) + + def test_non_list_value_raises(self) -> None: + p = Parameter(name="ds", description="d", param_type=list[str]) + with pytest.raises(ValueError, match="expects a list"): + p.coerce_value("not-a-list") + + +class TestCoerceValuePassthrough: + """Reference / arbitrary / None param_types pass through unchanged.""" + + def test_param_type_none_returns_distinct_object(self) -> None: + raw = ["a", "b"] + coerced = Parameter(name="opts", description="d").coerce_value(raw) + + assert coerced == raw + assert coerced is not raw + + raw.append("c") + assert coerced == ["a", "b"] + + def test_unsupported_type_with_value_passes_through(self) -> None: + sentinel = _Unsupported() + p = Parameter(name="obj", description="d", param_type=_Unsupported) + assert p.coerce_value(sentinel) is sentinel + + def test_reference_param_passes_value_through(self) -> None: + """A reference parameter never coerces — the registry resolves it by name.""" + p = Parameter( + name="converter_target", + description="d", + reference=RegistryReference(component_type=ComponentType.TARGET), + ) + assert p.coerce_value("my_target") == "my_target" + + +class TestValidate: + """``Parameter.validate`` accepts supported forms and tolerates defaulted others.""" + + @pytest.mark.parametrize( + "param_type", + [None, str, int, float, bool, Literal["a", "b"], _Speed, list[str], list[int], list[Literal["a", "b"]]], + ) + def test_supported_forms_ok(self, param_type: object) -> None: + Parameter(name="x", description="d", param_type=param_type).validate() + + def test_unsupported_without_default_raises(self) -> None: + p = Parameter(name="x", description="d", param_type=_Unsupported) + with pytest.raises(ValueError, match="unsupported param_type"): + p.validate() + + def test_unsupported_with_default_tolerated(self) -> None: + p = Parameter(name="x", description="d", param_type=_Unsupported, default=_Unsupported()) + p.validate() + + def test_reference_param_is_valid(self) -> None: + p = Parameter( + name="target", + description="d", + reference=RegistryReference(component_type=ComponentType.TARGET), + ) + p.validate() + + +class TestCoercionParity: + """Derivation feeds ``coerce_value`` the unwrapped type, so coercion round-trips.""" + + @pytest.mark.parametrize( + "annotation, raw, expected", + [ + (int, "7", 7), + (float, "1.5", 1.5), + (bool, "true", True), + (Literal["a", "b"], "b", "b"), + (Literal[1, 2], "2", 2), + (int | None, "9", 9), + ], + ) + def test_derived_param_coerces(self, annotation: object, raw: str, expected: object) -> None: + from pyrit.registry.resolution import derive_parameters + + class _Holder: + def __init__(self, *, value=None) -> None: + self.value = value + + _Holder.__init__.__annotations__["value"] = annotation + param = next(p for p in derive_parameters(cls=_Holder) if p.name == "value") + + assert param.coerce_value(raw) == expected diff --git a/tests/unit/prompt_converter/test_selective_text_converter.py b/tests/unit/prompt_converter/test_selective_text_converter.py index 9f28a8b9a4..fa3b3e505f 100644 --- a/tests/unit/prompt_converter/test_selective_text_converter.py +++ b/tests/unit/prompt_converter/test_selective_text_converter.py @@ -24,14 +24,14 @@ class TestSelectiveTextConverter: async def test_initialization_valid(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) assert converter is not None async def test_initialization_with_preserve_tokens(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), preserve_tokens=True, start_token="<<", @@ -47,7 +47,7 @@ def input_supported(self, input_type): with pytest.raises(ValueError, match="does not support text input"): SelectiveTextConverter( - converter=NonTextConverter(), + sub_converter=NonTextConverter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) @@ -59,13 +59,13 @@ def output_supported(self, output_type): with pytest.raises(ValueError, match="does not support text output"): SelectiveTextConverter( - converter=NonTextConverter(), + sub_converter=NonTextConverter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) async def test_convert_async_with_index_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) result = await converter.convert_async(prompt="Hello World", input_type="text") @@ -75,7 +75,7 @@ async def test_convert_async_with_index_strategy(self): async def test_convert_async_with_regex_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=RegexSelectionStrategy(pattern=r"\d+"), ) result = await converter.convert_async(prompt="The code is 12345 here", input_type="text") @@ -85,7 +85,7 @@ async def test_convert_async_with_regex_strategy(self): async def test_convert_async_with_keyword_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=KeywordSelectionStrategy(keyword="secret"), ) result = await converter.convert_async(prompt="The secret is here", input_type="text") @@ -95,7 +95,7 @@ async def test_convert_async_with_keyword_strategy(self): async def test_convert_async_with_position_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=PositionSelectionStrategy(start_proportion=0.0, end_proportion=0.5), ) result = await converter.convert_async(prompt="0123456789", input_type="text") @@ -105,7 +105,7 @@ async def test_convert_async_with_position_strategy(self): async def test_convert_async_with_proportion_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=ProportionSelectionStrategy(proportion=0.5, anchor="start"), ) result = await converter.convert_async(prompt="0123456789", input_type="text") @@ -115,7 +115,7 @@ async def test_convert_async_with_proportion_strategy(self): async def test_convert_async_with_range_strategy(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=RangeSelectionStrategy(start_proportion=0.0, end_proportion=0.5), ) result = await converter.convert_async(prompt="0123456789", input_type="text") @@ -125,7 +125,7 @@ async def test_convert_async_with_range_strategy(self): async def test_convert_async_with_preserve_tokens(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), preserve_tokens=True, ) @@ -136,7 +136,7 @@ async def test_convert_async_with_preserve_tokens(self): async def test_convert_async_with_custom_tokens(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), preserve_tokens=True, start_token="<<", @@ -149,7 +149,7 @@ async def test_convert_async_with_custom_tokens(self): async def test_convert_async_no_match_returns_original(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=RegexSelectionStrategy(pattern=r"\d+"), ) result = await converter.convert_async(prompt="No numbers here", input_type="text") @@ -158,7 +158,7 @@ async def test_convert_async_no_match_returns_original(self): async def test_convert_async_invalid_input_type(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) with pytest.raises(ValueError, match="only supports text input"): @@ -167,7 +167,7 @@ async def test_convert_async_invalid_input_type(self): async def test_convert_async_chaining_with_preserved_tokens(self): # First converter: convert first half with preserve_tokens converter1 = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=PositionSelectionStrategy(start_proportion=0.0, end_proportion=0.5), preserve_tokens=True, ) @@ -175,7 +175,7 @@ async def test_convert_async_chaining_with_preserved_tokens(self): # Second converter: convert second half with preserve_tokens converter2 = SelectiveTextConverter( - converter=ROT13Converter(), + sub_converter=ROT13Converter(), selection_strategy=PositionSelectionStrategy(start_proportion=0.5, end_proportion=1.0), preserve_tokens=True, ) @@ -189,7 +189,7 @@ async def test_convert_async_chaining_with_preserved_tokens(self): async def test_convert_async_middle_section(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=4, end=10), ) result = await converter.convert_async(prompt="The secret code", input_type="text") @@ -199,7 +199,7 @@ async def test_convert_async_middle_section(self): async def test_convert_async_end_section(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=11, end=None), ) result = await converter.convert_async(prompt="Hello World", input_type="text") @@ -209,7 +209,7 @@ async def test_convert_async_end_section(self): async def test_input_supported(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) assert converter.input_supported("text") is True @@ -217,7 +217,7 @@ async def test_input_supported(self): async def test_output_supported(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=IndexSelectionStrategy(start=0, end=5), ) assert converter.output_supported("text") is True @@ -225,7 +225,7 @@ async def test_output_supported(self): async def test_convert_async_with_keyword_and_context(self): converter = SelectiveTextConverter( - converter=Base64Converter(), + sub_converter=Base64Converter(), selection_strategy=KeywordSelectionStrategy(keyword="secret", context_before=4, context_after=3), ) result = await converter.convert_async(prompt="The secret is here", input_type="text") @@ -235,7 +235,7 @@ async def test_convert_async_with_keyword_and_context(self): async def test_convert_async_entire_text_with_range(self): converter = SelectiveTextConverter( - converter=ROT13Converter(), + sub_converter=ROT13Converter(), selection_strategy=RangeSelectionStrategy(start_proportion=0.0, end_proportion=1.0), ) result = await converter.convert_async(prompt="Hello", input_type="text") @@ -247,7 +247,9 @@ async def test_initialization_word_level_strategy_with_word_level_converter_rais that has a non-default word_selection_strategy raises ValueError.""" with pytest.raises(ValueError, match="Cannot use a WordSelectionStrategy"): SelectiveTextConverter( - converter=LeetspeakConverter(word_selection_strategy=WordProportionSelectionStrategy(proportion=0.5)), + sub_converter=LeetspeakConverter( + word_selection_strategy=WordProportionSelectionStrategy(proportion=0.5) + ), selection_strategy=WordIndexSelectionStrategy(indices=[0, 1]), ) @@ -256,7 +258,7 @@ async def test_initialization_word_level_strategy_with_default_word_level_conver that has the default (AllWordsSelectionStrategy) is allowed.""" # This should NOT raise - LeetspeakConverter with no explicit strategy uses AllWordsSelectionStrategy converter = SelectiveTextConverter( - converter=LeetspeakConverter(), + sub_converter=LeetspeakConverter(), selection_strategy=WordIndexSelectionStrategy(indices=[0]), ) assert converter is not None @@ -267,7 +269,7 @@ async def test_initialization_char_level_strategy_with_word_level_converter_allo # This should NOT raise - character-level strategy passes a substring to the converter, # so the converter's word selection strategy can meaningfully operate on it converter = SelectiveTextConverter( - converter=LeetspeakConverter(word_selection_strategy=WordProportionSelectionStrategy(proportion=0.5)), + sub_converter=LeetspeakConverter(word_selection_strategy=WordProportionSelectionStrategy(proportion=0.5)), selection_strategy=IndexSelectionStrategy(start=0, end=20), ) assert converter is not None diff --git a/tests/unit/registry/test_converter_registry.py b/tests/unit/registry/test_converter_registry.py index 926de6835e..62b6ab0762 100644 --- a/tests/unit/registry/test_converter_registry.py +++ b/tests/unit/registry/test_converter_registry.py @@ -10,7 +10,9 @@ import pytest +from pyrit.common import REQUIRED_VALUE from pyrit.models import ComponentIdentifier, Message, MessagePiece, PromptDataType +from pyrit.models.parameter import ComponentType from pyrit.prompt_converter import ( Base64Converter, CaesarConverter, @@ -29,13 +31,10 @@ ConverterMetadata, ConverterRegistry, ) -from pyrit.registry.components.converter_registry import ( - _extract_parameters, - _requires_llm_target, -) from pyrit.registry.object_registries import ( TargetRegistry, ) +from pyrit.registry.resolution import derive_parameters class MockPromptTarget(PromptTarget): @@ -108,10 +107,10 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text @pytest.fixture def registry(): """Provide a fresh ``ConverterRegistry`` singleton, reset around each test.""" - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() instance = ConverterRegistry.get_registry_singleton() yield instance - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() # --------------------------------------------------------------------------- @@ -123,10 +122,10 @@ class TestConverterRegistrySingleton: """Tests for the singleton pattern in ConverterRegistry.""" def setup_method(self): - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() def teardown_method(self): - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() def test_get_registry_singleton_returns_same_instance(self): assert ConverterRegistry.get_registry_singleton() is ConverterRegistry.get_registry_singleton() @@ -134,9 +133,9 @@ def test_get_registry_singleton_returns_same_instance(self): def test_get_registry_singleton_returns_converter_registry_type(self): assert isinstance(ConverterRegistry.get_registry_singleton(), ConverterRegistry) - def test_reset_instance_clears_singleton(self): + def test_reset_registry_singleton_clears_singleton(self): instance1 = ConverterRegistry.get_registry_singleton() - ConverterRegistry.reset_instance() + ConverterRegistry.reset_registry_singleton() assert ConverterRegistry.get_registry_singleton() is not instance1 @@ -308,16 +307,6 @@ def test_build_does_not_register_instance(self, registry: ConverterRegistry): registry.create_instance("Base64Converter") assert len(registry.instances) == 0 - def test_honors_registered_default_kwargs(self, registry: ConverterRegistry): - registry.register(CaesarConverter, name="CaesarDefault", default_kwargs={"caesar_offset": 5}) - converter = registry.create_instance("CaesarDefault") - assert converter.get_identifier().params.get("caesar_offset") == 5 - - def test_uses_registered_factory(self, registry: ConverterRegistry): - sentinel = Base64Converter() - registry.register(Base64Converter, name="B64Factory", factory=lambda **kwargs: sentinel) - assert registry.create_instance("B64Factory") is sentinel - @pytest.mark.usefixtures("patch_central_database") class TestCreateLLMConverter: @@ -347,13 +336,20 @@ class TestClassMetadata: """Tests for converter class-catalog metadata building.""" def _metadata_for(self, registry: ConverterRegistry, name: str) -> ConverterMetadata: - return next(m for m in registry.list_class_metadata() if m.class_name == name) + return next(m for m in registry.get_all_registered_class_metadata() if m.class_name == name) def test_metadata_includes_supported_types(self, registry: ConverterRegistry): meta = self._metadata_for(registry, "Base64Converter") assert "text" in meta.supported_input_types assert "text" in meta.supported_output_types + def test_metadata_carries_class_attributes(self, registry: ConverterRegistry): + meta = self._metadata_for(registry, "Base64Converter") + # Supported types are sourced from class attributes via Param.ClassAttr, + # not from a fabricated instance identifier. + assert "supported_input_types" in meta.class_attributes + assert "text" in [str(dt) for dt in meta.class_attributes["supported_input_types"]] + def test_metadata_has_no_catalog_visible_field(self, registry: ConverterRegistry): # catalog_visible is a presentation concern owned by the backend/frontend. assert not hasattr(self._metadata_for(registry, "Base64Converter"), "catalog_visible") @@ -377,16 +373,18 @@ def test_is_llm_based_flag(self, registry: ConverterRegistry): def test_parameters_extracted(self, registry: ConverterRegistry): meta = self._metadata_for(registry, "CaesarConverter") caesar_param = next(p for p in meta.parameters if p.name == "caesar_offset") - assert caesar_param.required is True - assert caesar_param.annotation is int - assert caesar_param.coercible_from_string is True + assert caesar_param.default is REQUIRED_VALUE + assert caesar_param.param_type is int + assert caesar_param.reference is None + assert caesar_param.is_string_coercible is True def test_surfaces_non_coercible_params(self, registry: ConverterRegistry): - # An LLM-based converter exposes its target parameter for dynamic - # construction even though it cannot be coerced from a string. + # An LLM-based converter exposes its target parameter (a registry reference) + # for dynamic construction even though it cannot be coerced from a string. meta = self._metadata_for(registry, "PersuasionConverter") - non_coercible = [p for p in meta.parameters if not p.coercible_from_string] - assert non_coercible, "expected at least one non-coercible parameter (the LLM target)" + references = [p for p in meta.parameters if p.reference is not None] + assert references, "expected at least one reference parameter (the LLM target)" + assert any(p.is_reference_to(ComponentType.TARGET) for p in meta.parameters) # --------------------------------------------------------------------------- @@ -397,8 +395,8 @@ def test_surfaces_non_coercible_params(self, registry: ConverterRegistry): class _UnionTargetConverter: """Helper with a PEP 604 unioned target parameter for introspection tests.""" - def __init__(self, *, target: PromptTarget | None = None, offset: int | None = None) -> None: - self.target = target + def __init__(self, *, converter_target: PromptTarget | None = None, offset: int | None = None) -> None: + self.converter_target = converter_target self.offset = offset @@ -409,42 +407,32 @@ def __init__(self, *, fmt: Literal["A", "B"] | None = None) -> None: self.fmt = fmt -class TestExtractParameters: - """Tests for the converter-parameter introspection helper.""" +class TestDeriveParameters: + """Tests for the converter-parameter derivation into the ``Parameter`` contract.""" - def test_exposes_raw_annotation(self) -> None: - offset_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "offset") - assert offset_param.annotation == (int | None) - assert offset_param.coercible_from_string is True + def test_unwraps_optional_into_param_type(self) -> None: + from pyrit.models.identifiers import ConverterIdentifier - def test_includes_non_coercible(self) -> None: - target_param = next(p for p in _extract_parameters(_UnionTargetConverter) if p.name == "target") - assert target_param.coercible_from_string is False - - def test_optional_literal_choices(self) -> None: - fmt_param = next(p for p in _extract_parameters(_OptionalLiteralConverter) if p.name == "fmt") - assert fmt_param.choices == ("A", "B") - - def test_sets_requires_llm(self) -> None: - params = _extract_parameters(_UnionTargetConverter) - target_param = next(p for p in params if p.name == "target") + params = derive_parameters(cls=_UnionTargetConverter, identifier_type=ConverterIdentifier) offset_param = next(p for p in params if p.name == "offset") - assert target_param.requires_llm is True - assert offset_param.requires_llm is False - + assert offset_param.param_type is int + assert offset_param.reference is None + assert offset_param.is_string_coercible is True -class TestRequiresLlmTarget: - """Tests for the _requires_llm_target helper.""" + def test_target_becomes_reference(self) -> None: + from pyrit.models.identifiers import ConverterIdentifier - def test_plain_target(self) -> None: - assert _requires_llm_target(PromptTarget) is True + params = derive_parameters(cls=_UnionTargetConverter, identifier_type=ConverterIdentifier) + target_param = next(p for p in params if p.name == "converter_target") + assert target_param.reference is not None + assert target_param.reference.component_type is ComponentType.TARGET + assert target_param.param_type is None - def test_optional_target(self) -> None: - assert _requires_llm_target(PromptTarget | None) is True + def test_optional_literal_choices(self) -> None: + from pyrit.registry.resolution import display_choices - def test_non_target(self) -> None: - assert _requires_llm_target(int) is False - assert _requires_llm_target(str | None) is False + fmt_param = next(p for p in derive_parameters(cls=_OptionalLiteralConverter) if p.name == "fmt") + assert display_choices(fmt_param.param_type) == ("A", "B") class TestNoBackendDependency: @@ -464,3 +452,58 @@ def test_module_has_no_backend_dependency(self) -> None: elif isinstance(node, ast.ImportFrom) and node.module: imported_modules.append(node.module) assert not any(name.startswith("pyrit.backend") for name in imported_modules) + + +class TestRegistrationGate: + """The identifier blueprint must line up with a resolvable contract for every converter.""" + + def test_discovery_validates_all_converters(self, registry: ConverterRegistry) -> None: + # Discovery registers every converter through ``register_class``, which + # validates each class. Accessing the catalog therefore proves every + # discovered converter is describable and buildable (all reference params + # map to a wired registry); otherwise discovery would have raised. + names = registry.get_class_names() + assert names + assert "Base64Converter" in names + + def test_every_converter_derives_a_contract(self, registry: ConverterRegistry) -> None: + from pyrit.models.identifiers import ConverterIdentifier + + for name in registry.get_class_names(): + cls = registry.get_class(name) + parameters = derive_parameters(cls=cls, identifier_type=ConverterIdentifier) + # Reference params only ever carry a component type the resolver can map. + for param in parameters: + if param.reference is not None: + assert param.reference.component_type in ( + ComponentType.TARGET, + ComponentType.CONVERTER, + ComponentType.SCORER, + ) + + def test_is_llm_based_matches_target_reference(self, registry: ConverterRegistry) -> None: + from pyrit.models.identifiers import ConverterIdentifier + + for meta in registry.get_all_registered_class_metadata(): + parameters = derive_parameters(cls=registry.get_class(meta.class_name), identifier_type=ConverterIdentifier) + has_target = any(p.is_reference_to(ComponentType.TARGET) for p in parameters) + assert meta.is_llm_based is has_target, f"is_llm_based mismatch for {meta.class_name}" + + def test_register_class_raises_for_unresolvable_reference(self, registry: ConverterRegistry) -> None: + from unittest.mock import patch + + from pyrit.models.parameter import Parameter, RegistryReference + + target_ref = Parameter( + name="converter_target", + description="", + reference=RegistryReference(component_type=ComponentType.TARGET, annotation=object), + ) + # A class whose reference parameter has no wired registry must fail the + # registration gate (validation runs at register_class time). + with ( + patch("pyrit.registry.registry.derive_parameters", return_value=[target_ref]), + patch("pyrit.registry.resolution._registry_getter_for_component_type", return_value=None), + ): + with pytest.raises(ValueError, match="no registry wired"): + registry.register_class(Base64Converter) diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py new file mode 100644 index 0000000000..177177d6f2 --- /dev/null +++ b/tests/unit/registry/test_registry.py @@ -0,0 +1,212 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for the standalone ``Registry`` base. + +``ConverterRegistry`` overrides ``_get_registry_name`` and ``_identifier_type``, so +exercising the base only through it leaves the base's own defaults uncovered: +snake_case naming, the no-identifier path, eager vs. lazy discovery, the metadata +accessors, and the filter wiring. These tests drive a minimal subclass that keeps +every base default. +""" + +from dataclasses import dataclass, field + +import pytest + +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.registry import Registry, _get_metadata_value, _matches_filters + + +class SampleWidget: + """A sample widget. + + A second paragraph that must not leak into the one-line summary. + """ + + def __init__(self, *, size: int = 1) -> None: + self.size = size + + +class UndocumentedWidget: + def __init__(self, *, size: int = 1) -> None: + self.size = size + + +class UnregisteredWidget: + """An unregistered widget.""" + + def __init__(self, *, size: int = 1) -> None: + self.size = size + + +class WidgetRegistry(Registry[object, ClassRegistryEntry]): + """Minimal Registry subclass that keeps every base default.""" + + def __init__(self, *, lazy_discovery: bool = True) -> None: + self.discover_calls = 0 + super().__init__(lazy_discovery=lazy_discovery) + + def _discover(self) -> None: + self.discover_calls += 1 + self.register_class(SampleWidget) + self.register_class(UndocumentedWidget) + + def _metadata_class(self) -> type[ClassRegistryEntry]: + return ClassRegistryEntry + + +@dataclass(frozen=True) +class _TaggedMetadata(ClassRegistryEntry): + tags: tuple[str, ...] = field(kw_only=True, default=()) + + +def test_get_registry_name_defaults_to_snake_case(): + registry = WidgetRegistry() + + assert registry.get_class_names() == ["sample_widget", "undocumented_widget"] + + +def test_build_metadata_uses_first_paragraph_summary(): + registry = WidgetRegistry() + + meta = registry.get_registered_class_metadata("sample_widget") + + assert meta is not None + assert meta.class_description == "A sample widget." + assert meta.class_name == "SampleWidget" + assert meta.class_module == SampleWidget.__module__ + + +def test_build_metadata_empty_description_without_docstring(): + registry = WidgetRegistry() + + meta = registry.get_registered_class_metadata("undocumented_widget") + + assert meta is not None + assert meta.class_description == "" + + +def test_class_attributes_empty_without_identifier_type(): + registry = WidgetRegistry() + + meta = registry.get_registered_class_metadata("sample_widget") + + assert meta is not None + assert meta.class_attributes == {} + + +def test_parameters_have_no_references_without_identifier_type(): + registry = WidgetRegistry() + + meta = registry.get_registered_class_metadata("sample_widget") + + assert meta is not None + assert all(p.reference is None for p in meta.parameters) + + +def test_create_instance_builds_object(): + registry = WidgetRegistry() + + widget = registry.create_instance("sample_widget", size=3) + + assert isinstance(widget, SampleWidget) + assert widget.size == 3 + + +def test_lazy_discovery_defers_until_access(): + registry = WidgetRegistry(lazy_discovery=True) + + assert registry.discover_calls == 0 + registry.get_class_names() + assert registry.discover_calls == 1 + + +def test_eager_discovery_runs_in_constructor(): + registry = WidgetRegistry(lazy_discovery=False) + + assert registry.discover_calls == 1 + + +def test_get_registered_class_metadata_unknown_name_returns_none(): + registry = WidgetRegistry() + + assert registry.get_registered_class_metadata("does_not_exist") is None + + +def test_get_class_metadata_builds_for_unregistered_class(): + registry = WidgetRegistry() + + meta = registry.get_class_metadata(UnregisteredWidget) + + assert meta.class_name == "UnregisteredWidget" + assert meta.registry_name == "unregistered_widget" + assert "unregistered_widget" not in registry.get_class_names() + + +def test_get_class_unknown_name_raises(): + registry = WidgetRegistry() + + with pytest.raises(KeyError, match="not found in registry"): + registry.get_class("nope") + + +def test_iter_and_contains_and_len(): + registry = WidgetRegistry() + + assert len(registry) == 2 + assert "sample_widget" in registry + assert list(registry) == ["sample_widget", "undocumented_widget"] + + +def test_get_all_metadata_no_filters_returns_all(): + registry = WidgetRegistry() + + all_meta = registry.get_all_registered_class_metadata() + + assert {m.registry_name for m in all_meta} == {"sample_widget", "undocumented_widget"} + + +def test_get_all_metadata_include_filter_matches_subset(): + registry = WidgetRegistry() + + result = registry.get_all_registered_class_metadata(include_filters={"registry_name": "sample_widget"}) + + assert [m.registry_name for m in result] == ["sample_widget"] + + +def test_get_all_metadata_exclude_filter_removes_match(): + registry = WidgetRegistry() + + result = registry.get_all_registered_class_metadata(exclude_filters={"registry_name": "sample_widget"}) + + assert [m.registry_name for m in result] == ["undocumented_widget"] + + +def test_matches_filters_list_containment(): + meta = _TaggedMetadata(class_name="X", class_module="m", tags=("a", "b")) + + assert _matches_filters(meta, include_filters={"tags": "a"}) + assert not _matches_filters(meta, include_filters={"tags": "z"}) + assert not _matches_filters(meta, exclude_filters={"tags": "a"}) + + +def test_matches_filters_unknown_include_key_fails(): + meta = ClassRegistryEntry(class_name="X", class_module="m") + + assert not _matches_filters(meta, include_filters={"nope": "x"}) + + +def test_get_metadata_value_falls_back_to_params(): + class HasParams: + def __init__(self) -> None: + self.params = {"k": "v"} + + found, value = _get_metadata_value(HasParams(), "k") + assert found is True + assert value == "v" + + missing_found, missing_value = _get_metadata_value(HasParams(), "missing") + assert missing_found is False + assert missing_value is None diff --git a/tests/unit/registry/test_resolution.py b/tests/unit/registry/test_resolution.py index 8a7889bd78..72400d692f 100644 --- a/tests/unit/registry/test_resolution.py +++ b/tests/unit/registry/test_resolution.py @@ -9,15 +9,16 @@ import pytest +from pyrit.common import REQUIRED_VALUE +from pyrit.common.apply_defaults import _RequiredValueSentinel from pyrit.models import Message, MessagePiece +from pyrit.models.identifiers import ConverterIdentifier +from pyrit.models.parameter import ComponentType from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries import TargetRegistry from pyrit.registry.resolution import ( - coerce_string_to_annotation, - get_resolvable_registry_getter, - get_union_non_none_args, - is_coercible_from_string, - is_registry_reference, + derive_parameters, + display_choices, resolve_constructor_args, ) @@ -56,12 +57,44 @@ def __init__( self.mode = mode -class _AcceptsKwargs: - """Helper whose constructor accepts arbitrary keyword arguments.""" +class _Plain: + def __init__( + self, *, count: int, ratio: float = 0.5, mode: Literal["a", "b"] = "a", note: str | None = None + ) -> None: + """Plain converter-like helper. + + Args: + count (int): A required count. + ratio (float): A ratio with a default. + mode (Literal): A constrained mode. + note (str): An optional note. + """ + self.count = count + self.ratio = ratio + self.mode = mode + self.note = note - def __init__(self, *, name: str = "n", **kwargs: object) -> None: + +class _SentinelDefault: + def __init__(self, *, value: int = REQUIRED_VALUE) -> None: # type: ignore[assignment] + self.value = value + + +class _VarArgs: + def __init__(self, *args: object, name: str = "n", **kwargs: object) -> None: self.name = name - self.kwargs = kwargs + + +class _StrTargetArg: + """A constructor arg named like the identifier reference but annotated as a plain type.""" + + def __init__(self, *, converter_target: str = "x") -> None: + self.converter_target = converter_target + + +def _resolve(cls: type, raw_args: dict[str, object], *, identifier_type: type | None = None) -> dict[str, object]: + """Resolve ``raw_args`` against the derived parameter contract for ``cls``.""" + return resolve_constructor_args(cls=cls, raw_args=raw_args, identifier_type=identifier_type) @pytest.fixture @@ -83,122 +116,118 @@ def empty_target_registry(): TargetRegistry.reset_instance() -class TestTypeHelpers: - """Tests for the type-introspection helpers.""" - - def test_get_union_non_none_args_pep604(self) -> None: - assert get_union_non_none_args(int | None) == [int] - - def test_get_union_non_none_args_not_a_union(self) -> None: - assert get_union_non_none_args(int) is None - - def test_is_coercible_from_string(self) -> None: - assert is_coercible_from_string(str) is True - assert is_coercible_from_string(int | None) is True - assert is_coercible_from_string(Literal["a", "b"]) is True - assert is_coercible_from_string(PromptTarget) is False - - def test_is_registry_reference(self) -> None: - assert is_registry_reference(PromptTarget) is True - assert is_registry_reference(PromptTarget | None) is True - assert is_registry_reference(int) is False - - def test_get_resolvable_registry_getter_returns_target_registry(self) -> None: - getter = get_resolvable_registry_getter(PromptTarget) - assert getter is not None - assert isinstance(getter(), TargetRegistry) +class TestDisplayChoices: + """Tests for the allowed-value presentation projection.""" - def test_get_resolvable_registry_getter_none_for_simple(self) -> None: - assert get_resolvable_registry_getter(int) is None + def test_literal(self) -> None: + assert display_choices(Literal["a", "b"]) == ("a", "b") + def test_optional_literal_unwrapped(self) -> None: + assert display_choices(Literal["a", "b"] | None) == ("a", "b") -class TestCoerceStringToAnnotation: - """Tests for scalar string coercion.""" - - def test_int(self) -> None: - assert coerce_string_to_annotation(value="42", annotation=int) == 42 - - def test_float(self) -> None: - assert coerce_string_to_annotation(value="0.25", annotation=float) == 0.25 - - def test_bool_true(self) -> None: - assert coerce_string_to_annotation(value="yes", annotation=bool) is True - - def test_bool_false(self) -> None: - assert coerce_string_to_annotation(value="0", annotation=bool) is False - - def test_bool_invalid_raises(self) -> None: - with pytest.raises(ValueError, match="boolean"): - coerce_string_to_annotation(value="maybe", annotation=bool) - - def test_optional_unwrapped(self) -> None: - assert coerce_string_to_annotation(value="7", annotation=int | None) == 7 - - def test_str_passthrough(self) -> None: - assert coerce_string_to_annotation(value="hello", annotation=str) == "hello" - - def test_literal_valid(self) -> None: - assert coerce_string_to_annotation(value="b", annotation=Literal["a", "b"]) == "b" - - def test_literal_invalid_raises(self) -> None: - with pytest.raises(ValueError, match="one of"): - coerce_string_to_annotation(value="c", annotation=Literal["a", "b"]) - - def test_literal_coerces_to_member_type(self) -> None: - result = coerce_string_to_annotation(value="2", annotation=Literal[1, 2]) - assert result == 2 - assert isinstance(result, int) + def test_unconstrained_returns_none(self) -> None: + assert display_choices(int) is None @pytest.mark.usefixtures("patch_central_database") class TestResolveConstructorArgs: - """Tests for the end-to-end resolve_constructor_args.""" + """Tests for the end-to-end resolve_constructor_args over a derived contract.""" def test_coerces_simple_params(self) -> None: - resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "3", "ratio": "0.75", "flag": "true"}) + resolved = _resolve(_SimpleOnly, {"count": "3", "ratio": "0.75", "flag": "true"}) assert resolved == {"count": 3, "ratio": 0.75, "flag": True} def test_literal_passthrough(self) -> None: - resolved = resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "b"}) + resolved = _resolve(_SimpleOnly, {"mode": "b"}) assert resolved == {"mode": "b"} def test_literal_invalid_raises(self) -> None: with pytest.raises(ValueError, match="mode"): - resolve_constructor_args(cls=_SimpleOnly, raw_args={"mode": "z"}) + _resolve(_SimpleOnly, {"mode": "z"}) def test_unknown_param_raises(self) -> None: with pytest.raises(ValueError, match="Unknown parameter 'nope'"): - resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) + _resolve(_SimpleOnly, {"nope": "1"}) def test_unknown_param_lists_valid_params(self) -> None: with pytest.raises(ValueError, match="count"): - resolve_constructor_args(cls=_SimpleOnly, raw_args={"nope": "1"}) - - def test_var_kwargs_accepts_unknown(self) -> None: - resolved = resolve_constructor_args(cls=_AcceptsKwargs, raw_args={"anything": "value"}) - assert resolved == {"anything": "value"} + _resolve(_SimpleOnly, {"nope": "1"}) def test_invalid_coercion_raises(self) -> None: with pytest.raises(ValueError, match="count"): - resolve_constructor_args(cls=_SimpleOnly, raw_args={"count": "not-an-int"}) + _resolve(_SimpleOnly, {"count": "not-an-int"}) def test_resolves_registry_reference_by_name(self, target_registry: TargetRegistry) -> None: - resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "my_target", "offset": "5"}) + resolved = _resolve( + _NeedsTarget, {"converter_target": "my_target", "offset": "5"}, identifier_type=ConverterIdentifier + ) assert resolved["converter_target"] is target_registry.get_instance_by_name("my_target") assert resolved["offset"] == 5 def test_registry_reference_instance_passthrough(self, target_registry: TargetRegistry) -> None: instance = MockPromptTarget() - resolved = resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": instance}) + resolved = _resolve(_NeedsTarget, {"converter_target": instance}, identifier_type=ConverterIdentifier) assert resolved["converter_target"] is instance def test_unknown_registry_reference_raises_with_names(self, target_registry: TargetRegistry) -> None: with pytest.raises(ValueError, match="my_target"): - resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + _resolve(_NeedsTarget, {"converter_target": "missing"}, identifier_type=ConverterIdentifier) def test_unknown_registry_reference_empty_registry_hint(self, empty_target_registry: TargetRegistry) -> None: with pytest.raises(ValueError, match="is empty"): - resolve_constructor_args(cls=_NeedsTarget, raw_args={"converter_target": "missing"}) + _resolve(_NeedsTarget, {"converter_target": "missing"}, identifier_type=ConverterIdentifier) + + +class TestDeriveParameters: + """Tests for deriving the Parameter contract from a constructor signature.""" + + def test_required_and_defaults(self) -> None: + params = {p.name: p for p in derive_parameters(cls=_Plain)} + assert params["count"].default is REQUIRED_VALUE + assert params["ratio"].default == 0.5 + assert params["count"].param_type is int + + def test_optional_unwrapped(self) -> None: + params = {p.name: p for p in derive_parameters(cls=_Plain)} + assert params["note"].param_type is str + + def test_descriptions_parsed(self) -> None: + params = {p.name: p for p in derive_parameters(cls=_Plain)} + assert params["count"].description == "A required count." + + def test_order_follows_signature(self) -> None: + names = [p.name for p in derive_parameters(cls=_Plain)] + assert names == ["count", "ratio", "mode", "note"] + + def test_sentinel_default_is_required(self) -> None: + param = derive_parameters(cls=_SentinelDefault)[0] + assert param.default is REQUIRED_VALUE + assert isinstance(REQUIRED_VALUE, _RequiredValueSentinel) + + def test_var_args_skipped(self) -> None: + names = [p.name for p in derive_parameters(cls=_VarArgs)] + assert names == ["name"] + + def test_identifier_marker_overrides_plain_annotation(self) -> None: + # The identifier marks ``converter_target`` as a TARGET reference, so even a + # plainly-annotated arg of that name becomes a reference (the marker wins). + param = derive_parameters(cls=_StrTargetArg, identifier_type=ConverterIdentifier)[0] + assert param.reference is not None + assert param.reference.component_type is ComponentType.TARGET + + def test_no_identifier_yields_no_references(self) -> None: + # Without an identifier, no parameter is treated as a reference. + param = derive_parameters(cls=_StrTargetArg)[0] + assert param.reference is None + assert param.param_type is str + + +def test_signature_inspection_failure_raises() -> None: + class _NoInit: + __init__ = None # type: ignore[assignment] + + with pytest.raises(ValueError, match="Failed to inspect"): + derive_parameters(cls=_NoInit) def test_module_has_no_backend_dependency() -> None: diff --git a/tests/unit/scenario/core/test_scenario_parameters.py b/tests/unit/scenario/core/test_scenario_parameters.py index 4a013f4365..a8bd9aa3b3 100644 --- a/tests/unit/scenario/core/test_scenario_parameters.py +++ b/tests/unit/scenario/core/test_scenario_parameters.py @@ -3,13 +3,12 @@ """Tests for Scenario custom parameter declaration, coercion, and validation (Stage 1b).""" -from typing import ClassVar +from typing import ClassVar, Literal from unittest.mock import MagicMock import pytest -from pyrit.common import Parameter -from pyrit.models import ComponentIdentifier +from pyrit.models import ComponentIdentifier, Parameter from pyrit.scenario import DatasetConfiguration from pyrit.scenario.core import BaselineAttackPolicy, Scenario, ScenarioStrategy from pyrit.score import Scorer @@ -161,46 +160,52 @@ def test_list_param_rejects_non_list_value(self) -> None: with pytest.raises(ValueError, match="expects a list"): scenario.set_params_from_args(args={"datasets": "single"}) - def test_unsupported_list_element_type_raises(self) -> None: - """list[int] is rejected at declaration time (only list[str] is supported).""" + def test_list_int_coerces_each_element(self) -> None: + """list[int] is supported and coerces each element.""" scenario = _make_scenario(declared_params=[Parameter(name="counts", description="d", param_type=list[int])]) + scenario.set_params_from_args(args={"counts": ["1", "2"]}) + assert scenario.params == {"counts": [1, 2]} + + def test_unsupported_list_element_type_raises(self) -> None: + """A list of a non-scalar element type is rejected at declaration time.""" + scenario = _make_scenario(declared_params=[Parameter(name="tags", description="d", param_type=list[set[str]])]) with pytest.raises(ValueError, match="unsupported.*param_type"): - scenario.set_params_from_args(args={"counts": [1, 2]}) + scenario.set_params_from_args(args={"tags": [{"a"}]}) @pytest.mark.usefixtures("patch_central_database") -class TestSetParamsFromArgsChoices: - """choices validation.""" +class TestSetParamsFromArgsConstrainedScalars: + """Constrained-scalar (Literal) membership validation.""" def test_valid_choice_is_accepted(self) -> None: scenario = _make_scenario( - declared_params=[Parameter(name="mode", description="d", param_type=str, choices=("fast", "slow"))] + declared_params=[Parameter(name="mode", description="d", param_type=Literal["fast", "slow"])] ) scenario.set_params_from_args(args={"mode": "fast"}) assert scenario.params == {"mode": "fast"} def test_invalid_choice_raises(self) -> None: scenario = _make_scenario( - declared_params=[Parameter(name="mode", description="d", param_type=str, choices=("fast", "slow"))] + declared_params=[Parameter(name="mode", description="d", param_type=Literal["fast", "slow"])] ) - with pytest.raises(ValueError, match="not in declared choices"): + with pytest.raises(ValueError, match="one of"): scenario.set_params_from_args(args={"mode": "medium"}) def test_choices_validated_after_coercion(self) -> None: - """A string '5' coerces to int 5, then is checked against int choices.""" + """A string '5' coerces to int 5, then is checked against the int Literal.""" scenario = _make_scenario( - declared_params=[Parameter(name="count", description="d", param_type=int, choices=(1, 5, 10))] + declared_params=[Parameter(name="count", description="d", param_type=Literal[1, 5, 10])] ) scenario.set_params_from_args(args={"count": "5"}) assert scenario.params == {"count": 5} - def test_stringy_choices_accept_typed_user_input(self) -> None: - """Author declares choices as strings; user input is coerced and accepted.""" + def test_list_literal_membership(self) -> None: + """A list of a constrained scalar validates membership per element.""" scenario = _make_scenario( - declared_params=[Parameter(name="count", description="d", param_type=int, choices=("1", "5", "10"))] + declared_params=[Parameter(name="modes", description="d", param_type=list[Literal["a", "b"]])] ) - scenario.set_params_from_args(args={"count": "5"}) - assert scenario.params == {"count": 5} + scenario.set_params_from_args(args={"modes": ["a", "b", "a"]}) + assert scenario.params == {"modes": ["a", "b", "a"]} @pytest.mark.usefixtures("patch_central_database") @@ -306,31 +311,18 @@ def test_invalid_default_type_raises(self) -> None: with pytest.raises(ValueError, match="invalid default"): scenario.set_params_from_args(args={}) - def test_default_not_in_choices_raises(self) -> None: + def test_default_not_in_literal_raises(self) -> None: scenario = _make_scenario( declared_params=[ Parameter( name="mode", description="d", - param_type=str, + param_type=Literal["fast", "slow"], default="medium", - choices=("fast", "slow"), ) ] ) - with pytest.raises(ValueError, match="not in declared choices"): - scenario.set_params_from_args(args={}) - - def test_choices_on_list_param_rejected_at_declaration(self) -> None: - """Combining `choices` with a list param_type is rejected pending semantic resolution. - - argparse's per-item choices for nargs='+' diverges from core's whole-list - post-coercion check, so we forbid the combination at declaration time. - """ - scenario = _make_scenario( - declared_params=[Parameter(name="datasets", description="d", param_type=list[str], choices=("a", "b"))] - ) - with pytest.raises(ValueError, match="choices on a list param_type"): + with pytest.raises(ValueError, match="invalid default"): scenario.set_params_from_args(args={}) def test_unsupported_param_type_rejected_at_declaration(self) -> None: @@ -339,14 +331,6 @@ def test_unsupported_param_type_rejected_at_declaration(self) -> None: with pytest.raises(ValueError, match="unsupported.*param_type"): scenario.set_params_from_args(args={}) - def test_choices_not_coercible_to_param_type_raises(self) -> None: - """A choices tuple with values that cannot be coerced to param_type fails fast.""" - scenario = _make_scenario( - declared_params=[Parameter(name="count", description="d", param_type=int, choices=("a", "b"))] - ) - with pytest.raises(ValueError, match="not coercible to"): - scenario.set_params_from_args(args={}) - def test_repeat_call_does_not_revalidate_declarations(self) -> None: """Once validated, a successful set_params_from_args should not re-run declaration checks. diff --git a/tests/unit/setup/test_pyrit_initializer.py b/tests/unit/setup/test_pyrit_initializer.py index eba9cb303e..7c5a3cd971 100644 --- a/tests/unit/setup/test_pyrit_initializer.py +++ b/tests/unit/setup/test_pyrit_initializer.py @@ -5,12 +5,12 @@ import pytest -from pyrit.common import Parameter from pyrit.common.apply_defaults import ( reset_default_values, set_default_value, set_global_variable, ) +from pyrit.models import Parameter from pyrit.setup.initializers import PyRITInitializer @@ -632,7 +632,7 @@ def test_package_level_alias_warning_points_to_replacement(self) -> None: """The deprecation warning tells users which class to use instead.""" import pyrit.setup.initializers as initializers_module - with pytest.warns(DeprecationWarning, match=r"pyrit\.common\.parameter\.Parameter"): + with pytest.warns(DeprecationWarning, match=r"pyrit\.models\.parameter\.Parameter"): _ = initializers_module.InitializerParameter def test_canonical_module_alias_emits_deprecation_warning(self) -> None: