diff --git a/build_scripts/memory_migrations.py b/build_scripts/memory_migrations.py index bb37584c11..59ea7b92f3 100644 --- a/build_scripts/memory_migrations.py +++ b/build_scripts/memory_migrations.py @@ -58,6 +58,20 @@ def _cmd_check() -> None: tmp_path.unlink(missing_ok=True) +def _cmd_head() -> None: + """Print the current Alembic head revision ID.""" + from pathlib import Path + + from alembic.config import Config + from alembic.script import ScriptDirectory + + script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + head = ScriptDirectory.from_config(config).get_current_head() + print(head) + + def _build_parser() -> argparse.ArgumentParser: """Build the CLI argument parser.""" parser = argparse.ArgumentParser( @@ -71,6 +85,8 @@ def _build_parser() -> argparse.ArgumentParser: sub.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.") + sub.add_parser("head", help="Print the current Alembic head revision ID.") + return parser @@ -82,6 +98,8 @@ def main() -> int: _cmd_generate(message=args.message, force=args.force) elif args.command == "check": _cmd_check() + elif args.command == "head": + _cmd_head() return 0 diff --git a/build_scripts/migrate_prod_memory_schema.py b/build_scripts/migrate_prod_memory_schema.py new file mode 100644 index 0000000000..d9e2139628 --- /dev/null +++ b/build_scripts/migrate_prod_memory_schema.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Deliberate schema migration tool for production databases. + +This script is the ONLY sanctioned way to apply Alembic migrations to a production +database. It is intended to be run during the release process (see +doc/contributing/10_release_process.md) or by a CD pipeline — never by normal +application startup. + +Safety rails: +- Requires an explicit --target-revision argument (no blind "upgrade head"). +- Validates the environment (release branch, clean working tree, no .dev version). +- Validates the target revision exists in the local migration graph. +- Reports current DB revision before and after migration. +- Runs schema check after upgrade to confirm models match. +- Interactive confirmation when running in a terminal. +- Exits non-zero on any failure. + +Usage: + python build_scripts/migrate_prod_memory_schema.py \ + --target-revision c3d5e7f9a1b2 + +The script reads the production connection string from the +AZURE_SQL_DB_CONNECTION_STRING environment variable (same as AzureSQLMemory). +""" + +import argparse +import os +import subprocess +import sys + +import dotenv +from alembic import command +from alembic.script import ScriptDirectory +from alembic.util.exc import AutogenerateDiffsDetected +from sqlalchemy import create_engine, text + +from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH + +# Load .env files from ~/.pyrit/ (same files that initialize_pyrit_async loads) +# Use override=False so explicitly-set env vars take precedence over .env values +for _env_file in [CONFIGURATION_DIRECTORY_PATH / ".env", CONFIGURATION_DIRECTORY_PATH / ".env.local"]: + if _env_file.exists(): + dotenv.load_dotenv(_env_file, override=False, interpolate=True) + +# ANSI color codes +_GREEN = "\033[92m" +_RED = "\033[91m" +_YELLOW = "\033[93m" +_RESET = "\033[0m" + +_CONNECTION_STRING_ENV = "AZURE_SQL_DB_CONNECTION_STRING_PROD" + + +def _print_error(message: str) -> None: + """Print an error message in red to stderr.""" + print(f"{_RED}ERROR: {message}{_RESET}", file=sys.stderr) + + +def _print_warning(message: str) -> None: + """Print a warning message in yellow.""" + print(f"{_YELLOW}WARNING: {message}{_RESET}") + + +def _print_success(message: str) -> None: + """Print a success message in green.""" + print(f"{_GREEN}{message}{_RESET}") + + +def _get_current_revision(*, engine) -> str | None: + """ + Read the current Alembic revision from the database. + + Returns None if no version table exists (fresh database). + """ + from sqlalchemy import inspect as sa_inspect + + from pyrit.memory.migration import PYRIT_MEMORY_ALEMBIC_VERSION_TABLE + + inspector = sa_inspect(engine) + if PYRIT_MEMORY_ALEMBIC_VERSION_TABLE not in inspector.get_table_names(): + return None + + with engine.connect() as conn: + result = conn.execute(text(f"SELECT version_num FROM {PYRIT_MEMORY_ALEMBIC_VERSION_TABLE}")) + row = result.fetchone() + return row[0] if row else None + + +def _validate_revision_exists(*, target_revision: str) -> bool: + """Check that the target revision exists in the local migration script directory.""" + + from pathlib import Path + + script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic" + # Build a minimal config to get the ScriptDirectory + from alembic.config import Config + + config = Config() + config.set_main_option("script_location", str(script_location)) + script_dir = ScriptDirectory.from_config(config) + + try: + script_dir.get_revision(target_revision) + return True + except Exception: + return False + + +def _get_all_revision_ids() -> list[str]: + """Get all revision IDs from the local migration graph.""" + from pathlib import Path + + from alembic.config import Config + + script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + script_dir = ScriptDirectory.from_config(config) + return [rev.revision for rev in script_dir.walk_revisions()] + + +def _check_release_environment() -> list[str]: + """ + Validate that the script is running in a proper release environment. + + Returns a list of warning/error messages. Empty list means all checks pass. + """ + issues: list[str] = [] + + # Check 1: Running from a release branch (releases/vX.Y.Z) + try: + branch = subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + if not branch.startswith("releases/"): + issues.append( + f"Not on a release branch (current: '{branch}'). " + "Production migrations should run from 'releases/vX.Y.Z'." + ) + except (subprocess.CalledProcessError, FileNotFoundError): + issues.append("Could not determine current Git branch.") + + # Check 2: Clean working tree (no uncommitted changes to memory files) + try: + dirty_files = subprocess.check_output( + ["git", "status", "--porcelain", "--", "pyrit/memory/"], + text=True, + stderr=subprocess.DEVNULL, + ).strip() + if dirty_files: + issues.append( + "Uncommitted changes detected in pyrit/memory/:\n" + f" {dirty_files}\n" + " Commit or stash changes before migrating production." + ) + except (subprocess.CalledProcessError, FileNotFoundError): + issues.append("Could not check Git working tree status.") + + # Check 3: Not a .dev version (should be a release version) + try: + from pyrit import __version__ + + if ".dev" in __version__: + issues.append( + f"PyRIT version is '{__version__}' (development). " + "Production migrations should use a release version (no .dev suffix)." + ) + except ImportError: + issues.append("Could not determine PyRIT version.") + + return issues + + +def _run_migration(*, connection_string: str, target_revision: str) -> int: + """ + Execute the migration against the target database. + + Returns 0 on success, 1 on failure. + """ + from pathlib import Path + + from alembic.config import Config + + from pyrit.memory.migration import check_schema_migrations + + engine = create_engine(connection_string) + + # Step 1: Report current state + current_rev = _get_current_revision(engine=engine) + print(f"Current database revision: {current_rev or '(none — fresh database)'}") + print(f"Target revision: {target_revision}") + print() + + if current_rev == target_revision: + _print_success("Database is already at the target revision. Nothing to do.") + engine.dispose() + return 0 + + # Step 2: Run upgrade to specific revision (not head) + print("Applying migrations...") + try: + script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic" + with engine.begin() as connection: + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + + # Stamp unversioned schemas if needed + from pyrit.memory.migration import _validate_and_stamp_unversioned_memory_schema + + _validate_and_stamp_unversioned_memory_schema(config=config, connection=connection) + command.upgrade(config, target_revision) + except Exception as e: + _print_error(f"Migration failed: {e}") + engine.dispose() + return 1 + + # Step 3: Verify new revision + new_rev = _get_current_revision(engine=engine) + print(f"Database revision after migration: {new_rev}") + + if new_rev != target_revision: + _print_error(f"Expected revision {target_revision}, but database is at {new_rev}.") + engine.dispose() + return 1 + + # Step 4: Schema check — verify models match + print("Verifying schema matches models...") + try: + check_schema_migrations(engine=engine) + except AutogenerateDiffsDetected as e: + _print_error(f"Schema check failed after migration: {e}") + _print_error("The models in this codebase do not match the database schema.") + engine.dispose() + return 1 + + _print_success("Migration completed and verified successfully.") + engine.dispose() + return 0 + + +def _build_parser() -> argparse.ArgumentParser: + """Build the CLI argument parser.""" + parser = argparse.ArgumentParser( + description=( + "Apply Alembic schema migrations to a production database. " + "Validates release environment and requires an explicit target revision." + ), + ) + parser.add_argument( + "--target-revision", + required=True, + help="The exact Alembic revision ID to upgrade to (e.g., 'c3d5e7f9a1b2').", + ) + parser.add_argument( + "--connection-string-env", + default=_CONNECTION_STRING_ENV, + help=f"Environment variable containing the connection string. Default: {_CONNECTION_STRING_ENV}", + ) + parser.add_argument( + "--skip-environment-check", + action="store_true", + help="Skip release environment checks (branch, clean tree, version). Use only in CI with caution.", + ) + return parser + + +def main() -> int: + """Entry point for production schema migration.""" + args = _build_parser().parse_args() + + # Safety rail 1: Require connection string + connection_string = os.environ.get(args.connection_string_env) + if not connection_string: + _print_error(f"Environment variable '{args.connection_string_env}' is not set.") + return 1 + + # Safety rail 2: Validate target revision exists in local migration graph + if not _validate_revision_exists(target_revision=args.target_revision): + available = _get_all_revision_ids() + _print_error( + f"Target revision '{args.target_revision}' not found in the local migration graph.\n" + f"Available revisions: {', '.join(available)}" + ) + return 1 + + # Safety rail 3: Verify release environment (branch, clean tree, version) + if not args.skip_environment_check: + issues = _check_release_environment() + if issues: + _print_error("Release environment checks failed:") + for issue in issues: + _print_error(f" - {issue}") + _print_error("Fix the above issues or pass --skip-environment-check (CI only).") + return 1 + else: + _print_warning("Skipping release environment checks (--skip-environment-check).") + + # Confirmation prompt (unless running in CI where stdin is not interactive) + if sys.stdin.isatty(): + print(f"About to migrate production database to revision: {args.target_revision}") + print(f"Connection string source: ${args.connection_string_env}") + response = input("Type 'yes' to proceed: ") + if response.strip().lower() != "yes": + print("Aborted.") + return 1 + + return _run_migration(connection_string=connection_string, target_revision=args.target_revision) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/doc/contributing/10_release_process.md b/doc/contributing/10_release_process.md index 7e5af48b65..5eb4fd527a 100644 --- a/doc/contributing/10_release_process.md +++ b/doc/contributing/10_release_process.md @@ -204,7 +204,58 @@ Note: You may need to build the package again if those changes modify any depend Lastly, **Verify pyrit-internal is up to date.** Follow the instructions at [aka.ms/internal-release](https://aka.ms/internal-release) to ensure the internal package is current. -## 9. Publish to PyPI +## 9. Migrate Production Database Schema + +Apply any pending Alembic migrations to the production database. This is the **only** +sanctioned path for modifying the production schema — normal startup only validates, +never upgrades. + +**Run from the release branch with release dependencies.** This ensures the migration +files and model definitions match exactly what will be shipped to users. Running from +`main` or a dev environment could apply unreleased migrations that break prod. + +```bash +git checkout releases/vx.y.z +uv sync --frozen +python -c "import pyrit; print(pyrit.__version__)" # verify: x.y.z (no .dev0) +``` + +**Identify the target revision** — the Alembic head on this branch. We use an explicit +revision (not `head`) so the migration is deterministic and tied to this exact release. + +```bash +python build_scripts/memory_migrations.py head +``` + +This prints the revision ID (e.g., `c3d5e7f9a1b2`) to use as `` below. + +**Run the migration** (reads `AZURE_SQL_DB_CONNECTION_STRING_PROD` from `~/.pyrit/.env`): + +```bash +python build_scripts/migrate_prod_memory_schema.py --target-revision +``` + +The script validates the environment (release branch, clean tree, no `.dev` version), +confirms the target revision exists, applies migrations, and verifies the schema matches +models. It exits non-zero on any failure, and migrations roll back automatically. + +**Verify prod is usable after migration.** This connects to prod using the check-only +path (no schema modification) and confirms compatibility: + +```bash +python -c "from pyrit.memory import AzureSQLMemory; AzureSQLMemory()" +``` + +If it exits without error, prod is ready. + +If no schema changes landed in this release, the script reports "already at target revision" +and exits cleanly. Still run it as confirmation. + +**Rollback policy:** forward-fix only. Ship a new corrective migration rather than downgrading, +since `downgrade()` risks data loss. + + +## 10. Publish to PyPI Create an account on pypi.org if you don't have one yet. Ask one of the other maintainers to add you to the `pyrit` project on PyPI. @@ -221,7 +272,7 @@ If successful, it will print > View at: > https://pypi.org/project/pyrit/x.y.z/ -## 10. Update main +## 11. Update main After the release is on PyPI, make sure to create a PR for the `main` branch where the only changes are: @@ -233,7 +284,7 @@ where the only changes are: The PR should be made from your fork and should be a different branch than the releases branch you created earlier. This should be something like `x.y.z+1.dev0`. -## 11. Create GitHub Release +## 12. Create GitHub Release Finally, go to the [releases page](https://github.com/microsoft/PyRIT/releases), select "Draft a new release" and the "tag" for which you want to create the release notes. It should match the version that you just released diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 47770db028..55d62778cb 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -58,6 +58,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton): AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL" AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN" + # Optional environment variable for production connection string to prevent accidental schema migrations on prod + AZURE_SQL_DB_CONNECTION_STRING_PROD: str = "AZURE_SQL_DB_CONNECTION_STRING_PROD" + def __init__( self, *, @@ -81,6 +84,9 @@ def __init__( verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False. skip_schema_migration (bool): Whether to skip schema migration. Defaults to False. silent (bool): If True, suppresses schema migration console output. Defaults to False. + + Raises: + AutogenerateDiffsDetected: If connected to a production database and schema does not match models. """ self._connection_string = default_values.get_required_value( env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string @@ -107,8 +113,19 @@ def __init__( self._enable_azure_authorization() self.SessionFactory = sessionmaker(bind=self.engine) + + prod_connection_string = default_values.get_non_required_value( + env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING_PROD + ) if not skip_schema_migration: - self._run_schema_migration(silent=silent) + if self._connection_string == prod_connection_string: + # Production guard: verify schema compatibility without modifying the database. + # Logs a warning on mismatch but does not block startup, so developers on + # newer code can still query prod data. + self._check_schema_migration(silent=silent) + else: + # For non-production databases, run normal schema migration which will create/update tables as needed. + self._run_schema_migration(silent=silent) super().__init__() diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index e03af2461d..4c4a2f99cd 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1434,6 +1434,35 @@ def _run_schema_migration(self, *, silent: bool = False) -> None: run_schema_migrations(engine=self.engine, silent=silent) check_schema_migrations(engine=self.engine, silent=silent) + def _check_schema_migration(self, *, silent: bool = False) -> None: + """ + Verify that the current database schema matches the models without modifying the database. + + Logs a warning if the schema does not match, but does not raise or block startup. + + Args: + silent (bool): If True, suppresses Alembic console output. Defaults to False. + + Raises: + RuntimeError: If the engine is not initialized. + """ + from alembic.util.exc import AutogenerateDiffsDetected + + from pyrit.memory.migration import check_schema_migrations + + logger.info("Checking schema migration compatibility.") + if self.engine is None: + raise RuntimeError("Engine must be initialized to check schema migrations.") + try: + check_schema_migrations(engine=self.engine, silent=silent) + except AutogenerateDiffsDetected: + logger.warning( + "Schema mismatch detected on production database. " + "Your code models differ from the database schema. " + "This may cause errors if your code references columns or tables that don't exist. " + "Schema was NOT modified." + ) + def reset_database(self) -> None: """ Drop and recreate all tables in the database. diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 34e9671461..b7e536be69 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -10,6 +10,7 @@ import pytest from sqlalchemy import inspect, text +from pyrit.common.singleton import Singleton from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry from pyrit.models import Conversation, MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter @@ -617,3 +618,162 @@ def test_reset_database_raises_when_engine_none(): obj.engine = None with pytest.raises(RuntimeError, match="Engine is not initialized"): obj.reset_database() + + +def test_init_prod_connection_runs_check_only_not_migration(): + """When connection matches prod, only check_schema_migrations runs — not run_schema_migrations.""" + prod_conn = "Server=tcp:prod.database.windows.net;Database=prod_db;" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_check_schema_migration") as mock_check, + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: prod_conn, + }, + ), + ): + AzureSQLMemory( + connection_string=prod_conn, + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + ) + mock_check.assert_called_once() + mock_migration.assert_not_called() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) + + +def test_init_prod_connection_warns_on_schema_mismatch(): + """When connection matches prod and schema doesn't match, startup succeeds with a warning (no raise).""" + prod_conn = "Server=tcp:prod.database.windows.net;Database=prod_db;" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_check_schema_migration") as mock_check, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: prod_conn, + }, + ), + ): + # Should not raise — _check_schema_migration warns internally on mismatch + AzureSQLMemory( + connection_string=prod_conn, + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + ) + mock_check.assert_called_once() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) + + +def test_init_allows_migration_when_connection_does_not_match_prod(): + """Migration proceeds normally when the connection string does not match the prod env var.""" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: "Server=tcp:prod.database.windows.net;", + }, + ), + ): + AzureSQLMemory( + connection_string="Server=tcp:dev.database.windows.net;", + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + ) + mock_migration.assert_called_once() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) + + +def test_init_allows_migration_when_prod_env_var_not_set(): + """Migration proceeds normally when AZURE_SQL_DB_CONNECTION_STRING_PROD is not set.""" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + }, + clear=False, + ), + ): + os.environ.pop(AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD, None) + AzureSQLMemory( + connection_string="Server=tcp:dev.database.windows.net;", + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + ) + mock_migration.assert_called_once() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) + + +def test_init_skips_prod_guard_when_skip_schema_migration_true(): + """When skip_schema_migration=True, the prod guard is bypassed entirely — no error, no migration.""" + prod_conn = "Server=tcp:prod.database.windows.net;Database=prod_db;" + saved = Singleton._instances.copy() + Singleton._instances.clear() + try: + with ( + patch("pyrit.memory.AzureSQLMemory._create_engine"), + patch("pyrit.memory.AzureSQLMemory._create_auth_token"), + patch("pyrit.memory.AzureSQLMemory._enable_azure_authorization"), + patch.object(AzureSQLMemory, "_run_schema_migration") as mock_migration, + patch.dict( + "os.environ", + { + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: "https://test.blob.core.windows.net/test", + AzureSQLMemory.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: "valid_sas_token", + AzureSQLMemory.AZURE_SQL_DB_CONNECTION_STRING_PROD: prod_conn, + }, + ), + ): + # Should not raise even though connection matches prod + AzureSQLMemory( + connection_string=prod_conn, + results_container_url="https://test.blob.core.windows.net/test", + results_sas_token="valid_sas_token", + skip_schema_migration=True, + ) + mock_migration.assert_not_called() + finally: + Singleton._instances.clear() + Singleton._instances.update(saved) diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index 7497acdf1d..4205e2291b 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -604,6 +604,78 @@ def test_check_schema_migrations_not_silent_prints_output(capsys): engine.dispose() +def test_memory_interface_check_schema_migration_calls_check(): + """_check_schema_migration on MemoryInterface calls check_schema_migrations without running upgrade.""" + from unittest.mock import MagicMock, patch + + from pyrit.memory.memory_interface import MemoryInterface + + obj = MagicMock(spec=MemoryInterface) + obj.engine = MagicMock() + + with patch("pyrit.memory.migration.check_schema_migrations") as mock_check: + MemoryInterface._check_schema_migration(obj, silent=True) + mock_check.assert_called_once_with(engine=obj.engine, silent=True) + + +def test_memory_interface_check_schema_migration_warns_on_mismatch(caplog): + """_check_schema_migration logs a warning instead of raising when schema mismatches.""" + import logging + from unittest.mock import MagicMock, patch + + from alembic.util.exc import AutogenerateDiffsDetected + + from pyrit.memory.memory_interface import MemoryInterface + + obj = MagicMock(spec=MemoryInterface) + obj.engine = MagicMock() + + with ( + patch( + "pyrit.memory.migration.check_schema_migrations", + side_effect=AutogenerateDiffsDetected( + "diffs detected", + revision_context=MagicMock(), + diffs=[], + ), + ), + caplog.at_level(logging.WARNING), + ): + # Should NOT raise + MemoryInterface._check_schema_migration(obj, silent=True) + + assert "Schema mismatch detected on production database" in caplog.text + + +def test_memory_interface_check_schema_migration_raises_without_engine(): + """_check_schema_migration raises RuntimeError when engine is None.""" + from unittest.mock import MagicMock + + from pyrit.memory.memory_interface import MemoryInterface + + obj = MagicMock(spec=MemoryInterface) + obj.engine = None + + with pytest.raises(RuntimeError, match="Engine must be initialized"): + MemoryInterface._check_schema_migration(obj, silent=False) + + +def test_memory_migrations_head_command(capsys): + """The 'head' subcommand of memory_migrations.py prints the current Alembic head revision.""" + import sys + + # Import the module's main function + sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "build_scripts")) + from memory_migrations import _cmd_head + + _cmd_head() + captured = capsys.readouterr() + revision = captured.out.strip() + # Should be a non-empty hex-ish string + assert len(revision) > 0 + assert all(c in "0123456789abcdef" for c in revision) + + # ============================================================================= # Backfill tests for the Conversations table migration (b2f4c6a8d1e3) # =============================================================================