diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 649fdfdf5e..d80d6a7501 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -42,6 +42,7 @@ Conversation, ConversationReference, ConversationType, + EvaluationIdentifier, MessagePiece, PromptDataType, ScenarioIdentifier, @@ -62,49 +63,28 @@ # Default pyrit_version for database records created before version tracking was added LEGACY_PYRIT_VERSION = "<0.10.0" -# Maximum length for string values in ComponentIdentifier.model_dump() when storing to the database. -# Longer values are truncated with a "..." suffix. -MAX_IDENTIFIER_VALUE_LENGTH: int = 80 - -def _dump_identifier(identifier: ComponentIdentifier | None) -> dict[str, Any] | None: - """ - Serialize a ``ComponentIdentifier`` to a dict for JSON storage, truncating long values. - - Args: - identifier (ComponentIdentifier | None): The identifier to serialize, or None. - - Returns: - dict[str, Any] | None: The serialized identifier, or None if ``identifier`` is falsy. - """ - if not identifier: - return None - return identifier.model_dump(context={"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}) - - -def _dump_identifiers(identifiers: list[ComponentIdentifier]) -> list[dict[str, Any]]: - """ - Serialize a list of ``ComponentIdentifier`` objects for JSON storage. - - Args: - identifiers (list[ComponentIdentifier]): The identifiers to serialize. - - Returns: - list[dict[str, Any]]: The serialized identifiers in order. - """ - return [ - identifier.model_dump(context={"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}) for identifier in identifiers - ] - - -def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None = None) -> ComponentIdentifier | None: +def _load_identifier( + stored: dict[str, Any] | None, + *, + pyrit_version: str | None = None, + eval_identifier_cls: type[EvaluationIdentifier] | None = None, +) -> ComponentIdentifier | None: """ Reconstruct a ``ComponentIdentifier`` from its stored dict representation. + The content hash is recomputed on validation (never trusted from storage). + When ``eval_identifier_cls`` is provided, the ``eval_hash`` is likewise + recomputed from the (full) stored params and re-stamped onto the identifier, + so the stored ``eval_hash`` value is never trusted on reload. + Args: stored (dict[str, Any] | None): The stored identifier dict, or None. pyrit_version (str | None): If provided, injected as the identifier's ``pyrit_version`` so the reconstructed object reflects the version that created the row. + eval_identifier_cls (type[EvaluationIdentifier] | None): If provided, the + ``EvaluationIdentifier`` subclass used to recompute and re-stamp the + identifier's ``eval_hash`` on reload. Returns: ComponentIdentifier | None: The reconstructed identifier, or None if ``stored`` is falsy. @@ -113,7 +93,10 @@ def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None return None if pyrit_version is not None: stored = {**stored, "pyrit_version": pyrit_version} - return ComponentIdentifier.model_validate(stored) + identifier = ComponentIdentifier.model_validate(stored) + if eval_identifier_cls is not None: + identifier = identifier.with_eval_hash(eval_identifier_cls(identifier).eval_hash) + return identifier def _load_identifiers( @@ -310,7 +293,7 @@ def __init__(self, *, entry: MessagePiece) -> None: self.timestamp = entry.timestamp self.labels = entry.labels self.prompt_metadata = entry.prompt_metadata - self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) + self.converter_identifiers = [identifier.model_dump() for identifier in entry.converter_identifiers] self.original_value = entry.original_value self.original_value_data_type = entry.original_value_data_type @@ -399,7 +382,7 @@ def __init__(self, *, conversation: Conversation) -> None: conversation (Conversation): The conversation metadata to persist. """ self.conversation_id = conversation.conversation_id - self.target_identifier = _dump_identifier(conversation.target_identifier) + self.target_identifier = conversation.target_identifier.model_dump() if conversation.target_identifier else None self.pyrit_version = pyrit.__version__ def get_conversation(self) -> Conversation: @@ -484,12 +467,13 @@ def __init__(self, *, entry: Score) -> None: self.score_rationale = entry.score_rationale self.score_metadata = entry.score_metadata or {} normalized_scorer = entry.scorer_class_identifier - # Ensure eval_hash is set before truncation so it survives the DB round-trip - if normalized_scorer is not None and normalized_scorer.eval_hash is None: + # Always recompute eval_hash before dumping so the stored JSON carries the + # freshly computed value for DB-level filtering (never a value from storage). + if normalized_scorer is not None: normalized_scorer = normalized_scorer.with_eval_hash( ScorerEvaluationIdentifier(normalized_scorer).eval_hash ) - self.scorer_class_identifier = _dump_identifier(normalized_scorer) or {} + self.scorer_class_identifier = normalized_scorer.model_dump() if normalized_scorer else {} self.prompt_request_response_id = entry.message_piece_id if entry.message_piece_id else None self.timestamp = entry.timestamp # Store in both columns for backward compatibility @@ -505,9 +489,14 @@ def get_score(self) -> Score: Returns: Score: The reconstructed score object with all its data. """ - # Convert dict back to ComponentIdentifier with the stored pyrit_version + # Convert dict back to ComponentIdentifier with the stored pyrit_version; + # eval_hash is recomputed on reload via ScorerEvaluationIdentifier. stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION - scorer_identifier = _load_identifier(self.scorer_class_identifier, pyrit_version=stored_version) + scorer_identifier = _load_identifier( + self.scorer_class_identifier, + pyrit_version=stored_version, + eval_identifier_cls=ScorerEvaluationIdentifier, + ) return Score( id=self.id, score_value=self.score_value, @@ -933,12 +922,15 @@ def __init__(self, *, entry: AttackResult) -> None: self.id = uuid.UUID(entry.attack_result_id) self.conversation_id = entry.conversation_id self.objective = entry.objective - # Ensure eval_hash is set before truncation so it survives the DB round-trip - if entry.atomic_attack_identifier and entry.atomic_attack_identifier.eval_hash is None: + # Always recompute eval_hash before dumping so the stored JSON carries the + # freshly computed value for DB-level filtering (never a value from storage). + if entry.atomic_attack_identifier: entry.atomic_attack_identifier = entry.atomic_attack_identifier.with_eval_hash( AtomicAttackEvaluationIdentifier(entry.atomic_attack_identifier).eval_hash ) - self.atomic_attack_identifier = _dump_identifier(entry.atomic_attack_identifier) + self.atomic_attack_identifier = ( + entry.atomic_attack_identifier.model_dump() if entry.atomic_attack_identifier else None + ) self.objective_sha256 = to_sha256(entry.objective) # Use helper method for UUID conversions @@ -1055,7 +1047,11 @@ def get_attack_result(self) -> AttackResult: ) ) - atomic_id = _load_identifier(self.atomic_attack_identifier) + # eval_hash is recomputed on reload via AtomicAttackEvaluationIdentifier. + atomic_id = _load_identifier( + self.atomic_attack_identifier, + eval_identifier_cls=AtomicAttackEvaluationIdentifier, + ) # Deserialize retry events from JSON retry_events = [] @@ -1172,13 +1168,18 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data # Convert ComponentIdentifier to dict for JSON storage - self.objective_target_identifier = _dump_identifier(entry.objective_target_identifier) # type: ignore[ty:invalid-assignment] - # Ensure eval_hash is set before truncation so it survives the DB round-trip. - if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None: + self.objective_target_identifier = ( # type: ignore[ty:invalid-assignment] + entry.objective_target_identifier.model_dump() if entry.objective_target_identifier else None + ) + # Always recompute eval_hash before dumping so the stored JSON carries the + # freshly computed value for DB-level filtering (never a value from storage). + if entry.objective_scorer_identifier: entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash( ScorerEvaluationIdentifier(entry.objective_scorer_identifier).eval_hash ) - self.objective_scorer_identifier = _dump_identifier(entry.objective_scorer_identifier) + self.objective_scorer_identifier = ( + entry.objective_scorer_identifier.model_dump() if entry.objective_scorer_identifier else None + ) self.scenario_run_state = entry.scenario_run_state.value self.labels = entry.labels self.number_tries = entry.number_tries @@ -1224,8 +1225,13 @@ def get_scenario_result(self) -> ScenarioResult: # Return empty attack_results - will be populated by memory_interface attack_results: dict[str, list[AttackResult]] = {} - # Convert dict back to ComponentIdentifier with the stored pyrit_version - scorer_identifier = _load_identifier(self.objective_scorer_identifier, pyrit_version=stored_version) + # Convert dict back to ComponentIdentifier with the stored pyrit_version; + # eval_hash is recomputed on reload via ScorerEvaluationIdentifier. + scorer_identifier = _load_identifier( + self.objective_scorer_identifier, + pyrit_version=stored_version, + eval_identifier_cls=ScorerEvaluationIdentifier, + ) # Convert dict back to ComponentIdentifier for reconstruction target_identifier = _load_identifier(self.objective_target_identifier) diff --git a/pyrit/models/identifiers/component_identifier.py b/pyrit/models/identifiers/component_identifier.py index 19834e3450..284197f7da 100644 --- a/pyrit/models/identifiers/component_identifier.py +++ b/pyrit/models/identifiers/component_identifier.py @@ -22,7 +22,16 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, get_args, get_origin -from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, model_serializer, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + SerializationInfo, + computed_field, + model_serializer, + model_validator, +) from typing_extensions import Self, TypeAliasType import pyrit @@ -136,7 +145,7 @@ def _dump_child_identifiers_to_dict(value: Any) -> Any: base ``ComponentIdentifier`` (or a different subclass) for that slot, which Pydantic's strict model validation would reject. Dumping such instances to their flat ``model_dump()`` dict lets validation re-parse them into the - declared subclass; the stored ``hash`` rides along, so identity is preserved. + declared subclass; the content hash is recomputed identically on revalidation. Args: value (Any): The raw child value (an identifier instance, a dict, a list @@ -177,9 +186,10 @@ class ComponentIdentifier(BaseModel): Serialization: ``model_dump()`` returns a flat dict where reserved keys (``class_name``, ``class_module``, ``hash``, ``pyrit_version``, ``eval_hash``, ``children``) sit at the top level alongside the inlined param values. This shape is - also the storage / REST format. Pass ``context={"max_value_length": N}`` to truncate - long string param values. ``model_validate()`` accepts the same flat shape (plus a - structured form with an explicit ``params`` dict). + also the storage / REST format. Param values are stored in full (no truncation). + ``model_validate()`` accepts the same flat shape (plus a structured form with an + explicit ``params`` dict); the content ``hash`` is always recomputed on validation, + so any stored ``hash`` is ignored. Mutability: the model is frozen, but ``params`` and ``children`` are dicts whose contents are not deep-frozen — mutating them after construction creates an @@ -208,16 +218,18 @@ class ComponentIdentifier(BaseModel): params: dict[str, JSONValue] = Field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = Field(default_factory=dict) - #: Content-addressed SHA256 hash. Computed automatically when ``None``; - #: pass an explicit value to preserve a hash from DB storage where params - #: may have been truncated. - hash: str | None = None #: Version tag for storage. Not included in the content hash. pyrit_version: str = Field(default=pyrit.__version__) - #: Evaluation hash. Computed by EvaluationIdentifier subclasses and attached - #: to the identifier so it survives DB round-trips with truncated params. + #: Evaluation hash. The base identifier cannot compute it (the eval rules live + #: in EvaluationIdentifier subclasses), so it is attached only through + #: ``with_eval_hash``, which is the single supported way to set it. Stamped on + #: so it lands in the stored JSON for DB-level filtering. eval_hash: str | None = None + #: Cache backing the read-only ``hash`` computed field. Populated once by the + #: after-validator from the identifier's content; never set externally. + _hash: str = PrivateAttr(default="") + # ------------------------------------------------------------------ # Promotion (typed projection — derived from the subclass's own fields) # ------------------------------------------------------------------ @@ -317,6 +329,11 @@ def _normalize_input(cls, data: Any) -> Any: data = dict(data) + # hash is a read-only computed field, never supplied. Drop any incoming + # value (a constructor kwarg or one read back from the flat storage form) + # so extra="forbid" does not reject it; the computed field derives it. + data.pop(cls.KEY_HASH, None) + # Map legacy keys onto canonical keys when canonical is absent. if cls.KEY_CLASS_NAME not in data and cls.LEGACY_KEY_TYPE in data: data[cls.KEY_CLASS_NAME] = data.pop(cls.LEGACY_KEY_TYPE) @@ -383,7 +400,7 @@ def _normalize_input(cls, data: Any) -> Any: # (possibly a base ComponentIdentifier or a different subclass than # the typed field declares). Dump them to their flat dict form so # Pydantic re-parses them into the declared identifier subclass. - # Round-tripping through model_dump preserves the stored hash. + # The content hash is recomputed identically on revalidation. for name in cls._promoted_child_fields(): if name in data: data[name] = _dump_child_identifiers_to_dict(data[name]) @@ -391,16 +408,16 @@ def _normalize_input(cls, data: Any) -> Any: return data @model_validator(mode="after") - def _promote_and_compute_hash(self) -> ComponentIdentifier: + def _promote_typed_fields(self) -> ComponentIdentifier: """ - Mirror promoted typed fields into ``params`` / ``children`` and hash. + Mirror promoted typed fields into ``params`` / ``children``, then hash. Promoted scalar fields are written into ``params`` and promoted identifier fields into ``children`` (``None`` / empty list dropped), so a typed subclass serializes and hashes identically to a plain ``ComponentIdentifier`` with the same values. The content-addressed hash - is then computed if it was not provided — a pre-set hash (e.g. one - reconstructed from a truncated DB row) is preserved. + is then computed once from the populated content and cached in ``_hash``, + backing the read-only ``hash`` computed field. Returns: ``self`` (mutated in-place). @@ -422,16 +439,32 @@ def _promote_and_compute_hash(self) -> ComponentIdentifier: else: self.children[name] = value - if self.hash is None: - hash_dict = _build_hash_dict( - class_name=self.class_name, - class_module=self.class_module, - params=self.params, - children=self.children, - ) - object.__setattr__(self, "hash", config_hash(hash_dict)) + hash_dict = _build_hash_dict( + class_name=self.class_name, + class_module=self.class_module, + params=self.params, + children=self.children, + ) + object.__setattr__(self, "_hash", config_hash(hash_dict)) return self + @computed_field # type: ignore[prop-decorator] + @property + def hash(self) -> str: + """ + Content-addressed SHA256 hash, derived from this identifier's content. + + Computed once by the after-validator from ``class_name`` / + ``class_module`` / ``params`` / ``children`` and cached in ``_hash``. It + is a read-only computed field: nothing can set it, and any ``hash`` value + supplied at construction (a kwarg, or one read back from the flat storage + form) is dropped before validation. + + Returns: + The SHA256 content hash. + """ + return self._hash + # ------------------------------------------------------------------ # Serializer # ------------------------------------------------------------------ @@ -441,15 +474,13 @@ def _serialize_flat(self, info: SerializationInfo) -> dict[str, Any]: """ Emit the flat storage shape. - Honors ``context={"max_value_length": N}`` to truncate long string - param values, propagating both context and mode (``"python"`` vs - ``"json"``) into recursive child dumps. + Propagates the serialization mode (``"python"`` vs ``"json"``) into + recursive child dumps. Values are stored in full — identifiers are no + longer truncated. Returns: The flat dict representation of this identifier. """ - context = info.context if isinstance(info.context, dict) else {} - max_len = context.get("max_value_length") mode = info.mode result: dict[str, Any] = { @@ -461,16 +492,15 @@ def _serialize_flat(self, info: SerializationInfo) -> dict[str, Any]: if self.eval_hash is not None: result[self.KEY_EVAL_HASH] = self.eval_hash - for key, value in self.params.items(): - result[key] = self._truncate_value(value=value, max_length=max_len) + result.update(self.params) if self.children: serialized_children: dict[str, Any] = {} for name, child in self.children.items(): if isinstance(child, ComponentIdentifier): - serialized_children[name] = child.model_dump(mode=mode, context=context) + serialized_children[name] = child.model_dump(mode=mode) elif isinstance(child, list): - serialized_children[name] = [c.model_dump(mode=mode, context=context) for c in child] + serialized_children[name] = [c.model_dump(mode=mode) for c in child] result[self.KEY_CHILDREN] = serialized_children return result @@ -508,10 +538,10 @@ def with_eval_hash(self, eval_hash: str) -> ComponentIdentifier: """ Return a new identifier with ``eval_hash`` set. - Builds a fresh instance, passing the existing ``hash`` through - explicitly so it is preserved rather than recomputed. This matters - for identifiers reconstructed from truncated DB data, where - recomputing from the truncated params would produce a wrong hash. + This is the single supported way to set ``eval_hash``: it is not + computed by the base model, so callers attach it here rather than via + the constructor. The content hash is recomputed from the (unchanged) + params and children, so it is identical to this identifier's hash. Args: eval_hash: The evaluation hash to attach. @@ -525,7 +555,6 @@ def with_eval_hash(self, eval_hash: str) -> ComponentIdentifier: class_module=self.class_module, params=self.params, children=self.children, - hash=self.hash, pyrit_version=self.pyrit_version, eval_hash=eval_hash, ) @@ -538,12 +567,7 @@ def with_eval_hash(self, eval_hash: str) -> ComponentIdentifier: def short_hash(self) -> str: """ Return the first 8 characters of the hash for display and logging. - - Raises: - RuntimeError: If the hash has not been set by the validator. """ - if self.hash is None: - raise RuntimeError("hash should be set by validator") return self.hash[:8] @property @@ -628,8 +652,8 @@ def from_component_identifier(cls, identifier: ComponentIdentifier) -> Self: Pass-through when ``identifier`` is already an instance of ``cls``; otherwise revalidate its flat dump into ``cls`` (e.g. a base identifier - loaded from the DB), rehydrating promoted typed fields. The hash is - preserved across the round-trip. + loaded from the DB), rehydrating promoted typed fields. The content hash + is recomputed identically across the round-trip. Args: identifier: A ``ComponentIdentifier`` (possibly the base type). @@ -695,33 +719,14 @@ def _collect_child_eval_hashes(self) -> set[str]: hashes.update(child._collect_child_eval_hashes()) return hashes - @staticmethod - def _truncate_value(*, value: Any, max_length: int | None) -> Any: - """ - Truncate string values longer than ``max_length`` with a ``...`` suffix. - - Args: - value: The value to potentially truncate. - max_length: Maximum length, or ``None`` to disable. - - Returns: - The (possibly truncated) value. - """ - if max_length is not None and isinstance(value, str) and len(value) > max_length: - return value[:max_length] + "..." - return value - # ------------------------------------------------------------------ # Deprecated shims — kept for one release cycle # ------------------------------------------------------------------ - def to_dict(self, *, max_value_length: int | None = None) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Return the flat storage dict (deprecated; use ``model_dump`` instead). - Args: - max_value_length: Optional truncation length for string params. - Returns: The flat dict representation. """ @@ -730,8 +735,7 @@ def to_dict(self, *, max_value_length: int | None = None) -> dict[str, Any]: new_item="ComponentIdentifier.model_dump", removed_in="0.16.0", ) - context = {"max_value_length": max_value_length} if max_value_length is not None else None - return self.model_dump(context=context) + return self.model_dump() @classmethod def from_dict(cls, data: dict[str, Any]) -> ComponentIdentifier: diff --git a/pyrit/models/identifiers/evaluation_identifier.py b/pyrit/models/identifiers/evaluation_identifier.py index f3cc7a05bc..73823fe6cd 100644 --- a/pyrit/models/identifiers/evaluation_identifier.py +++ b/pyrit/models/identifiers/evaluation_identifier.py @@ -241,8 +241,6 @@ def compute_eval_hash( identifier = inner[0] if not child_eval_rules and own_rule is None: - if identifier.hash is None: - raise RuntimeError("hash should be set by __post_init__") return identifier.hash eval_dict = _build_eval_dict( @@ -465,24 +463,20 @@ def __init_subclass__(cls, **kwargs: Any) -> None: def __init__(self, identifier: ComponentIdentifier) -> None: """ - Wrap a ComponentIdentifier and resolve its eval hash. + Wrap a ComponentIdentifier and compute its eval hash. - If the identifier carries an ``eval_hash`` (preserved from a prior - DB round-trip or set by the scorer), that value is used directly. - Otherwise the eval hash is computed from the identifier's params - and children using the subclass's ``CHILD_EVAL_RULES``, ``OWN_RULE``, - and ``ROOT_UNWRAP_CHILD``. + The eval hash is always computed fresh from the identifier's params and + children using the subclass's ``CHILD_EVAL_RULES``, ``OWN_RULE``, and + ``ROOT_UNWRAP_CHILD`` — any ``eval_hash`` already carried on the + identifier (e.g. a value read back from storage) is never trusted. """ self._identifier = identifier - if identifier.eval_hash is not None: - self._eval_hash = identifier.eval_hash - else: - self._eval_hash = compute_eval_hash( - identifier, - child_eval_rules=self.CHILD_EVAL_RULES, - own_rule=self.OWN_RULE, - root_unwrap_child=self.ROOT_UNWRAP_CHILD, - ) + self._eval_hash = compute_eval_hash( + identifier, + child_eval_rules=self.CHILD_EVAL_RULES, + own_rule=self.OWN_RULE, + root_unwrap_child=self.ROOT_UNWRAP_CHILD, + ) @property def identifier(self) -> ComponentIdentifier: diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index c35c521da6..0a8b9ba17e 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -22,7 +22,6 @@ from pyrit.executor.attack.core.attack_executor import AttackExecutorResult from pyrit.executor.attack.core.attack_result_attribution import AttackResultAttribution from pyrit.memory import CentralMemory -from pyrit.memory.memory_models import MAX_IDENTIFIER_VALUE_LENGTH from pyrit.models import AtomicAttackEvaluationIdentifier, AtomicAttackIdentifier, AttackResult, SeedAttackGroup from pyrit.scenario.core.attack_technique import AttackTechnique @@ -439,9 +438,8 @@ def _enrich_atomic_attack_identifiers(self, *, results: AttackExecutorResult[Att ) # Persist the enriched identifier back to the database. - # Set eval_hash before truncation so it survives the DB round-trip. - if identifier.eval_hash is None: - identifier = identifier.with_eval_hash(AtomicAttackEvaluationIdentifier(identifier).eval_hash) + # Stamp eval_hash so it lands in the stored JSON for DB-level filtering. + identifier = identifier.with_eval_hash(AtomicAttackEvaluationIdentifier(identifier).eval_hash) result.atomic_attack_identifier = identifier @@ -449,8 +447,6 @@ def _enrich_atomic_attack_identifiers(self, *, results: AttackExecutorResult[Att memory.update_attack_result_by_id( attack_result_id=result.attack_result_id, update_fields={ - "atomic_attack_identifier": identifier.model_dump( - context={"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}, - ), + "atomic_attack_identifier": identifier.model_dump(), }, ) diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index ee0dc4a9fc..031c7c1553 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -134,9 +134,8 @@ def get_identifier(self) -> ComponentIdentifier: ComponentIdentifier: The identity with ``eval_hash`` set. """ identifier = super().get_identifier() - if identifier.eval_hash is None: - identifier = identifier.with_eval_hash(ScorerEvaluationIdentifier(identifier).eval_hash) - self._identifier = identifier + identifier = identifier.with_eval_hash(ScorerEvaluationIdentifier(identifier).eval_hash) + self._identifier = identifier return identifier @property diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index 852b6de9d8..a2b244fc1f 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -18,7 +18,6 @@ ScoreEntry, SeedEntry, UTCDateTime, - _dump_identifier, _load_identifier, ) from pyrit.models import ( @@ -140,10 +139,6 @@ def test_utcdatetime_passes_through_none(): # --------------------------------------------------------------------------- -def test_dump_identifier_returns_none_for_none(): - assert _dump_identifier(None) is None - - def test_load_identifier_returns_none_for_falsy(): assert _load_identifier(None) is None assert _load_identifier({}) is None @@ -151,7 +146,7 @@ def test_load_identifier_returns_none_for_falsy(): def test_dump_then_load_identifier_round_trips(): identifier = ComponentIdentifier(class_name="MyConverter", class_module="pyrit.converters", pyrit_version="0.1.0") - stored = _dump_identifier(identifier) + stored = identifier.model_dump() assert stored is not None loaded = _load_identifier(stored) assert loaded is not None @@ -161,7 +156,7 @@ def test_dump_then_load_identifier_round_trips(): def test_load_identifier_injects_pyrit_version(): identifier = ComponentIdentifier(class_name="MyConverter", class_module="pyrit.converters", pyrit_version="0.1.0") - stored = _dump_identifier(identifier) + stored = identifier.model_dump() loaded = _load_identifier(stored, pyrit_version="9.9.9") assert loaded is not None assert loaded.pyrit_version == "9.9.9" diff --git a/tests/unit/models/identifiers/test_component_identifier.py b/tests/unit/models/identifiers/test_component_identifier.py index ec51120619..a1cf622544 100644 --- a/tests/unit/models/identifiers/test_component_identifier.py +++ b/tests/unit/models/identifiers/test_component_identifier.py @@ -92,6 +92,25 @@ def test_pyrit_version_set(self): class TestComponentIdentifierHash: """Tests for hash computation.""" + def test_hash_cannot_be_set_via_constructor(self): + """Test that a hash supplied at construction is dropped and recomputed.""" + computed = ComponentIdentifier(class_name="C", class_module="m", params={"key": "value"}).hash + with_bogus = ComponentIdentifier( + class_name="C", + class_module="m", + params={"key": "value"}, + hash="bogus-not-used", + ) + assert with_bogus.hash == computed + + def test_hash_dropped_from_flat_storage_on_load(self): + """Test that a stored hash is dropped and recomputed on model_validate.""" + ident = ComponentIdentifier(class_name="C", class_module="m", params={"key": "value"}) + stored = ident.model_dump() + stored["hash"] = "tampered-value" + reloaded = ComponentIdentifier.model_validate(stored) + assert reloaded.hash == ident.hash + def test_hash_deterministic(self): """Test that identical configs produce the same hash.""" id1 = ComponentIdentifier( @@ -228,7 +247,7 @@ def test_to_dict_no_children_key_when_empty(self): assert "children" not in result def test_to_dict_no_truncation_by_default(self): - """Test that values are not truncated when max_value_length is not set.""" + """Test that values are stored in full (truncation removed).""" long_value = "x" * 200 identifier = ComponentIdentifier( class_name="Target", @@ -238,55 +257,32 @@ def test_to_dict_no_truncation_by_default(self): result = identifier.to_dict() assert result["system_prompt"] == long_value - def test_to_dict_truncates_long_string_params(self): - """Test that string params exceeding max_value_length are truncated.""" - long_value = "x" * 200 - identifier = ComponentIdentifier( - class_name="Target", - class_module="mod", - params={"system_prompt": long_value}, - ) - result = identifier.to_dict(max_value_length=100) - assert result["system_prompt"] == "x" * 100 + "..." - assert len(result["system_prompt"]) == 103 - - def test_to_dict_does_not_truncate_short_string_params(self): - """Test that string params within max_value_length are not truncated.""" - short_value = "short" - identifier = ComponentIdentifier( - class_name="Target", - class_module="mod", - params={"system_prompt": short_value}, - ) - result = identifier.to_dict(max_value_length=100) - assert result["system_prompt"] == short_value - def test_to_dict_does_not_truncate_non_string_params(self): - """Test that non-string params are not affected by max_value_length.""" + """Test that non-string params are stored unchanged.""" identifier = ComponentIdentifier( class_name="Target", class_module="mod", params={"count": 999999, "flag": True}, ) - result = identifier.to_dict(max_value_length=5) + result = identifier.to_dict() assert result["count"] == 999999 assert result["flag"] is True - def test_to_dict_does_not_truncate_structural_keys(self): - """Test that class_name, class_module, hash, pyrit_version are never truncated.""" + def test_to_dict_preserves_structural_keys(self): + """Test that class_name, class_module, hash, pyrit_version are stored unchanged.""" long_module = "pyrit.module." + "sub." * 50 identifier = ComponentIdentifier( class_name="VeryLongClassNameForTesting", class_module=long_module, ) - result = identifier.to_dict(max_value_length=10) + result = identifier.to_dict() assert result["class_name"] == "VeryLongClassNameForTesting" assert result["class_module"] == long_module assert result["hash"] == identifier.hash assert result["pyrit_version"] == identifier.pyrit_version - def test_to_dict_truncation_propagates_to_children(self): - """Test that max_value_length is propagated to children.""" + def test_to_dict_stores_full_child_values(self): + """Test that child values are stored in full (no truncation).""" long_value = "y" * 200 child = ComponentIdentifier( class_name="Child", @@ -298,12 +294,12 @@ def test_to_dict_truncation_propagates_to_children(self): class_module="mod.parent", children={"target": child}, ) - result = parent.to_dict(max_value_length=50) + result = parent.to_dict() child_result = result["children"]["target"] - assert child_result["endpoint"] == "y" * 50 + "..." + assert child_result["endpoint"] == long_value - def test_to_dict_truncation_propagates_to_list_children(self): - """Test that max_value_length is propagated to list children.""" + def test_to_dict_stores_full_list_child_values(self): + """Test that list-child values are stored in full (no truncation).""" long_value = "z" * 200 c1 = ComponentIdentifier(class_name="Conv1", class_module="m", params={"data": long_value}) c2 = ComponentIdentifier(class_name="Conv2", class_module="m", params={"data": "short"}) @@ -312,8 +308,8 @@ def test_to_dict_truncation_propagates_to_list_children(self): class_module="m", children={"converters": [c1, c2]}, ) - result = parent.to_dict(max_value_length=80) - assert result["children"]["converters"][0]["data"] == "z" * 80 + "..." + result = parent.to_dict() + assert result["children"]["converters"][0]["data"] == long_value assert result["children"]["converters"][1]["data"] == "short" @@ -333,8 +329,10 @@ def test_from_dict_basic(self): identifier = ComponentIdentifier.from_dict(data) assert identifier.class_name == "TestClass" assert identifier.class_module == "test.module" - # Stored hash is preserved as-is - assert identifier.hash == stored_hash + # The stored hash is ignored; the content hash is always recomputed. + fresh = ComponentIdentifier(class_name="TestClass", class_module="test.module") + assert identifier.hash == fresh.hash + assert identifier.hash != stored_hash def test_from_dict_with_params(self): """Test from_dict with inlined params.""" @@ -422,12 +420,8 @@ def test_from_dict_does_not_mutate_input(self): ComponentIdentifier.from_dict(data) assert data == original - def test_from_dict_preserves_stored_hash(self): - """Test that from_dict preserves the stored hash rather than recomputing it. - - The stored hash was computed from untruncated data and is the correct identity. - Recomputing from potentially truncated DB values would produce a wrong hash. - """ + def test_from_dict_recomputes_hash_from_full_params(self): + """Test that from_dict recomputes the content hash from the (full) stored params.""" original = ComponentIdentifier( class_name="Target", class_module="mod", @@ -435,17 +429,15 @@ def test_from_dict_preserves_stored_hash(self): ) original_hash = original.hash - # Serialize with truncation (simulates DB storage with column limits) - truncated_dict = original.to_dict(max_value_length=50) - # The stored hash in truncated_dict is the original (correct) hash - assert truncated_dict["hash"] == original_hash + # Full values are stored (no truncation), so the recomputed hash matches. + stored_dict = original.to_dict() + assert stored_dict["hash"] == original_hash - # Deserialize — from_dict should preserve the stored hash - reconstructed = ComponentIdentifier.from_dict(truncated_dict) + reconstructed = ComponentIdentifier.from_dict(stored_dict) assert reconstructed.hash == original_hash - def test_from_dict_preserves_stored_hash_with_children(self): - """Test that from_dict preserves stored hash when children have truncated params.""" + def test_from_dict_recomputes_hash_with_children(self): + """Test that from_dict recomputes hashes from full stored params for parent and children.""" child = ComponentIdentifier( class_name="Child", class_module="mod.child", @@ -459,17 +451,16 @@ def test_from_dict_preserves_stored_hash_with_children(self): original_parent_hash = parent.hash original_child_hash = child.hash - truncated_dict = parent.to_dict(max_value_length=50) - reconstructed = ComponentIdentifier.from_dict(truncated_dict) + stored_dict = parent.to_dict() + reconstructed = ComponentIdentifier.from_dict(stored_dict) - # Both parent and child should preserve their stored hashes assert reconstructed.hash == original_parent_hash child_recon = reconstructed.children["target"] assert isinstance(child_recon, ComponentIdentifier) assert child_recon.hash == original_child_hash - def test_from_dict_preserves_explicit_stored_hash(self): - """Test that from_dict uses the stored hash value exactly as provided.""" + def test_from_dict_ignores_explicit_stored_hash(self): + """Test that from_dict recomputes the hash, ignoring any stored hash value.""" known_hash = "abc123def456" * 5 + "abcd" # 64 chars data = { "class_name": "Test", @@ -478,7 +469,9 @@ def test_from_dict_preserves_explicit_stored_hash(self): "param": "value", } identifier = ComponentIdentifier.from_dict(data) - assert identifier.hash == known_hash + fresh = ComponentIdentifier(class_name="Test", class_module="mod", params={"param": "value"}) + assert identifier.hash == fresh.hash + assert identifier.hash != known_hash def test_from_dict_computes_hash_when_no_stored_hash(self): """Test that from_dict computes a hash when none is stored.""" @@ -559,33 +552,24 @@ def test_roundtrip_preserves_eval_hash(self): reconstructed = ComponentIdentifier.from_dict(d) assert reconstructed.eval_hash == expected_eval_hash - def test_roundtrip_eval_hash_survives_truncation(self): - """Regression test: eval_hash computed before truncation is preserved after round-trip. - - This is the core bug fix — long params get truncated in to_dict(), which would - cause eval_hash recomputation to produce a wrong hash. By storing eval_hash in - the dict, it survives truncation. - """ - long_prompt = "You are a scorer that evaluates responses. " * 20 # >80 chars - eval_hash_before_truncation = "correct_eval_hash_" + "0" * 46 # 64 chars + def test_roundtrip_eval_hash_survives_full_value_roundtrip(self): + """Test that a stored eval_hash survives a to_dict -> from_dict round-trip.""" + long_prompt = "You are a scorer that evaluates responses. " * 20 + stored_eval_hash = "correct_eval_hash_" + "0" * 46 # 64 chars original = ComponentIdentifier( class_name="SelfAskTrueFalseScorer", class_module="pyrit.score", params={"system_prompt_template": long_prompt}, - ).with_eval_hash(eval_hash_before_truncation) - - # Serialize with truncation (simulates DB storage) - truncated_dict = original.to_dict(max_value_length=80) - # Params are truncated - assert truncated_dict["system_prompt_template"].endswith("...") - # But eval_hash is preserved - assert truncated_dict["eval_hash"] == eval_hash_before_truncation - - # Deserialize - reconstructed = ComponentIdentifier.from_dict(truncated_dict) - # eval_hash is available on the reconstructed identifier - assert reconstructed.eval_hash == eval_hash_before_truncation - # And it's NOT in params (from_dict pops it as a reserved key) + ).with_eval_hash(stored_eval_hash) + + stored_dict = original.to_dict() + # Full params are stored (no truncation). + assert stored_dict["system_prompt_template"] == long_prompt + assert stored_dict["eval_hash"] == stored_eval_hash + + reconstructed = ComponentIdentifier.from_dict(stored_dict) + assert reconstructed.eval_hash == stored_eval_hash + # eval_hash is not part of params (popped as a reserved key). assert "eval_hash" not in reconstructed.params def test_roundtrip_no_eval_hash_when_not_set(self): @@ -627,14 +611,14 @@ def test_double_roundtrip_preserves_eval_hash_and_identity_hash(self): eval_hash = "eval_" + "a1b2c3d4" * 7 + "a1b2c3" # 64 chars original = original.with_eval_hash(eval_hash) - # First round-trip: store with truncation - d1 = original.to_dict(max_value_length=80) + # First round-trip + d1 = original.to_dict() r1 = ComponentIdentifier.from_dict(d1) assert r1.hash == original_hash assert r1.eval_hash == eval_hash - # Second round-trip: re-store (simulating retrieve → use → re-store) - d2 = r1.to_dict(max_value_length=80) + # Second round-trip (simulating retrieve → use → re-store) + d2 = r1.to_dict() r2 = ComponentIdentifier.from_dict(d2) assert r2.hash == original_hash assert r2.eval_hash == eval_hash @@ -1320,13 +1304,9 @@ def test_mixed_children_with_and_without_eval_hash(self): assert parent._collect_child_eval_hashes() == {"has_hash"} -def test_short_hash_raises_when_hash_none(): - obj = ComponentIdentifier.__new__(ComponentIdentifier) - object.__setattr__(obj, "hash", None) - object.__setattr__(obj, "class_name", "Test") - object.__setattr__(obj, "class_module", "test.module") - with pytest.raises(RuntimeError, match="hash should be set"): - _ = obj.short_hash +def test_short_hash_returns_hash_prefix(): + identifier = ComponentIdentifier(class_name="Test", class_module="test.module") + assert identifier.short_hash == identifier.hash[:8] class TestComponentIdentifierPydanticMethods: @@ -1355,17 +1335,17 @@ def test_model_dump_matches_to_dict_nested(self): warnings.simplefilter("ignore", DeprecationWarning) assert ident.model_dump() == ident.to_dict() - def test_model_dump_context_truncates(self): + def test_model_dump_stores_full_value(self): ident = ComponentIdentifier(class_name="Foo", class_module="m", params={"v": "x" * 200}) - dumped = ident.model_dump(context={"max_value_length": 50}) - assert isinstance(dumped["v"], str) and len(dumped["v"]) < 200 + dumped = ident.model_dump() + assert dumped["v"] == "x" * 200 - def test_model_dump_context_propagates_to_children(self): + def test_model_dump_stores_full_nested_values(self): child = ComponentIdentifier(class_name="C", class_module="m", params={"v": "y" * 200}) parent = ComponentIdentifier(class_name="P", class_module="m", params={"v": "x" * 200}, children={"c": child}) - dumped = parent.model_dump(context={"max_value_length": 50}) - assert len(dumped["v"]) < 200 - assert len(dumped["children"]["c"]["v"]) < 200 + dumped = parent.model_dump() + assert dumped["v"] == "x" * 200 + assert dumped["children"]["c"]["v"] == "y" * 200 def test_model_validate_roundtrip(self): ident = self._nested() @@ -1374,14 +1354,14 @@ def test_model_validate_roundtrip(self): assert rebuilt.hash == ident.hash assert rebuilt.children["c"].hash == ident.children["c"].hash - def test_model_validate_preserves_stored_hash(self): - # Simulates DB round-trip where params were truncated but hash was preserved. + def test_model_validate_recomputes_hash(self): + # The content hash is always recomputed from params, never trusted from storage. ident = self._simple() stored_hash = ident.hash flat = ident.model_dump() - flat["a"] = "TRUNCATED" + flat["a"] = "MUTATED" rebuilt = ComponentIdentifier.model_validate(flat) - assert rebuilt.hash == stored_hash + assert rebuilt.hash != stored_hash def test_model_validate_omits_eval_hash_when_none(self): ident = self._simple() @@ -1397,11 +1377,12 @@ def test_with_eval_hash_preserves_stored_hash(self): assert new.hash == stored_hash assert new.eval_hash == "abc123" - def test_with_eval_hash_preserves_truncated_hash(self): - # A hash reconstructed from truncated params must survive unchanged. + def test_with_eval_hash_recomputes_hash(self): + # hash cannot be set; a passed-in value is dropped and recomputed from content. ident = ComponentIdentifier(class_name="Foo", class_module="m", params={"a": 1}, hash="deadbeef") + fresh = ComponentIdentifier(class_name="Foo", class_module="m", params={"a": 1}) new = ident.with_eval_hash("abc123") - assert new.hash == "deadbeef" + assert new.hash == fresh.hash assert new.eval_hash == "abc123" def test_with_eval_hash_returns_new_instance(self): diff --git a/tests/unit/models/identifiers/test_evaluation_identifier.py b/tests/unit/models/identifiers/test_evaluation_identifier.py index d7230fc613..4798559c92 100644 --- a/tests/unit/models/identifiers/test_evaluation_identifier.py +++ b/tests/unit/models/identifiers/test_evaluation_identifier.py @@ -224,17 +224,19 @@ class CustomIdentity(EvaluationIdentifier): ) assert identity.eval_hash == expected - def test_uses_eval_hash_when_available(self): - """Test that EvaluationIdentifier uses eval_hash instead of recomputing.""" + def test_always_recomputes_eval_hash_ignoring_stored(self): + """Test that EvaluationIdentifier always recomputes, ignoring any stored eval_hash.""" stored_hash = "stored_eval_hash_value_" + "0" * 42 # 64 chars cid = ComponentIdentifier( class_name="Scorer", class_module="pyrit.score", - params={"system_prompt": "truncated..."}, + params={"system_prompt": "full value"}, ).with_eval_hash(stored_hash) identity = _StubEvaluationIdentifier(cid) - assert identity.eval_hash == stored_hash + expected = compute_eval_hash(cid, child_eval_rules=_StubEvaluationIdentifier.CHILD_EVAL_RULES) + assert identity.eval_hash == expected + assert identity.eval_hash != stored_hash def test_computes_eval_hash_when_not_set(self): """Test that eval_hash is computed normally when eval_hash is None.""" @@ -249,55 +251,37 @@ def test_computes_eval_hash_when_not_set(self): expected = compute_eval_hash(cid, child_eval_rules=_StubEvaluationIdentifier.CHILD_EVAL_RULES) assert identity.eval_hash == expected - def test_truncation_roundtrip_preserves_eval_hash(self): - """Regression test: eval_hash survives DB round-trip with param truncation. + def test_full_value_roundtrip_recomputes_matching_eval_hash(self): + """Test that eval_hash recomputes consistently after a full-value DB round-trip. - This is the core scenario for the bug fix. A scorer with a long system_prompt - gets stored to the DB with truncation. The eval_hash computed from the untruncated - identifier is included in to_dict(). After from_dict() reconstruction, the - EvaluationIdentifier should use the stored eval_hash (not recompute from truncated params). + With truncation removed, full params survive the round-trip, so the always-recompute + EvaluationIdentifier produces the same eval_hash before and after the round-trip. """ - # Build a scorer identifier with a long system_prompt and a target child - long_prompt = "Evaluate whether the response achieves the objective. " * 10 target_child = ComponentIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", params={"model_name": "gpt-4o", "endpoint": "https://api.openai.com", "temperature": 0.0}, ) + long_prompt = "Evaluate whether the response achieves the objective. " * 10 scorer_id = ComponentIdentifier( class_name="SelfAskTrueFalseScorer", class_module="pyrit.score", params={"system_prompt_template": long_prompt}, - children={"prompt_target": target_child}, + children={"my_target": target_child}, ) - # Compute eval_hash from the untruncated identifier (the correct hash) - correct_eval_hash = compute_eval_hash(scorer_id, child_eval_rules=_CHILD_EVAL_RULES) - scorer_id = scorer_id.with_eval_hash(correct_eval_hash) - - # Simulate DB storage: serialize with truncation - truncated_dict = scorer_id.to_dict(max_value_length=80) - - # Verify params are actually truncated - assert truncated_dict["system_prompt_template"].endswith("...") - - # Reconstruct from truncated dict (simulates DB read) - reconstructed = ComponentIdentifier.from_dict(truncated_dict) + original_eval_hash = _StubEvaluationIdentifier(scorer_id).eval_hash - # The reconstructed identifier has truncated params, so recomputing would give wrong hash - recomputed = compute_eval_hash(reconstructed, child_eval_rules=_CHILD_EVAL_RULES) - assert recomputed != correct_eval_hash, "Truncated params should produce different eval_hash" + # Simulate DB storage: full values are retained (no truncation). + stored_dict = scorer_id.to_dict() + assert stored_dict["system_prompt_template"] == long_prompt - # But EvaluationIdentifier uses the preserved eval_hash, giving the correct result - identity = _StubEvaluationIdentifier(reconstructed) - assert identity.eval_hash == correct_eval_hash + # Reconstruct from the stored dict (simulates DB read) and recompute. + reconstructed = ComponentIdentifier.from_dict(stored_dict) + assert _StubEvaluationIdentifier(reconstructed).eval_hash == original_eval_hash - def test_eval_hash_preserved_through_double_roundtrip(self): - """Test that eval_hash is preserved when retrieved from DB and re-stored. - - Simulates: fresh save → DB retrieve → re-store → DB retrieve. - The eval_hash computed at first save should survive all round-trips. - """ + def test_eval_hash_recomputed_through_double_roundtrip(self): + """Test that eval_hash recomputes consistently across retrieve → re-store → retrieve.""" long_prompt = "Evaluate whether the response achieves the objective. " * 10 scorer_id = ComponentIdentifier( class_name="SelfAskTrueFalseScorer", @@ -305,21 +289,17 @@ def test_eval_hash_preserved_through_double_roundtrip(self): params={"system_prompt_template": long_prompt}, ) - # First save: compute eval_hash from untruncated identifier - correct_eval_hash = compute_eval_hash(scorer_id, child_eval_rules=_CHILD_EVAL_RULES) - scorer_id = scorer_id.with_eval_hash(correct_eval_hash) - d1 = scorer_id.to_dict(max_value_length=80) + original_eval_hash = _StubEvaluationIdentifier(scorer_id).eval_hash + d1 = scorer_id.to_dict() # First retrieve r1 = ComponentIdentifier.from_dict(d1) - assert _StubEvaluationIdentifier(r1).eval_hash == correct_eval_hash + assert _StubEvaluationIdentifier(r1).eval_hash == original_eval_hash - # Re-store: EvaluationIdentifier should use stored value, not recompute - d2 = r1.to_dict(max_value_length=80) - - # Second retrieve + # Re-store and retrieve again + d2 = r1.to_dict() r2 = ComponentIdentifier.from_dict(d2) - assert _StubEvaluationIdentifier(r2).eval_hash == correct_eval_hash + assert _StubEvaluationIdentifier(r2).eval_hash == original_eval_hash class TestParamFallbacks: @@ -440,13 +420,9 @@ def test_no_fallback_when_no_rules(self): assert result1["children"]["prompt_target"] != result2["children"]["prompt_target"] -def test_compute_eval_hash_raises_when_hash_none_and_no_rules(): - identifier = ComponentIdentifier.__new__(ComponentIdentifier) - object.__setattr__(identifier, "hash", None) - object.__setattr__(identifier, "class_name", "Test") - object.__setattr__(identifier, "class_module", "test.module") - with pytest.raises(RuntimeError, match="hash should be set by __post_init__"): - compute_eval_hash(identifier, child_eval_rules={}) +def test_compute_eval_hash_no_rules_returns_content_hash(): + identifier = ComponentIdentifier(class_name="Test", class_module="test.module") + assert compute_eval_hash(identifier, child_eval_rules={}) == identifier.hash # --------------------------------------------------------------------------- @@ -855,8 +831,8 @@ def test_model_name_fallback_to_model_name(self): eval_b = ObjectiveTargetEvaluationIdentifier(target_only_model_name).eval_hash assert eval_a == eval_b - def test_stored_eval_hash_takes_precedence(self): - """A pre-stamped eval_hash is honored (DB round-trip safety).""" + def test_stored_eval_hash_is_ignored_and_recomputed(self): + """A pre-stamped eval_hash is ignored; the value is always recomputed from params.""" from pyrit.models.identifiers.evaluation_identifier import ObjectiveTargetEvaluationIdentifier stored = "objective_target_stored_hash" + "0" * 36 @@ -866,7 +842,15 @@ def test_stored_eval_hash_takes_precedence(self): params={"underlying_model_name": "gpt-4o"}, ).with_eval_hash(stored) - assert ObjectiveTargetEvaluationIdentifier(cid).eval_hash == stored + recomputed = ObjectiveTargetEvaluationIdentifier(cid).eval_hash + assert recomputed != stored + # Recompute is deterministic: a fresh identifier without the stored value matches. + fresh = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"underlying_model_name": "gpt-4o"}, + ) + assert ObjectiveTargetEvaluationIdentifier(fresh).eval_hash == recomputed # --------------------------------------------------------------------------- diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index f4e54ca9f9..0dbf240da6 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ b/tests/unit/score/test_self_ask_true_false.py @@ -183,8 +183,8 @@ def test_self_ask_true_false_get_identifier_type(patch_central_database): assert "system_prompt_template" in identifier.params -def test_self_ask_true_false_get_identifier_long_prompt_hashed(patch_central_database): - """Test that long system prompts are truncated when serialized via to_dict().""" +def test_self_ask_true_false_get_identifier_long_prompt_stored_in_full(patch_central_database): + """Test that long system prompts are stored in full (no truncation) via to_dict().""" chat_target = MagicMock() chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") @@ -194,17 +194,14 @@ def test_self_ask_true_false_get_identifier_long_prompt_hashed(patch_central_dat identifier = scorer.get_identifier() - # The identifier object itself stores the full prompt in params - assert identifier.params["system_prompt_template"] is not None - assert len(identifier.params["system_prompt_template"]) > 100 # GROUNDED prompt is long + # The identifier object stores the full prompt in params + full_prompt = identifier.params["system_prompt_template"] + assert full_prompt is not None + assert len(full_prompt) > 100 # GROUNDED prompt is long - # But when serialized via to_dict(), long prompts are truncated - # Format: "... [sha256:]" # noqa: ERA001 + # to_dict() flattens params and stores the full value (no truncation) id_dict = identifier.to_dict() - sys_prompt_in_dict = id_dict.get("params", {}).get("system_prompt_template", "") - if sys_prompt_in_dict: - # If it's truncated, it will contain "... [sha256:" - assert "[sha256:" in sys_prompt_in_dict or len(sys_prompt_in_dict) <= 100 + assert id_dict["system_prompt_template"] == full_prompt def test_self_ask_true_false_no_path_no_question(patch_central_database):