Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 59 additions & 53 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Conversation,
ConversationReference,
ConversationType,
EvaluationIdentifier,
MessagePiece,
PromptDataType,
ScenarioIdentifier,
Expand All @@ -62,49 +63,28 @@
# Default pyrit_version for database records created before version tracking was added
LEGACY_PYRIT_VERSION = "<0.10.0"

# Maximum length for string values in ComponentIdentifier.model_dump() when storing to the database.
# Longer values are truncated with a "..." suffix.
MAX_IDENTIFIER_VALUE_LENGTH: int = 80


def _dump_identifier(identifier: ComponentIdentifier | None) -> dict[str, Any] | None:
    """
    Convert a single ``ComponentIdentifier`` into a dict suitable for JSON storage.

    The identifier is dumped with ``max_value_length`` set to
    ``MAX_IDENTIFIER_VALUE_LENGTH`` in the serialization context, which is the
    limit used to truncate long string values when storing to the database.

    Args:
        identifier (ComponentIdentifier | None): The identifier to serialize, or None.

    Returns:
        dict[str, Any] | None: The dumped identifier, or None when ``identifier`` is falsy.
    """
    if identifier:
        dump_context = {"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}
        return identifier.model_dump(context=dump_context)
    # Nothing to serialize: a missing identifier is stored as None.
    return None


def _dump_identifiers(identifiers: list[ComponentIdentifier]) -> list[dict[str, Any]]:
    """
    Convert a list of ``ComponentIdentifier`` objects into dicts for JSON storage.

    Every identifier is dumped with ``max_value_length`` set to
    ``MAX_IDENTIFIER_VALUE_LENGTH`` in the serialization context.

    Args:
        identifiers (list[ComponentIdentifier]): The identifiers to serialize.

    Returns:
        list[dict[str, Any]]: One dumped dict per identifier, in input order.
    """
    serialized: list[dict[str, Any]] = []
    for item in identifiers:
        # A fresh context dict is built per identifier, as in the comprehension form.
        dump_context = {"max_value_length": MAX_IDENTIFIER_VALUE_LENGTH}
        serialized.append(item.model_dump(context=dump_context))
    return serialized


def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None = None) -> ComponentIdentifier | None:
def _load_identifier(
stored: dict[str, Any] | None,
*,
pyrit_version: str | None = None,
eval_identifier_cls: type[EvaluationIdentifier] | None = None,
) -> ComponentIdentifier | None:
"""
Reconstruct a ``ComponentIdentifier`` from its stored dict representation.

The content hash is recomputed on validation (never trusted from storage).
When ``eval_identifier_cls`` is provided, the ``eval_hash`` is likewise
recomputed from the (full) stored params and re-stamped onto the identifier,
so the stored ``eval_hash`` value is never trusted on reload.

Args:
stored (dict[str, Any] | None): The stored identifier dict, or None.
pyrit_version (str | None): If provided, injected as the identifier's ``pyrit_version``
so the reconstructed object reflects the version that created the row.
eval_identifier_cls (type[EvaluationIdentifier] | None): If provided, the
``EvaluationIdentifier`` subclass used to recompute and re-stamp the
identifier's ``eval_hash`` on reload.

Returns:
ComponentIdentifier | None: The reconstructed identifier, or None if ``stored`` is falsy.
Expand All @@ -113,7 +93,10 @@ def _load_identifier(stored: dict[str, Any] | None, *, pyrit_version: str | None
return None
if pyrit_version is not None:
stored = {**stored, "pyrit_version": pyrit_version}
return ComponentIdentifier.model_validate(stored)
identifier = ComponentIdentifier.model_validate(stored)
if eval_identifier_cls is not None:
identifier = identifier.with_eval_hash(eval_identifier_cls(identifier).eval_hash)
return identifier


def _load_identifiers(
Expand Down Expand Up @@ -310,7 +293,7 @@ def __init__(self, *, entry: MessagePiece) -> None:
self.timestamp = entry.timestamp
self.labels = entry.labels
self.prompt_metadata = entry.prompt_metadata
self.converter_identifiers = _dump_identifiers(entry.converter_identifiers)
self.converter_identifiers = [identifier.model_dump() for identifier in entry.converter_identifiers]

self.original_value = entry.original_value
self.original_value_data_type = entry.original_value_data_type
Expand Down Expand Up @@ -399,7 +382,7 @@ def __init__(self, *, conversation: Conversation) -> None:
conversation (Conversation): The conversation metadata to persist.
"""
self.conversation_id = conversation.conversation_id
self.target_identifier = _dump_identifier(conversation.target_identifier)
self.target_identifier = conversation.target_identifier.model_dump() if conversation.target_identifier else None
self.pyrit_version = pyrit.__version__

def get_conversation(self) -> Conversation:
Expand Down Expand Up @@ -484,12 +467,13 @@ def __init__(self, *, entry: Score) -> None:
self.score_rationale = entry.score_rationale
self.score_metadata = entry.score_metadata or {}
normalized_scorer = entry.scorer_class_identifier
# Ensure eval_hash is set before truncation so it survives the DB round-trip
if normalized_scorer is not None and normalized_scorer.eval_hash is None:
# Always recompute eval_hash before dumping so the stored JSON carries the
# freshly computed value for DB-level filtering (never a value from storage).
if normalized_scorer is not None:
normalized_scorer = normalized_scorer.with_eval_hash(
ScorerEvaluationIdentifier(normalized_scorer).eval_hash
)
self.scorer_class_identifier = _dump_identifier(normalized_scorer) or {}
self.scorer_class_identifier = normalized_scorer.model_dump() if normalized_scorer else {}
self.prompt_request_response_id = entry.message_piece_id if entry.message_piece_id else None
self.timestamp = entry.timestamp
# Store in both columns for backward compatibility
Expand All @@ -505,9 +489,14 @@ def get_score(self) -> Score:
Returns:
Score: The reconstructed score object with all its data.
"""
# Convert dict back to ComponentIdentifier with the stored pyrit_version
# Convert dict back to ComponentIdentifier with the stored pyrit_version;
# eval_hash is recomputed on reload via ScorerEvaluationIdentifier.
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
scorer_identifier = _load_identifier(self.scorer_class_identifier, pyrit_version=stored_version)
scorer_identifier = _load_identifier(
self.scorer_class_identifier,
pyrit_version=stored_version,
eval_identifier_cls=ScorerEvaluationIdentifier,
)
return Score(
id=self.id,
score_value=self.score_value,
Expand Down Expand Up @@ -933,12 +922,15 @@ def __init__(self, *, entry: AttackResult) -> None:
self.id = uuid.UUID(entry.attack_result_id)
self.conversation_id = entry.conversation_id
self.objective = entry.objective
# Ensure eval_hash is set before truncation so it survives the DB round-trip
if entry.atomic_attack_identifier and entry.atomic_attack_identifier.eval_hash is None:
# Always recompute eval_hash before dumping so the stored JSON carries the
# freshly computed value for DB-level filtering (never a value from storage).
if entry.atomic_attack_identifier:
entry.atomic_attack_identifier = entry.atomic_attack_identifier.with_eval_hash(
AtomicAttackEvaluationIdentifier(entry.atomic_attack_identifier).eval_hash
)
self.atomic_attack_identifier = _dump_identifier(entry.atomic_attack_identifier)
self.atomic_attack_identifier = (
entry.atomic_attack_identifier.model_dump() if entry.atomic_attack_identifier else None
)
self.objective_sha256 = to_sha256(entry.objective)

# Use helper method for UUID conversions
Expand Down Expand Up @@ -1055,7 +1047,11 @@ def get_attack_result(self) -> AttackResult:
)
)

atomic_id = _load_identifier(self.atomic_attack_identifier)
# eval_hash is recomputed on reload via AtomicAttackEvaluationIdentifier.
atomic_id = _load_identifier(
self.atomic_attack_identifier,
eval_identifier_cls=AtomicAttackEvaluationIdentifier,
)

# Deserialize retry events from JSON
retry_events = []
Expand Down Expand Up @@ -1172,13 +1168,18 @@ def __init__(self, *, entry: ScenarioResult) -> None:
self.pyrit_version = entry.scenario_identifier.pyrit_version
self.scenario_init_data = entry.scenario_identifier.init_data
# Convert ComponentIdentifier to dict for JSON storage
self.objective_target_identifier = _dump_identifier(entry.objective_target_identifier) # type: ignore[ty:invalid-assignment]
# Ensure eval_hash is set before truncation so it survives the DB round-trip.
if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None:
self.objective_target_identifier = ( # type: ignore[ty:invalid-assignment]
entry.objective_target_identifier.model_dump() if entry.objective_target_identifier else None
)
# Always recompute eval_hash before dumping so the stored JSON carries the
# freshly computed value for DB-level filtering (never a value from storage).
if entry.objective_scorer_identifier:
entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash(
ScorerEvaluationIdentifier(entry.objective_scorer_identifier).eval_hash
)
self.objective_scorer_identifier = _dump_identifier(entry.objective_scorer_identifier)
self.objective_scorer_identifier = (
entry.objective_scorer_identifier.model_dump() if entry.objective_scorer_identifier else None
)
self.scenario_run_state = entry.scenario_run_state.value
self.labels = entry.labels
self.number_tries = entry.number_tries
Expand Down Expand Up @@ -1224,8 +1225,13 @@ def get_scenario_result(self) -> ScenarioResult:
# Return empty attack_results - will be populated by memory_interface
attack_results: dict[str, list[AttackResult]] = {}

# Convert dict back to ComponentIdentifier with the stored pyrit_version
scorer_identifier = _load_identifier(self.objective_scorer_identifier, pyrit_version=stored_version)
# Convert dict back to ComponentIdentifier with the stored pyrit_version;
# eval_hash is recomputed on reload via ScorerEvaluationIdentifier.
scorer_identifier = _load_identifier(
self.objective_scorer_identifier,
pyrit_version=stored_version,
eval_identifier_cls=ScorerEvaluationIdentifier,
)

# Convert dict back to ComponentIdentifier for reconstruction
target_identifier = _load_identifier(self.objective_target_identifier)
Expand Down
Loading
Loading