Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions build_scripts/memory_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def _cmd_check() -> None:
tmp_path.unlink(missing_ok=True)


def _cmd_head() -> None:
    """Resolve the head revision of the memory migration scripts and print its ID."""
    from pathlib import Path

    from alembic.config import Config
    from alembic.script import ScriptDirectory

    # The Alembic script directory is located relative to this file:
    # two directory levels up, then pyrit/memory/alembic.
    base_dir = Path(__file__).parent.parent
    alembic_dir = base_dir / "pyrit" / "memory" / "alembic"

    alembic_config = Config()
    alembic_config.set_main_option("script_location", str(alembic_dir))

    scripts = ScriptDirectory.from_config(alembic_config)
    print(scripts.get_current_head())


def _build_parser() -> argparse.ArgumentParser:
"""Build the CLI argument parser."""
parser = argparse.ArgumentParser(
Expand All @@ -71,6 +85,8 @@ def _build_parser() -> argparse.ArgumentParser:

sub.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.")

sub.add_parser("head", help="Print the current Alembic head revision ID.")

return parser


Expand All @@ -82,6 +98,8 @@ def main() -> int:
_cmd_generate(message=args.message, force=args.force)
elif args.command == "check":
_cmd_check()
elif args.command == "head":
_cmd_head()

return 0

Expand Down
317 changes: 317 additions & 0 deletions build_scripts/migrate_prod_memory_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Deliberate schema migration tool for production databases.

This script is the ONLY sanctioned way to apply Alembic migrations to a production
database. It is intended to be run during the release process (see
doc/contributing/10_release_process.md) or by a CD pipeline — never by normal
application startup.

Safety rails:
- Requires an explicit --target-revision argument (no blind "upgrade head").
- Validates the environment (release branch, clean working tree, no .dev version).
- Validates the target revision exists in the local migration graph.
- Reports current DB revision before and after migration.
- Runs schema check after upgrade to confirm models match.
- Interactive confirmation when running in a terminal.
- Exits non-zero on any failure.

Usage:
python build_scripts/migrate_prod_memory_schema.py \
--target-revision c3d5e7f9a1b2

The script reads the production connection string from the
AZURE_SQL_DB_CONNECTION_STRING_PROD environment variable by default
(use --connection-string-env to read it from a different variable).
"""

import argparse
import os
import subprocess
import sys

import dotenv
from alembic import command
from alembic.script import ScriptDirectory
from alembic.util.exc import AutogenerateDiffsDetected
from sqlalchemy import create_engine, text

from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH

# Load .env files from ~/.pyrit/ (same files that initialize_pyrit_async loads).
# override=False means a variable that is already set (explicitly, or by a file
# loaded earlier in this loop) is never replaced by a later .env value.
for _env_file in (
    CONFIGURATION_DIRECTORY_PATH / ".env",
    CONFIGURATION_DIRECTORY_PATH / ".env.local",
):
    if not _env_file.exists():
        continue
    dotenv.load_dotenv(_env_file, override=False, interpolate=True)

# ANSI escape sequences used to colorize console output
_RESET = "\033[0m"
_RED = "\033[91m"
_GREEN = "\033[92m"
_YELLOW = "\033[93m"

# Default name of the environment variable holding the production connection string
_CONNECTION_STRING_ENV = "AZURE_SQL_DB_CONNECTION_STRING_PROD"


def _print_error(message: str) -> None:
    """Emit ``message`` on stderr as a red line prefixed with ``ERROR:``."""
    colored = f"{_RED}ERROR: {message}{_RESET}"
    print(colored, file=sys.stderr)


def _print_warning(message: str) -> None:
    """Emit ``message`` on stdout as a yellow line prefixed with ``WARNING:``."""
    colored = f"{_YELLOW}WARNING: {message}{_RESET}"
    print(colored)


def _print_success(message: str) -> None:
    """Emit ``message`` on stdout in green."""
    colored = f"{_GREEN}{message}{_RESET}"
    print(colored)


def _get_current_revision(*, engine) -> str | None:
    """
    Look up the Alembic revision currently recorded in the database.

    Args:
        engine: The SQLAlchemy engine connected to the database to inspect.

    Returns:
        The ``version_num`` value of the first row in the PyRIT Alembic version
        table, or None when that table does not exist (fresh database) or
        holds no rows.
    """
    from sqlalchemy import inspect as sa_inspect

    from pyrit.memory.migration import PYRIT_MEMORY_ALEMBIC_VERSION_TABLE

    version_table = PYRIT_MEMORY_ALEMBIC_VERSION_TABLE

    # No version table means Alembic has never stamped this database.
    existing_tables = sa_inspect(engine).get_table_names()
    if version_table not in existing_tables:
        return None

    query = text(f"SELECT version_num FROM {version_table}")
    with engine.connect() as conn:
        first_row = conn.execute(query).fetchone()
        if not first_row:
            return None
        return first_row[0]


def _validate_revision_exists(*, target_revision: str) -> bool:
    """
    Check that ``target_revision`` is an exact revision ID in the local migration graph.

    Alembic's ``ScriptDirectory.get_revision`` also resolves symbolic names
    ("head", "base", relative specifiers) and partial ID prefixes. Accepting those
    would defeat the "explicit revision only" safety rail, and the post-migration
    check compares the database revision against the literal argument, so it would
    fail only after the upgrade had already run. This function therefore requires
    the resolved script's revision ID to equal the argument exactly.

    Args:
        target_revision: The revision identifier supplied by the operator.

    Returns:
        True only if the identifier resolves to a migration script whose revision
        ID is exactly ``target_revision``; False otherwise.
    """
    from pathlib import Path

    from alembic.config import Config
    from alembic.util.exc import CommandError

    # Build a minimal config: only the script location is needed for ScriptDirectory.
    script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic"
    config = Config()
    config.set_main_option("script_location", str(script_location))
    script_dir = ScriptDirectory.from_config(config)

    try:
        script = script_dir.get_revision(target_revision)
    except CommandError:
        # Alembic reports unresolvable or ambiguous identifiers as CommandError.
        return False

    # "base" resolves to None; symbolic names and prefixes resolve to a script
    # whose ID differs from what the operator typed.
    return script is not None and script.revision == target_revision


def _get_all_revision_ids() -> list[str]:
    """
    Collect every revision ID known to the local migration scripts.

    Returns:
        The revision IDs in the order yielded by ``ScriptDirectory.walk_revisions()``.
    """
    from pathlib import Path

    from alembic.config import Config

    alembic_dir = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic"

    alembic_config = Config()
    alembic_config.set_main_option("script_location", str(alembic_dir))

    scripts = ScriptDirectory.from_config(alembic_config)
    revision_ids: list[str] = []
    for script in scripts.walk_revisions():
        revision_ids.append(script.revision)
    return revision_ids


def _check_release_environment() -> list[str]:
    """
    Validate that the script is running in a proper release environment.

    Three checks are performed:
      1. The current Git branch name starts with ``releases/``.
      2. ``git status`` reports no uncommitted changes under ``pyrit/memory/``.
      3. The installed PyRIT version does not contain ``.dev``.

    Git is invoked with the repository root (derived from this file's location)
    as its working directory. Without that, the ``pyrit/memory/`` pathspec would be
    resolved relative to wherever the operator launched the script, and the
    clean-tree check could silently match nothing and report a clean tree.

    Returns:
        A list of human-readable problem descriptions. An empty list means all
        checks passed.
    """
    from pathlib import Path

    issues: list[str] = []

    # This script lives in <repo root>/build_scripts/, the same layout assumption
    # the rest of this file uses to locate the Alembic directory.
    repo_root = Path(__file__).resolve().parent.parent

    # Check 1: Running from a release branch (releases/vX.Y.Z)
    try:
        branch = subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            text=True,
            stderr=subprocess.DEVNULL,
            cwd=repo_root,
        ).strip()
        if not branch.startswith("releases/"):
            issues.append(
                f"Not on a release branch (current: '{branch}'). "
                "Production migrations should run from 'releases/vX.Y.Z'."
            )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # git missing, or the directory is not a Git checkout
        issues.append("Could not determine current Git branch.")

    # Check 2: Clean working tree (no uncommitted changes to memory files)
    try:
        dirty_files = subprocess.check_output(
            ["git", "status", "--porcelain", "--", "pyrit/memory/"],
            text=True,
            stderr=subprocess.DEVNULL,
            cwd=repo_root,
        ).strip()
        if dirty_files:
            issues.append(
                "Uncommitted changes detected in pyrit/memory/:\n"
                f" {dirty_files}\n"
                " Commit or stash changes before migrating production."
            )
    except (subprocess.CalledProcessError, FileNotFoundError):
        issues.append("Could not check Git working tree status.")

    # Check 3: Not a .dev version (should be a release version)
    try:
        from pyrit import __version__

        if ".dev" in __version__:
            issues.append(
                f"PyRIT version is '{__version__}' (development). "
                "Production migrations should use a release version (no .dev suffix)."
            )
    except ImportError:
        issues.append("Could not determine PyRIT version.")

    return issues


def _run_migration(*, connection_string: str, target_revision: str) -> int:
    """
    Execute the migration against the target database.

    Steps:
      1. Report the database's current revision; return early if it already
         equals ``target_revision``.
      2. Inside a single ``engine.begin()`` transaction, stamp an unversioned
         schema if needed and run ``alembic upgrade`` to ``target_revision``.
      3. Re-read the database revision and confirm it equals ``target_revision``.
      4. Run the schema check to confirm the models match the migrated schema.

    The engine is disposed in a ``finally`` clause, so its connection pool is
    released on every exit path, including unexpected exceptions.

    Args:
        connection_string: SQLAlchemy connection URL of the database to migrate.
        target_revision: The exact Alembic revision ID to upgrade to.

    Returns:
        0 on success, 1 on failure.
    """
    from pathlib import Path

    from alembic.config import Config

    from pyrit.memory.migration import check_schema_migrations

    engine = create_engine(connection_string)
    try:
        # Step 1: Report current state
        current_rev = _get_current_revision(engine=engine)
        print(f"Current database revision: {current_rev or '(none — fresh database)'}")
        print(f"Target revision: {target_revision}")
        print()

        if current_rev == target_revision:
            _print_success("Database is already at the target revision. Nothing to do.")
            return 0

        # Step 2: Run upgrade to specific revision (not head)
        print("Applying migrations...")
        try:
            script_location = Path(__file__).parent.parent / "pyrit" / "memory" / "alembic"
            with engine.begin() as connection:
                config = Config()
                config.set_main_option("script_location", str(script_location))
                # Hand the open connection to Alembic through the config attributes.
                config.attributes["connection"] = connection

                # Stamp unversioned schemas if needed
                from pyrit.memory.migration import _validate_and_stamp_unversioned_memory_schema

                _validate_and_stamp_unversioned_memory_schema(config=config, connection=connection)
                command.upgrade(config, target_revision)
        except Exception as e:
            # Top-level boundary: report any migration failure and exit non-zero.
            _print_error(f"Migration failed: {e}")
            return 1

        # Step 3: Verify new revision
        new_rev = _get_current_revision(engine=engine)
        print(f"Database revision after migration: {new_rev}")

        if new_rev != target_revision:
            _print_error(f"Expected revision {target_revision}, but database is at {new_rev}.")
            return 1

        # Step 4: Schema check — verify models match
        print("Verifying schema matches models...")
        try:
            check_schema_migrations(engine=engine)
        except AutogenerateDiffsDetected as e:
            _print_error(f"Schema check failed after migration: {e}")
            _print_error("The models in this codebase do not match the database schema.")
            return 1

        _print_success("Migration completed and verified successfully.")
        return 0
    finally:
        engine.dispose()


def _build_parser() -> argparse.ArgumentParser:
    """
    Create the argument parser for the production migration CLI.

    Returns:
        A parser accepting ``--target-revision`` (required),
        ``--connection-string-env`` and ``--skip-environment-check``.
    """
    description = (
        "Apply Alembic schema migrations to a production database. "
        "Validates release environment and requires an explicit target revision."
    )
    parser = argparse.ArgumentParser(description=description)

    # Each entry is (flag, keyword arguments for add_argument), registered in order.
    option_specs = (
        (
            "--target-revision",
            {
                "required": True,
                "help": "The exact Alembic revision ID to upgrade to (e.g., 'c3d5e7f9a1b2').",
            },
        ),
        (
            "--connection-string-env",
            {
                "default": _CONNECTION_STRING_ENV,
                "help": f"Environment variable containing the connection string. Default: {_CONNECTION_STRING_ENV}",
            },
        ),
        (
            "--skip-environment-check",
            {
                "action": "store_true",
                "help": "Skip release environment checks (branch, clean tree, version). Use only in CI with caution.",
            },
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser


def main() -> int:
    """
    Run the production schema migration CLI.

    Applies the safety rails in order (connection string present, target revision
    known locally, release environment valid), asks for confirmation when stdin
    is a terminal, then runs the migration.

    Returns:
        The process exit code: 0 on success, 1 on any failure or user abort.
    """
    args = _build_parser().parse_args()
    env_var_name = args.connection_string_env
    target = args.target_revision

    # Safety rail 1: the connection string must be available in the environment.
    connection_string = os.environ.get(env_var_name)
    if not connection_string:
        _print_error(f"Environment variable '{env_var_name}' is not set.")
        return 1

    # Safety rail 2: the target revision must exist in the local migration graph.
    if not _validate_revision_exists(target_revision=target):
        known_revisions = ", ".join(_get_all_revision_ids())
        _print_error(
            f"Target revision '{target}' not found in the local migration graph.\n"
            f"Available revisions: {known_revisions}"
        )
        return 1

    # Safety rail 3: release environment (branch, clean tree, version), unless skipped.
    if args.skip_environment_check:
        _print_warning("Skipping release environment checks (--skip-environment-check).")
    else:
        problems = _check_release_environment()
        if problems:
            _print_error("Release environment checks failed:")
            for problem in problems:
                _print_error(f" - {problem}")
            _print_error("Fix the above issues or pass --skip-environment-check (CI only).")
            return 1

    # Ask for confirmation only when stdin is interactive; CI runs skip the prompt.
    if sys.stdin.isatty():
        print(f"About to migrate production database to revision: {target}")
        print(f"Connection string source: ${env_var_name}")
        answer = input("Type 'yes' to proceed: ")
        confirmed = answer.strip().lower() == "yes"
        if not confirmed:
            print("Aborted.")
            return 1

    return _run_migration(connection_string=connection_string, target_revision=target)


# Script entry point: the exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())
Loading
Loading