diff --git a/doc/bibliography.md b/doc/bibliography.md index d5a55f4428..ccbab89d8c 100644 --- a/doc/bibliography.md +++ b/doc/bibliography.md @@ -5,6 +5,6 @@ All academic papers, research blogs, and technical reports referenced throughout :::{dropdown} Citation Keys :class: hidden-citations -[@aakanksha2024multilingual; @adversaai2023universal; @andriushchenko2024tense; @anthropic2024manyshot; @aqrawi2024singleturncrescendo; @atr2026; @bethany2024mathprompt; @bhardwaj2023harmfulqa; @bhardwaj2024homer; @brahman2024coconot; @bryan2025agentictaxonomy; @bullwinkel2025airtlessons; @bullwinkel2025repeng; @bullwinkel2026trigger; @chao2023pair; @chao2024jailbreakbench; @cui2024orbench; @darkbench2025; @derczynski2024garak; @ding2023wolf; @embracethered2024unicode; @embracethered2025sneakybits; @gehman2020realtoxicityprompts; @ghosh2025aegis; @ghosh2025ailuminate; @gong2025figstep; @gupta2024walledeval; @haider2024phi3safety; @han2024medsafetybench; @hines2024spotlighting; @ji2023beavertails; @ji2024pkusaferlhf; @jiang2025sosbench; @jones2025computeruse; @kingma2014adam; @li2024mossbench; @li2024saladbench; @li2024wmdp; @lin2023toxicchat; @liu2024flipattack; @liu2024mmsafetybench; @lopez2024pyrit; @luo2024jailbreakv; @lv2024codechameleon; @mazeika2023tdc; @mazeika2024harmbench; @mckee2024transparency; @mehrotra2023tap; @microsoft2024skeletonkey; @palaskar2025vlsu; @pfohl2024equitymedqa; @promptfoo2025ccp; @robustintelligence2024bypass; @roccia2024promptintel; @rottger2023xstest; @rottger2025msts; @russinovich2024crescendo; @russinovich2025price; @scheuerman2025transphobia; @shaikh2022second; @shayegani2025computeruse; @shen2023donotanything; @sheshadri2024lat; @souly2024strongreject; @stok2023ansi; @tan2026comicjailbreak; @tang2025multilingual; @tedeschi2024alert; @vantaylor2024socialbias; @vidgen2023simplesafetytests; @wang2023decodingtrust; @wang2023donotanswer; @wang2025siuo; @wei2023jailbroken; @xie2024sorrybench; @yu2023gptfuzzer; @yuan2023cipherchat; @zeng2024persuasion; @zhang2024cbtbench; @ziems2022mic; @zou2023gcg] +[@aakanksha2024multilingual; @adversaai2023universal; @andriushchenko2024tense; @anthropic2024manyshot; @aqrawi2024singleturncrescendo; @atr2026; @bethany2024mathprompt; @bhardwaj2023harmfulqa; @bhardwaj2024homer; @brahman2024coconot; @bryan2025agentictaxonomy; @bullwinkel2025airtlessons; @bullwinkel2025repeng; @bullwinkel2026trigger; @chao2023pair; @chao2024jailbreakbench; @cui2024orbench; @darkbench2025; @derczynski2024garak; @ding2023wolf; @embracethered2024unicode; @embracethered2025sneakybits; @gehman2020realtoxicityprompts; @ghosh2025aegis; @ghosh2025ailuminate; @gong2025figstep; @gupta2024walledeval; @haider2024phi3safety; @han2024medsafetybench; @hines2024spotlighting; @inie2025summon; @ji2023beavertails; @ji2024pkusaferlhf; @jiang2025sosbench; @jones2025computeruse; @kingma2014adam; @li2024mossbench; @li2024saladbench; @li2024wmdp; @lin2023toxicchat; @liu2024flipattack; @liu2024mmsafetybench; @lopez2024pyrit; @luo2024jailbreakv; @lv2024codechameleon; @mazeika2023tdc; @mazeika2024harmbench; @mckee2024transparency; @mehrotra2023tap; @microsoft2024skeletonkey; @odin2024; @palaskar2025vlsu; @pfohl2024equitymedqa; @promptfoo2025ccp; @robustintelligence2024bypass; @roccia2024promptintel; @rottger2023xstest; @rottger2025msts; @russinovich2024crescendo; @russinovich2025price; @scheuerman2025transphobia; @shaikh2022second; @shayegani2025computeruse; @shen2023donotanything; @sheshadri2024lat; @souly2024strongreject; @stok2023ansi; @tan2026comicjailbreak; @tang2025multilingual; @tedeschi2024alert; @vantaylor2024socialbias; @vidgen2023simplesafetytests; @wang2023decodingtrust; @wang2023donotanswer; @wang2025siuo; @wei2023jailbroken; @xie2024sorrybench; @yu2023gptfuzzer; @yuan2023cipherchat; @zeng2024persuasion; @zhang2024cbtbench; @ziems2022mic; @zou2023gcg] ::: diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 5fec62ccf1..0dd5fc8389 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -14,6 +14,7 @@ "The following command lists all built-in datasets available in PyRIT. Some datasets are stored locally, while others are fetched remotely from sources like HuggingFace.\n", "\n", "Many of these datasets come from published research, including\n", + "0DIN [@odin2024],\n", "Aegis [@ghosh2025aegis],\n", "Agent Threat Rules [@atr2026],\n", "ALERT [@tedeschi2024alert],\n", @@ -74,6 +75,7 @@ " '0din_incremental_table_completion',\n", " '0din_placeholder_injection',\n", " '0din_technical_field_guide',\n", + " '0din_threatfeed',\n", " 'adv_bench',\n", " 'aegis_content_safety',\n", " 'agent_threat_rules',\n", diff --git a/doc/code/datasets/1_loading_datasets.py b/doc/code/datasets/1_loading_datasets.py index 164a2b53d8..9523b0dd80 100644 --- a/doc/code/datasets/1_loading_datasets.py +++ b/doc/code/datasets/1_loading_datasets.py @@ -18,6 +18,7 @@ # The following command lists all built-in datasets available in PyRIT. Some datasets are stored locally, while others are fetched remotely from sources like HuggingFace. # # Many of these datasets come from published research, including +# 0DIN [@odin2024], # Aegis [@ghosh2025aegis], # Agent Threat Rules [@atr2026], # ALERT [@tedeschi2024alert], diff --git a/doc/references.bib b/doc/references.bib index 3f987b823a..ae212f599f 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -88,6 +88,22 @@ @misc{roccia2024promptintel url = {https://promptintel.novahunting.ai/feed}, } +@misc{odin2024, + title = {{0DIN}: {GenAI} Bug Bounty and Threat Feed}, + author = {{Mozilla 0DIN}}, + year = {2024}, + url = {https://0din.ai/}, + note = {0DIN Jailbreak / Threat Feed}, +} + +@article{inie2025summon, + title = {Summon a Demon and Bind it: A Grounded Theory of {LLM} Red Teaming}, + author = {Nanna Inie and Jonathan Stray and Leon Derczynski}, + journal = {PLoS ONE}, + year = {2025}, + url = {https://arxiv.org/abs/2311.06237}, +} + @misc{vantaylor2024socialbias, title = {A Red-Teaming Repository of Existing Social Bias Prompts}, author = {Simone Van Taylor}, diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py index e605659b1d..a32e9861d4 100644 --- a/pyrit/datasets/seed_datasets/remote/__init__.py +++ b/pyrit/datasets/seed_datasets/remote/__init__.py @@ -119,6 +119,12 @@ from pyrit.datasets.seed_datasets.remote.multilingual_vulnerability_dataset import ( _MultilingualVulnerabilityDataset, ) +from pyrit.datasets.seed_datasets.remote.odin_dataset import ( + ODINSecurityBoundary, + ODINSeverity, + ODINTaxonomyCategory, + _ODINDataset, +) from pyrit.datasets.seed_datasets.remote.or_bench_dataset import ( _ORBench80KDataset, _ORBenchHardDataset, @@ -199,6 +205,9 @@ "MMSafetyBenchCategory", "MMSafetyBenchVariant", "MossBenchOversensitivityType", + "ODINSecurityBoundary", + "ODINSeverity", + "ODINTaxonomyCategory", "PromptIntelCategory", "PromptIntelSeverity", "SGXSTestLabel", @@ -244,6 +253,7 @@ "_MossBenchDataset", "_MSTSDataset", "_MultilingualVulnerabilityDataset", + "_ODINDataset", "_ORBench80KDataset", "_ORBenchHardDataset", "_ORBenchToxicDataset", diff --git a/pyrit/datasets/seed_datasets/remote/odin_dataset.py b/pyrit/datasets/seed_datasets/remote/odin_dataset.py new file mode 100644 index 0000000000..d9db63ecf3 --- /dev/null +++ b/pyrit/datasets/seed_datasets/remote/odin_dataset.py @@ -0,0 +1,534 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +import json +import logging +import os +import time +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any + +import requests +from typing_extensions import override + +from pyrit.common.path import DB_DATA_PATH +from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( + _RemoteDatasetLoader, +) +from pyrit.models import SeedDataset, SeedPrompt, SeedUnion + +logger = logging.getLogger(__name__) + + +class ODINSeverity(Enum): + """Severity ratings assigned to 0DIN threat-feed reports.""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + SEVERE = "severe" + + +class ODINSecurityBoundary(Enum): + """Security boundary categories for 0DIN threat-feed reports.""" + + GUARDRAIL_JAILBREAK = "guardrail_jailbreak" + PROMPT_INJECTION = "prompt_injection" + PROMPT_EXTRACTION = "prompt_extraction" + CONTENT_MANIPULATION = "content_manipulation" + INTERPRETER_JAILBREAK = "interpreter_jailbreak" + OTHER = "other" + + +class ODINTaxonomyCategory(Enum): + """Top-level categories from the 0DIN jailbreak taxonomy.""" + + STRATAGEMS = "stratagems" + FICTIONALIZING = "fictionalizing" + LANGUAGE = "language" + RHETORIC = "rhetoric" + POSSIBLE_WORLDS = "possible_worlds" + + +class _ODINDataset(_RemoteDatasetLoader): + """ + Loader for the 0DIN (0din.ai) Jailbreak / Threat Feed dataset. + + 0DIN is Mozilla's GenAI bug-bounty and threat-intelligence program. The Threat Feed + publishes verified jailbreak disclosures against production models, each annotated with + a taxonomy (category/strategy/technique), severity, affected models, reproducibility test + results, and impact scores. The taxonomy axis is drawn from 0DIN's published taxonomy, + which is grounded in the "Summon a Demon and Bind it" grounded theory of LLM red teaming + [@inie2025summon] (public taxonomy: https://0din.ai/research/taxonomy). Note this taxonomy + describes *how* an attack is structured, not the *harm* it targets. + + Each report exposes one or more sample exploit prompts (``messages``), and—optionally—a + large set of industry-specific ``variant_prompts``. Every prompt is mapped to a SeedPrompt + containing the literal attack text; the report title is stored in the SeedPrompt's ``name`` + field. Sample prompts that repeat across multiple tested models are de-duplicated. + + Note: 0DIN does not expose separate objective data, so no SeedObjective objects are created. + + The 0DIN feed is live and grows over time. The raw feed is cached on disk (under + ``DB_DATA_PATH``); because reports are returned newest-first, subsequent fetches sync + incrementally — fetching only newly disclosed reports and merging them onto the cache. + Pass ``cache=False`` to ``fetch_dataset_async`` to force a full refresh. + + Reference: [@odin2024], [@inie2025summon] + API Docs: https://0din.ai/docs/jailbreak-feed/api + + This dataset is gated: programmatic access requires a 0DIN Team or Enterprise subscription + and an API key. Provide the key via the ``api_key`` parameter or the ``0DIN_API_KEY`` + environment variable. See https://0din.ai/products for subscription details. + + Warning: This dataset contains adversarial prompts designed to exploit LLMs. Use responsibly + and consult your legal department before using for testing. + """ + + # Metadata + modalities: list[str] = ["text"] + size: str = "large" # ~1,346 unique sample prompts; far larger with variant prompts enabled + tags: set[str] = {"safety", "jailbreak", "cybersecurity"} + harm_categories: list[str] = sorted(c.value for c in ODINTaxonomyCategory) + + API_BASE_URL = "https://0din.ai/api/v1/threatfeed/" + REPORT_WEB_URL = "https://0din.ai/threatfeed" + PAGE_SIZE = 100 + # On-disk cache of the raw (unfiltered) feed, shared across filter configurations. + CACHE_FILENAME = "0din_threatfeed.json" + # 0DIN enforces a 25 req/min rate limit and returns transient 5xx (or 429/406 from its + # anti-abuse layer) under load; retry those with backoff. + MAX_RETRIES = 4 + RETRY_BACKOFF_SECONDS = 5.0 + _RETRYABLE_STATUS_CODES = frozenset({406, 429, 500, 502, 503, 504}) + + def __init__( + self, + *, + api_key: str | None = None, + severity: ODINSeverity | None = None, + security_boundaries: list[ODINSecurityBoundary] | None = None, + categories: list[ODINTaxonomyCategory] | None = None, + include_variant_prompts: bool = False, + ) -> None: + """ + Initialize the 0DIN dataset loader. + + The 0DIN API does not support server-side filtering, so all filters are applied + client-side after the full feed is fetched. + + Args: + api_key: 0DIN API key. Falls back to the ``0DIN_API_KEY`` environment variable + if not provided. + severity: Keep only reports with this severity. Defaults to None (all severities). + security_boundaries: Keep only reports whose security boundary is in this list. + Defaults to None (all boundaries). + categories: Keep only reports tagged with at least one of these taxonomy categories. + Defaults to None (all categories). + include_variant_prompts: Whether to additionally emit the industry-specific variant + prompts attached to each report. Defaults to False (sample prompts only), since + variants greatly increase the dataset size. + + Raises: + ValueError: If an invalid severity, security boundary, or category is provided, or + if a filter list is provided but empty (pass None to include all). + """ + self._api_key = api_key + + if severity is not None: + self._validate_enum(severity, ODINSeverity, "severity") + + if security_boundaries is not None: + if not security_boundaries: + raise ValueError( + "`security_boundaries` must be a non-empty list (pass None to include all security boundaries)" + ) + self._validate_enums(security_boundaries, ODINSecurityBoundary, "security_boundary") + + if categories is not None: + if not categories: + raise ValueError("`categories` must be a non-empty list (pass None to include all categories)") + self._validate_enums(categories, ODINTaxonomyCategory, "category") + + self._severity = severity + self._security_boundaries = security_boundaries + self._categories = categories + self._include_variant_prompts = include_variant_prompts + self.source = "https://0din.ai" + + @property + @override + def dataset_name(self) -> str: + """Return the dataset name.""" + return "0din_threatfeed" + + def _resolve_api_key(self) -> str: + """ + Resolve the 0DIN API key from the constructor argument or environment. + + Returns: + str: The resolved API key. + + Raises: + ValueError: If no API key is provided and ``0DIN_API_KEY`` is not set. + """ + api_key = self._api_key or os.environ.get("0DIN_API_KEY") + if not api_key: + raise ValueError( + "0DIN API key is required. Provide it via the 'api_key' parameter " + "or set the 0DIN_API_KEY environment variable." + ) + return api_key + + def _fetch_page(self, *, page: int, headers: dict[str, str]) -> dict[str, Any]: + """ + Fetch a single page of the threat feed, retrying transient errors with backoff. + + Args: + page: The 1-based page number to fetch. + headers: Request headers including the Authorization key. + + Returns: + dict[str, Any]: The parsed JSON body for the page. + + Raises: + ConnectionError: If the request fails with a non-retryable status, or if all + retries are exhausted on transient errors. + """ + last_status: int | None = None + last_text = "" + for attempt in range(self.MAX_RETRIES): + response = requests.get( + self.API_BASE_URL, + headers=headers, + params={"page": page, "per_page": self.PAGE_SIZE}, + timeout=60, + ) + + if response.status_code == 200: + return response.json() + + last_status = response.status_code + last_text = response.text + if response.status_code not in self._RETRYABLE_STATUS_CODES: + break + + if attempt < self.MAX_RETRIES - 1: + backoff = self.RETRY_BACKOFF_SECONDS * (attempt + 1) + logger.warning( + f"0DIN API page {page} returned status {response.status_code}; " + f"retrying in {backoff:.0f}s (attempt {attempt + 1}/{self.MAX_RETRIES})." + ) + time.sleep(backoff) + + raise ConnectionError(f"0DIN API request failed with status {last_status}: {last_text}") + + def _cache_path(self) -> Path: + """ + Return the on-disk path of the cached raw threat feed. + + Returns: + Path: The JSON cache file path under ``DB_DATA_PATH``. + """ + return DB_DATA_PATH / "seed-prompt-entries" / self.CACHE_FILENAME + + def _load_cached_reports(self) -> list[dict[str, Any]]: + """ + Load previously cached threat-feed reports from disk. + + Returns: + list[dict[str, Any]]: The cached reports, or an empty list if no usable cache exists. + """ + path = self._cache_path() + if not path.exists(): + return [] + try: + with path.open("r", encoding="utf-8") as file: + data = json.load(file) + except (OSError, json.JSONDecodeError) as exc: + logger.warning(f"Ignoring unreadable 0DIN cache at {path}: {exc}") + return [] + return data if isinstance(data, list) else [] + + def _write_cached_reports(self, reports: list[dict[str, Any]]) -> None: + """ + Persist the full set of threat-feed reports to disk. + + Args: + reports: The complete (unfiltered) list of reports to cache. + """ + path = self._cache_path() + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as file: + json.dump(reports, file, ensure_ascii=False) + except OSError as exc: + logger.warning(f"Failed to write 0DIN cache at {path}: {exc}") + + def _fetch_all_reports(self, *, cache: bool = True) -> list[dict[str, Any]]: + """ + Fetch all threat-feed reports, incrementally syncing against the on-disk cache. + + The feed is returned newest-first, so when a cache exists this paginates from the + first page and stops as soon as it encounters a report UUID already present in the + cache — fetching only newly disclosed reports and merging them on top. When no cache + exists (or ``cache`` is False) the full feed is fetched. + + Note: edits to already-cached reports (``updated_at`` changes) are not picked up by + the incremental sync; pass ``cache=False`` to force a full refresh. + + Args: + cache: Whether to read from and write to the on-disk cache. Defaults to True. + + Returns: + list[dict[str, Any]]: All report records (newest-first). + + Raises: + ValueError: If no API key is provided and ``0DIN_API_KEY`` is not set. + ConnectionError: If an API request fails. + """ + api_key = self._resolve_api_key() + headers = {"Authorization": api_key} + + cached_reports = self._load_cached_reports() if cache else [] + cached_uuids = {r.get("uuid") for r in cached_reports if r.get("uuid")} + + new_reports: list[dict[str, Any]] = [] + reached_cache = False + page = 1 + + while not reached_cache: + body = self._fetch_page(page=page, headers=headers) + for report in body.get("threat_feeds", []): + if report.get("uuid") in cached_uuids: + reached_cache = True + break + new_reports.append(report) + + total_pages = body.get("total_pages", 1) + if page >= total_pages: + break + page += 1 + + if not cache: + return new_reports + + if not new_reports: + return cached_reports + + # Merge newest-first, de-duplicating by UUID (newly fetched reports win). + merged: list[dict[str, Any]] = [] + seen: set[Any] = set() + for report in (*new_reports, *cached_reports): + uuid = report.get("uuid") + if uuid and uuid in seen: + continue + if uuid: + seen.add(uuid) + merged.append(report) + + self._write_cached_reports(merged) + return merged + + def _matches_filters(self, report: dict[str, Any]) -> bool: + """ + Determine whether a report satisfies the configured client-side filters. + + Args: + report: A single threat-feed report record. + + Returns: + bool: True if the report should be included. + """ + if self._severity is not None and report.get("severity") != self._severity.value: + return False + + if self._security_boundaries is not None: + allowed = {b.value for b in self._security_boundaries} + if report.get("security_boundary") not in allowed: + return False + + if self._categories is not None: + allowed_categories = {c.value for c in self._categories} + report_categories = {t.get("category") for t in report.get("taxonomies") or []} + if not (report_categories & allowed_categories): + return False + + return True + + def _parse_datetime(self, date_str: str | None) -> datetime | None: + """ + Parse an ISO 8601 datetime string from the API. + + Args: + date_str: ISO format datetime string, or None. + + Returns: + datetime or None if parsing fails. + """ + if not date_str: + return None + try: + return datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except (ValueError, AttributeError): + return None + + def _build_metadata( + self, report: dict[str, Any], *, extra: dict[str, str | int] | None = None + ) -> dict[str, str | int]: + """ + Build the metadata dict from a 0DIN report. + + Args: + report: A single threat-feed report record. + extra: Optional additional key/value pairs to merge in (e.g. variant info). + + Returns: + dict[str, str | int]: Metadata dictionary with string or integer values. + """ + metadata: dict[str, str | int] = {} + + if report.get("uuid"): + metadata["uuid"] = report["uuid"] + if report.get("severity"): + metadata["severity"] = report["severity"] + if report.get("security_boundary"): + metadata["security_boundary"] = report["security_boundary"] + if report.get("source"): + metadata["report_source"] = report["source"] + + taxonomies = report.get("taxonomies") or [] + categories = sorted({t["category"] for t in taxonomies if t.get("category")}) + strategies = sorted({t["strategy"] for t in taxonomies if t.get("strategy")}) + techniques = sorted({t["technique"] for t in taxonomies if t.get("technique")}) + if categories: + metadata["taxonomy_categories"] = ", ".join(categories) + if strategies: + metadata["taxonomy_strategies"] = ", ".join(strategies) + if techniques: + metadata["taxonomy_techniques"] = ", ".join(techniques) + + model_names = [] + for model in report.get("models") or []: + name = model.get("name") + if not name: + continue + vendor = (model.get("vendor") or {}).get("name") + model_names.append(f"{vendor}: {name}" if vendor else name) + if model_names: + metadata["affected_models"] = ", ".join(model_names) + + for entry in report.get("metadata") or []: + if entry.get("type") == "SocialImpact" and entry.get("result") is not None: + metadata["social_impact"] = int(entry["result"]) + + signatures = report.get("detection_signatures") or [] + if signatures and signatures[0].get("signature"): + metadata["detection_signature"] = signatures[0]["signature"] + + if report.get("disclosed_at"): + metadata["disclosed_at"] = report["disclosed_at"] + + if extra: + metadata.update(extra) + + return metadata + + def _convert_report_to_seed_prompts(self, report: dict[str, Any]) -> list[SeedPrompt]: + """ + Convert a single 0DIN report into one or more SeedPrompts. + + Sample prompts from ``messages`` are emitted first (de-duplicated by text). When + ``include_variant_prompts`` is set, industry-specific variant prompts are appended. + + Args: + report: A single threat-feed report record. + + Returns: + list[SeedPrompt]: The seed prompts derived from this report. + """ + title = report.get("title") or None + uuid = report.get("uuid", "") + summary = report.get("summary") or None + taxonomies = report.get("taxonomies") or [] + harm_categories = sorted({t["category"] for t in taxonomies if t.get("category")}) or None + date_added = self._parse_datetime(report.get("disclosed_at")) + source_url = f"{self.REPORT_WEB_URL}/{uuid}" if uuid else self.source + + seeds: list[SeedPrompt] = [] + seen_prompts: set[str] = set() + + def _add_prompt(text: str, *, extra: dict[str, str | int] | None = None) -> None: + if not text or text in seen_prompts: + return + seen_prompts.add(text) + seeds.append( + SeedPrompt( + value=text, + data_type="text", + name=title, + dataset_name=self.dataset_name, + harm_categories=harm_categories, + description=summary, + groups=["0DIN", "Mozilla"], + source=source_url, + date_added=date_added, + metadata=self._build_metadata(report, extra=extra), + ) + ) + + for message in report.get("messages") or []: + _add_prompt(message.get("prompt", "")) + + if self._include_variant_prompts: + for variant in report.get("variant_prompts") or []: + industry = variant.get("industry") + for subindustry in variant.get("subindustries") or []: + sub_name = subindustry.get("subindustry") + for prompt in subindustry.get("prompts") or []: + extra: dict[str, str | int] = {} + if industry: + extra["variant_industry"] = industry + if sub_name: + extra["variant_subindustry"] = sub_name + _add_prompt(prompt.get("prompt", ""), extra=extra) + + return seeds + + @override + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: + """ + Fetch reports from the 0DIN API and return them as a SeedDataset. + + Args: + cache: Whether to use the on-disk cache. Defaults to True. When True, the raw feed + is cached and subsequent calls only fetch newly disclosed reports (see + ``_fetch_all_reports``). When False, the full feed is fetched fresh and the + cache is neither read nor written. + + Returns: + SeedDataset: A SeedDataset containing the fetched prompts. + + Raises: + ValueError: If no API key is available or if the filters produce no seeds. + ConnectionError: If an API request fails. + """ + logger.info("Fetching reports from 0DIN threat feed API") + + reports = await asyncio.to_thread(self._fetch_all_reports, cache=cache) + + all_seeds: list[SeedUnion] = [] + for report in reports: + if not self._matches_filters(report): + continue + all_seeds.extend(self._convert_report_to_seed_prompts(report)) + + if not all_seeds: + raise ValueError("SeedDataset cannot be empty. Check your filter criteria.") + + logger.info(f"Successfully loaded {len(all_seeds)} prompts from 0DIN") + + return SeedDataset(seeds=all_seeds, dataset_name=self.dataset_name) diff --git a/tests/unit/datasets/test_odin_dataset.py b/tests/unit/datasets/test_odin_dataset.py new file mode 100644 index 0000000000..bde3cec697 --- /dev/null +++ b/tests/unit/datasets/test_odin_dataset.py @@ -0,0 +1,508 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.datasets.seed_datasets.remote.odin_dataset import ( + ODINSecurityBoundary, + ODINSeverity, + ODINTaxonomyCategory, + _ODINDataset, +) +from pyrit.models import SeedDataset, SeedPrompt + + +@pytest.fixture +def api_key(): + """A fake API key for testing.""" + return "odin_test_key_000000000000000000000000000000000000000000000000" + + +@pytest.fixture(autouse=True) +def isolate_cache(tmp_path): + """Point the loader's on-disk cache at a per-test temp file so tests never touch real cache.""" + cache_file = tmp_path / "0din_threatfeed.json" + with patch.object(_ODINDataset, "_cache_path", return_value=cache_file): + yield cache_file + + +def _report( + *, + uuid, + title="Sample Jailbreak", + severity="low", + security_boundary="guardrail_jailbreak", + source="internal", + prompts=("attack one", "attack two"), + categories=("stratagems", "language"), + variant_prompts=None, +): + """Build a single threat-feed report record matching the 0DIN API schema.""" + # Each prompt is repeated across two "models" to mimic the real de-duplication scenario. + messages = [] + for idx, prompt in enumerate(prompts): + messages.append({"prompt": prompt, "response": "...", "model_id": idx, "interface": "api"}) + messages.append({"prompt": prompt, "response": "...", "model_id": idx + 100, "interface": "api"}) + + taxonomies = [ + {"category": cat, "strategy": f"{cat}_strategy", "technique": f"{cat}_technique"} for cat in categories + ] + + return { + "uuid": uuid, + "title": title, + "summary": "A short summary.", + "detail": "A long detail.", + "severity": severity, + "security_boundary": security_boundary, + "source": source, + "disclosed_at": "2026-06-15T14:54:11.981Z", + "published_at": None, + "updated_at": "2026-06-15T14:54:12.029Z", + "detection_signatures": [{"version": "v1", "signature": f"sig-{uuid}"}], + "models": [ + {"id": 1, "name": "Gemini 3 Flash", "vendor": {"name": "Google"}}, + {"id": 2, "name": "Command R", "vendor": {"name": "Cohere"}}, + ], + "messages": messages, + "taxonomies": taxonomies, + "test_results": [{"result": 85.0, "temperature": 0.7, "model_id": 1, "test_type": {"id": 4, "name": "x"}}], + "metadata": [{"type": "SocialImpact", "result": 4}], + "reference_urls": [], + "variant_prompts": variant_prompts or [], + } + + +def _page(reports, *, page=1, total_pages=1, total_count=None): + """Build a paginated list response.""" + return { + "page": page, + "total_pages": total_pages, + "total_count": total_count if total_count is not None else len(reports), + "threat_feeds": reports, + } + + +def _make_mock_response(*, json_data, status_code=200): + """Create a mock requests.Response.""" + mock_resp = MagicMock() + mock_resp.status_code = status_code + mock_resp.json.return_value = json_data + mock_resp.text = str(json_data) + return mock_resp + + +@pytest.fixture +def single_page_response(): + """A one-page feed with two reports of differing severity/boundary/category.""" + return _page( + [ + _report( + uuid="11111111-1111-1111-1111-111111111111", + title="Report A", + severity="low", + security_boundary="guardrail_jailbreak", + prompts=("attack one", "attack two"), + categories=("stratagems", "language"), + ), + _report( + uuid="22222222-2222-2222-2222-222222222222", + title="Report B", + severity="high", + security_boundary="prompt_injection", + prompts=("attack three",), + categories=("rhetoric",), + ), + ] + ) + + +class TestODINDatasetInit: + """Test initialization and validation of _ODINDataset.""" + + def test_init_with_api_key(self, api_key): + loader = _ODINDataset(api_key=api_key) + assert loader.dataset_name == "0din_threatfeed" + assert loader._api_key == api_key + + def test_init_with_env_var(self, api_key): + with patch.dict("os.environ", {"0DIN_API_KEY": api_key}): + loader = _ODINDataset() + assert loader._api_key is None # env var resolved at fetch time + + def test_init_no_api_key_succeeds(self): + with patch.dict("os.environ", {}, clear=True): + loader = _ODINDataset() + assert loader._api_key is None + + def test_init_invalid_severity_raises(self, api_key): + with pytest.raises(ValueError, match="Expected ODINSeverity"): + _ODINDataset(api_key=api_key, severity="low") + + def test_init_invalid_security_boundary_raises(self, api_key): + with pytest.raises(ValueError, match="Expected ODINSecurityBoundary"): + _ODINDataset(api_key=api_key, security_boundaries=["guardrail_jailbreak"]) + + def test_init_invalid_category_raises(self, api_key): + with pytest.raises(ValueError, match="Expected ODINTaxonomyCategory"): + _ODINDataset(api_key=api_key, categories=["stratagems"]) + + def test_init_empty_security_boundaries_raises(self, api_key): + with pytest.raises(ValueError, match="`security_boundaries` must be a non-empty list"): + _ODINDataset(api_key=api_key, security_boundaries=[]) + + def test_init_empty_categories_raises(self, api_key): + with pytest.raises(ValueError, match="`categories` must be a non-empty list"): + _ODINDataset(api_key=api_key, categories=[]) + + def test_init_accepts_valid_enums(self, api_key): + loader = _ODINDataset( + api_key=api_key, + severity=ODINSeverity.HIGH, + security_boundaries=[ODINSecurityBoundary.GUARDRAIL_JAILBREAK], + categories=[ODINTaxonomyCategory.STRATAGEMS, ODINTaxonomyCategory.LANGUAGE], + ) + assert loader._severity == ODINSeverity.HIGH + assert loader._categories == [ODINTaxonomyCategory.STRATAGEMS, ODINTaxonomyCategory.LANGUAGE] + + def test_dataset_name(self, api_key): + assert _ODINDataset(api_key=api_key).dataset_name == "0din_threatfeed" + + +class TestODINDatasetFetch: + """Test fetch_dataset_async and data transformation.""" + + async def test_fetch_no_api_key_raises(self): + with patch.dict("os.environ", {}, clear=True): + loader = _ODINDataset() + with pytest.raises(ValueError, match="API key is required"): + await loader.fetch_dataset_async() + + async def test_fetch_returns_seed_dataset(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + # Report A: 2 unique prompts (de-duplicated from 4 messages); Report B: 1 prompt -> 3 total + assert len(dataset.seeds) == 3 + + async def test_deduplicates_message_prompts(self, api_key): + report = _report( + uuid="33333333-3333-3333-3333-333333333333", + prompts=("only attack",), # repeated across two models -> 2 messages + ) + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=_page([report])) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert len(dataset.seeds) == 1 + assert dataset.seeds[0].value == "only attack" + + async def test_seed_prompt_fields(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + prompts = [s for s in dataset.seeds if isinstance(s, SeedPrompt)] + first = prompts[0] + assert first.data_type == "text" + assert first.dataset_name == "0din_threatfeed" + assert first.name == "Report A" + assert first.description == "A short summary." + assert first.harm_categories == ["language", "stratagems"] + assert first.groups == ["0DIN", "Mozilla"] + assert first.source == "https://0din.ai/threatfeed/11111111-1111-1111-1111-111111111111" + + async def test_seed_prompt_metadata(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + first = next(s for s in dataset.seeds if s.name == "Report A") + assert first.metadata["uuid"] == "11111111-1111-1111-1111-111111111111" + assert first.metadata["severity"] == "low" + assert first.metadata["security_boundary"] == "guardrail_jailbreak" + assert first.metadata["report_source"] == "internal" + assert first.metadata["taxonomy_categories"] == "language, stratagems" + assert "Google: Gemini 3 Flash" in first.metadata["affected_models"] + assert first.metadata["social_impact"] == 4 + assert first.metadata["detection_signature"] == "sig-11111111-1111-1111-1111-111111111111" + + async def test_value_preserved_verbatim(self, api_key): + # Jinja-like syntax must not be rendered for untrusted remote text. + report = _report(uuid="44444444-4444-4444-4444-444444444444", prompts=("{{ 7 * 7 }} literal",)) + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=_page([report])) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert dataset.seeds[0].value == "{{ 7 * 7 }} literal" + + +class TestODINDatasetFilters: + """Test client-side filtering.""" + + async def test_severity_filter(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key, severity=ODINSeverity.HIGH) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + # Only Report B (high) survives -> its single prompt + assert len(dataset.seeds) == 1 + assert dataset.seeds[0].name == "Report B" + + async def test_security_boundary_filter(self, api_key, single_page_response): + loader = _ODINDataset( + api_key=api_key, + security_boundaries=[ODINSecurityBoundary.PROMPT_INJECTION], + ) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert {s.name for s in dataset.seeds} == {"Report B"} + + async def test_category_filter(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key, categories=[ODINTaxonomyCategory.RHETORIC]) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert {s.name for s in dataset.seeds} == {"Report B"} + + async def test_filter_empty_result_raises(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key, severity=ODINSeverity.SEVERE) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp): + with pytest.raises(ValueError, match="SeedDataset cannot be empty"): + await loader.fetch_dataset_async() + + +class TestODINDatasetVariants: + """Test variant-prompt inclusion.""" + + @staticmethod + def _report_with_variants(uuid): + return _report( + uuid=uuid, + prompts=("primary attack",), + variant_prompts=[ + { + "industry": "automotive", + "subindustries": [ + { + "subindustry": "autonomous_driving", + "industry_id": 2, + "prompts": [ + {"prompt": "variant a", "key_changes": "...", "rationale": "..."}, + {"prompt": "variant b", "key_changes": "...", "rationale": "..."}, + ], + } + ], + } + ], + ) + + async def test_variants_excluded_by_default(self, api_key): + report = self._report_with_variants("55555555-5555-5555-5555-555555555555") + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=_page([report])) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + assert {s.value for s in dataset.seeds} == {"primary attack"} + + async def test_variants_included_when_requested(self, api_key): + report = self._report_with_variants("66666666-6666-6666-6666-666666666666") + loader = _ODINDataset(api_key=api_key, include_variant_prompts=True) + mock_resp = _make_mock_response(json_data=_page([report])) + + with patch("requests.get", return_value=mock_resp): + dataset = await loader.fetch_dataset_async() + + values = {s.value for s in dataset.seeds} + assert values == {"primary attack", "variant a", "variant b"} + + variant = next(s for s in dataset.seeds if s.value == "variant a") + assert variant.metadata["variant_industry"] == "automotive" + assert variant.metadata["variant_subindustry"] == "autonomous_driving" + + +class TestODINDatasetPagination: + """Test pagination handling.""" + + async def test_fetches_all_pages(self, api_key): + page1 = _page( + [_report(uuid="aaaaaaaa-0000-0000-0000-000000000001", prompts=("p1",))], + page=1, + total_pages=2, + total_count=2, + ) + page2 = _page( + [_report(uuid="aaaaaaaa-0000-0000-0000-000000000002", prompts=("p2",))], + page=2, + total_pages=2, + total_count=2, + ) + loader = _ODINDataset(api_key=api_key) + responses = [_make_mock_response(json_data=page1), _make_mock_response(json_data=page2)] + + with patch("requests.get", side_effect=responses) as mock_get: + dataset = await loader.fetch_dataset_async() + + assert mock_get.call_count == 2 + assert {s.value for s in dataset.seeds} == {"p1", "p2"} + + async def test_auth_header_has_no_bearer_prefix(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data=single_page_response) + + with patch("requests.get", return_value=mock_resp) as mock_get: + await loader.fetch_dataset_async() + + assert mock_get.call_args.kwargs["headers"]["Authorization"] == api_key + + +class TestODINDatasetAPIErrors: + """Test error handling for API failures.""" + + async def test_api_401_raises_immediately(self, api_key): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data={"error": "unauthorized"}, status_code=401) + + with patch("requests.get", return_value=mock_resp) as mock_get: + with pytest.raises(ConnectionError, match="status 401"): + await loader.fetch_dataset_async() + + # 401 is not retryable -> exactly one request + assert mock_get.call_count == 1 + + async def test_api_500_retries_then_raises(self, api_key): + loader = _ODINDataset(api_key=api_key) + mock_resp = _make_mock_response(json_data={"error": "server"}, status_code=500) + + with patch("time.sleep"): + with patch("requests.get", return_value=mock_resp) as mock_get: + with pytest.raises(ConnectionError, match="status 500"): + await loader.fetch_dataset_async() + + assert mock_get.call_count == _ODINDataset.MAX_RETRIES + + async def test_transient_error_then_success(self, api_key, single_page_response): + loader = _ODINDataset(api_key=api_key) + responses = [ + _make_mock_response(json_data={"error": "bad gateway"}, status_code=502), + _make_mock_response(json_data=single_page_response), + ] + + with patch("time.sleep"): + with patch("requests.get", side_effect=responses) as mock_get: + dataset = await loader.fetch_dataset_async() + + assert mock_get.call_count == 2 + assert len(dataset.seeds) == 3 + + +class TestODINDatasetCaching: + """Test the incremental on-disk cache.""" + + async def test_first_fetch_writes_cache(self, api_key, isolate_cache): + loader = _ODINDataset(api_key=api_key) + report = _report(uuid="cache-0000-0000-0000-000000000001", prompts=("p1",)) + mock_resp = _make_mock_response(json_data=_page([report])) + + assert not isolate_cache.exists() + with patch("requests.get", return_value=mock_resp): + await loader.fetch_dataset_async() + + assert isolate_cache.exists() + cached = json.loads(isolate_cache.read_text(encoding="utf-8")) + assert [r["uuid"] for r in cached] == ["cache-0000-0000-0000-000000000001"] + + async def test_no_new_reports_single_request(self, api_key): + loader = _ODINDataset(api_key=api_key) + report = _report(uuid="cache-0000-0000-0000-000000000001", prompts=("p1",)) + + # First fetch populates the cache. + with patch("requests.get", return_value=_make_mock_response(json_data=_page([report]))): + await loader.fetch_dataset_async() + + # Second fetch: page 1's first UUID is already cached -> stop after one request. + with patch("requests.get", return_value=_make_mock_response(json_data=_page([report]))) as mock_get: + dataset = await loader.fetch_dataset_async() + + assert mock_get.call_count == 1 + assert {s.value for s in dataset.seeds} == {"p1"} + + async def test_incremental_fetch_only_pulls_new_reports(self, api_key, isolate_cache): + loader = _ODINDataset(api_key=api_key) + old = _report(uuid="cache-0000-0000-0000-00000000000A", prompts=("old",)) + with patch("requests.get", return_value=_make_mock_response(json_data=_page([old]))): + await loader.fetch_dataset_async() + + # Feed now returns a new report on top of the known one (newest-first). + new = _report(uuid="cache-0000-0000-0000-00000000000B", prompts=("new",)) + feed = _page([new, old]) + with patch("requests.get", return_value=_make_mock_response(json_data=feed)) as mock_get: + dataset = await loader.fetch_dataset_async() + + assert mock_get.call_count == 1 + # Both old and new prompts are present, new merged on top. + assert {s.value for s in dataset.seeds} == {"new", "old"} + cached = json.loads(isolate_cache.read_text(encoding="utf-8")) + assert [r["uuid"] for r in cached] == [ + "cache-0000-0000-0000-00000000000B", + "cache-0000-0000-0000-00000000000A", + ] + + async def test_cache_false_bypasses_cache(self, api_key, isolate_cache): + loader = _ODINDataset(api_key=api_key) + report = _report(uuid="cache-0000-0000-0000-000000000001", prompts=("p1",)) + + with patch("requests.get", return_value=_make_mock_response(json_data=_page([report]))): + await loader.fetch_dataset_async(cache=True) + cached_mtime = isolate_cache.stat().st_mtime_ns + + # cache=False must not read or write the cache, and must fully paginate. + page1 = _page([report], page=1, total_pages=2, total_count=2) + other = _report(uuid="cache-0000-0000-0000-000000000002", prompts=("p2",)) + page2 = _page([other], page=2, total_pages=2, total_count=2) + responses = [_make_mock_response(json_data=page1), _make_mock_response(json_data=page2)] + with patch("requests.get", side_effect=responses) as mock_get: + dataset = await loader.fetch_dataset_async(cache=False) + + assert mock_get.call_count == 2 # full pagination, cache ignored + assert {s.value for s in dataset.seeds} == {"p1", "p2"} + assert isolate_cache.stat().st_mtime_ns == cached_mtime # cache untouched + + async def test_corrupt_cache_is_ignored(self, api_key, isolate_cache): + isolate_cache.parent.mkdir(parents=True, exist_ok=True) + isolate_cache.write_text("{ not json", encoding="utf-8") + + loader = _ODINDataset(api_key=api_key) + report = _report(uuid="cache-0000-0000-0000-000000000001", prompts=("p1",)) + with patch("requests.get", return_value=_make_mock_response(json_data=_page([report]))): + dataset = await loader.fetch_dataset_async() + + assert {s.value for s in dataset.seeds} == {"p1"} + cached = json.loads(isolate_cache.read_text(encoding="utf-8")) + assert [r["uuid"] for r in cached] == ["cache-0000-0000-0000-000000000001"]