diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index e429526..31c265e 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -8,6 +8,7 @@ uv sync --extra dev # create .venv and install dev d .scripts/run.sh integration-test # requires ADLA + az login + .env .scripts/run.sh lint # ruff check + format --check .scripts/run.sh fix # ruff auto-fix + format +.scripts/run.sh upload # build and upload the adapter to a storage account # Single test file or test uv run pytest tests/unit/test_script_builder.py -v diff --git a/.scripts/run.sh b/.scripts/run.sh index cc29bbd..a87ae5d 100755 --- a/.scripts/run.sh +++ b/.scripts/run.sh @@ -163,7 +163,7 @@ run_build() { } run_upload() { - write_step "upload: Building wheel and uploading to static site" + write_step "upload: Building wheel and uploading to static storage" run_build local dist_dir="${PROJECT_DIR}/dist" local whl diff --git a/README.md b/README.md index e899978..48d366e 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,9 @@ my_project: | `job_timeout_seconds` | `36000` | Max seconds to wait for a SCOPE job before timing out | | `max_files_per_trigger` | `50` | Default max files per SCOPE job (overridable per-model) | | `max_bytes_per_trigger` | `10737418240000` (10 TB) | Default max estimated bytes per batch (overridable per-model) | +| `max_file_count_per_output_file_set` | `5000` | SCOPE `@@MaxFileCountPerOutputFileSet` (overridable per-model) | +| `cancel_jobs_on_shutdown` | `true` | Cancel in-flight ADLA jobs on SIGINT/SIGTERM | +| `wait_on_cancel_seconds` | `30` | Per-job wait for ADLA terminal state when cancelling on shutdown | | `http_timeout_seconds` | `30` | HTTP request timeout for ADLA REST API calls | | `http_retries` | `3` | Number of HTTP retries for transient errors (429, 5xx) | | `scope_feature_previews` | `"EnableDeltaTableDynamicInsert:on"` | SCOPE feature preview flags (overridable per-model) | @@ -386,6 +389,39 @@ Send `SIGTERM` 
or `SIGINT` (Ctrl+C) to the dbt process. The adapter finishes the > **Note:** `processing_time` is only supported for `incremental` materializations. +## Graceful shutdown of ADLA jobs + +When the dbt process receives `SIGINT` (Ctrl+C) or `SIGTERM`, the adapter: + +1. Sets a shared shutdown flag so every in-flight `submit_and_wait` loop self-cancels its own SCOPE job. +2. Snapshots the process-wide registry of in-flight ADLA jobs and fans out parallel `CancelJob` REST calls — one worker per job, bounded at 32 threads. +3. **Waits for each cancelled job to reach a terminal `Ended` state** (typically a few seconds in ADLA) up to `wait_on_cancel_seconds` per job. Cancels run in parallel and the overall wait is capped, so total wall-clock stays at roughly `wait_on_cancel_seconds` (plus a few seconds of grace) regardless of job count. + +This is on by default. To opt out (e.g. in a CI environment where you'd rather let jobs run to completion): + +```yaml +# profiles.yml +outputs: + dev: + type: scope + cancel_jobs_on_shutdown: false +``` + +To tune how long the adapter blocks waiting for ADLA to confirm each cancel: + +```yaml +outputs: + dev: + type: scope + wait_on_cancel_seconds: 60 # default 30 +``` + +**Caveats:** + +- `SIGKILL` is uncatchable at the OS level — no Python handler can run. The existing `cancel_orphaned_jobs` cleanup (runs at the start of every new `dbt run` per model) is the safety net for that case: orphaned jobs from previous runs are cancelled before submitting a new one. +- On Windows, `SIGTERM` is not delivered the same way as on POSIX; `SIGINT` (Ctrl+C) works. +- Only jobs submitted by the **current** Python process are tracked — orphans from earlier `dbt run` invocations are handled by `cancel_orphaned_jobs`, not by this shutdown hook. + ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md). 
diff --git a/dbt/adapters/scope/_file_lock.py b/dbt/adapters/scope/_file_lock.py index d1baa8b..511dfb8 100644 --- a/dbt/adapters/scope/_file_lock.py +++ b/dbt/adapters/scope/_file_lock.py @@ -18,6 +18,8 @@ # Well-known lock file for Azure CLI token serialization AZ_CLI_TOKEN_LOCK = str(Path(tempfile.gettempdir()) / "dbt-scope-az-cli-token") +# Well-known lock file for custom (e.g. Fabric notebook / SNI) token credentials +FABRIC_TOKEN_LOCK = str(Path(tempfile.gettempdir()) / "dbt-scope-fabric-token") # Default timeout for acquiring the lock (seconds). With several xdist workers # all racing for the Azure CLI token at startup, contention is high. diff --git a/dbt/adapters/scope/adls_gen1_client.py b/dbt/adapters/scope/adls_gen1_client.py index 6e50efe..cbd7517 100644 --- a/dbt/adapters/scope/adls_gen1_client.py +++ b/dbt/adapters/scope/adls_gen1_client.py @@ -7,22 +7,84 @@ from __future__ import annotations +import inspect import logging import re +import threading import time from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from dataclasses import dataclass, field, replace from datetime import datetime, timezone +import requests +from azure.core.credentials import TokenCredential from azure.datalake.store import core as adls_core -from azure.identity import AzureCliCredential +from azure.identity import CredentialUnavailableError from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FileLock +from dbt.adapters.scope.delta_lake import RetryPolicy +from dbt.adapters.scope.message_retry import MessageRetryPolicy, retry_on_message log = AdapterLogger("scope") +_LEGACY_GEN1_SCOPE = "https://datalake.azure.net//.default" + + +class _LegacyDataLakeCredentialAdapter: + """Bridge a modern ``azure.core.credentials.TokenCredential`` to the legacy + ``azure.datalake.store.lib.DataLakeCredential`` ``signed_session()`` API. 
+ + Fabric notebook runtimes ship ``azure-datalake-store`` 0.0.5x preinstalled. + That version's ``DatalakeRESTInterface.__init__`` silently drops the + modern ``token_credential=`` kwarg and falls back to MSAL device-code + interactive auth — a hard-failure on any headless surface. This adapter + keeps the bundled wheel honest by exposing ``signed_session()`` on top + of our non-interactive credential, refreshing the bearer token a few + minutes before expiry. The 5-minute skew matches the legacy SDK's own + 100-second slop window with extra headroom for long-running directory + walks. + """ + + _REFRESH_LEAD_SECONDS = 300 + + def __init__(self, credential: TokenCredential, *, scope: str = _LEGACY_GEN1_SCOPE) -> None: + self._credential = credential + self._scope = scope + self._lock = threading.Lock() + self._access_token: str | None = None + self._expires_on: int = 0 + + def _refresh(self) -> None: + token = self._credential.get_token(self._scope) + self._access_token = token.token + self._expires_on = int(token.expires_on) + + def signed_session(self) -> requests.Session: + with self._lock: + now = int(time.time()) + if not self._access_token or now > self._expires_on - self._REFRESH_LEAD_SECONDS: + self._refresh() + bearer = self._access_token + session = requests.Session() + session.headers["Authorization"] = f"Bearer {bearer}" + return session + + def refresh_token(self, authority: str | None = None) -> None: + with self._lock: + self._refresh() + + +def _legacy_gen1_sdk_in_use() -> bool: + """Return True when the running ``AzureDLFileSystem.__init__`` predates the + 1.x ``token_credential=`` kwarg and therefore needs the legacy adapter.""" + try: + params = inspect.signature(adls_core.AzureDLFileSystem.__init__).parameters + except (TypeError, ValueError): + return False + return "token_credential" not in params + + class _SuppressFileNotFound(logging.Filter): """Reject Azure SDK log records that report a 404 / FileNotFoundError. 
@@ -80,14 +142,24 @@ def _list_one_dir( fs: adls_core.AzureDLFileSystem, dir_path: str, depth: int, + *, + message_retry_policy: MessageRetryPolicy | None = None, ) -> tuple[list[dict], list[dict], str, int, float]: """List a single directory. Returns (files, subdirs, path, depth, elapsed_ms).""" t0 = time.monotonic() + policy = message_retry_policy or MessageRetryPolicy.disabled() try: - entries = fs.ls(dir_path, detail=True) + entries = retry_on_message( + lambda: fs.ls(dir_path, detail=True), + policy=policy, + label=f"gen1.ls {dir_path}", + ) except FileNotFoundError: log.debug(f"Path not found (skipping): {dir_path}") return [], [], dir_path, depth, (time.monotonic() - t0) * 1000 + except CredentialUnavailableError: + log.error(f"_list_directory: credential acquisition exhausted for {dir_path}") + raise except Exception: log.warning(f"Failed to list {dir_path} (skipping)") return [], [], dir_path, depth, (time.monotonic() - t0) * 1000 @@ -105,23 +177,50 @@ def __init__( self, account: str, *, - lock_file: str = AZ_CLI_TOKEN_LOCK, + credential: TokenCredential | None = None, + retry_policy: RetryPolicy | None = None, + message_retry_policy: MessageRetryPolicy | None = None, ) -> None: self._account = account - self._lock_file = lock_file + self._credential = credential + self._retry_policy = retry_policy + self._message_retry_policy = message_retry_policy or MessageRetryPolicy.disabled() self._fs: adls_core.AzureDLFileSystem | None = None self._file_cache: dict[tuple[str, str | None], list[FileInfo]] = {} self._enrichment_cache: dict[str, tuple[int, tuple[str, ...]]] = {} + def _retry(self, op, *, label: str): + return retry_on_message(op, policy=self._message_retry_policy, label=label) + def _get_fs(self) -> adls_core.AzureDLFileSystem: - """Lazily initialize the ADLS Gen1 filesystem client.""" + """Lazily initialize the ADLS Gen1 filesystem client. 
+ + On Fabric notebook runtimes the preinstalled ``azure-datalake-store`` + is the 0.0.5x line, whose ``DatalakeRESTInterface.__init__`` silently + ignores ``token_credential=`` and falls back to MSAL device-code + interactive auth. We detect that signature mismatch and route + through :class:`_LegacyDataLakeCredentialAdapter` instead. + """ if self._fs is None: - with FileLock(self._lock_file): - credential = AzureCliCredential() - self._fs = adls_core.AzureDLFileSystem( - token_credential=credential, - store_name=self._account, - ) + if self._credential is None: + raise RuntimeError( + "AdlsGen1Client requires an explicit ``credential``; " + "callers should pass ``credential=build_credential(creds)``." + ) + if _legacy_gen1_sdk_in_use(): + log.debug( + "AdlsGen1Client: legacy azure-datalake-store detected — " + "wrapping credential in _LegacyDataLakeCredentialAdapter" + ) + self._fs = adls_core.AzureDLFileSystem( + token=_LegacyDataLakeCredentialAdapter(self._credential), + store_name=self._account, + ) + else: + self._fs = adls_core.AzureDLFileSystem( + token_credential=self._credential, + store_name=self._account, + ) return self._fs def list_files( @@ -162,14 +261,20 @@ def list_files( walk_start = time.monotonic() if recursive: - raw_entries = self._walk(fs, root, max_workers) + raw_entries = self._walk(fs, root, max_workers, self._message_retry_policy) else: t0 = time.monotonic() try: - raw_entries = fs.ls(root, detail=True) + raw_entries = self._retry( + lambda: fs.ls(root, detail=True), + label=f"gen1.list_files {root}", + ) except FileNotFoundError: log.debug(f"Path not found: {root}") return [] + except CredentialUnavailableError: + log.error(f"list_files: credential acquisition exhausted for {root}") + raise except Exception: log.warning(f"Failed to list {root}") return [] @@ -221,6 +326,7 @@ def _walk( fs: adls_core.AzureDLFileSystem, root: str, max_workers: int, + message_retry_policy: MessageRetryPolicy | None = None, ) -> list[dict]: """Walk 
directories in parallel, logging per-directory progress.""" all_files: list[dict] = [] @@ -229,7 +335,9 @@ def _walk( with ThreadPoolExecutor(max_workers=max_workers) as executor: futures: dict[Future, tuple[str, int]] = {} - f = executor.submit(_list_one_dir, fs, root, 0) + f = executor.submit( + _list_one_dir, fs, root, 0, message_retry_policy=message_retry_policy + ) futures[f] = (root, 0) while futures: @@ -239,6 +347,10 @@ def _walk( futures.pop(completed) try: files, dirs, dir_path, depth, elapsed_ms = completed.result() + except CredentialUnavailableError: + for pending in futures: + pending.cancel() + raise except Exception: dirs_done += 1 continue @@ -255,7 +367,13 @@ def _walk( all_files.extend(files) for d in sorted(dirs, key=lambda e: e.get("name", "")): - new_f = executor.submit(_list_one_dir, fs, d["name"], depth + 1) + new_f = executor.submit( + _list_one_dir, + fs, + d["name"], + depth + 1, + message_retry_policy=message_retry_policy, + ) futures[new_f] = (d["name"], depth + 1) if futures: @@ -344,6 +462,9 @@ def enrich_with_estimates(self, files: list[FileInfo]) -> list[FileInfo]: contributing_files=contrib_tuple, ) ) + except CredentialUnavailableError: + log.error(f"enrich_with_estimates: credential acquisition exhausted for {f.path}") + raise except Exception: log.warning(f"Failed to estimate bytes for {f.path} — using file length") self._enrichment_cache[f.path] = (f.length, ()) @@ -356,20 +477,22 @@ def enrich_with_estimates(self, files: list[FileInfo]) -> list[FileInfo]: ) return enriched - @staticmethod - def _directory_exists(path: str, fs: adls_core.AzureDLFileSystem) -> bool: + def _directory_exists(self, path: str, fs: adls_core.AzureDLFileSystem) -> bool: """Check if a directory exists on ADLS Gen1.""" try: - info = fs.info(path) + info = self._retry(lambda: fs.info(path), label=f"gen1.info {path}") return info.get("type") == "DIRECTORY" except FileNotFoundError: return False + except CredentialUnavailableError: + 
log.error(f"_directory_exists: credential acquisition exhausted for {path}") + raise except Exception: log.debug(f"_directory_exists: error checking {path} — assuming not exists") return False - @staticmethod def _list_directory_files( + self, dir_path: str, fs: adls_core.AzureDLFileSystem, ) -> list[dict]: @@ -380,9 +503,15 @@ def _list_directory_files( while dirs_to_visit: current = dirs_to_visit.pop() try: - entries = fs.ls(current, detail=True) + entries = self._retry( + lambda c=current: fs.ls(c, detail=True), + label=f"gen1.ls {current}", + ) except FileNotFoundError: continue + except CredentialUnavailableError: + log.error(f"_list_directory_files: credential acquisition exhausted for {current}") + raise except Exception: log.debug(f"_list_directory_files: failed to list {current} — skipping") continue diff --git a/dbt/adapters/scope/checkpoint.py b/dbt/adapters/scope/checkpoint.py index b9c64b2..232528e 100644 --- a/dbt/adapters/scope/checkpoint.py +++ b/dbt/adapters/scope/checkpoint.py @@ -23,12 +23,13 @@ from dataclasses import dataclass from datetime import datetime, timezone -from azure.identity import AzureCliCredential +from azure.core.credentials import TokenCredential +from azure.identity import CredentialUnavailableError from azure.storage.filedatalake import DataLakeServiceClient from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK -from dbt.adapters.scope.delta_lake import AbfssLocation, LockedTokenCredential +from dbt.adapters.scope.delta_lake import AbfssLocation, RetryPolicy +from dbt.adapters.scope.message_retry import MessageRetryPolicy, retry_on_message log = AdapterLogger("scope") @@ -36,6 +37,28 @@ _WATERMARK_FILE = "watermark.json" _SOURCES_DIR = "sources" + +def _json_default(o: object) -> str: + """``json.dumps`` ``default=`` hook for source records. 
+ + Source records can carry ``datetime`` values when they originate from + a previously written parquet snapshot (DuckDB returns ``TIMESTAMP`` + columns as Python ``datetime``). Convert them back to ISO 8601 strings + so the records round-trip through NDJSON cleanly. + + DuckDB's ``TIMESTAMP`` is naive — it drops the timezone offset on the + cast that produces the snapshot — but every value produced inside + this module is created from ``datetime.now(timezone.utc)``. So if + the datetime comes back naive we re-attach UTC, preserving the + timezone-aware ISO 8601 contract that JSONL diffs use. + """ + if isinstance(o, datetime): + if o.tzinfo is None: + o = o.replace(tzinfo=timezone.utc) + return o.isoformat() + raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") + + # Virtual column names that map to SCOPE FILE.* functions VIRTUAL_COLUMNS: dict[str, str] = { "source_file_uri": "FILE.URI()", @@ -80,15 +103,31 @@ def from_json(cls, raw: str) -> Watermark: ) -def _get_service(parsed: AbfssLocation, credential: LockedTokenCredential) -> DataLakeServiceClient: +def _get_service(parsed: AbfssLocation, credential: TokenCredential) -> DataLakeServiceClient: return DataLakeServiceClient(account_url=parsed.account_url, credential=credential) class CheckpointManager: """Manage ``_checkpoint/`` on ADLS Gen2 Delta table roots.""" - def __init__(self, *, lock_file: str = AZ_CLI_TOKEN_LOCK) -> None: - self._credential = LockedTokenCredential(AzureCliCredential(), lock_file=lock_file) + def __init__( + self, + *, + credential: TokenCredential | None = None, + retry_policy: RetryPolicy | None = None, + message_retry_policy: MessageRetryPolicy | None = None, + ) -> None: + if credential is None: + raise RuntimeError( + "CheckpointManager requires an explicit ``credential``; " + "callers should pass ``credential=build_credential(creds)``." 
+ ) + self._credential = credential + self._retry_policy = retry_policy + self._message_retry_policy = message_retry_policy or MessageRetryPolicy.disabled() + + def _retry(self, op, *, label: str): + return retry_on_message(op, policy=self._message_retry_policy, label=label) # -- Watermark --------------------------------------------------------- @@ -99,7 +138,7 @@ def read_watermark(self, delta_location: str) -> Watermark | None: log.warning(f"read_watermark: invalid delta_location: {delta_location}") return None - try: + def _read() -> Watermark: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) file_path = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_WATERMARK_FILE}" @@ -113,6 +152,14 @@ def read_watermark(self, delta_location: str) -> Watermark | None: f"batch_id={watermark.batch_id}" ) return watermark + + try: + return self._retry(_read, label=f"checkpoint.read_watermark {delta_location}") + except CredentialUnavailableError: + # Don't mask auth failures as "no checkpoint" — that would + # silently flip an incremental run into a full refresh. 
+ log.error(f"read_watermark: credential acquisition exhausted for {delta_location}") + raise except Exception: log.debug(f"No checkpoint found for {delta_location} (first run or full refresh)") return None @@ -124,7 +171,7 @@ def write_watermark(self, delta_location: str, watermark: Watermark) -> None: log.warning(f"write_watermark: invalid delta_location: {delta_location}") return - try: + def _write() -> None: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) @@ -142,6 +189,9 @@ def write_watermark(self, delta_location: str, watermark: Watermark) -> None: f"modified_time={watermark.modified_time}, " f"batch_id={watermark.batch_id} → {delta_location}" ) + + try: + self._retry(_write, label=f"checkpoint.write_watermark {delta_location}") except Exception: log.error(f"write_watermark failed for {delta_location}") raise @@ -153,13 +203,16 @@ def delete_watermark(self, delta_location: str) -> None: log.warning(f"delete_watermark: invalid delta_location: {delta_location}") return - try: + def _delete() -> None: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) file_path = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_WATERMARK_FILE}" file_client = fs.get_file_client(file_path) file_client.delete_file() log.debug(f"Deleted watermark for {delta_location}") + + try: + self._retry(_delete, label=f"checkpoint.delete_watermark {delta_location}") except Exception: log.debug(f"No watermark to delete for {delta_location} (already clean)") @@ -209,7 +262,7 @@ def write_batch_sources( is_compaction = batch_id > 0 and batch_id % compaction_interval == 0 - try: + def _write() -> None: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) sources_dir = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_SOURCES_DIR}" @@ -226,6 +279,9 @@ def write_batch_sources( f"{'parquet snapshot' if is_compaction else 'jsonl diff'}) → " 
f"{delta_location}" ) + + try: + self._retry(_write, label=f"checkpoint.write_batch_sources batch={batch_id}") except Exception: log.error(f"write_batch_sources failed for batch {batch_id}") raise @@ -251,7 +307,7 @@ def _build_source_records( @staticmethod def _write_jsonl(fs, sources_dir: str, batch_id: int, records: list[dict]) -> None: - lines = [json.dumps(r, separators=(",", ":")) for r in records] + lines = [json.dumps(r, default=_json_default, separators=(",", ":")) for r in records] content = "\n".join(lines) file_path = f"{sources_dir}/{batch_id}" file_client = fs.get_file_client(file_path) @@ -345,18 +401,36 @@ def _write_snapshot_parquet( # Add current batch records all_records.extend(current_batch_records) - # Write consolidated parquet via DuckDB (NDJSON → read_json_auto → COPY) + # Write consolidated parquet via DuckDB (NDJSON → read_json_auto → COPY). + # + # ``batchProcessingTime`` is written here as an ISO 8601 string, but + # DuckDB's ``read_json_auto`` will infer it as ``TIMESTAMP`` once the + # NDJSON has enough rows of consistent ISO text. We force the cast + # explicitly so the parquet schema is deterministic regardless of + # sample size — and so subsequent reads of this snapshot always come + # back as ``datetime`` (handled by ``_json_default`` on the next + # round-trip). parquet_local = f"/tmp/dbt_scope_{batch_id}.parquet" ndjson_local = f"/tmp/dbt_scope_{batch_id}.ndjson" try: with open(ndjson_local, "w") as nf: for r in all_records: - nf.write(json.dumps(r) + "\n") + nf.write(json.dumps(r, default=_json_default) + "\n") conn = duckdb.connect() try: + # Cast every column explicitly so the snapshot schema is + # fully deterministic regardless of DuckDB's + # ``read_json_auto`` heuristics (which vary with sample + # size and version). 
conn.execute( - f"CREATE TABLE sources AS SELECT * FROM read_json_auto('{ndjson_local}')" + "CREATE TABLE sources AS " + "SELECT " + "CAST(path AS VARCHAR) AS path, " + 'CAST("modificationTime" AS BIGINT) AS "modificationTime", ' + 'CAST("batchId" AS BIGINT) AS "batchId", ' + 'CAST("batchProcessingTime" AS TIMESTAMP) AS "batchProcessingTime" ' + f"FROM read_json_auto('{ndjson_local}')" ) conn.execute(f"COPY sources TO '{parquet_local}' (FORMAT PARQUET)") finally: @@ -388,12 +462,11 @@ def cleanup_sources( if parsed is None: return 0 - try: + def _cleanup() -> int: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) sources_dir = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_SOURCES_DIR}" - # List all files files: list[tuple[str, str]] = [] # (name, full_path) for path_info in fs.get_paths(path=sources_dir, recursive=False): if getattr(path_info, "is_directory", False): @@ -404,7 +477,6 @@ def cleanup_sources( if len(files) <= max_files: return 0 - # Sort: JSONL files (numeric names) first by batch_id, then parquet by name def sort_key(item: tuple[str, str]) -> tuple[int, str]: name = item[0] try: @@ -414,7 +486,6 @@ def sort_key(item: tuple[str, str]) -> tuple[int, str]: files.sort(key=sort_key) - # Delete oldest files until we're at the limit to_delete = len(files) - max_files deleted = 0 for _name, full_path in files[:to_delete]: @@ -429,6 +500,9 @@ def sort_key(item: tuple[str, str]) -> tuple[int, str]: f"cleanup_sources: deleted {deleted} files (was {len(files)}, limit {max_files})" ) return deleted + + try: + return self._retry(_cleanup, label=f"checkpoint.cleanup_sources {delta_location}") except Exception: log.warning(f"cleanup_sources failed for {delta_location}") return 0 @@ -439,7 +513,7 @@ def delete_all_sources(self, delta_location: str) -> None: if parsed is None: return - try: + def _delete_all() -> None: service = _get_service(parsed, self._credential) fs = 
service.get_file_system_client(parsed.container) sources_dir = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_SOURCES_DIR}" @@ -455,6 +529,9 @@ def delete_all_sources(self, delta_location: str) -> None: except Exception: pass log.debug(f"delete_all_sources: deleted {deleted} files for {delta_location}") + + try: + self._retry(_delete_all, label=f"checkpoint.delete_all_sources {delta_location}") except Exception: log.debug(f"No sources to delete for {delta_location} (already clean)") @@ -464,7 +541,7 @@ def list_source_files(self, delta_location: str) -> list[str]: if parsed is None: return [] - try: + def _list() -> list[str]: service = _get_service(parsed, self._credential) fs = service.get_file_system_client(parsed.container) sources_dir = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_SOURCES_DIR}" @@ -475,6 +552,9 @@ def list_source_files(self, delta_location: str) -> list[str]: continue names.append(path_info.name.rsplit("/", 1)[-1]) return sorted(names) + + try: + return self._retry(_list, label=f"checkpoint.list_source_files {delta_location}") except Exception: return [] @@ -489,17 +569,21 @@ def read_batch_source(self, delta_location: str, batch_id: int) -> list[dict]: fs = service.get_file_system_client(parsed.container) sources_dir = f"{parsed.path.rstrip('/')}/{_CHECKPOINT_DIR}/{_SOURCES_DIR}" - # Try JSONL first - try: + def _read_jsonl() -> list[dict]: jsonl_path = f"{sources_dir}/{batch_id}" file_client = fs.get_file_client(jsonl_path) raw = file_client.download_file().readall().decode("utf-8") return [json.loads(line) for line in raw.strip().split("\n") if line.strip()] + + try: + return self._retry( + _read_jsonl, + label=f"checkpoint.read_batch_source jsonl batch={batch_id}", + ) except Exception: pass - # Try parquet snapshot (compaction batches) - try: + def _read_parquet() -> list[dict]: import os import duckdb @@ -519,12 +603,17 @@ def read_batch_source(self, delta_location: str, batch_id: int) -> list[dict]: f"DESCRIBE SELECT * FROM 
read_parquet('{tmp_path}')" ).fetchall() ] - # Filter to only records for this batch_id all_records = [dict(zip(cols, row, strict=False)) for row in rows] return [r for r in all_records if r.get("batchId") == batch_id] finally: conn.close() os.remove(tmp_path) + + try: + return self._retry( + _read_parquet, + label=f"checkpoint.read_batch_source parquet batch={batch_id}", + ) except Exception: pass diff --git a/dbt/adapters/scope/connections.py b/dbt/adapters/scope/connections.py index ee04a1f..298e12e 100644 --- a/dbt/adapters/scope/connections.py +++ b/dbt/adapters/scope/connections.py @@ -2,8 +2,10 @@ from __future__ import annotations +import threading import time import uuid +from concurrent.futures import ThreadPoolExecutor, wait from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, ClassVar @@ -11,7 +13,6 @@ import agate import requests -from azure.identity import AzureCliCredential from dbt.adapters.base import BaseConnectionManager from dbt.adapters.contracts.connection import ( AdapterResponse, @@ -23,8 +24,10 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FileLock from dbt.adapters.scope.credentials import ScopeCredentials +from dbt.adapters.scope.delta_lake import build_credential +from dbt.adapters.scope.message_retry import MessageRetryPolicy, retry_on_message +from dbt.adapters.scope.quota_eviction import QuotaEvictionPolicy, retry_with_quota_eviction log = AdapterLogger("scope") @@ -39,6 +42,98 @@ _UUID_NAMESPACE = uuid.NAMESPACE_DNS +# --------------------------------------------------------------------------- +# Process-wide active-jobs registry (for cancel-all-on-shutdown) +# --------------------------------------------------------------------------- + + +@dataclass +class _ActiveJobEntry: + """Reference to an in-flight ADLA job, used by ``cancel_all_active_jobs``.""" + + job_id: str + name: str + 
handle: ScopeConnectionHandle + submitted_at: float + model_name: str | None = None + + +_active_jobs: dict[str, _ActiveJobEntry] = {} +_active_jobs_lock = threading.Lock() +_cancelled_job_ids: set[str] = set() + +# Shared shutdown event. Set by the SIGINT/SIGTERM handler in ``impl.py``; observed +# by ``submit_and_wait``'s poll loop and ``wait_for_next_cycle`` so that in-flight +# work aborts promptly when the operator hits Ctrl+C. +_shutdown_event = threading.Event() + + +def _register_active_job(entry: _ActiveJobEntry) -> None: + with _active_jobs_lock: + _active_jobs[entry.job_id] = entry + + +def _deregister_active_job(job_id: str) -> None: + with _active_jobs_lock: + _active_jobs.pop(job_id, None) + + +def _snapshot_active_jobs() -> list[_ActiveJobEntry]: + with _active_jobs_lock: + return list(_active_jobs.values()) + + +def cancel_all_active_jobs(reason: str, wait_seconds: int) -> tuple[int, int]: + """Cancel every in-flight ADLA job and wait for each to reach a terminal state. + + Each per-job cancel runs on a worker thread that POSTs ``/CancelJob`` and then + polls until ``Ended`` (Cancelled) or ``wait_seconds`` elapses. Workers run in + parallel so the total wall-clock is ``~wait_seconds`` regardless of job count. + + Returns ``(attempted, confirmed_terminal)``. 
+ """ + entries = _snapshot_active_jobs() + if not entries: + return (0, 0) + + log.info( + f"Shutdown ({reason}) — cancelling {len(entries)} active ADLA job(s), " + f"waiting up to {wait_seconds}s for terminal state" + ) + + max_workers = min(len(entries), 32) + + def _cancel_one(entry: _ActiveJobEntry) -> bool: + if entry.job_id in _cancelled_job_ids: + return True + try: + entry.handle.cancel_job( + entry.job_id, + poll_interval=2, + max_wait=wait_seconds, + ) + _cancelled_job_ids.add(entry.job_id) + return True + except Exception as exc: + log.warning(f"Failed to cancel ADLA job '{entry.name}' ({entry.job_id}): {exc}") + _cancelled_job_ids.add(entry.job_id) + return False + + executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="scope-cancel") + try: + futures = [executor.submit(_cancel_one, e) for e in entries] + # Bound the overall wait — cancel_job has its own max_wait per job, but we + # add a small grace for thread scheduling + the synchronous POST itself. + grace_seconds = 5 + wait(futures, timeout=wait_seconds + grace_seconds) + confirmed = sum(1 for f in futures if f.done() and not f.cancelled() and f.result()) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + log.info(f"Shutdown cancel complete: {confirmed}/{len(entries)} ADLA job(s) confirmed terminal") + return (len(entries), confirmed) + + @dataclass class ADLAJob: """Lightweight job tracker returned by ``submit_job``.""" @@ -93,8 +188,10 @@ def __init__(self, credentials: ScopeCredentials) -> None: self._account = credentials.adla_account self._base_url = f"https://{self._account}.azuredatalakeanalytics.net" self._timeout = credentials.http_timeout_seconds - self._credential = AzureCliCredential() + self._credential = build_credential(credentials) self._session = self._build_session(credentials.http_retries) + self._message_retry_policy = MessageRetryPolicy.from_credentials(credentials) + self._quota_eviction_policy = 
QuotaEvictionPolicy.from_credentials(credentials) self._cached_token: str | None = None self._token_expires_at: float = 0 self._next_job_name: str | None = None @@ -137,7 +234,16 @@ def submit_job( } log.debug(f"Submitting SCOPE job '{name}' (AU={au}) → {job_id}") log.debug(f"SCOPE script for '{name}':\n{script}") - resp = self._request("PUT", url, json=body) + + def _put() -> dict: + return self._request("PUT", url, json=body) + + resp = retry_with_quota_eviction( + _put, + eviction_ctx=self, + policy=self._quota_eviction_policy, + label=f"submit_job '{name}' ({job_id})", + ) job = ADLAJob(job_id=job_id, name=name) job.update_from_response(resp) return job @@ -185,6 +291,17 @@ def cancel_job( ) return + def cancel_job_async(self, job_id: str) -> None: + """Fire-and-forget cancel — POST CancelJob without polling for terminal state. + + Used by the quota-eviction layer where we cancel multiple victims + per attempt and only need the queue slot to free up, not strict + confirmation. + """ + url = f"{self._base_url}/jobs/{job_id}/CancelJob?api-version={API_VERSION}" + log.debug(f"Cancelling ADLA job {job_id} (fire-and-forget)") + self._request("POST", url) + def list_jobs(self, filter_expr: str | None = None, top: int = 100) -> list[dict[str, Any]]: """List ADLA jobs, optionally filtered by an OData ``$filter`` expression.""" url = f"{self._base_url}/Jobs?api-version={API_VERSION}&$top={top}" @@ -257,72 +374,125 @@ def submit_and_wait( poll_interval: int = 5, max_wait: int = 3600, model_name: str | None = None, + wait_on_cancel_seconds: int = 30, ) -> ADLAJob: - """Submit a SCOPE job and poll until terminal.""" + """Submit a SCOPE job and poll until terminal. 
+ + Registers the job in the process-wide active-jobs registry so that + ``cancel_all_active_jobs`` can reach it on SIGINT/SIGTERM, and checks + ``_shutdown_event`` between polls — if a shutdown is in progress, this + method calls ``cancel_job`` for its own job (blocking up to + ``wait_on_cancel_seconds`` for terminal state) and raises + ``DbtRuntimeError``. + """ job = self.submit_job(name, script, au, priority, model_name=model_name) - start = time.monotonic() - last_state = job.state - consecutive_failures = 0 - - while not job.is_terminal: - elapsed = time.monotonic() - start - if elapsed >= max_wait: - raise DbtRuntimeError( - f"SCOPE job '{name}' ({job.job_id}) timed out after " - f"{elapsed:.0f}s in state {job.state}" - ) - time.sleep(poll_interval) - try: - self.poll_job(job) - consecutive_failures = 0 - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc: - consecutive_failures += 1 - if consecutive_failures >= self._MAX_CONSECUTIVE_POLL_FAILURES: - raise DbtDatabaseError( - f"SCOPE job '{name}' ({job.job_id}) poll failed " - f"{consecutive_failures} consecutive times: {exc}" - ) from exc - log.warning( - f"Transient poll error for '{name}' ({job.job_id}), " - f"attempt {consecutive_failures}/{self._MAX_CONSECUTIVE_POLL_FAILURES}: {exc}" + _register_active_job( + _ActiveJobEntry( + job_id=job.job_id, + name=name, + handle=self, + submitted_at=time.monotonic(), + model_name=model_name, + ) + ) + try: + start = time.monotonic() + last_state = job.state + consecutive_failures = 0 + + while not job.is_terminal: + if _shutdown_event.is_set(): + if job.job_id not in _cancelled_job_ids: + log.info( + f"[{name}] Shutdown signalled — cancelling job {job.job_id} " + f"(waiting up to {wait_on_cancel_seconds}s for terminal state)" + ) + try: + self.cancel_job( + job.job_id, + poll_interval=2, + max_wait=wait_on_cancel_seconds, + ) + except Exception as exc: + log.warning(f"[{name}] Self-cancel failed for {job.job_id}: {exc}") + finally: + 
_cancelled_job_ids.add(job.job_id) + raise DbtRuntimeError( + f"SCOPE job '{name}' ({job.job_id}) cancelled by shutdown signal" + ) + + elapsed = time.monotonic() - start + if elapsed >= max_wait: + raise DbtRuntimeError( + f"SCOPE job '{name}' ({job.job_id}) timed out after " + f"{elapsed:.0f}s in state {job.state}" + ) + time.sleep(poll_interval) + try: + self.poll_job(job) + consecutive_failures = 0 + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc: + consecutive_failures += 1 + if consecutive_failures >= self._MAX_CONSECUTIVE_POLL_FAILURES: + raise DbtDatabaseError( + f"SCOPE job '{name}' ({job.job_id}) poll failed " + f"{consecutive_failures} consecutive times: {exc}" + ) from exc + log.warning( + f"Transient poll error for '{name}' ({job.job_id}), " + f"attempt {consecutive_failures}/{self._MAX_CONSECUTIVE_POLL_FAILURES}: {exc}" + ) + continue + if job.state != last_state: + log.debug(f"[{name}] {last_state} → {job.state}") + last_state = job.state + + if not job.succeeded: + raise DbtDatabaseError( + f"SCOPE job '{name}' ({job.job_id}) failed: {job.error_message}" ) - continue - if job.state != last_state: - log.debug(f"[{name}] {last_state} → {job.state}") - last_state = job.state - if not job.succeeded: - raise DbtDatabaseError(f"SCOPE job '{name}' ({job.job_id}) failed: {job.error_message}") - - log.debug(f"[{name}] Completed successfully ({job.result})") - return job + log.debug(f"[{name}] Completed successfully ({job.result})") + return job + finally: + _deregister_active_job(job.job_id) # -- Internal ----------------------------------------------------- def _get_token(self) -> str: if self._cached_token and time.time() < self._token_expires_at - 300: return self._cached_token - with FileLock(AZ_CLI_TOKEN_LOCK): - token = self._credential.get_token(ADLA_TOKEN_SCOPE) + # ``LockedTokenCredential`` handles both the FileLock and retry on + # transient ``CredentialUnavailableError`` failures. 
+ token = self._credential.get_token(ADLA_TOKEN_SCOPE) self._cached_token = token.token self._token_expires_at = token.expires_on return self._cached_token def _request(self, method: str, url: str, **kwargs: Any) -> dict: - headers = { - "Authorization": f"Bearer {self._get_token()}", - "Content-Type": "application/json", - "Accept": "application/json", - } - resp = self._session.request(method, url, headers=headers, timeout=self._timeout, **kwargs) - if resp.status_code >= 400: - raise DbtDatabaseError( - f"ADLA API {method} {url} returned {resp.status_code}: {resp.text[:500]}" + def _send() -> dict: + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Content-Type": "application/json", + "Accept": "application/json", + } + resp = self._session.request( + method, url, headers=headers, timeout=self._timeout, **kwargs ) - # Some endpoints (e.g. CancelJob) return 200 with an empty body - if not resp.content: - return {} - return resp.json() + if resp.status_code >= 400: + raise DbtDatabaseError( + f"ADLA API {method} {url} returned {resp.status_code}: {resp.text[:500]}" + ) + # Some endpoints (e.g. CancelJob) return 200 with an empty body + if not resp.content: + return {} + return resp.json() + + return retry_on_message( + _send, + policy=self._message_retry_policy, + label=f"ADLA {method} {url}", + ) @staticmethod def _build_session(retries: int) -> requests.Session: @@ -344,6 +514,11 @@ class ScopeConnectionManager(BaseConnectionManager): TYPE = "scope" + # Lazy-bound hook so impl.py can install signal handlers + capture + # credentials when adapters are opened. We can't import impl.py here + # (circular), so impl.py sets this on import. 
+ _on_open: ClassVar[Any] = None + @classmethod def open(cls, connection: Connection) -> Connection: if connection.state == ConnectionState.OPEN: @@ -353,6 +528,11 @@ def open(cls, connection: Connection) -> Connection: handle = ScopeConnectionHandle(credentials) connection.handle = handle connection.state = ConnectionState.OPEN + if cls._on_open is not None: + try: + cls._on_open(credentials) + except Exception as exc: + log.warning(f"ScopeConnectionManager open hook failed: {exc}") return connection @classmethod @@ -360,11 +540,14 @@ def get_response(cls, _cursor: Any) -> AdapterResponse: return AdapterResponse(_message="OK") def cancel(self, connection: Connection) -> None: - pass + creds = getattr(connection, "credentials", None) + wait_seconds = getattr(creds, "wait_on_cancel_seconds", 30) if creds else 30 + if getattr(creds, "cancel_jobs_on_shutdown", True): + cancel_all_active_jobs("dbt-native:cancel", wait_seconds=wait_seconds) @classmethod def cancel_open(cls) -> None: - pass + cancel_all_active_jobs("dbt-native:cancel_open", wait_seconds=30) @contextmanager def exception_handler(self, sql: str): # type: ignore[override] @@ -435,6 +618,7 @@ def execute( poll_interval=credentials.poll_interval_seconds, max_wait=effective_max_wait, model_name=effective_model_name, + wait_on_cancel_seconds=credentials.wait_on_cancel_seconds, ) response = AdapterResponse( diff --git a/dbt/adapters/scope/constants.py b/dbt/adapters/scope/constants.py index 84bc06d..3538cf1 100644 --- a/dbt/adapters/scope/constants.py +++ b/dbt/adapters/scope/constants.py @@ -10,6 +10,14 @@ DEFAULT_SOURCE_COMPACTION_INTERVAL: int = 10 DEFAULT_SOURCE_RETENTION_FILES: int = 100 +# SCOPE @@MaxFileCountPerOutputFileSet cap. Compiler upstream allows [1, 1_000_000] +# with a default of 100_000, but Fabric/OneLake clusters often enforce a stricter +# 5_000 ceiling at runtime, so the adapter mirrors that as its safe default and +# always emits the SET explicitly to make the value deterministic. 
+DEFAULT_MAX_FILE_COUNT_PER_OUTPUT_FILE_SET: int = 5000 +MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MIN: int = 1 +MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MAX: int = 1_000_000 + # Valid values for @@DeltaLakeCommitCondition VALID_DELTA_LAKE_COMMIT_CONDITIONS: frozenset[str] = frozenset( { @@ -23,3 +31,9 @@ # Trigger mode constants DEFAULT_TRIGGER_TYPE: str = "available_now" DEFAULT_PROCESSING_TIME_TIMEOUT_SECONDS: int = 2_592_000 # 30 days + +# Graceful shutdown: on SIGINT/SIGTERM, POST CancelJob for every in-flight ADLA job +# and block until each reaches a terminal state (or wait_on_cancel_seconds elapses +# per job, with cancels running in parallel so the total wall-clock is bounded). +DEFAULT_CANCEL_JOBS_ON_SHUTDOWN: bool = True +DEFAULT_WAIT_ON_CANCEL_SECONDS: int = 30 diff --git a/dbt/adapters/scope/credentials.py b/dbt/adapters/scope/credentials.py index b9f9556..2f7bacd 100644 --- a/dbt/adapters/scope/credentials.py +++ b/dbt/adapters/scope/credentials.py @@ -2,9 +2,17 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any from dbt.adapters.contracts.connection import Credentials +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.scope.constants import ( + DEFAULT_CANCEL_JOBS_ON_SHUTDOWN, + DEFAULT_MAX_FILE_COUNT_PER_OUTPUT_FILE_SET, + DEFAULT_WAIT_ON_CANCEL_SECONDS, +) @dataclass @@ -26,6 +34,24 @@ class ScopeCredentials(Credentials): priority: 1 max_files_per_trigger: 50 max_bytes_per_trigger: 10737418240000 # ~10 TB + max_file_count_per_output_file_set: 5000 # SCOPE @@MaxFileCountPerOutputFileSet + cancel_jobs_on_shutdown: true # cancel in-flight ADLA jobs on SIGINT/SIGTERM + wait_on_cancel_seconds: 30 # wait per job for ADLA terminal state + + # Optional: plug a custom azure.core.credentials.TokenCredential. + # Defaults to authentication='cli' which uses az login. 
+ authentication: token_credential + credential_class: "fabric_entra_auth.EntraTokenCredential" + credential_kwargs: + auth: + authentication_method: SNI + sni: + client_id: + tenant_id: + vault_url: 'https://.vault.azure.net/' + vault_certificate_name: + vault_pull_config: + authentication_method: azCli """ adla_account: str = "" @@ -39,11 +65,32 @@ class ScopeCredentials(Credentials): job_timeout_seconds: int = 36_000 max_files_per_trigger: int = 50 max_bytes_per_trigger: int = 10_737_418_240_000 # ~10 TB + max_file_count_per_output_file_set: int = DEFAULT_MAX_FILE_COUNT_PER_OUTPUT_FILE_SET + cancel_jobs_on_shutdown: bool = DEFAULT_CANCEL_JOBS_ON_SHUTDOWN + wait_on_cancel_seconds: int = DEFAULT_WAIT_ON_CANCEL_SECONDS http_timeout_seconds: int = 120 - http_retries: int = 3 + http_retries: int = 10 scope_feature_previews: str | None = "EnableDeltaTableDynamicInsert:on" delta_lake_commit_condition: str = "FailIfFileConflict" + retry_on_error_messages: list[str] = field(default_factory=list) + max_retries_on_error: int = 25 + initial_wait_on_error_seconds: float = 1.0 + max_wait_on_error_seconds: float = 30.0 + + enable_quota_eviction: bool = True + quota_eviction_max_attempts: int = 25 + quota_eviction_cancel_num: int = 5 + quota_eviction_wait_seconds: float = 30.0 + quota_eviction_jitter_seconds: float = 5.0 + + # "cli" (default — AzureCliCredential) or "token_credential" (dotted-path) + authentication: str = "cli" + # Dotted path to a TokenCredential implementation loaded via importlib + # when authentication='token_credential'. 
+ credential_class: str | None = None + credential_kwargs: dict[str, Any] = field(default_factory=dict) + @property def type(self) -> str: return "scope" @@ -63,5 +110,77 @@ def _connection_keys(self) -> tuple[str, ...]: "priority", "max_files_per_trigger", "max_bytes_per_trigger", + "max_file_count_per_output_file_set", + "cancel_jobs_on_shutdown", + "wait_on_cancel_seconds", "delta_lake_commit_condition", + "retry_on_error_messages", + "max_retries_on_error", + "initial_wait_on_error_seconds", + "max_wait_on_error_seconds", + "enable_quota_eviction", + "quota_eviction_max_attempts", + "quota_eviction_cancel_num", + "quota_eviction_wait_seconds", + "quota_eviction_jitter_seconds", + "authentication", + "credential_class", ) + + def __post_init__(self) -> None: + is_token_credential_auth = ( + isinstance(self.authentication, str) + and self.authentication.lower() == "token_credential" + ) + if is_token_credential_auth and not self.credential_class: + raise DbtRuntimeError( + "authentication='token_credential' requires `credential_class` " + "(dotted path to an azure.core.credentials.TokenCredential)." + ) + if not is_token_credential_auth and (self.credential_class or self.credential_kwargs): + raise DbtRuntimeError( + "`credential_class` and `credential_kwargs` are only valid when " + "authentication='token_credential'." 
+ ) + + if self.max_retries_on_error < 0: + raise DbtRuntimeError( + f"max_retries_on_error must be >= 0; got {self.max_retries_on_error}" + ) + if self.initial_wait_on_error_seconds <= 0: + raise DbtRuntimeError( + "initial_wait_on_error_seconds must be > 0; " + f"got {self.initial_wait_on_error_seconds}" + ) + if self.max_wait_on_error_seconds <= 0: + raise DbtRuntimeError( + f"max_wait_on_error_seconds must be > 0; got {self.max_wait_on_error_seconds}" + ) + if self.initial_wait_on_error_seconds > self.max_wait_on_error_seconds: + raise DbtRuntimeError( + "initial_wait_on_error_seconds must be <= max_wait_on_error_seconds; " + f"got {self.initial_wait_on_error_seconds} > {self.max_wait_on_error_seconds}" + ) + for entry in self.retry_on_error_messages: + if not isinstance(entry, str) or not entry: + raise DbtRuntimeError( + f"retry_on_error_messages entries must be non-empty strings; got {entry!r}" + ) + + if self.quota_eviction_max_attempts < 0: + raise DbtRuntimeError( + f"quota_eviction_max_attempts must be >= 0; got {self.quota_eviction_max_attempts}" + ) + if self.quota_eviction_cancel_num < 1: + raise DbtRuntimeError( + f"quota_eviction_cancel_num must be >= 1; got {self.quota_eviction_cancel_num}" + ) + if self.quota_eviction_wait_seconds <= 0: + raise DbtRuntimeError( + f"quota_eviction_wait_seconds must be > 0; got {self.quota_eviction_wait_seconds}" + ) + if self.quota_eviction_jitter_seconds < 0: + raise DbtRuntimeError( + "quota_eviction_jitter_seconds must be >= 0; " + f"got {self.quota_eviction_jitter_seconds}" + ) diff --git a/dbt/adapters/scope/custom_credential.py b/dbt/adapters/scope/custom_credential.py new file mode 100644 index 0000000..0cd8bde --- /dev/null +++ b/dbt/adapters/scope/custom_credential.py @@ -0,0 +1,92 @@ +"""Lazy dotted-path loader for user-supplied ``TokenCredential`` implementations. 
+ +Mirrors the pattern from dbt-fabricspark PR #177 (``livysession._load_custom_credential``): +the user's profile carries a dotted path under ``credential_class`` and an +arbitrary ``credential_kwargs`` mapping. We: + +1. Validate the dotted path against an identifier regex (defence-in-depth — ``importlib`` + wouldn't shell out, but rejecting non-identifier chars gives clearer errors). +2. ``importlib.import_module`` the module portion. +3. ``getattr`` the class portion and ``cls(**kwargs)`` it. +4. Enforce ``isinstance(instance, azure.core.credentials.TokenCredential)`` — the protocol + is ``@runtime_checkable`` so this checks for a callable ``get_token``. + +Instances are cached process-wide keyed by ``(dotted_path, repr-of-sorted-kwargs)`` so +that refreshes reuse the same object (matching how ``azure-identity`` credentials +are typically held) and so that the YAML round-trip of nested dicts/lists doesn't +defeat caching. +""" + +from __future__ import annotations + +import importlib +import re +import threading +from typing import Any + +from azure.core.credentials import TokenCredential +from dbt_common.exceptions import DbtRuntimeError + +_DOTTED_PATH_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)+$") + +_custom_credential_cache: dict[tuple[str, str], TokenCredential] = {} +_custom_credential_lock = threading.Lock() + + +def _cache_key(dotted: str, kwargs: dict[str, Any]) -> tuple[str, str]: + # repr() of sorted items keeps the key hashable even when kwargs contain + # nested dicts/lists from YAML. + return (dotted, repr(sorted(kwargs.items(), key=lambda kv: kv[0]))) + + +def load_custom_credential(dotted: str | None, kwargs: dict[str, Any] | None) -> TokenCredential: + """Import and instantiate the user-supplied ``TokenCredential``.""" + if not dotted: + raise DbtRuntimeError( + "authentication='token_credential' requires `credential_class` " + "(dotted path to an azure.core.credentials.TokenCredential)." 
+ ) + if not _DOTTED_PATH_PATTERN.match(dotted): + raise DbtRuntimeError( + f"credential_class must be a dotted path like 'pkg.module.ClassName', got: {dotted!r}" + ) + kwargs = kwargs or {} + key = _cache_key(dotted, kwargs) + with _custom_credential_lock: + cached = _custom_credential_cache.get(key) + if cached is not None: + return cached + module_path, _, class_name = dotted.rpartition(".") + try: + module = importlib.import_module(module_path) + except ImportError as exc: + raise DbtRuntimeError( + f"Could not import module for credential_class={dotted!r}: {exc}" + ) from exc + try: + cls = getattr(module, class_name) + except AttributeError as exc: + raise DbtRuntimeError( + f"Module {module_path!r} has no attribute {class_name!r} " + f"(from credential_class={dotted!r})" + ) from exc + try: + instance = cls(**kwargs) + except TypeError as exc: + raise DbtRuntimeError( + f"Failed to instantiate {dotted!r} with credential_kwargs: {exc}" + ) from exc + # TokenCredential is @runtime_checkable — this checks for callable get_token. + if not isinstance(instance, TokenCredential): + raise DbtRuntimeError( + f"{dotted!r} must implement azure.core.credentials.TokenCredential " + f"(missing callable get_token)." + ) + _custom_credential_cache[key] = instance + return instance + + +def clear_cache() -> None: + """Clear the process-wide cache. 
Intended for tests.""" + with _custom_credential_lock: + _custom_credential_cache.clear() diff --git a/dbt/adapters/scope/delta_lake.py b/dbt/adapters/scope/delta_lake.py index 409c957..b3862d5 100644 --- a/dbt/adapters/scope/delta_lake.py +++ b/dbt/adapters/scope/delta_lake.py @@ -12,21 +12,22 @@ from __future__ import annotations import re +import time from abc import ABC, abstractmethod from collections.abc import Callable, Iterator from contextlib import contextmanager from dataclasses import dataclass -from functools import lru_cache from typing import Any import duckdb from azure.core.credentials import AccessToken, TokenCredential -from azure.identity import AzureCliCredential +from azure.identity import AzureCliCredential, CredentialUnavailableError from azure.storage.filedatalake import DataLakeServiceClient from dbt.adapters.events.logging import AdapterLogger from dbt_common.exceptions import DbtRuntimeError -from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FileLock +from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FABRIC_TOKEN_LOCK, FileLock +from dbt.adapters.scope.message_retry import MessageRetryPolicy, retry_on_message log = AdapterLogger("scope") @@ -76,18 +77,128 @@ def account_url(self) -> str: return f"https://{self.account}.dfs.core.windows.net" +@dataclass(frozen=True) +class RetryPolicy: + """Linear-backoff retry policy for transient credential failures. + + ``max_retries`` is the number of additional attempts AFTER the first + try — matching the semantics of urllib3's ``Retry(total=...)``. + Total attempts == ``max_retries + 1``. + + Delay between attempts is ``min(attempt * initial_delay_seconds, + max_delay_seconds)`` (linear, capped). No jitter — keep it + deterministic for testing. 
+ """ + + max_retries: int = 10 + initial_delay_seconds: float = 1.0 + max_delay_seconds: float = 10.0 + + @classmethod + def from_http_retries(cls, http_retries: int | None) -> RetryPolicy: + """Build a policy from the ``http_retries`` profile field. + + Reuses the same field as the urllib3 HTTP retry count for + consistency. ``None`` (or any value below 0) returns the + defaults: 10 retries, 1s linear, 10s cap. + """ + if http_retries is None or http_retries < 0: + return cls() + return cls( + max_retries=http_retries, + initial_delay_seconds=1.0, + max_delay_seconds=10.0, + ) + + class LockedTokenCredential(TokenCredential): """Serialize token acquisition for credentials that share a cache on disk.""" - def __init__(self, credential: TokenCredential, lock_file: str = AZ_CLI_TOKEN_LOCK) -> None: + def __init__( + self, + credential: TokenCredential, + lock_file: str = AZ_CLI_TOKEN_LOCK, + retry_policy: RetryPolicy | None = None, + sleep: Callable[[float], None] = time.sleep, + ) -> None: self._credential = credential self._lock_file = lock_file + self._retry_policy = retry_policy or RetryPolicy() + self._sleep = sleep def get_token(self, *scopes: str, claims: str | None = None, **kwargs: Any) -> AccessToken: - with FileLock(self._lock_file): - if claims is None: - return self._credential.get_token(*scopes, **kwargs) - return self._credential.get_token(*scopes, claims=claims, **kwargs) + # ``CredentialUnavailableError`` is what ``AzureCliCredential`` raises + # when the underlying ``az`` subprocess times out or otherwise fails + # transiently (it wraps ``subprocess.TimeoutExpired`` and friends). + # Retry with linear backoff while releasing the file lock between + # attempts so other workers get a fair chance at the lock. 
+ policy = self._retry_policy + last_exc: CredentialUnavailableError | None = None + for attempt in range(1, policy.max_retries + 2): # +1 for the initial try + try: + with FileLock(self._lock_file): + if claims is None: + return self._credential.get_token(*scopes, **kwargs) + return self._credential.get_token(*scopes, claims=claims, **kwargs) + except CredentialUnavailableError as exc: + last_exc = exc + if attempt > policy.max_retries: + log.error( + f"Azure credential acquisition failed after " + f"{policy.max_retries + 1} attempts: {exc.message}" + ) + raise + delay = min(policy.initial_delay_seconds * attempt, policy.max_delay_seconds) + log.warning( + f"Azure credential acquisition failed " + f"(attempt {attempt}/{policy.max_retries + 1}): " + f"{exc.message}. Retrying in {delay:.1f}s" + ) + self._sleep(delay) + # Unreachable: the loop either returns or raises. Keep mypy happy. + assert last_exc is not None + raise last_exc + + +def build_credential( + credentials: Any, *, retry_policy: RetryPolicy | None = None +) -> TokenCredential: + """Return the configured TokenCredential for a ScopeCredentials object, + always wrapped in ``LockedTokenCredential``. + + The file lock serializes concurrent dbt threads through a single token + acquisition. Without it, 4 parallel workers each independently walk the + inner credential's fallback chain — which on headless Fabric notebooks can + land on interactive device-code auth (one prompt per thread). + + - ``authentication='cli'``: wraps ``AzureCliCredential()``. File lock and + transient-error retry are tuned for the ``az`` subprocess token cache. + - ``authentication='token_credential'``: wraps the user-supplied credential + (e.g. ``EntraTokenCredential``). The first thread populates the cache; + subsequent threads reuse the cached token without re-entering the inner + credential's fallback chain. 
+ """ + policy = retry_policy or RetryPolicy.from_http_retries( + getattr(credentials, "http_retries", None) + ) + auth = (getattr(credentials, "authentication", "cli") or "cli").lower() + if auth == "token_credential": + # Lazy import keeps `delta_lake.py` importable in places that don't + # need the custom-credential plumbing. + from dbt.adapters.scope.custom_credential import load_custom_credential + + inner: TokenCredential = load_custom_credential( + credentials.credential_class, credentials.credential_kwargs + ) + lock_file = FABRIC_TOKEN_LOCK + else: + inner = AzureCliCredential() + lock_file = AZ_CLI_TOKEN_LOCK + return LockedTokenCredential( + inner, + lock_file=lock_file, + retry_policy=policy, + ) class DeltaLakeClient(ABC): @@ -118,6 +229,9 @@ def table_exists(self, delta_location: str) -> bool: escaped_location = _sql_literal(delta_location) self.fetchone(f"SELECT 1 FROM delta_scan('{escaped_location}') LIMIT 0") return True + except CredentialUnavailableError: + log.error(f"table_exists: credential acquisition exhausted for {delta_location}") + raise except Exception: log.debug(f"table_exists({delta_location}) → False (not found or error)") return False @@ -135,6 +249,9 @@ def get_max_partition(self, delta_location: str, partition_col: str) -> str | No log.debug(f"get_max_partition({delta_location}, {partition_col}) → {result}") return result return None + except CredentialUnavailableError: + log.error(f"get_max_partition: credential acquisition exhausted for {delta_location}") + raise except Exception: log.debug(f"get_max_partition({delta_location}, {partition_col}) → None (error)") return None @@ -150,6 +267,9 @@ def get_columns(self, delta_location: str) -> list[str] | None: columns = [column[0] for column in column_description] log.debug(f"get_columns({delta_location}) → {columns!s}") return columns + except CredentialUnavailableError: + log.error(f"get_columns: credential acquisition exhausted for {delta_location}") + raise except Exception: 
log.debug(f"get_columns({delta_location}) → None (error)") return None @@ -233,10 +353,11 @@ def __init__( credential: TokenCredential, *, connection_factory: Callable[[], duckdb.DuckDBPyConnection] | None = None, - lock_file: str = AZ_CLI_TOKEN_LOCK, + message_retry_policy: MessageRetryPolicy | None = None, ) -> None: - self._credential = LockedTokenCredential(credential, lock_file=lock_file) + self._credential = credential self._connection_factory = connection_factory or duckdb.connect + self._message_retry_policy = message_retry_policy or MessageRetryPolicy.disabled() @contextmanager def connect(self) -> Iterator[duckdb.DuckDBPyConnection]: @@ -261,7 +382,7 @@ def list_table_paths(self, delta_location: str) -> list[str]: if parsed is None: return [] - try: + def _list() -> list[str]: service = DataLakeServiceClient( account_url=parsed.account_url, credential=self._credential, @@ -273,12 +394,16 @@ def list_table_paths(self, delta_location: str) -> list[str]: for path in file_system.get_paths(path=prefix, recursive=True) if not getattr(path, "is_directory", False) ] + + try: + return retry_on_message( + _list, + policy=self._message_retry_policy, + label=f"delta_lake.list_table_paths {delta_location}", + ) + except CredentialUnavailableError: + log.error(f"list_table_paths: credential acquisition exhausted for {delta_location}") + raise except Exception: log.warning(f"list_table_paths({delta_location}) failed") return [] - - -@lru_cache(maxsize=1) -def get_default_delta_client() -> DuckDbDeltaLakeClient: - """Return the default Delta client used by the adapter and test helpers.""" - return DuckDbDeltaLakeClient(credential=AzureCliCredential()) diff --git a/dbt/adapters/scope/impl.py b/dbt/adapters/scope/impl.py index f6c59f5..7205980 100644 --- a/dbt/adapters/scope/impl.py +++ b/dbt/adapters/scope/impl.py @@ -2,6 +2,7 @@ from __future__ import annotations +import atexit import signal import threading import time @@ -18,16 +19,24 @@ from 
dbt.adapters.scope.adls_gen1_client import AdlsGen1Client, FileInfo from dbt.adapters.scope.checkpoint import CheckpointManager, Watermark from dbt.adapters.scope.column import ScopeColumn -from dbt.adapters.scope.connections import ScopeConnectionHandle, ScopeConnectionManager +from dbt.adapters.scope.connections import ( + ScopeConnectionHandle, + ScopeConnectionManager, + _shutdown_event, + cancel_all_active_jobs, +) from dbt.adapters.scope.constants import ( DEFAULT_MAX_BYTES_PER_TRIGGER, DEFAULT_PROCESSING_TIME_TIMEOUT_SECONDS, DEFAULT_SAFETY_BUFFER_SECONDS, DEFAULT_SOURCE_COMPACTION_INTERVAL, DEFAULT_SOURCE_RETENTION_FILES, + DEFAULT_WAIT_ON_CANCEL_SECONDS, ) from dbt.adapters.scope.credentials import ScopeCredentials +from dbt.adapters.scope.delta_lake import RetryPolicy, build_credential from dbt.adapters.scope.file_tracker import FileTracker +from dbt.adapters.scope.message_retry import MessageRetryPolicy, retry_on_message from dbt.adapters.scope.relation import ScopeRelation from dbt.adapters.scope.script_builder import ColumnDef, ScriptConfig from dbt.adapters.scope.trigger_config import parse_trigger_config @@ -37,13 +46,45 @@ # --------------------------------------------------------------------------- # Graceful shutdown support # --------------------------------------------------------------------------- -_shutdown_event = threading.Event() _signal_handlers_installed = False _signal_lock = threading.Lock() +_atexit_registered = False + +# Credentials observed across all ScopeConnectionManager.open() calls in this +# process. Used by the signal handler to decide (a) whether to cancel +# in-flight jobs, and (b) how long to wait for ADLA to confirm terminal state. 
+_observed_credentials: list[ScopeCredentials] = [] +_observed_credentials_lock = threading.Lock() + + +def _observe_credentials(credentials: ScopeCredentials) -> None: + """Record a credentials object so the signal handler can read its preferences.""" + with _observed_credentials_lock: + for existing in _observed_credentials: + if existing is credentials: + return + _observed_credentials.append(credentials) + + +def _any_observed_cancel_on_shutdown_enabled() -> bool: + with _observed_credentials_lock: + if not _observed_credentials: + return True + return any(getattr(c, "cancel_jobs_on_shutdown", True) for c in _observed_credentials) + + +def _observed_max_wait_on_cancel_seconds() -> int: + with _observed_credentials_lock: + values = [ + getattr(c, "wait_on_cancel_seconds", DEFAULT_WAIT_ON_CANCEL_SECONDS) + for c in _observed_credentials + if getattr(c, "cancel_jobs_on_shutdown", True) + ] + return max(values) if values else DEFAULT_WAIT_ON_CANCEL_SECONDS def _install_signal_handlers() -> None: - """Install SIGTERM/SIGINT handlers that set the shutdown event. + """Install SIGTERM/SIGINT handlers that trigger graceful shutdown. Safe to call from any thread — only installs handlers when called from the main thread. Subsequent calls are no-ops. @@ -61,11 +102,16 @@ def _install_signal_handlers() -> None: def _handler(signum: int, frame: Any) -> None: sig_name = signal.Signals(signum).name - log.info( - f"Received {sig_name} — requesting graceful shutdown of processing_time models" - ) + log.info(f"Received {sig_name} — requesting graceful shutdown") _shutdown_event.set() - # Chain to previous handler (e.g. 
dbt's own handler) + if _any_observed_cancel_on_shutdown_enabled(): + try: + cancel_all_active_jobs( + f"signal:{sig_name}", + wait_seconds=_observed_max_wait_on_cancel_seconds(), + ) + except Exception as exc: + log.warning(f"cancel_all_active_jobs failed in signal handler: {exc}") prev = _prev_sigterm if signum == signal.SIGTERM else _prev_sigint if callable(prev) and prev not in (signal.SIG_DFL, signal.SIG_IGN): prev(signum, frame) @@ -75,6 +121,52 @@ def _handler(signum: int, frame: Any) -> None: _signal_handlers_installed = True +def _atexit_cancel_all() -> None: + """Fallback cancel-all invoked on interpreter shutdown. + + Covers paths where dbt unwinds via an unhandled exception that does not + pass through our signal handler. + """ + if not _any_observed_cancel_on_shutdown_enabled(): + return + try: + cancel_all_active_jobs( + "atexit", + wait_seconds=_observed_max_wait_on_cancel_seconds(), + ) + except Exception as exc: + log.warning(f"cancel_all_active_jobs failed in atexit hook: {exc}") + + +def _register_atexit() -> None: + global _atexit_registered + if _atexit_registered: + return + atexit.register(_atexit_cancel_all) + _atexit_registered = True + + +def _scope_open_hook(credentials: ScopeCredentials) -> None: + """Invoked by ``ScopeConnectionManager.open()`` for every connection.""" + _observe_credentials(credentials) + _install_signal_handlers() + _register_atexit() + + +ScopeConnectionManager._on_open = staticmethod(_scope_open_hook) + +# Install signal handlers eagerly at module-load time. ``dbt.adapters.scope.impl`` +# is imported during dbt's main-thread CLI bootstrap (before any worker threads +# are spawned for model execution), so this is the only reliable place to win +# the race against ``signal.signal()``'s main-thread-only requirement. 
+# ``ScopeConnectionManager.open()`` runs on per-model worker threads (via +# dbt's ``LazyHandle(self.open)``), where ``signal.signal()`` would raise +# ``ValueError: signal only works in main thread of the main interpreter`` +# and our ``_install_signal_handlers`` guard would early-return. +_install_signal_handlers() +_register_atexit() + + _TIMESTAMP_COLS = ("accessTime", "modificationTime", "msExpirationTime", "expiryTime") _SIZE_COLS = ("length", "blockSize") @@ -222,7 +314,7 @@ def date_function(cls) -> str: @classmethod def is_cancelable(cls) -> bool: - return False + return True def list_schemas(self, database: str) -> list[str]: """Return the single 'schema' — the container path.""" @@ -275,19 +367,19 @@ def list_relations_without_caching(self, schema_relation: ScopeRelation) -> list if not creds.storage_account or not creds.container: return [] + message_retry_policy = MessageRetryPolicy.from_credentials(creds) + try: - from azure.identity import AzureCliCredential + from azure.identity import CredentialUnavailableError from azure.storage.filedatalake import DataLakeServiceClient - from dbt.adapters.scope.delta_lake import LockedTokenCredential - t_start = time.monotonic() log.debug( f"list_relations: scanning {creds.storage_account}/{creds.container}/" f"{creds.delta_base_path} for Delta tables" ) - credential = LockedTokenCredential(AzureCliCredential()) + credential = build_credential(creds) service = DataLakeServiceClient( account_url=f"https://{creds.storage_account}.dfs.core.windows.net", credential=credential, @@ -295,11 +387,15 @@ def list_relations_without_caching(self, schema_relation: ScopeRelation) -> list fs = service.get_file_system_client(creds.container) t0 = time.monotonic() - dirs = [ - p - for p in fs.get_paths(path=creds.delta_base_path, recursive=False) - if p.is_directory - ] + dirs = retry_on_message( + lambda: [ + p + for p in fs.get_paths(path=creds.delta_base_path, recursive=False) + if p.is_directory + ], + 
policy=message_retry_policy, + label=f"list_relations.get_paths {creds.delta_base_path}", + ) elapsed_ms = (time.monotonic() - t0) * 1000 log.debug( f"list_relations: get_paths found {len(dirs)} directories in {elapsed_ms:.1f} ms" @@ -310,8 +406,16 @@ def list_relations_without_caching(self, schema_relation: ScopeRelation) -> list table_name = path_info.name.split("/")[-1] t0 = time.monotonic() try: - delta_log = fs.get_directory_client(f"{path_info.name}/_delta_log") - delta_log.get_directory_properties() + + def _probe(name=path_info.name): + delta_log = fs.get_directory_client(f"{name}/_delta_log") + delta_log.get_directory_properties() + + retry_on_message( + _probe, + policy=message_retry_policy, + label=f"list_relations.probe {table_name}", + ) relations.append( self.Relation.create( database=creds.storage_account, @@ -325,6 +429,8 @@ def list_relations_without_caching(self, schema_relation: ScopeRelation) -> list f"list_relations: [{i + 1}/{len(dirs)}] {table_name} — " f"Delta table found in {elapsed_ms:.1f} ms" ) + except CredentialUnavailableError: + raise except Exception: elapsed_ms = (time.monotonic() - t0) * 1000 log.debug( @@ -335,6 +441,11 @@ def list_relations_without_caching(self, schema_relation: ScopeRelation) -> list total_ms = (time.monotonic() - t_start) * 1000 log.debug(f"list_relations: found {len(relations)} Delta tables in {total_ms:.1f} ms") return relations + except CredentialUnavailableError: + log.error( + f"list_relations: credential acquisition exhausted for {creds.delta_base_path}" + ) + raise except Exception: log.debug(f"No Delta tables found at {creds.delta_base_path} (path may not exist yet)") return [] @@ -808,13 +919,23 @@ def _get_gen1_client(self) -> AdlsGen1Client: """Return an ADLS Gen1 client for the configured account.""" if not hasattr(self, "_gen1_client"): creds = self._credentials() - self._gen1_client = AdlsGen1Client(account=creds.adls_gen1_account) + self._gen1_client = AdlsGen1Client( + 
account=creds.adls_gen1_account, + credential=build_credential(creds), + retry_policy=RetryPolicy.from_http_retries(creds.http_retries), + message_retry_policy=MessageRetryPolicy.from_credentials(creds), + ) return self._gen1_client def _get_checkpoint_manager(self) -> CheckpointManager: """Return the checkpoint manager singleton.""" if not hasattr(self, "_checkpoint_manager"): - self._checkpoint_manager = CheckpointManager() + creds = self._credentials() + self._checkpoint_manager = CheckpointManager( + credential=build_credential(creds), + retry_policy=RetryPolicy.from_http_retries(creds.http_retries), + message_retry_policy=MessageRetryPolicy.from_credentials(creds), + ) return self._checkpoint_manager def _get_file_tracker(self) -> FileTracker: diff --git a/dbt/adapters/scope/message_retry.py b/dbt/adapters/scope/message_retry.py new file mode 100644 index 0000000..aa55c26 --- /dev/null +++ b/dbt/adapters/scope/message_retry.py @@ -0,0 +1,165 @@ +"""Message-pattern-based retry layer for Azure API calls. + +Sits above the urllib3 transport retry (which handles 429 / 5xx + connect/read +errors) and above the ``RetryPolicy`` token retry (which handles +``CredentialUnavailableError``). This layer inspects the **exception message** +of whatever leaks past those — typically structured error bodies returned by +ADLA / ADLS as 4xx responses — and retries with bounded exponential backoff +when the message matches one of the user-configured patterns. 
T = TypeVar("T")

# ``retry_on_error_messages`` entries starting with this prefix are compiled
# as regular expressions; all other entries are plain substrings.
_REGEX_PREFIX = "re:"


@dataclass(frozen=True)
class MessageRetryPolicy:
    """Exponential-backoff retry triggered by exception message patterns.

    ``max_retries`` is the number of additional attempts AFTER the first try
    (matching the semantics of ``urllib3.Retry(total=...)``). Total attempts
    == ``max_retries + 1``.

    Delay between attempts is ``min(initial_wait_seconds * 2**(attempt-1),
    max_wait_seconds)`` (capped exponential, no jitter — deterministic for
    testing). The layer is a no-op when ``patterns`` is empty or
    ``max_retries`` is 0 (see ``enabled``).
    """

    # Plain ``str`` entries are substring matches; ``re.Pattern`` entries use ``search``.
    patterns: tuple[Any, ...] = ()
    max_retries: int = 25
    initial_wait_seconds: float = 1.0
    max_wait_seconds: float = 30.0

    @classmethod
    def disabled(cls) -> MessageRetryPolicy:
        """Return a policy with no patterns, i.e. one that never retries."""
        return cls(patterns=())

    @classmethod
    def from_credentials(cls, credentials: Any) -> MessageRetryPolicy:
        """Build a policy from the retry-related attributes of ``credentials``.

        Reads ``retry_on_error_messages``, ``max_retries_on_error``,
        ``initial_wait_on_error_seconds`` and ``max_wait_on_error_seconds``,
        falling back to the dataclass defaults for missing attributes.

        Raises:
            ValueError: for an empty or non-string pattern, an invalid regex
                after ``re:``, a negative retry count, a non-positive wait,
                or ``initial_wait > max_wait``.
        """
        raw_patterns = getattr(credentials, "retry_on_error_messages", None) or []
        if isinstance(raw_patterns, str):
            # A bare YAML scalar would otherwise be iterated character by
            # character, yielding one-letter patterns that match nearly everything.
            raw_patterns = [raw_patterns]

        compiled: list[Any] = []
        for entry in raw_patterns:
            if not isinstance(entry, str) or not entry:
                raise ValueError(
                    f"retry_on_error_messages entries must be non-empty strings; got {entry!r}"
                )
            if not entry.startswith(_REGEX_PREFIX):
                compiled.append(entry)
                continue
            pattern_text = entry[len(_REGEX_PREFIX) :]
            if not pattern_text:
                raise ValueError(
                    f"retry_on_error_messages regex entry {entry!r} is empty after 're:'"
                )
            try:
                compiled.append(re.compile(pattern_text))
            except re.error as exc:
                # Surface bad regexes as the same error type as every other
                # configuration problem reported by this method.
                raise ValueError(
                    f"retry_on_error_messages regex entry {entry!r} is not a valid "
                    f"regular expression: {exc}"
                ) from exc

        max_retries = int(getattr(credentials, "max_retries_on_error", 25))
        initial_wait = float(getattr(credentials, "initial_wait_on_error_seconds", 1.0))
        max_wait = float(getattr(credentials, "max_wait_on_error_seconds", 30.0))

        if max_retries < 0:
            raise ValueError(f"max_retries_on_error must be >= 0; got {max_retries}")
        if initial_wait <= 0:
            raise ValueError(f"initial_wait_on_error_seconds must be > 0; got {initial_wait}")
        if max_wait <= 0:
            raise ValueError(f"max_wait_on_error_seconds must be > 0; got {max_wait}")
        if initial_wait > max_wait:
            raise ValueError(
                "initial_wait_on_error_seconds must be <= max_wait_on_error_seconds; "
                f"got {initial_wait} > {max_wait}"
            )

        return cls(
            patterns=tuple(compiled),
            max_retries=max_retries,
            initial_wait_seconds=initial_wait,
            max_wait_seconds=max_wait,
        )

    @property
    def enabled(self) -> bool:
        """True when there is at least one pattern and at least one retry allowed."""
        return bool(self.patterns) and self.max_retries > 0

    def matches(self, exc: BaseException) -> str | None:
        """Return a label for the first pattern found in ``str(exc)``, else ``None``.

        Substring patterns are returned as-is; regex patterns are returned as
        ``"re:<pattern source>"``.
        """
        if not self.patterns:
            return None
        text = str(exc)
        for pattern in self.patterns:
            if isinstance(pattern, re.Pattern):
                if pattern.search(text):
                    return f"re:{pattern.pattern}"
            elif pattern in text:
                return pattern
        return None

    def delay_for_attempt(self, attempt: int) -> float:
        """Return the capped exponential delay to wait after attempt ``attempt``.

        ``attempt`` is 1-based; values below 1 are treated as 1.
        """
        exponent = max(attempt, 1) - 1
        try:
            raw = self.initial_wait_seconds * (2.0**exponent)
        except OverflowError:
            # Past the float range the uncapped delay is certainly above the cap.
            return self.max_wait_seconds
        return min(raw, self.max_wait_seconds)


def retry_on_message(
    operation: Callable[[], T],
    *,
    policy: MessageRetryPolicy,
    label: str,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Run ``operation``; on a pattern-matching exception, retry with backoff.

    Non-matching exceptions are re-raised immediately without delay. When the
    retry budget is exhausted the **original** exception is re-raised so
    downstream callers see the unmodified type (e.g. ``DbtDatabaseError``
    stays ``DbtDatabaseError``).

    Only ``Exception`` subclasses are considered for retry.
    ``KeyboardInterrupt``, ``SystemExit`` and other ``BaseException``-only
    types always propagate, so a broad regex cannot delay a shutdown.

    Args:
        operation: zero-argument callable to execute.
        policy: patterns and backoff parameters; a disabled policy runs
            ``operation`` exactly once.
        label: prefix used in log messages.
        sleep: injectable sleep function (defaults to ``time.sleep``).
    """
    if not policy.enabled:
        return operation()

    total_attempts = policy.max_retries + 1
    attempt = 0
    while True:
        attempt += 1
        try:
            return operation()
        except Exception as exc:
            matched = policy.matches(exc)
            if matched is None:
                raise
            if attempt >= total_attempts:
                log.error(
                    f"[{label}] retry budget exhausted after {attempt} attempt(s) "
                    f"on pattern {matched!r}: {exc}"
                )
                raise
            delay = policy.delay_for_attempt(attempt)
            log.warning(
                f"[{label}] transient error matched pattern {matched!r} "
                f"(attempt {attempt}/{total_attempts}); retrying in {delay:.1f}s: {exc}"
            )
        # Sleep outside the ``except`` clause so an interrupt during the wait
        # is not chained onto the transient error.
        sleep(delay)
T = TypeVar("T")

# Both substrings must appear in the exception text for it to count as a quota error.
_QUOTA_NEEDLES = ("Cannot exceed", "queued SCOPE jobs")
# Filter expression and page size passed to ``list_jobs`` when looking for victims.
_NON_TERMINAL_FILTER = "state ne 'Ended'"
_LIST_TOP = 1000

# One eviction lock per ADLA account, created lazily under the guard lock.
_EVICTION_LOCKS: dict[str, threading.Lock] = {}
_EVICTION_LOCKS_GUARD = threading.Lock()


def _lock_for(account: str) -> threading.Lock:
    """Return the process-wide eviction lock for ``account``, creating it on first use."""
    with _EVICTION_LOCKS_GUARD:
        lk = _EVICTION_LOCKS.get(account)
        if lk is None:
            lk = threading.Lock()
            _EVICTION_LOCKS[account] = lk
        return lk


@dataclass(frozen=True)
class QuotaEvictionPolicy:
    """Eviction recovery policy for ADLA queue-saturation 400s."""

    account: str
    enabled: bool
    max_attempts: int
    cancel_num: int
    wait_seconds: float
    jitter_seconds: float

    @classmethod
    def from_credentials(cls, credentials: Any) -> QuotaEvictionPolicy:
        """Build a policy from the ``quota_eviction_*`` attributes of ``credentials``.

        Missing attributes fall back to the defaults spelled out below.
        """
        return cls(
            account=getattr(credentials, "adla_account", "") or "",
            enabled=bool(getattr(credentials, "enable_quota_eviction", True)),
            max_attempts=int(getattr(credentials, "quota_eviction_max_attempts", 25)),
            cancel_num=int(getattr(credentials, "quota_eviction_cancel_num", 5)),
            wait_seconds=float(getattr(credentials, "quota_eviction_wait_seconds", 30.0)),
            jitter_seconds=float(getattr(credentials, "quota_eviction_jitter_seconds", 5.0)),
        )

    @classmethod
    def disabled(cls) -> QuotaEvictionPolicy:
        """Return a policy that makes ``retry_with_quota_eviction`` a plain call."""
        return cls(
            account="",
            enabled=False,
            max_attempts=0,
            cancel_num=1,
            wait_seconds=1.0,
            jitter_seconds=0.0,
        )


class EvictionContext(Protocol):
    """Operations the eviction layer needs from the connection handle.

    ``cancel_job_async`` MUST be fire-and-forget — it must NOT block waiting
    for the job to reach a terminal state. Blocking would multiply the
    recovery time by the number of victims.
    """

    def list_jobs(self, filter_expr: str | None = None, top: int = 100) -> list[dict[str, Any]]: ...

    def cancel_job_async(self, job_id: str) -> None: ...


def is_quota_error(exc: BaseException) -> bool:
    """Return True when ``exc`` carries the ADLA queue-saturation message."""
    msg = str(exc)
    return all(needle in msg for needle in _QUOTA_NEEDLES)


def select_victims(jobs: list[dict[str, Any]], k: int) -> list[dict[str, Any]]:
    """Return the top-``k`` eviction victims.

    Sorts by highest ``priority`` number (least important) first, then by
    oldest ``submitTime`` (lexical ISO-8601 sort).

    A missing or non-numeric ``priority`` counts as 0, the most important
    tier, so such jobs are evicted last. A missing ``submitTime`` sorts as
    the empty string, so such jobs are evicted first within their tier.
    Returns an empty list when ``k <= 0`` or ``jobs`` is empty.
    """
    if k <= 0 or not jobs:
        return []

    def sort_key(job: dict[str, Any]) -> tuple[int, str]:
        raw_priority = job.get("priority", 0)
        try:
            priority = int(raw_priority)
        except (TypeError, ValueError):
            priority = 0
        submit_time = job.get("submitTime") or ""
        # Negate so the largest priority number sorts first.
        return (-priority, str(submit_time))

    return sorted(jobs, key=sort_key)[:k]


def retry_with_quota_eviction(
    op: Callable[[], T],
    *,
    eviction_ctx: EvictionContext,
    policy: QuotaEvictionPolicy,
    label: str,
    sleep: Callable[[float], None] = time.sleep,
    random_uniform: Callable[[float, float], float] = random.uniform,
) -> T:
    """Execute ``op``; on ADLA quota 400, evict and retry up to ``max_attempts`` times.

    Every eviction round is followed by another call to ``op``, so ``op`` runs
    at most ``max_attempts + 1`` times and jobs are never cancelled without a
    retry that can benefit from the freed slots.

    Args:
        op: zero-argument callable performing the submit.
        eviction_ctx: provides ``list_jobs`` and ``cancel_job_async``.
        policy: eviction parameters; a disabled policy or a non-positive
            ``max_attempts`` runs ``op`` exactly once.
        label: prefix used in log messages.
        sleep: injectable sleep function.
        random_uniform: injectable jitter source, called with
            ``(-jitter_seconds, jitter_seconds)``.

    Raises:
        Whatever ``op`` raises: immediately for non-quota errors, when no
        victim could be cancelled, or once the eviction rounds are used up.
    """
    if not policy.enabled or policy.max_attempts <= 0:
        return op()

    lock = _lock_for(policy.account)
    rounds_used = 0

    while True:
        try:
            return op()
        except Exception as exc:
            if not is_quota_error(exc):
                raise
            if rounds_used >= policy.max_attempts:
                log.warning(
                    f"{label}: exhausted {policy.max_attempts} quota-eviction attempts; raising."
                )
                raise
            rounds_used += 1
            log.warning(
                f"{label}: ADLA quota hit ({exc}). "
                f"Attempt {rounds_used}/{policy.max_attempts}: "
                f"evicting up to {policy.cancel_num} job(s) and retrying."
            )

            # Serialize list+cancel per account so concurrent threads do not
            # pick victims from the same snapshot at the same time.
            with lock:
                victims = _gather_and_cancel(eviction_ctx, policy)

            if not victims:
                log.warning(
                    f"{label}: no eviction candidates found in workspace "
                    f"'{policy.account}'; aborting retry."
                )
                raise

            jitter = random_uniform(-policy.jitter_seconds, policy.jitter_seconds)
            delay = max(0.0, policy.wait_seconds + jitter)
            log.info(
                f"{label}: cancelled {len(victims)} victim(s); sleeping {delay:.1f}s before retry."
            )
        # Sleep outside the ``except`` clause so an interrupt during the wait
        # is not chained onto the quota error.
        sleep(delay)


def _gather_and_cancel(ctx: EvictionContext, policy: QuotaEvictionPolicy) -> list[dict[str, Any]]:
    """List non-terminal jobs, cancel up to ``policy.cancel_num`` and return those cancelled.

    Best-effort: a failing ``list_jobs`` yields an empty list, and a failing
    cancel is logged and skipped. Jobs without a ``jobId`` are ignored.
    """
    try:
        jobs = ctx.list_jobs(filter_expr=_NON_TERMINAL_FILTER, top=_LIST_TOP)
    except Exception as exc:
        log.warning(f"Quota eviction: list_jobs failed ({exc}); cannot pick victims.")
        return []

    cancelled: list[dict[str, Any]] = []
    for victim in select_victims(jobs, policy.cancel_num):
        job_id = str(victim.get("jobId") or "")
        if not job_id:
            continue
        name = victim.get("name", "")
        priority = victim.get("priority", "?")
        submit_time = victim.get("submitTime", "?")
        log.info(
            f"Quota eviction: cancelling job {job_id} "
            f"(name='{name}', priority={priority}, submitTime={submit_time})"
        )
        try:
            ctx.cancel_job_async(job_id)
        except Exception as exc:
            log.warning(f"Quota eviction: cancel of {job_id} failed ({exc}); continuing.")
        else:
            cancelled.append(victim)

    return cancelled
def _set_max_file_count_per_output_file_set(value: int) -> str:
    """Render the ``SET @@MaxFileCountPerOutputFileSet`` statement for ``value``.

    ``value`` must be a real ``int`` (``bool`` is rejected) within
    ``[MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MIN,
    MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MAX]``. Anything else raises
    ``DbtRuntimeError`` here, before any script is built.
    """
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if not is_plain_int:
        raise DbtRuntimeError(
            f"Invalid max_file_count_per_output_file_set '{value!r}': must be an int."
        )

    lower = MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MIN
    upper = MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MAX
    if not (lower <= value <= upper):
        raise DbtRuntimeError(
            f"Invalid max_file_count_per_output_file_set '{value}': must be in "
            f"[{MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MIN}, "
            f"{MAX_FILE_COUNT_PER_OUTPUT_FILE_SET_MAX}]."
        )

    return f"SET @@MaxFileCountPerOutputFileSet = {value};\n"
defaults.source_compaction_interval) | int -%} {%- set source_retention_files = config.get('source_retention_files', defaults.source_retention_files) | int -%} @@ -141,7 +145,8 @@ ns.file_batch, is_full_refresh=is_first_full_refresh_batch, is_incremental=(not is_first_full_refresh_batch), - delta_lake_commit_condition=delta_lake_commit_condition + delta_lake_commit_condition=delta_lake_commit_condition, + max_file_count_per_output_file_set=max_file_count_per_output_file_set ) -%} {%- set mode_label = "full-refresh" if full_refresh_mode else "incremental" -%} diff --git a/dbt/include/scope/macros/materializations/table.sql b/dbt/include/scope/macros/materializations/table.sql index 2d5b1f6..c6691b8 100644 --- a/dbt/include/scope/macros/materializations/table.sql +++ b/dbt/include/scope/macros/materializations/table.sql @@ -22,6 +22,10 @@ {%- set source_patterns = config.get('source_patterns', ['.*\\.ss$']) -%} {%- set max_files_per_trigger = config.get('max_files_per_trigger', target.max_files_per_trigger) | int -%} {%- set max_bytes_per_trigger = config.get('max_bytes_per_trigger', target.max_bytes_per_trigger) | int -%} + {%- set max_file_count_per_output_file_set = config.get( + 'max_file_count_per_output_file_set', + target.max_file_count_per_output_file_set | default(defaults.max_file_count_per_output_file_set, true) + ) | int -%} {%- set safety_buffer_seconds = config.get('safety_buffer_seconds', defaults.safety_buffer_seconds) | int -%} {%- set source_compaction_interval = config.get('source_compaction_interval', defaults.source_compaction_interval) | int -%} {%- set source_retention_files = config.get('source_retention_files', defaults.source_retention_files) | int -%} @@ -74,7 +78,8 @@ sql, ns.file_batch, is_full_refresh=(ns.batch_num == 1), - delta_lake_commit_condition=delta_lake_commit_condition + delta_lake_commit_condition=delta_lake_commit_condition, + max_file_count_per_output_file_set=max_file_count_per_output_file_set ) -%} {{ log("SCOPE: 
full-refresh " ~ identifier ~ " batch " ~ ns.batch_num ~ " of " ~ total_batches ~ " (" ~ ns.file_batch | length ~ " files)", info=True) }} @@ -119,7 +124,8 @@ source_files, is_full_refresh=false, is_incremental=false, - delta_lake_commit_condition="FailIfFileConflict" + delta_lake_commit_condition="FailIfFileConflict", + max_file_count_per_output_file_set=5000 ) %} {# -- Normalize partition_by to a list -- #} {%- set partition_cols = partition_by if partition_by is iterable and partition_by is not string else ([partition_by] if partition_by else []) -%} @@ -133,6 +139,7 @@ SET @@FeaturePreviews = "{{ feature_previews }}"; SET @@DeltaLakeCommitCondition = "{{ delta_lake_commit_condition }}"; +SET @@MaxFileCountPerOutputFileSet = {{ max_file_count_per_output_file_set }}; #DECLARE @deltaPath string = "{{ delta_location }}"; diff --git a/dbt/include/scope/profile_template.yml b/dbt/include/scope/profile_template.yml index 5b4a91a..6a5d7ba 100644 --- a/dbt/include/scope/profile_template.yml +++ b/dbt/include/scope/profile_template.yml @@ -17,3 +17,12 @@ prompts: priority: hint: "Job priority (lower = higher priority)" default: 1 + authentication: + hint: "Auth mode: 'cli' (AzureCliCredential) or 'token_credential' (custom)" + default: "cli" + credential_class: + hint: "Dotted path to a TokenCredential implementation (e.g. 
'fabric_entra_auth.EntraTokenCredential') — only used when authentication='token_credential'" + default: "" + credential_kwargs: + hint: "Mapping of kwargs forwarded to credential_class — only used when authentication='token_credential'" + default: {} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index de34bb9..ddd1965 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -20,10 +20,11 @@ import duckdb import pytest +from azure.identity import AzureCliCredential from datagen import ScopeDataset, make_default_dataset, submit_datagen_job from dbt.cli.main import dbtRunner -from dbt.adapters.scope.delta_lake import get_default_delta_client +from dbt.adapters.scope.delta_lake import DuckDbDeltaLakeClient, LockedTokenCredential PROJECT_DIR = Path(__file__).parent / "dbt_project" REPO_ROOT = Path(__file__).parent.parent.parent @@ -52,7 +53,8 @@ def _delta_path(table_name: str) -> str: def _delta_client(): - return get_default_delta_client() + # Integration tests run on dev machines authenticated via ``az login``. 
+ return DuckDbDeltaLakeClient(credential=LockedTokenCredential(AzureCliCredential())) # -- Datagen datasets -------------------------------------------------------- @@ -273,22 +275,22 @@ def verify_delta_with_duckdb( # -- Watermark + sources checkpoint verification ------------------------------ -def read_watermark(delta_path: str): - """Read the watermark checkpoint for a Delta table.""" +def _checkpoint_manager(): from dbt.adapters.scope.checkpoint import CheckpointManager - return CheckpointManager().read_watermark(delta_path) + return CheckpointManager(credential=LockedTokenCredential(AzureCliCredential())) + + +def read_watermark(delta_path: str): + """Read the watermark checkpoint for a Delta table.""" + return _checkpoint_manager().read_watermark(delta_path) def list_source_files(delta_path: str) -> list[str]: """List files in ``_checkpoint/sources/``.""" - from dbt.adapters.scope.checkpoint import CheckpointManager - - return CheckpointManager().list_source_files(delta_path) + return _checkpoint_manager().list_source_files(delta_path) def read_batch_source(delta_path: str, batch_id: int) -> list[dict]: """Read a batch JSONL file from ``_checkpoint/sources/{batch_id}``.""" - from dbt.adapters.scope.checkpoint import CheckpointManager - - return CheckpointManager().read_batch_source(delta_path, batch_id) + return _checkpoint_manager().read_batch_source(delta_path, batch_id) diff --git a/tests/unit/test_adls_gen1_client.py b/tests/unit/test_adls_gen1_client.py index b6f0216..52cc828 100644 --- a/tests/unit/test_adls_gen1_client.py +++ b/tests/unit/test_adls_gen1_client.py @@ -49,8 +49,7 @@ def test_from_adls_entry_no_mod_time_returns_none(self): class TestAdlsGen1Client: @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_list_files_basic(self, mock_cred, mock_adls): + def test_list_files_basic(self, mock_adls): mock_fs = MagicMock() mock_fs.ls.return_value = [ { @@ -68,14 
+67,13 @@ def test_list_files_basic(self, mock_cred, mock_adls): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", recursive=False) assert len(files) == 2 assert files[0].name == "a.ss" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_list_files_with_pattern(self, mock_cred, mock_adls): + def test_list_files_with_pattern(self, mock_adls): mock_fs = MagicMock() mock_fs.ls.return_value = [ { @@ -93,14 +91,13 @@ def test_list_files_with_pattern(self, mock_cred, mock_adls): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", pattern=r".*\.ss$", recursive=False) assert len(files) == 1 assert files[0].name == "file.ss" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_list_files_sorted_by_modification_time(self, mock_cred, mock_adls): + def test_list_files_sorted_by_modification_time(self, mock_adls): mock_fs = MagicMock() mock_fs.ls.return_value = [ { @@ -118,25 +115,23 @@ def test_list_files_sorted_by_modification_time(self, mock_cred, mock_adls): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", recursive=False) assert files[0].name == "older.ss" assert files[1].name == "newer.ss" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_list_files_not_found_returns_empty(self, mock_cred, mock_adls): + def test_list_files_not_found_returns_empty(self, mock_adls): mock_fs = 
MagicMock() mock_fs.ls.side_effect = FileNotFoundError("not found") mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/nonexistent", recursive=False) assert files == [] @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_list_files_recursive(self, mock_cred, mock_adls): + def test_list_files_recursive(self, mock_adls): mock_fs = MagicMock() # Root has a directory and a file @@ -162,15 +157,14 @@ def test_list_files_recursive(self, mock_cred, mock_adls): mock_fs.ls.side_effect = [root_entries, subdir_entries] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", recursive=True) assert len(files) == 2 names = {f.name for f in files} assert names == {"root.ss", "nested.ss"} @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_recursive_parallel_multiple_subdirs(self, mock_cred, mock_adls): + def test_recursive_parallel_multiple_subdirs(self, mock_adls): """Multiple sibling directories are walked in parallel.""" mock_fs = MagicMock() @@ -215,14 +209,13 @@ def mock_ls(path, detail=True): mock_fs.ls.side_effect = mock_ls mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", recursive=True) assert len(files) == 3 assert {f.name for f in files} == {"a.ss", "b.ss", "c.ss"} @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_recursive_subdir_error_skips_gracefully(self, mock_cred, mock_adls): + def 
test_recursive_subdir_error_skips_gracefully(self, mock_adls): """If one subdirectory fails, the others still succeed.""" mock_fs = MagicMock() @@ -250,14 +243,70 @@ def mock_ls(path, detail=True): mock_fs.ls.side_effect = mock_ls mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/shares/test", recursive=True) assert len(files) == 1 assert files[0].name == "ok.ss" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_recursive_deep_nesting(self, mock_cred, mock_adls): + def test_recursive_subdir_credential_exhaustion_propagates(self, mock_adls): + """If a subdirectory raises CredentialUnavailableError, ``list_files`` must + propagate it — otherwise the watermark advances past unseen files. + + Regression for the ``CredentialUnavailableError: Failed to invoke the + Azure CLI`` resiliency work (PR #32). + """ + from azure.identity import CredentialUnavailableError + + mock_fs = MagicMock() + + root_entries = [ + {"name": "/shares/test/good_dir", "type": "DIRECTORY"}, + {"name": "/shares/test/bad_dir", "type": "DIRECTORY"}, + ] + good_dir_entries = [ + { + "name": "/shares/test/good_dir/ok.ss", + "type": "FILE", + "length": 100, + "modificationTime": 1775018672000, + }, + ] + + def mock_ls(path, detail=True): + if path == "/shares/test/bad_dir": + # Simulate ``LockedTokenCredential`` retries being + # exhausted under the parallel walk. 
+ raise CredentialUnavailableError(message="Failed to invoke the Azure CLI") + return { + "/shares/test": root_entries, + "/shares/test/good_dir": good_dir_entries, + }[path] + + mock_fs.ls.side_effect = mock_ls + mock_adls.AzureDLFileSystem.return_value = mock_fs + + client = AdlsGen1Client("test-account", credential=MagicMock()) + with pytest.raises(CredentialUnavailableError): + client.list_files("/shares/test", recursive=True) + + @patch("dbt.adapters.scope.adls_gen1_client.adls_core") + def test_non_recursive_credential_exhaustion_propagates(self, mock_adls): + """Same guarantee for the non-recursive top-level ``ls`` path.""" + from azure.identity import CredentialUnavailableError + + mock_fs = MagicMock() + mock_fs.ls.side_effect = CredentialUnavailableError( + message="Failed to invoke the Azure CLI" + ) + mock_adls.AzureDLFileSystem.return_value = mock_fs + + client = AdlsGen1Client("test-account", credential=MagicMock()) + with pytest.raises(CredentialUnavailableError): + client.list_files("/shares/test", recursive=False) + + @patch("dbt.adapters.scope.adls_gen1_client.adls_core") + def test_recursive_deep_nesting(self, mock_adls): """Parallel walk works with nested subdirectories (depth > 1).""" mock_fs = MagicMock() @@ -278,7 +327,7 @@ def mock_ls(path, detail=True): mock_fs.ls.side_effect = mock_ls mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = client.list_files("/root", recursive=True) assert len(files) == 1 assert files[0].name == "deep.ss" @@ -287,8 +336,7 @@ def mock_ls(path, detail=True): class TestWalkProgressLogging: @patch("dbt.adapters.scope.adls_gen1_client.log") @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_logs_per_directory_progress(self, mock_cred, mock_adls, mock_log): + def test_logs_per_directory_progress(self, mock_adls, mock_log): 
"""Walk logs depth, dir/file counts, and in-flight per directory.""" mock_fs = MagicMock() @@ -319,7 +367,7 @@ def mock_ls(path, detail=True): mock_fs.ls.side_effect = mock_ls mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) client.list_files("/root", recursive=True) debug_msgs = [c.args[0] if c.args else "" for c in mock_log.debug.call_args_list] @@ -328,8 +376,7 @@ def mock_ls(path, detail=True): @patch("dbt.adapters.scope.adls_gen1_client.log") @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_logs_timing_for_non_recursive(self, mock_cred, mock_adls, mock_log): + def test_logs_timing_for_non_recursive(self, mock_adls, mock_log): mock_fs = MagicMock() mock_fs.ls.return_value = [ { @@ -341,7 +388,7 @@ def test_logs_timing_for_non_recursive(self, mock_cred, mock_adls, mock_log): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) client.list_files("/shares/test", recursive=False) debug_msgs = [c.args[0] if c.args else "" for c in mock_log.debug.call_args_list] @@ -350,14 +397,13 @@ def test_logs_timing_for_non_recursive(self, mock_cred, mock_adls, mock_log): @patch("dbt.adapters.scope.adls_gen1_client.log") @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_logs_on_not_found(self, mock_cred, mock_adls, mock_log): + def test_logs_on_not_found(self, mock_adls, mock_log): """Walk complete is logged even when root path not found.""" mock_fs = MagicMock() mock_fs.ls.side_effect = FileNotFoundError("not found") mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) 
client.list_files("/shares/gone", recursive=True) debug_msgs = [c.args[0] if c.args else "" for c in mock_log.debug.call_args_list] @@ -368,22 +414,20 @@ class TestEstimateBytes: """Tests for estimate_bytes — SSv3/v4 vs SSv5/v6 detection.""" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_ssv3_no_sibling_folder(self, mock_cred, mock_adls): + def test_ssv3_no_sibling_folder(self, mock_adls): """SSv3/v4: no sibling folder → returns (file_length, []).""" mock_fs = MagicMock() mock_fs.info.side_effect = FileNotFoundError("not found") mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) est_bytes, contrib = client.estimate_bytes("/shares/test/data.ss", 727393) assert est_bytes == 727393 assert contrib == [] @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_ssv5_with_du_files(self, mock_cred, mock_adls): + def test_ssv5_with_du_files(self, mock_adls): """SSv5/v6: sibling folder with .du files → returns sum.""" mock_fs = MagicMock() mock_fs.info.return_value = {"type": "DIRECTORY"} @@ -393,7 +437,7 @@ def test_ssv5_with_du_files(self, mock_cred, mock_adls): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) est_bytes, contrib = client.estimate_bytes("/shares/test/data.ss", 4096) assert est_bytes == 4096 + 50_000_000 + 48_000_000 @@ -402,8 +446,7 @@ def test_ssv5_with_du_files(self, mock_cred, mock_adls): assert "/shares/test/data/part-00001.du" in contrib @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_ssv5_with_delta_subfolder(self, mock_cred, mock_adls): + def test_ssv5_with_delta_subfolder(self, 
mock_adls): """SSv5 with delta updates — recursive listing includes subdirs.""" mock_fs = MagicMock() mock_fs.info.return_value = {"type": "DIRECTORY"} @@ -430,22 +473,21 @@ def mock_ls(path, detail=True): mock_fs.ls.side_effect = mock_ls mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) est_bytes, contrib = client.estimate_bytes("/shares/test/data.ss", 4096) assert est_bytes == 4096 + 50_000_000 + 1_000_000 assert len(contrib) == 2 @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_empty_sibling_folder(self, mock_cred, mock_adls): + def test_empty_sibling_folder(self, mock_adls): """Empty sibling folder → returns (manifest_size, []).""" mock_fs = MagicMock() mock_fs.info.return_value = {"type": "DIRECTORY"} mock_fs.ls.return_value = [] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) est_bytes, contrib = client.estimate_bytes("/shares/test/data.ss", 2048) assert est_bytes == 2048 @@ -461,14 +503,13 @@ class TestEnrichWithEstimates: """Tests for enrich_with_estimates — bulk enrichment of FileInfo lists.""" @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_enriches_all_files(self, mock_cred, mock_adls): + def test_enriches_all_files(self, mock_adls): mock_fs = MagicMock() # No sibling folders (SSv3/v4 for both) mock_fs.info.side_effect = FileNotFoundError("not found") mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = [ FileInfo( path="/shares/test/a.ss", @@ -492,8 +533,7 @@ def test_enriches_all_files(self, mock_cred, mock_adls): assert enriched[1].contributing_files == () 
@patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_enriches_ssv5_files(self, mock_cred, mock_adls): + def test_enriches_ssv5_files(self, mock_adls): mock_fs = MagicMock() mock_fs.info.return_value = {"type": "DIRECTORY"} mock_fs.ls.return_value = [ @@ -501,7 +541,7 @@ def test_enriches_ssv5_files(self, mock_cred, mock_adls): ] mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = [ FileInfo( path="/shares/test/a.ss", @@ -520,14 +560,13 @@ def test_empty_list_returns_empty(self): assert client.enrich_with_estimates([]) == [] @patch("dbt.adapters.scope.adls_gen1_client.adls_core") - @patch("dbt.adapters.scope.adls_gen1_client.AzureCliCredential") - def test_fallback_on_error(self, mock_cred, mock_adls): + def test_fallback_on_error(self, mock_adls): """If estimate_bytes fails for a file, fall back to file length.""" mock_fs = MagicMock() mock_fs.info.side_effect = Exception("network error") mock_adls.AzureDLFileSystem.return_value = mock_fs - client = AdlsGen1Client("test-account") + client = AdlsGen1Client("test-account", credential=MagicMock()) files = [ FileInfo( path="/shares/test/bad.ss", @@ -540,3 +579,112 @@ def test_fallback_on_error(self, mock_cred, mock_adls): enriched = client.enrich_with_estimates(files) assert enriched[0].estimated_bytes == 999 assert enriched[0].contributing_files == () + + +class TestLegacyDataLakeCredentialAdapter: + """Cover the legacy ``azure-datalake-store`` 0.0.5x compatibility shim.""" + + def test_signed_session_caches_token_until_expiry(self): + from azure.core.credentials import AccessToken + + from dbt.adapters.scope.adls_gen1_client import _LegacyDataLakeCredentialAdapter + + far_future = int(__import__("time").time()) + 3600 + credential = MagicMock() + credential.get_token.return_value = AccessToken("tok-A", far_future) + + 
adapter = _LegacyDataLakeCredentialAdapter(credential) + s1 = adapter.signed_session() + s2 = adapter.signed_session() + + assert s1.headers["Authorization"] == "Bearer tok-A" + assert s2.headers["Authorization"] == "Bearer tok-A" + credential.get_token.assert_called_once_with("https://datalake.azure.net//.default") + + def test_signed_session_refreshes_when_near_expiry(self): + from azure.core.credentials import AccessToken + + from dbt.adapters.scope.adls_gen1_client import _LegacyDataLakeCredentialAdapter + + now = int(__import__("time").time()) + # First token expires in 60s — below the 300s refresh lead, so the + # next signed_session() call must refresh. + credential = MagicMock() + credential.get_token.side_effect = [ + AccessToken("tok-old", now + 60), + AccessToken("tok-new", now + 3600), + ] + + adapter = _LegacyDataLakeCredentialAdapter(credential) + s1 = adapter.signed_session() + s2 = adapter.signed_session() + + assert s1.headers["Authorization"] == "Bearer tok-old" + assert s2.headers["Authorization"] == "Bearer tok-new" + assert credential.get_token.call_count == 2 + + def test_refresh_token_forces_unconditional_reacquire(self): + from azure.core.credentials import AccessToken + + from dbt.adapters.scope.adls_gen1_client import _LegacyDataLakeCredentialAdapter + + now = int(__import__("time").time()) + credential = MagicMock() + credential.get_token.side_effect = [ + AccessToken("first", now + 3600), + AccessToken("second", now + 3600), + ] + adapter = _LegacyDataLakeCredentialAdapter(credential) + adapter.signed_session() + adapter.refresh_token() + s = adapter.signed_session() + + assert s.headers["Authorization"] == "Bearer second" + assert credential.get_token.call_count == 2 + + +class TestLegacyGen1SdkDetection: + """Ensure ``_get_fs`` routes the legacy SDK through the adapter.""" + + @patch("dbt.adapters.scope.adls_gen1_client.adls_core") + @patch("dbt.adapters.scope.adls_gen1_client._legacy_gen1_sdk_in_use") + def 
test_legacy_sdk_uses_token_kwarg_with_adapter(self, mock_is_legacy, mock_adls): + from dbt.adapters.scope.adls_gen1_client import ( + AdlsGen1Client, + _LegacyDataLakeCredentialAdapter, + ) + + mock_is_legacy.return_value = True + mock_adls.AzureDLFileSystem.return_value = MagicMock() + credential = MagicMock() + + client = AdlsGen1Client("acct", credential=credential) + client._get_fs() + + kwargs = mock_adls.AzureDLFileSystem.call_args.kwargs + assert "token_credential" not in kwargs + assert isinstance(kwargs["token"], _LegacyDataLakeCredentialAdapter) + assert kwargs["store_name"] == "acct" + + @patch("dbt.adapters.scope.adls_gen1_client.adls_core") + @patch("dbt.adapters.scope.adls_gen1_client._legacy_gen1_sdk_in_use") + def test_modern_sdk_uses_token_credential_kwarg(self, mock_is_legacy, mock_adls): + from dbt.adapters.scope.adls_gen1_client import AdlsGen1Client + + mock_is_legacy.return_value = False + mock_adls.AzureDLFileSystem.return_value = MagicMock() + credential = MagicMock() + + client = AdlsGen1Client("acct", credential=credential) + client._get_fs() + + kwargs = mock_adls.AzureDLFileSystem.call_args.kwargs + assert kwargs["token_credential"] is credential + assert "token" not in kwargs + assert kwargs["store_name"] == "acct" + + def test_legacy_sdk_in_use_inspects_constructor_signature(self): + from dbt.adapters.scope.adls_gen1_client import _legacy_gen1_sdk_in_use + + # The locally-installed SDK is 1.x and accepts ``token_credential``. 
+ assert _legacy_gen1_sdk_in_use() is False diff --git a/tests/unit/test_checkpoint.py b/tests/unit/test_checkpoint.py index 32918f9..cba90f4 100644 --- a/tests/unit/test_checkpoint.py +++ b/tests/unit/test_checkpoint.py @@ -56,8 +56,7 @@ def test_backward_compat_old_format(self): class TestCheckpointManagerWatermark: @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_read_watermark(self, mock_cred, mock_service): + def test_read_watermark(self, mock_service): wm_json = Watermark( version=1, modified_time="2026-04-01T12:00:00+00:00", batch_id=3 ).to_json() @@ -69,24 +68,51 @@ def test_read_watermark(self, mock_cred, mock_service): mock_fs.get_file_client.return_value = mock_file mock_service.return_value.get_file_system_client.return_value = mock_fs - result = CheckpointManager().read_watermark("abfss://c@a.dfs.core.windows.net/d/t") + result = CheckpointManager(credential=MagicMock()).read_watermark( + "abfss://c@a.dfs.core.windows.net/d/t" + ) assert result is not None assert result.batch_id == 3 @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_read_watermark_none_on_missing(self, mock_cred, mock_service): + def test_read_watermark_none_on_missing(self, mock_service): mock_fs = MagicMock() mock_fs.get_file_client.side_effect = Exception("Not found") mock_service.return_value.get_file_system_client.return_value = mock_fs - assert CheckpointManager().read_watermark("abfss://c@a.dfs.core.windows.net/d/t") is None + assert ( + CheckpointManager(credential=MagicMock()).read_watermark( + "abfss://c@a.dfs.core.windows.net/d/t" + ) + is None + ) + + @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") + def test_read_watermark_propagates_credential_exhaustion(self, mock_service): + """``read_watermark`` MUST NOT swallow ``CredentialUnavailableError``. 
+ + Returning ``None`` on auth failure would silently flip an + incremental run into a full refresh and re-ingest the entire + source history. Regression for PR #32. + """ + import pytest + from azure.identity import CredentialUnavailableError + + mock_fs = MagicMock() + mock_fs.get_file_client.side_effect = CredentialUnavailableError( + message="Failed to invoke the Azure CLI" + ) + mock_service.return_value.get_file_system_client.return_value = mock_fs + + with pytest.raises(CredentialUnavailableError): + CheckpointManager(credential=MagicMock()).read_watermark( + "abfss://c@a.dfs.core.windows.net/d/t" + ) def test_read_watermark_none_for_bad_path(self): - assert CheckpointManager().read_watermark("https://bad") is None + assert CheckpointManager(credential=MagicMock()).read_watermark("https://bad") is None @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_write_watermark(self, mock_cred, mock_service): + def test_write_watermark(self, mock_service): mock_dir = MagicMock() mock_file = MagicMock() mock_fs = MagicMock() @@ -95,26 +121,28 @@ def test_write_watermark(self, mock_cred, mock_service): mock_service.return_value.get_file_system_client.return_value = mock_fs wm = Watermark(version=0, modified_time="2026-04-01T12:00:00+00:00", batch_id=0) - CheckpointManager().write_watermark("abfss://c@a.dfs.core.windows.net/d/t", wm) + CheckpointManager(credential=MagicMock()).write_watermark( + "abfss://c@a.dfs.core.windows.net/d/t", wm + ) mock_file.upload_data.assert_called_once() @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_delete_watermark(self, mock_cred, mock_service): + def test_delete_watermark(self, mock_service): mock_file = MagicMock() mock_fs = MagicMock() mock_fs.get_file_client.return_value = mock_file mock_fs.get_paths.return_value = [] 
mock_service.return_value.get_file_system_client.return_value = mock_fs - CheckpointManager().delete_watermark("abfss://c@a.dfs.core.windows.net/d/t") + CheckpointManager(credential=MagicMock()).delete_watermark( + "abfss://c@a.dfs.core.windows.net/d/t" + ) mock_file.delete_file.assert_called_once() class TestCheckpointManagerSources: @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_write_jsonl_on_normal_batch(self, mock_cred, mock_service): + def test_write_jsonl_on_normal_batch(self, mock_service): """Non-compaction batch writes a JSONL diff file.""" from datetime import datetime, timezone @@ -125,7 +153,7 @@ def test_write_jsonl_on_normal_batch(self, mock_cred, mock_service): mock_fs.get_file_client.return_value = mock_file mock_service.return_value.get_file_system_client.return_value = mock_fs - CheckpointManager().write_batch_sources( + CheckpointManager(credential=MagicMock()).write_batch_sources( "abfss://c@a.dfs.core.windows.net/d/t", batch_id=3, file_paths=["/shares/a.ss", "/shares/b.ss"], @@ -147,8 +175,7 @@ def test_write_jsonl_on_normal_batch(self, mock_cred, mock_service): assert "batchProcessingTime" in r0 @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_batch_zero_always_writes_jsonl(self, mock_cred, mock_service): + def test_batch_zero_always_writes_jsonl(self, mock_service): """Batch 0 is never a compaction boundary, even with interval=1.""" from datetime import datetime, timezone @@ -159,7 +186,7 @@ def test_batch_zero_always_writes_jsonl(self, mock_cred, mock_service): mock_fs.get_file_client.return_value = mock_file mock_service.return_value.get_file_system_client.return_value = mock_fs - CheckpointManager().write_batch_sources( + CheckpointManager(credential=MagicMock()).write_batch_sources( "abfss://c@a.dfs.core.windows.net/d/t", batch_id=0, file_paths=["/shares/a.ss"], @@ 
-175,8 +202,7 @@ def test_batch_zero_always_writes_jsonl(self, mock_cred, mock_service): json.loads(uploaded.decode("utf-8").split("\n")[0]) @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_cleanup_deletes_oldest(self, mock_cred, mock_service): + def test_cleanup_deletes_oldest(self, mock_service): mock_file = MagicMock() mock_fs = MagicMock() mock_fs.get_file_client.return_value = mock_file @@ -189,22 +215,21 @@ def test_cleanup_deletes_oldest(self, mock_cred, mock_service): ] mock_service.return_value.get_file_system_client.return_value = mock_fs - deleted = CheckpointManager().cleanup_sources( + deleted = CheckpointManager(credential=MagicMock()).cleanup_sources( "abfss://c@a.dfs.core.windows.net/d/t", max_files=3 ) assert deleted == 2 assert mock_file.delete_file.call_count == 2 @patch("dbt.adapters.scope.checkpoint.DataLakeServiceClient") - @patch("dbt.adapters.scope.checkpoint.AzureCliCredential") - def test_cleanup_noop_under_limit(self, mock_cred, mock_service): + def test_cleanup_noop_under_limit(self, mock_service): mock_fs = MagicMock() mock_fs.get_paths.return_value = [ SimpleNamespace(name="d/t/_checkpoint/sources/0", is_directory=False), ] mock_service.return_value.get_file_system_client.return_value = mock_fs - deleted = CheckpointManager().cleanup_sources( + deleted = CheckpointManager(credential=MagicMock()).cleanup_sources( "abfss://c@a.dfs.core.windows.net/d/t", max_files=100 ) assert deleted == 0 diff --git a/tests/unit/test_checkpoint_lifecycle.py b/tests/unit/test_checkpoint_lifecycle.py index a1b0438..d125c56 100644 --- a/tests/unit/test_checkpoint_lifecycle.py +++ b/tests/unit/test_checkpoint_lifecycle.py @@ -7,10 +7,11 @@ from __future__ import annotations +import json import os -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock 
import MagicMock, patch import pytest @@ -103,9 +104,8 @@ def checkpoint_mgr(adls_store): service = InMemoryServiceClient(adls_store) with ( patch("dbt.adapters.scope.checkpoint._get_service", return_value=service), - patch("dbt.adapters.scope.checkpoint.AzureCliCredential"), ): - yield CheckpointManager() + yield CheckpointManager(credential=MagicMock()) def _make_times(count: int, base_year: int = 2026, base_month: int = 4) -> list[datetime]: @@ -505,7 +505,6 @@ def test_dump_checkpoint_to_disk(self, checkpoint_mgr, adls_store): ls -la /tmp/dbt_scope_checkpoint_demo/_checkpoint/sources/ duckdb -c "SELECT * FROM read_parquet('/tmp/dbt_scope_checkpoint_demo/_checkpoint/sources/100.parquet') LIMIT 20" """ - import json import shutil import duckdb @@ -690,3 +689,162 @@ def sort_key(r): f"Validated {len(parquet_snapshots)} parquet snapshots: " f"each = previous parquet + intermediate JSONL diffs ✓" ) + + def test_compaction_handles_timestamp_typed_prior_snapshot(self, checkpoint_mgr, adls_store): + """Regression for: ``Object of type datetime is not JSON serializable``. + + In production, batches are spaced minutes apart so the + ``batchProcessingTime`` strings in a per-snapshot NDJSON have + enough variation for DuckDB's ``read_json_auto`` to infer + ``TIMESTAMP``. The resulting parquet snapshot then stores the + column as ``TIMESTAMP``, and the **next** compaction reads it + back as Python ``datetime`` — which used to crash + ``json.dumps(...)`` inside ``_write_snapshot_parquet``. + + This test reproduces that exact shape deterministically by + pre-seeding a parquet snapshot whose ``batchProcessingTime`` + column is ``TIMESTAMP``, then triggering a second compaction. + Pre-fix: raises ``TypeError: Object of type datetime is not JSON + serializable``. Post-fix: succeeds. + """ + import duckdb + + sources_prefix = "delta/lifecycle_test/_checkpoint/sources" # matches DELTA_LOC's path + compaction_interval = 10 + + # ── 1. 
Build a parquet snapshot whose ``batchProcessingTime`` is + # explicitly ``TIMESTAMP``-typed — mimicking what happens in + # production when DuckDB's ``read_json_auto`` infers TIMESTAMP + # from well-spaced batch processing times. (DuckDB's inference + # heuristic varies across versions/sample sizes, so we force the + # cast here to make the test deterministic across environments.) + seeded_records = 22 # 11 batches x 2 files + prior_snapshot_path = f"/tmp/test_prior_snapshot_{id(self)}.parquet" + prior_ndjson_path = f"/tmp/test_prior_snapshot_{id(self)}.ndjson" + try: + with open(prior_ndjson_path, "w") as f: + for i in range(seeded_records): + ts = datetime(2026, 6, 1, tzinfo=timezone.utc) + timedelta(minutes=i * 5) + f.write( + json.dumps( + { + "path": f"/shares/seed/file_{i:04d}.ss", + "modificationTime": 1700000000000 + i, + "batchId": i // 2, + "batchProcessingTime": ts.isoformat(), + } + ) + + "\n" + ) + conn = duckdb.connect() + try: + # Force TIMESTAMP via explicit CAST so the test is + # deterministic regardless of DuckDB's auto-inference + # heuristic — what we want to assert is the read-back + + # next-compaction behaviour, not DuckDB's inference. + conn.execute( + "CREATE TABLE t AS " + "SELECT path, " + '"modificationTime", ' + '"batchId", ' + 'CAST("batchProcessingTime" AS TIMESTAMP) AS "batchProcessingTime" ' + f"FROM read_json_auto('{prior_ndjson_path}')" + ) + schema = conn.execute("DESCRIBE t").fetchall() + assert any( + col[0] == "batchProcessingTime" and col[1] == "TIMESTAMP" for col in schema + ), f"Setup precondition failed: expected TIMESTAMP, got {schema}" + conn.execute(f"COPY t TO '{prior_snapshot_path}' (FORMAT PARQUET)") + finally: + conn.close() + + with open(prior_snapshot_path, "rb") as f: + prior_parquet_bytes = f.read() + finally: + for p in (prior_ndjson_path, prior_snapshot_path): + if os.path.exists(p): + os.remove(p) + + # ── 2. 
Inject the prior snapshot directly into in-memory ADLS as + # "10.parquet" so the next compaction will see it. + adls_store[f"{sources_prefix}/10.parquet"] = prior_parquet_bytes + + # ── 2b. Also seed a JSONL diff for batch 15 — _write_snapshot_parquet + # is supposed to merge "prior snapshot + JSONL diffs since snapshot + + # current batch", so we want to exercise all three legs. + intermediate_paths = _make_paths(3, prefix="/shares/b15") + intermediate_times = _make_times(3, base_month=6) + checkpoint_mgr.write_batch_sources( + DELTA_LOC, + batch_id=15, + file_paths=intermediate_paths, + modification_times=intermediate_times, + compaction_interval=compaction_interval, # 15 % 10 != 0 → JSONL + ) + assert "15" in checkpoint_mgr.list_source_files(DELTA_LOC), ( + "Intermediate JSONL diff for batch 15 should exist" + ) + + # ── 3. Trigger compaction at batch 20. Pre-fix: this raises + # ``TypeError: Object of type datetime is not JSON serializable`` + # from inside _write_snapshot_parquet. Post-fix: success. + new_batch_paths = _make_paths(2, prefix="/shares/b20") + new_batch_times = _make_times(2, base_month=7) + checkpoint_mgr.write_batch_sources( + DELTA_LOC, + batch_id=20, + file_paths=new_batch_paths, + modification_times=new_batch_times, + compaction_interval=compaction_interval, + ) + + # ── 4. Verify the new snapshot exists and contains records from + # every leg of the union (prior 22 + intermediate JSONL 3 + current 2 = 27). 
+ sources = checkpoint_mgr.list_source_files(DELTA_LOC) + assert "20.parquet" in sources, f"Expected 20.parquet, got {sources}" + assert "10.parquet" in sources, "Prior snapshot should still exist" + assert "15" in sources, "Intermediate JSONL diff should still exist" + + new_snapshot_key = next(k for k in adls_store if k.endswith("20.parquet")) + new_local = f"/tmp/test_new_snapshot_{id(self)}.parquet" + with open(new_local, "wb") as f: + f.write(adls_store[new_snapshot_key]) + try: + conn = duckdb.connect() + try: + total = conn.execute( + f"SELECT count(*) FROM read_parquet('{new_local}')" + ).fetchone()[0] + expected = seeded_records + len(intermediate_paths) + len(new_batch_paths) + assert total == expected, ( + f"Expected {expected} records in new snapshot, got {total}" + ) + + seen_batch_ids = { + row[0] + for row in conn.execute( + f"SELECT DISTINCT \"batchId\" FROM read_parquet('{new_local}')" + ).fetchall() + } + assert 15 in seen_batch_ids, ( + f"JSONL diff records (batchId=15) missing from snapshot: {seen_batch_ids}" + ) + assert 20 in seen_batch_ids, ( + f"Current batch records (batchId=20) missing from snapshot: {seen_batch_ids}" + ) + assert any(bid <= 10 for bid in seen_batch_ids), ( + f"Prior snapshot records (batchId<=10) missing: {seen_batch_ids}" + ) + + # New snapshot's batchProcessingTime is deterministically + # TIMESTAMP (enforced by the explicit CAST in the fix). 
+ new_schema = conn.execute( + f"DESCRIBE SELECT * FROM read_parquet('{new_local}')" + ).fetchall() + assert any( + col[0] == "batchProcessingTime" and col[1] == "TIMESTAMP" for col in new_schema + ), f"Expected TIMESTAMP schema in new snapshot, got {new_schema}" + finally: + conn.close() + finally: + os.remove(new_local) diff --git a/tests/unit/test_credentials.py b/tests/unit/test_credentials.py index 3d52bee..2375bbf 100644 --- a/tests/unit/test_credentials.py +++ b/tests/unit/test_credentials.py @@ -1,5 +1,8 @@ """Tests for ScopeCredentials.""" +import pytest +from dbt_common.exceptions import DbtRuntimeError + from dbt.adapters.scope.credentials import ScopeCredentials # Base Credentials requires database and schema @@ -29,6 +32,9 @@ def test_default_values(self): assert creds.poll_interval_seconds == 5 assert creds.job_timeout_seconds == 36_000 assert creds.delta_base_path == "delta" + assert creds.max_file_count_per_output_file_set == 5000 + assert creds.cancel_jobs_on_shutdown is True + assert creds.wait_on_cancel_seconds == 30 def test_custom_values(self): creds = ScopeCredentials( @@ -37,6 +43,9 @@ def test_custom_values(self): container="mycontainer", au=50, priority=2, + max_file_count_per_output_file_set=250000, + cancel_jobs_on_shutdown=False, + wait_on_cancel_seconds=120, **_BASE_KWARGS, ) assert creds.adla_account == "my-adla" @@ -44,3 +53,179 @@ def test_custom_values(self): assert creds.container == "mycontainer" assert creds.au == 50 assert creds.priority == 2 + assert creds.max_file_count_per_output_file_set == 250000 + assert creds.cancel_jobs_on_shutdown is False + assert creds.wait_on_cancel_seconds == 120 + + def test_max_file_count_in_connection_keys(self): + creds = ScopeCredentials(adla_account="test-account", **_BASE_KWARGS) + assert "max_file_count_per_output_file_set" in creds._connection_keys() + + def test_shutdown_settings_in_connection_keys(self): + creds = ScopeCredentials(adla_account="test-account", **_BASE_KWARGS) + keys = 
creds._connection_keys() + assert "cancel_jobs_on_shutdown" in keys + assert "wait_on_cancel_seconds" in keys + + +class TestAuthenticationFields: + def test_authentication_defaults_to_cli(self): + creds = ScopeCredentials(adla_account="x", **_BASE_KWARGS) + assert creds.authentication == "cli" + assert creds.credential_class is None + assert creds.credential_kwargs == {} + + def test_token_credential_requires_credential_class(self): + with pytest.raises(DbtRuntimeError, match="requires `credential_class`"): + ScopeCredentials( + adla_account="x", + authentication="token_credential", + **_BASE_KWARGS, + ) + + def test_token_credential_accepts_credential_class(self): + creds = ScopeCredentials( + adla_account="x", + authentication="token_credential", + credential_class="my_pkg.MyCredential", + credential_kwargs={"foo": "bar"}, + **_BASE_KWARGS, + ) + assert creds.credential_class == "my_pkg.MyCredential" + assert creds.credential_kwargs == {"foo": "bar"} + + def test_credential_class_rejected_under_cli_auth(self): + with pytest.raises(DbtRuntimeError, match="only valid when"): + ScopeCredentials( + adla_account="x", + credential_class="my_pkg.MyCredential", + **_BASE_KWARGS, + ) + + def test_credential_kwargs_rejected_under_cli_auth(self): + with pytest.raises(DbtRuntimeError, match="only valid when"): + ScopeCredentials( + adla_account="x", + credential_kwargs={"foo": "bar"}, + **_BASE_KWARGS, + ) + + def test_authentication_in_connection_keys(self): + creds = ScopeCredentials(adla_account="x", **_BASE_KWARGS) + keys = creds._connection_keys() + assert "authentication" in keys + assert "credential_class" in keys + + def test_authentication_case_insensitive_match(self): + creds = ScopeCredentials( + adla_account="x", + authentication="Token_Credential", + credential_class="my_pkg.MyCredential", + **_BASE_KWARGS, + ) + assert creds.credential_class == "my_pkg.MyCredential" + + +class TestMessageRetryFields: + def test_defaults(self): + creds = 
ScopeCredentials(adla_account="x", **_BASE_KWARGS) + assert creds.retry_on_error_messages == [] + assert creds.max_retries_on_error == 25 + assert creds.initial_wait_on_error_seconds == 1.0 + assert creds.max_wait_on_error_seconds == 30.0 + + def test_custom_values_accepted(self): + creds = ScopeCredentials( + adla_account="x", + retry_on_error_messages=["a", "re:b\\d+"], + max_retries_on_error=10, + initial_wait_on_error_seconds=2.0, + max_wait_on_error_seconds=60.0, + **_BASE_KWARGS, + ) + assert creds.retry_on_error_messages == ["a", "re:b\\d+"] + assert creds.max_retries_on_error == 10 + assert creds.initial_wait_on_error_seconds == 2.0 + assert creds.max_wait_on_error_seconds == 60.0 + + def test_negative_max_retries_rejected(self): + with pytest.raises(DbtRuntimeError, match="max_retries_on_error must be >= 0"): + ScopeCredentials(adla_account="x", max_retries_on_error=-1, **_BASE_KWARGS) + + def test_non_positive_initial_wait_rejected(self): + with pytest.raises(DbtRuntimeError, match="initial_wait_on_error_seconds must be > 0"): + ScopeCredentials(adla_account="x", initial_wait_on_error_seconds=0, **_BASE_KWARGS) + + def test_non_positive_max_wait_rejected(self): + with pytest.raises(DbtRuntimeError, match="max_wait_on_error_seconds must be > 0"): + ScopeCredentials(adla_account="x", max_wait_on_error_seconds=0, **_BASE_KWARGS) + + def test_initial_greater_than_max_rejected(self): + with pytest.raises(DbtRuntimeError, match="must be <= max_wait_on_error_seconds"): + ScopeCredentials( + adla_account="x", + initial_wait_on_error_seconds=60, + max_wait_on_error_seconds=30, + **_BASE_KWARGS, + ) + + def test_empty_pattern_string_rejected(self): + with pytest.raises(DbtRuntimeError, match="non-empty strings"): + ScopeCredentials(adla_account="x", retry_on_error_messages=[""], **_BASE_KWARGS) + + def test_retry_fields_in_connection_keys(self): + keys = ScopeCredentials(adla_account="x", **_BASE_KWARGS)._connection_keys() + assert "retry_on_error_messages" in keys 
+ assert "max_retries_on_error" in keys + assert "initial_wait_on_error_seconds" in keys + assert "max_wait_on_error_seconds" in keys + + +class TestQuotaEvictionFields: + def test_defaults(self): + creds = ScopeCredentials(adla_account="x", **_BASE_KWARGS) + assert creds.enable_quota_eviction is True + assert creds.quota_eviction_max_attempts == 25 + assert creds.quota_eviction_cancel_num == 5 + assert creds.quota_eviction_wait_seconds == 30.0 + assert creds.quota_eviction_jitter_seconds == 5.0 + + def test_custom_values_accepted(self): + creds = ScopeCredentials( + adla_account="x", + enable_quota_eviction=False, + quota_eviction_max_attempts=3, + quota_eviction_cancel_num=10, + quota_eviction_wait_seconds=60.0, + quota_eviction_jitter_seconds=0.0, + **_BASE_KWARGS, + ) + assert creds.enable_quota_eviction is False + assert creds.quota_eviction_max_attempts == 3 + assert creds.quota_eviction_cancel_num == 10 + assert creds.quota_eviction_wait_seconds == 60.0 + assert creds.quota_eviction_jitter_seconds == 0.0 + + def test_negative_max_attempts_rejected(self): + with pytest.raises(DbtRuntimeError, match="quota_eviction_max_attempts must be >= 0"): + ScopeCredentials(adla_account="x", quota_eviction_max_attempts=-1, **_BASE_KWARGS) + + def test_zero_cancel_num_rejected(self): + with pytest.raises(DbtRuntimeError, match="quota_eviction_cancel_num must be >= 1"): + ScopeCredentials(adla_account="x", quota_eviction_cancel_num=0, **_BASE_KWARGS) + + def test_non_positive_wait_rejected(self): + with pytest.raises(DbtRuntimeError, match="quota_eviction_wait_seconds must be > 0"): + ScopeCredentials(adla_account="x", quota_eviction_wait_seconds=0, **_BASE_KWARGS) + + def test_negative_jitter_rejected(self): + with pytest.raises(DbtRuntimeError, match="quota_eviction_jitter_seconds must be >= 0"): + ScopeCredentials(adla_account="x", quota_eviction_jitter_seconds=-1, **_BASE_KWARGS) + + def test_fields_in_connection_keys(self): + keys = ScopeCredentials(adla_account="x", 
**_BASE_KWARGS)._connection_keys() + assert "enable_quota_eviction" in keys + assert "quota_eviction_max_attempts" in keys + assert "quota_eviction_cancel_num" in keys + assert "quota_eviction_wait_seconds" in keys + assert "quota_eviction_jitter_seconds" in keys diff --git a/tests/unit/test_custom_credential.py b/tests/unit/test_custom_credential.py new file mode 100644 index 0000000..ca1fca4 --- /dev/null +++ b/tests/unit/test_custom_credential.py @@ -0,0 +1,126 @@ +"""Tests for custom_credential — dotted-path TokenCredential loader.""" + +from __future__ import annotations + +import sys +import types +from datetime import datetime, timezone + +import pytest +from azure.core.credentials import AccessToken, TokenCredential +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.scope.custom_credential import ( + _cache_key, + clear_cache, + load_custom_credential, +) + + +class _FakeCredential: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def get_token(self, *scopes, **kw): + return AccessToken( + token="fake-token", expires_on=int(datetime.now(tz=timezone.utc).timestamp()) + 3600 + ) + + +class _StrictCredential: + def __init__(self, foo: str): + self.foo = foo + + def get_token(self, *scopes, **kw): + return AccessToken( + token="strict", expires_on=int(datetime.now(tz=timezone.utc).timestamp()) + 3600 + ) + + +class _NotACredential: + def __init__(self, **kwargs): + pass + + +@pytest.fixture(autouse=True) +def _clear_cache(): + clear_cache() + yield + clear_cache() + + +@pytest.fixture +def fake_module(): + name = "_test_custom_cred_module" + mod = types.ModuleType(name) + mod.FakeCredential = _FakeCredential + mod.StrictCredential = _StrictCredential + mod.NotACredential = _NotACredential + sys.modules[name] = mod + yield name + sys.modules.pop(name, None) + + +class TestLoadCustomCredential: + def test_loads_and_isinstance(self, fake_module): + cred = load_custom_credential(f"{fake_module}.FakeCredential", {"foo": "bar"}) + 
assert isinstance(cred, TokenCredential) + assert cred.kwargs == {"foo": "bar"} + + def test_caches_instance(self, fake_module): + first = load_custom_credential(f"{fake_module}.FakeCredential", {"foo": "bar"}) + second = load_custom_credential(f"{fake_module}.FakeCredential", {"foo": "bar"}) + assert first is second + + def test_cache_distinguishes_kwargs(self, fake_module): + first = load_custom_credential(f"{fake_module}.FakeCredential", {"foo": "bar"}) + second = load_custom_credential(f"{fake_module}.FakeCredential", {"foo": "baz"}) + assert first is not second + + def test_cache_handles_nested_dicts(self, fake_module): + nested = {"auth": {"method": "SNI", "sni": {"client_id": "abc"}}} + first = load_custom_credential(f"{fake_module}.FakeCredential", nested) + second = load_custom_credential(f"{fake_module}.FakeCredential", dict(nested)) + assert first is second + + def test_rejects_missing_class(self): + with pytest.raises(DbtRuntimeError, match="requires `credential_class`"): + load_custom_credential(None, {}) + with pytest.raises(DbtRuntimeError, match="requires `credential_class`"): + load_custom_credential("", {}) + + def test_rejects_non_dotted_path(self): + with pytest.raises(DbtRuntimeError, match="dotted path"): + load_custom_credential("notdotted", {}) + + def test_rejects_invalid_identifier(self): + with pytest.raises(DbtRuntimeError, match="dotted path"): + load_custom_credential("pkg.123bad", {}) + + def test_import_error_surfaces(self): + with pytest.raises(DbtRuntimeError, match="Could not import module"): + load_custom_credential("nonexistent_module_xyz.SomeClass", {}) + + def test_attribute_error_surfaces(self, fake_module): + with pytest.raises(DbtRuntimeError, match="no attribute"): + load_custom_credential(f"{fake_module}.MissingClass", {}) + + def test_type_error_surfaces(self, fake_module): + with pytest.raises(DbtRuntimeError, match="Failed to instantiate"): + load_custom_credential(f"{fake_module}.StrictCredential", {"unknown_kwarg": 
"value"}) + + def test_isinstance_enforced(self, fake_module): + with pytest.raises(DbtRuntimeError, match="must implement"): + load_custom_credential(f"{fake_module}.NotACredential", {}) + + +class TestCacheKey: + def test_stable_for_reordered_kwargs(self): + a = _cache_key("pkg.Cls", {"b": 1, "a": 2}) + b = _cache_key("pkg.Cls", {"a": 2, "b": 1}) + assert a == b + + def test_distinguishes_class(self): + a = _cache_key("pkg.A", {"x": 1}) + b = _cache_key("pkg.B", {"x": 1}) + assert a != b diff --git a/tests/unit/test_file_lock.py b/tests/unit/test_file_lock.py index 975d5e5..e2ec3b6 100644 --- a/tests/unit/test_file_lock.py +++ b/tests/unit/test_file_lock.py @@ -6,7 +6,7 @@ import threading from pathlib import Path -from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FileLock +from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK, FABRIC_TOKEN_LOCK, FileLock class TestFileLock: @@ -70,3 +70,9 @@ def test_az_cli_token_lock_constant(self): """AZ_CLI_TOKEN_LOCK is a well-known path in the temp directory.""" assert "dbt-scope-az-cli-token" in AZ_CLI_TOKEN_LOCK assert tempfile.gettempdir() in AZ_CLI_TOKEN_LOCK + + def test_fabric_token_lock_constant(self): + """FABRIC_TOKEN_LOCK is a well-known path in the temp directory.""" + assert "dbt-scope-fabric-token" in FABRIC_TOKEN_LOCK + assert tempfile.gettempdir() in FABRIC_TOKEN_LOCK + assert FABRIC_TOKEN_LOCK != AZ_CLI_TOKEN_LOCK diff --git a/tests/unit/test_locked_token_credential.py b/tests/unit/test_locked_token_credential.py new file mode 100644 index 0000000..0210b1a --- /dev/null +++ b/tests/unit/test_locked_token_credential.py @@ -0,0 +1,295 @@ +"""Unit tests for ``LockedTokenCredential`` + ``RetryPolicy``. + +Verifies the credential-retry resilience added for the +``CredentialUnavailableError: Failed to invoke the Azure CLI`` failure +mode observed in production (PR #32). 
+""" + +from __future__ import annotations + +import tempfile +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from azure.identity import CredentialUnavailableError + +from dbt.adapters.scope.delta_lake import LockedTokenCredential, RetryPolicy + + +@pytest.fixture +def lock_path() -> str: + """Per-test lock file under tmp to avoid cross-test contention.""" + with tempfile.NamedTemporaryFile(suffix=".lock", delete=False) as f: + return f.name + + +class _RecordingSleep: + """Record every sleep duration without actually sleeping.""" + + def __init__(self) -> None: + self.calls: list[float] = [] + + def __call__(self, seconds: float) -> None: + self.calls.append(seconds) + + +# -- RetryPolicy --------------------------------------------------------- + + +class TestRetryPolicy: + def test_defaults(self) -> None: + policy = RetryPolicy() + assert policy.max_retries == 10 + assert policy.initial_delay_seconds == 1.0 + assert policy.max_delay_seconds == 10.0 + + def test_from_http_retries_none_uses_defaults(self) -> None: + policy = RetryPolicy.from_http_retries(None) + assert policy.max_retries == 10 + + def test_from_http_retries_negative_uses_defaults(self) -> None: + policy = RetryPolicy.from_http_retries(-1) + assert policy.max_retries == 10 + + def test_from_http_retries_zero_disables_retries(self) -> None: + # ``0`` means "do not retry" — only the initial attempt runs. 
+ policy = RetryPolicy.from_http_retries(0) + assert policy.max_retries == 0 + + def test_from_http_retries_passthrough(self) -> None: + policy = RetryPolicy.from_http_retries(25) + assert policy.max_retries == 25 + + +# -- LockedTokenCredential.get_token ------------------------------------- + + +class TestLockedTokenCredential: + def test_succeeds_on_first_attempt_no_sleep(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.return_value = SimpleNamespace(token="t", expires_on=0) + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy(max_retries=3), + sleep=sleep, + ) + + token = cred.get_token("https://example.com/.default") + + assert token.token == "t" + assert inner.get_token.call_count == 1 + assert sleep.calls == [] + + def test_succeeds_after_transient_failures(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.side_effect = [ + CredentialUnavailableError(message="cli timeout 1"), + CredentialUnavailableError(message="cli timeout 2"), + SimpleNamespace(token="t", expires_on=0), + ] + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy( + max_retries=5, initial_delay_seconds=1.0, max_delay_seconds=10.0 + ), + sleep=sleep, + ) + + token = cred.get_token("scope") + + assert token.token == "t" + assert inner.get_token.call_count == 3 + # Slept twice (after attempts 1 and 2), not after the success. 
+ assert sleep.calls == [1.0, 2.0] + + def test_exhausts_retries_and_reraises(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.side_effect = CredentialUnavailableError(message="permanent") + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy( + max_retries=3, initial_delay_seconds=1.0, max_delay_seconds=10.0 + ), + sleep=sleep, + ) + + with pytest.raises(CredentialUnavailableError): + cred.get_token("scope") + + # 3 retries + 1 initial attempt == 4 total calls + assert inner.get_token.call_count == 4 + # Sleeps happen between attempts: 1s, 2s, 3s (3 sleeps == max_retries) + # No sleep after the final failed attempt. + assert sleep.calls == [1.0, 2.0, 3.0] + + def test_linear_backoff_caps_at_max_delay(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.side_effect = CredentialUnavailableError(message="x") + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy( + max_retries=15, initial_delay_seconds=1.0, max_delay_seconds=10.0 + ), + sleep=sleep, + ) + + with pytest.raises(CredentialUnavailableError): + cred.get_token("scope") + + # Linear ramp 1..10s then capped at 10s. 
+ expected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] + assert sleep.calls == expected + # 15 retries + 1 initial == 16 total calls; 15 sleeps (one per retry) + assert inner.get_token.call_count == 16 + assert len(sleep.calls) == 15 + + def test_zero_retries_makes_single_attempt(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.side_effect = CredentialUnavailableError(message="x") + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy(max_retries=0), + sleep=sleep, + ) + + with pytest.raises(CredentialUnavailableError): + cred.get_token("scope") + + assert inner.get_token.call_count == 1 + assert sleep.calls == [] # never sleep when no retries configured + + def test_does_not_retry_unrelated_exceptions(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.side_effect = RuntimeError("boom") + sleep = _RecordingSleep() + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy(max_retries=5), + sleep=sleep, + ) + + with pytest.raises(RuntimeError, match="boom"): + cred.get_token("scope") + + # Did NOT retry — only ``CredentialUnavailableError`` is retried. + assert inner.get_token.call_count == 1 + assert sleep.calls == [] + + def test_passes_claims_kwarg_through(self, lock_path: str) -> None: + inner = MagicMock() + inner.get_token.return_value = SimpleNamespace(token="t", expires_on=0) + + cred = LockedTokenCredential(inner, lock_file=lock_path) + + cred.get_token("scope", claims="my-claims") + + inner.get_token.assert_called_once_with("scope", claims="my-claims") + + def test_lock_is_released_between_attempts(self, lock_path: str) -> None: + # The lock must be released between attempts so concurrent + # workers can make progress. We verify by acquiring it from a + # parallel "thread" (via the lock_file path on disk) in between + # the failing attempts. 
+ from dbt.adapters.scope._file_lock import FileLock + + attempts_observed_unlocked: list[bool] = [] + call_counter = {"n": 0} + + def fake_get_token(*args, **kwargs): + call_counter["n"] += 1 + if call_counter["n"] < 3: + raise CredentialUnavailableError(message="transient") + return SimpleNamespace(token="t", expires_on=0) + + inner = MagicMock() + inner.get_token.side_effect = fake_get_token + + def sleep_and_probe(_seconds: float) -> None: + # Between attempts, attempt to acquire the lock — should + # succeed instantly because LockedTokenCredential released + # it after the failed attempt. + try: + with FileLock(lock_path): + attempts_observed_unlocked.append(True) + except Exception: + attempts_observed_unlocked.append(False) + + cred = LockedTokenCredential( + inner, + lock_file=lock_path, + retry_policy=RetryPolicy(max_retries=5), + sleep=sleep_and_probe, + ) + + cred.get_token("scope") + + # We made 2 failed attempts → 2 sleeps → 2 probes. + assert attempts_observed_unlocked == [True, True] + + def test_uses_default_policy_when_none(self, lock_path: str) -> None: + # No explicit policy → default 10 retries. 
+ inner = MagicMock() + inner.get_token.side_effect = CredentialUnavailableError(message="x") + sleep = _RecordingSleep() + + cred = LockedTokenCredential(inner, lock_file=lock_path, sleep=sleep) + + with pytest.raises(CredentialUnavailableError): + cred.get_token("scope") + + # Default RetryPolicy → 10 retries + 1 initial == 11 calls + assert inner.get_token.call_count == 11 + assert len(sleep.calls) == 10 + + +# -- build_credential lock-file dispatch --------------------------------- + + +class TestBuildCredentialLockFile: + """``build_credential`` picks the lock file based on ``authentication``.""" + + def test_cli_auth_uses_az_cli_lock(self) -> None: + from dbt.adapters.scope._file_lock import AZ_CLI_TOKEN_LOCK + from dbt.adapters.scope.delta_lake import build_credential + + creds = SimpleNamespace(authentication="cli", http_retries=0) + cred = build_credential(creds) + assert isinstance(cred, LockedTokenCredential) + assert cred._lock_file == AZ_CLI_TOKEN_LOCK + + def test_token_credential_auth_uses_fabric_lock(self, monkeypatch: pytest.MonkeyPatch) -> None: + from dbt.adapters.scope._file_lock import FABRIC_TOKEN_LOCK + from dbt.adapters.scope.delta_lake import build_credential + + fake_inner = MagicMock(name="custom_inner") + monkeypatch.setattr( + "dbt.adapters.scope.custom_credential.load_custom_credential", + lambda *_a, **_k: fake_inner, + ) + creds = SimpleNamespace( + authentication="token_credential", + credential_class="some.module.SomeCred", + credential_kwargs={}, + http_retries=0, + ) + cred = build_credential(creds) + assert isinstance(cred, LockedTokenCredential) + assert cred._lock_file == FABRIC_TOKEN_LOCK diff --git a/tests/unit/test_message_retry.py b/tests/unit/test_message_retry.py new file mode 100644 index 0000000..397a979 --- /dev/null +++ b/tests/unit/test_message_retry.py @@ -0,0 +1,260 @@ +"""Tests for MessageRetryPolicy + retry_on_message.""" + +from __future__ import annotations + +import re + +import pytest +from 
dbt_common.exceptions import DbtDatabaseError + +from dbt.adapters.scope.message_retry import ( + MessageRetryPolicy, + retry_on_message, +) + + +class _Creds: + """Tiny duck-typed stand-in for ScopeCredentials.""" + + def __init__( + self, + retry_on_error_messages=None, + max_retries_on_error=25, + initial_wait_on_error_seconds=1.0, + max_wait_on_error_seconds=30.0, + ): + self.retry_on_error_messages = retry_on_error_messages + self.max_retries_on_error = max_retries_on_error + self.initial_wait_on_error_seconds = initial_wait_on_error_seconds + self.max_wait_on_error_seconds = max_wait_on_error_seconds + + +class TestMessageRetryPolicyConstruction: + def test_disabled_factory_produces_empty_policy(self): + policy = MessageRetryPolicy.disabled() + assert policy.patterns == () + assert policy.enabled is False + + def test_from_credentials_compiles_substring_and_regex(self): + creds = _Creds( + retry_on_error_messages=[ + "Cannot exceed", + "re:queued \\d+ jobs", + ], + max_retries_on_error=3, + initial_wait_on_error_seconds=0.5, + max_wait_on_error_seconds=4.0, + ) + + policy = MessageRetryPolicy.from_credentials(creds) + + assert policy.max_retries == 3 + assert policy.initial_wait_seconds == 0.5 + assert policy.max_wait_seconds == 4.0 + assert len(policy.patterns) == 2 + assert policy.patterns[0] == "Cannot exceed" + assert isinstance(policy.patterns[1], re.Pattern) + assert policy.patterns[1].pattern == "queued \\d+ jobs" + + def test_from_credentials_with_no_patterns_disables_policy(self): + policy = MessageRetryPolicy.from_credentials(_Creds(retry_on_error_messages=[])) + assert policy.enabled is False + + def test_from_credentials_rejects_empty_string_entry(self): + with pytest.raises(ValueError, match="non-empty strings"): + MessageRetryPolicy.from_credentials(_Creds(retry_on_error_messages=[""])) + + def test_from_credentials_rejects_empty_regex_after_prefix(self): + with pytest.raises(ValueError, match="empty after"): + 
MessageRetryPolicy.from_credentials(_Creds(retry_on_error_messages=["re:"])) + + def test_from_credentials_rejects_negative_max_retries(self): + creds = _Creds(retry_on_error_messages=["x"], max_retries_on_error=-1) + with pytest.raises(ValueError, match="max_retries_on_error must be >= 0"): + MessageRetryPolicy.from_credentials(creds) + + def test_from_credentials_rejects_non_positive_initial_wait(self): + creds = _Creds(retry_on_error_messages=["x"], initial_wait_on_error_seconds=0) + with pytest.raises(ValueError, match="initial_wait_on_error_seconds must be > 0"): + MessageRetryPolicy.from_credentials(creds) + + def test_from_credentials_rejects_initial_greater_than_max(self): + creds = _Creds( + retry_on_error_messages=["x"], + initial_wait_on_error_seconds=10, + max_wait_on_error_seconds=5, + ) + with pytest.raises(ValueError, match="must be <= max_wait_on_error_seconds"): + MessageRetryPolicy.from_credentials(creds) + + +class TestMessageRetryPolicyMatching: + def test_substring_match(self): + policy = MessageRetryPolicy(patterns=("Cannot exceed",)) + assert policy.matches(RuntimeError("400: Cannot exceed 1000 queued jobs")) == ( + "Cannot exceed" + ) + + def test_regex_match_returns_pattern_label(self): + policy = MessageRetryPolicy(patterns=(re.compile(r"Cannot exceed \d+"),)) + assert policy.matches(RuntimeError("Cannot exceed 1000 queued")) == ( + "re:Cannot exceed \\d+" + ) + + def test_no_match_returns_none(self): + policy = MessageRetryPolicy(patterns=("Cannot exceed",)) + assert policy.matches(RuntimeError("Permission denied")) is None + + def test_empty_patterns_returns_none(self): + policy = MessageRetryPolicy(patterns=()) + assert policy.matches(RuntimeError("Cannot exceed")) is None + + +class TestDelayCurve: + def test_exponential_capped_curve(self): + policy = MessageRetryPolicy( + patterns=("x",), + max_retries=10, + initial_wait_seconds=1.0, + max_wait_seconds=30.0, + ) + delays = [policy.delay_for_attempt(n) for n in range(1, 11)] + assert 
delays == [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0, 30.0, 30.0, 30.0] + + +class TestRetryOnMessage: + def _policy(self, max_retries: int = 3) -> MessageRetryPolicy: + return MessageRetryPolicy( + patterns=("Cannot exceed",), + max_retries=max_retries, + initial_wait_seconds=1.0, + max_wait_seconds=30.0, + ) + + def test_returns_immediately_on_success(self): + sleeps: list[float] = [] + result = retry_on_message( + lambda: 42, + policy=self._policy(), + label="test", + sleep=sleeps.append, + ) + assert result == 42 + assert sleeps == [] + + def test_disabled_policy_runs_operation_once_without_retry(self): + calls = {"n": 0} + + def op(): + calls["n"] += 1 + raise RuntimeError("Cannot exceed 1000 queued") + + sleeps: list[float] = [] + with pytest.raises(RuntimeError, match="Cannot exceed"): + retry_on_message( + op, + policy=MessageRetryPolicy.disabled(), + label="test", + sleep=sleeps.append, + ) + assert calls["n"] == 1 + assert sleeps == [] + + def test_retries_with_capped_exponential_backoff(self): + attempts: list[int] = [] + + def op(): + attempts.append(len(attempts) + 1) + if len(attempts) < 7: + raise DbtDatabaseError("ADLA 400: Cannot exceed 1000 queued SCOPE jobs") + return "ok" + + sleeps: list[float] = [] + result = retry_on_message( + op, + policy=MessageRetryPolicy( + patterns=("Cannot exceed",), + max_retries=10, + initial_wait_seconds=1.0, + max_wait_seconds=30.0, + ), + label="test", + sleep=sleeps.append, + ) + assert result == "ok" + assert len(attempts) == 7 + # 6 failures → 6 sleeps before the 7th attempt succeeds + assert sleeps == [1.0, 2.0, 4.0, 8.0, 16.0, 30.0] + + def test_exhausts_budget_and_reraises_original_exception_type(self): + def op(): + raise DbtDatabaseError("ADLA 400: Cannot exceed 1000 queued SCOPE jobs") + + sleeps: list[float] = [] + with pytest.raises(DbtDatabaseError, match="Cannot exceed"): + retry_on_message( + op, + policy=self._policy(max_retries=3), + label="test", + sleep=sleeps.append, + ) + # 4 total attempts (1 + 3 
retries) → 3 sleeps + assert sleeps == [1.0, 2.0, 4.0] + + def test_non_matching_exception_propagates_immediately(self): + calls = {"n": 0} + + def op(): + calls["n"] += 1 + raise PermissionError("403 Forbidden — credential lacks role") + + sleeps: list[float] = [] + with pytest.raises(PermissionError): + retry_on_message( + op, + policy=self._policy(), + label="test", + sleep=sleeps.append, + ) + assert calls["n"] == 1 + assert sleeps == [] + + def test_regex_pattern_triggers_retry(self): + attempts = {"n": 0} + + def op(): + attempts["n"] += 1 + if attempts["n"] < 2: + raise RuntimeError("Cannot exceed 1234 queued SCOPE jobs") + return "ok" + + sleeps: list[float] = [] + policy = MessageRetryPolicy( + patterns=(re.compile(r"Cannot exceed \d+ queued"),), + max_retries=3, + initial_wait_seconds=1.0, + max_wait_seconds=30.0, + ) + result = retry_on_message(op, policy=policy, label="test", sleep=sleeps.append) + assert result == "ok" + assert attempts["n"] == 2 + assert sleeps == [1.0] + + def test_zero_max_retries_with_patterns_still_runs_once_and_raises(self): + calls = {"n": 0} + + def op(): + calls["n"] += 1 + raise RuntimeError("Cannot exceed 1000 queued") + + sleeps: list[float] = [] + policy = MessageRetryPolicy( + patterns=("Cannot exceed",), + max_retries=0, + initial_wait_seconds=1.0, + max_wait_seconds=30.0, + ) + with pytest.raises(RuntimeError): + retry_on_message(op, policy=policy, label="test", sleep=sleeps.append) + assert calls["n"] == 1 + assert sleeps == [] diff --git a/tests/unit/test_quota_eviction.py b/tests/unit/test_quota_eviction.py new file mode 100644 index 0000000..ffcd028 --- /dev/null +++ b/tests/unit/test_quota_eviction.py @@ -0,0 +1,362 @@ +"""Unit tests for the quota-eviction recovery layer.""" + +from __future__ import annotations + +import threading +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from dbt.adapters.scope.quota_eviction import ( + _EVICTION_LOCKS, + QuotaEvictionPolicy, + _lock_for, + 
is_quota_error, + retry_with_quota_eviction, + select_victims, +) + + +class _Creds: + def __init__( + self, + *, + adla_account: str = "acct", + enable_quota_eviction: bool = True, + quota_eviction_max_attempts: int = 25, + quota_eviction_cancel_num: int = 5, + quota_eviction_wait_seconds: float = 30.0, + quota_eviction_jitter_seconds: float = 5.0, + ) -> None: + self.adla_account = adla_account + self.enable_quota_eviction = enable_quota_eviction + self.quota_eviction_max_attempts = quota_eviction_max_attempts + self.quota_eviction_cancel_num = quota_eviction_cancel_num + self.quota_eviction_wait_seconds = quota_eviction_wait_seconds + self.quota_eviction_jitter_seconds = quota_eviction_jitter_seconds + + +class _FakeCtx: + def __init__(self, jobs: list[dict[str, Any]]) -> None: + self.jobs = jobs + self.list_calls: list[tuple[str | None, int]] = [] + self.cancel_calls: list[str] = [] + self.cancel_failures: dict[str, Exception] = {} + + def list_jobs(self, filter_expr: str | None = None, top: int = 100) -> list[dict[str, Any]]: + self.list_calls.append((filter_expr, top)) + return list(self.jobs) + + def cancel_job_async(self, job_id: str) -> None: + self.cancel_calls.append(job_id) + if job_id in self.cancel_failures: + raise self.cancel_failures[job_id] + + +class TestQuotaEvictionPolicyConstruction: + def test_from_credentials_uses_provided_values(self): + creds = _Creds( + adla_account="my-acct", + enable_quota_eviction=True, + quota_eviction_max_attempts=7, + quota_eviction_cancel_num=3, + quota_eviction_wait_seconds=12.5, + quota_eviction_jitter_seconds=1.5, + ) + policy = QuotaEvictionPolicy.from_credentials(creds) + assert policy.account == "my-acct" + assert policy.enabled is True + assert policy.max_attempts == 7 + assert policy.cancel_num == 3 + assert policy.wait_seconds == 12.5 + assert policy.jitter_seconds == 1.5 + + def test_from_credentials_respects_disabled_flag(self): + policy = 
QuotaEvictionPolicy.from_credentials(_Creds(enable_quota_eviction=False)) + assert policy.enabled is False + + def test_disabled_factory(self): + policy = QuotaEvictionPolicy.disabled() + assert policy.enabled is False + assert policy.max_attempts == 0 + + def test_policy_is_frozen(self): + policy = QuotaEvictionPolicy.disabled() + with pytest.raises(AttributeError): + policy.max_attempts = 99 # type: ignore[misc] + + +class TestIsQuotaError: + def test_matches_production_error(self): + msg = ( + "ADLA API PUT https://x.azuredatalakeanalytics.net/jobs/abc returned 400: " + '{"Message":"Cannot exceed 1000 queued SCOPE jobs in an ADLA workspace."}' + ) + assert is_quota_error(RuntimeError(msg)) is True + + def test_matches_with_different_quota_number(self): + assert is_quota_error(RuntimeError("Cannot exceed 250 queued SCOPE jobs")) is True + + def test_rejects_unrelated_error(self): + assert is_quota_error(RuntimeError("Internal server error")) is False + + def test_rejects_partial_match_first_needle_only(self): + assert is_quota_error(RuntimeError("Cannot exceed budget")) is False + + def test_rejects_partial_match_second_needle_only(self): + assert is_quota_error(RuntimeError("queued SCOPE jobs reported")) is False + + +class TestSelectVictims: + def test_sorts_highest_priority_number_first(self): + jobs = [ + {"jobId": "j1", "priority": 1, "submitTime": "2026-01-01T00:00:00Z"}, + {"jobId": "j2", "priority": 5, "submitTime": "2026-01-01T00:00:00Z"}, + {"jobId": "j3", "priority": 3, "submitTime": "2026-01-01T00:00:00Z"}, + ] + victims = select_victims(jobs, k=3) + assert [v["jobId"] for v in victims] == ["j2", "j3", "j1"] + + def test_within_priority_tier_oldest_first(self): + jobs = [ + {"jobId": "j1", "priority": 9, "submitTime": "2026-02-01T00:00:00Z"}, + {"jobId": "j2", "priority": 9, "submitTime": "2026-01-01T00:00:00Z"}, + {"jobId": "j3", "priority": 9, "submitTime": "2026-03-01T00:00:00Z"}, + ] + victims = select_victims(jobs, k=3) + assert [v["jobId"] for v 
in victims] == ["j2", "j1", "j3"] + + def test_respects_k(self): + jobs = [{"jobId": f"j{i}", "priority": i, "submitTime": ""} for i in range(10)] + victims = select_victims(jobs, k=3) + assert len(victims) == 3 + assert [v["jobId"] for v in victims] == ["j9", "j8", "j7"] + + def test_returns_fewer_when_list_shorter_than_k(self): + jobs = [{"jobId": "j1", "priority": 1, "submitTime": ""}] + victims = select_victims(jobs, k=5) + assert len(victims) == 1 + + def test_empty_list_returns_empty(self): + assert select_victims([], k=5) == [] + + def test_zero_k_returns_empty(self): + jobs = [{"jobId": "j1", "priority": 1, "submitTime": ""}] + assert select_victims(jobs, k=0) == [] + + def test_handles_missing_priority(self): + jobs = [ + {"jobId": "j1", "submitTime": "2026-01-01T00:00:00Z"}, + {"jobId": "j2", "priority": 5, "submitTime": "2026-01-01T00:00:00Z"}, + ] + victims = select_victims(jobs, k=2) + assert victims[0]["jobId"] == "j2" + + def test_handles_non_numeric_priority(self): + jobs = [ + {"jobId": "j1", "priority": "not-a-number", "submitTime": ""}, + {"jobId": "j2", "priority": 3, "submitTime": ""}, + ] + victims = select_victims(jobs, k=2) + assert victims[0]["jobId"] == "j2" + + +class TestLockFor: + def setup_method(self): + _EVICTION_LOCKS.clear() + + def test_same_account_returns_same_lock(self): + assert _lock_for("a") is _lock_for("a") + + def test_different_accounts_get_different_locks(self): + assert _lock_for("a") is not _lock_for("b") + + def test_locks_dict_thread_safe_under_contention(self): + results: list[threading.Lock] = [] + barrier = threading.Barrier(20) + + def grab() -> None: + barrier.wait() + results.append(_lock_for("shared")) + + threads = [threading.Thread(target=grab) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + assert len({id(r) for r in results}) == 1 + + +class TestRetryWithQuotaEviction: + def setup_method(self): + _EVICTION_LOCKS.clear() + + def _policy(self, **overrides: Any) -> 
QuotaEvictionPolicy: + defaults = { + "account": "test-acct", + "enabled": True, + "max_attempts": 3, + "cancel_num": 2, + "wait_seconds": 30.0, + "jitter_seconds": 5.0, + } + defaults.update(overrides) + return QuotaEvictionPolicy(**defaults) + + def test_happy_path_succeeds_without_eviction(self): + ctx = _FakeCtx(jobs=[]) + op = MagicMock(return_value={"ok": True}) + result = retry_with_quota_eviction( + op, eviction_ctx=ctx, policy=self._policy(), label="test" + ) + assert result == {"ok": True} + assert ctx.cancel_calls == [] + assert ctx.list_calls == [] + assert op.call_count == 1 + + def test_disabled_policy_passes_through(self): + ctx = _FakeCtx(jobs=[]) + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=err) + with pytest.raises(RuntimeError, match="Cannot exceed"): + retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=QuotaEvictionPolicy.disabled(), + label="test", + ) + assert ctx.cancel_calls == [] + assert op.call_count == 1 + + def test_non_quota_error_reraises_immediately(self): + ctx = _FakeCtx(jobs=[{"jobId": "j1", "priority": 9, "submitTime": ""}]) + op = MagicMock(side_effect=RuntimeError("unrelated 500 error")) + with pytest.raises(RuntimeError, match="unrelated"): + retry_with_quota_eviction(op, eviction_ctx=ctx, policy=self._policy(), label="test") + assert ctx.cancel_calls == [] + assert op.call_count == 1 + + def test_one_eviction_recovery(self): + jobs = [{"jobId": f"j{i}", "priority": i, "submitTime": ""} for i in range(1, 11)] + ctx = _FakeCtx(jobs=jobs) + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=[err, {"ok": True}]) + sleeps: list[float] = [] + + result = retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(cancel_num=3), + label="test", + sleep=lambda s: sleeps.append(s), + random_uniform=lambda a, b: 0.0, + ) + assert result == {"ok": True} + assert op.call_count == 2 + assert ctx.list_calls == [("state ne 'Ended'", 
1000)] + assert ctx.cancel_calls == ["j10", "j9", "j8"] + assert sleeps == [30.0] + + def test_exhausts_max_attempts(self): + jobs = [{"jobId": "j1", "priority": 9, "submitTime": ""}] + ctx = _FakeCtx(jobs=jobs) + err = RuntimeError("Cannot exceed 999 queued SCOPE jobs forever") + op = MagicMock(side_effect=err) + + with pytest.raises(RuntimeError, match="Cannot exceed"): + retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(max_attempts=2), + label="test", + sleep=lambda s: None, + random_uniform=lambda a, b: 0.0, + ) + assert op.call_count == 2 + assert len(ctx.cancel_calls) == 2 + + def test_empty_job_list_reraises(self): + ctx = _FakeCtx(jobs=[]) + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=err) + with pytest.raises(RuntimeError, match="Cannot exceed"): + retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(), + label="test", + sleep=lambda s: None, + random_uniform=lambda a, b: 0.0, + ) + assert op.call_count == 1 + + def test_individual_cancel_failure_swallowed(self): + jobs = [{"jobId": f"j{i}", "priority": i, "submitTime": ""} for i in range(1, 4)] + ctx = _FakeCtx(jobs=jobs) + ctx.cancel_failures["j3"] = RuntimeError("transient cancel 500") + + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=[err, {"ok": True}]) + + result = retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(cancel_num=3), + label="test", + sleep=lambda s: None, + random_uniform=lambda a, b: 0.0, + ) + assert result == {"ok": True} + assert set(ctx.cancel_calls) == {"j1", "j2", "j3"} + + def test_list_jobs_failure_aborts(self): + ctx = _FakeCtx(jobs=[]) + ctx.list_jobs = MagicMock(side_effect=RuntimeError("list failed")) # type: ignore[method-assign] + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=err) + + with pytest.raises(RuntimeError, match="Cannot exceed"): + retry_with_quota_eviction( + op, 
+ eviction_ctx=ctx, + policy=self._policy(), + label="test", + sleep=lambda s: None, + random_uniform=lambda a, b: 0.0, + ) + assert op.call_count == 1 + + def test_jitter_applied_to_sleep(self): + jobs = [{"jobId": "j1", "priority": 9, "submitTime": ""}] + ctx = _FakeCtx(jobs=jobs) + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=[err, {"ok": True}]) + sleeps: list[float] = [] + + retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(wait_seconds=30, jitter_seconds=5), + label="test", + sleep=lambda s: sleeps.append(s), + random_uniform=lambda a, b: 3.5, + ) + assert sleeps == [33.5] + + def test_negative_jitter_clamped_to_zero(self): + jobs = [{"jobId": "j1", "priority": 9, "submitTime": ""}] + ctx = _FakeCtx(jobs=jobs) + err = RuntimeError("Cannot exceed 1000 queued SCOPE jobs") + op = MagicMock(side_effect=[err, {"ok": True}]) + sleeps: list[float] = [] + + retry_with_quota_eviction( + op, + eviction_ctx=ctx, + policy=self._policy(wait_seconds=2, jitter_seconds=10), + label="test", + sleep=lambda s: sleeps.append(s), + random_uniform=lambda a, b: -100, + ) + assert sleeps == [0.0] diff --git a/tests/unit/test_script_builder.py b/tests/unit/test_script_builder.py index 5982207..3dde717 100644 --- a/tests/unit/test_script_builder.py +++ b/tests/unit/test_script_builder.py @@ -271,6 +271,82 @@ def test_safety_buffer_default(self): cfg = ScriptConfig() assert cfg.safety_buffer_seconds == 30 + def test_max_file_count_per_output_file_set_default(self): + cfg = ScriptConfig() + assert cfg.max_file_count_per_output_file_set == 5000 + + +class TestMaxFileCountPerOutputFileSet: + """Tests for the @@MaxFileCountPerOutputFileSet SET emission and validation.""" + + def test_default_emitted_in_full_refresh(self, sample_config): + script = ScriptBuilder.build_full_refresh(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 5000;" in script + + def 
test_default_emitted_in_incremental(self, sample_config): + script = ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 5000;" in script + + def test_override_emitted_in_full_refresh(self, sample_config): + sample_config.max_file_count_per_output_file_set = 100000 + script = ScriptBuilder.build_full_refresh(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 100000;" in script + assert "SET @@MaxFileCountPerOutputFileSet = 5000;" not in script + + def test_override_emitted_in_incremental(self, sample_config): + sample_config.max_file_count_per_output_file_set = 250000 + script = ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 250000;" in script + + def test_min_boundary_accepted(self, sample_config): + sample_config.max_file_count_per_output_file_set = 1 + script = ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 1;" in script + + def test_max_boundary_accepted(self, sample_config): + sample_config.max_file_count_per_output_file_set = 1_000_000 + script = ScriptBuilder.build_full_refresh(sample_config, "SELECT * FROM @data") + assert "SET @@MaxFileCountPerOutputFileSet = 1000000;" in script + + def test_zero_raises(self, sample_config): + sample_config.max_file_count_per_output_file_set = 0 + with pytest.raises(DbtRuntimeError, match="max_file_count_per_output_file_set"): + ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + + def test_negative_raises(self, sample_config): + sample_config.max_file_count_per_output_file_set = -1 + with pytest.raises(DbtRuntimeError, match="max_file_count_per_output_file_set"): + ScriptBuilder.build_full_refresh(sample_config, "SELECT * FROM @data") + + def test_above_max_raises(self, sample_config): + sample_config.max_file_count_per_output_file_set = 2_000_000 + with 
pytest.raises(DbtRuntimeError, match=r"\[1, 1000000\]"): + ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + + def test_non_int_raises(self, sample_config): + sample_config.max_file_count_per_output_file_set = "5000" + with pytest.raises(DbtRuntimeError, match="must be an int"): + ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + + def test_bool_rejected_despite_subclassing_int(self, sample_config): + sample_config.max_file_count_per_output_file_set = True + with pytest.raises(DbtRuntimeError, match="must be an int"): + ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + + def test_emitted_before_declare_path_full_refresh(self, sample_config): + script = ScriptBuilder.build_full_refresh(sample_config, "SELECT * FROM @data") + commit_pos = script.index("SET @@DeltaLakeCommitCondition") + max_files_pos = script.index("SET @@MaxFileCountPerOutputFileSet") + declare_pos = script.index("#DECLARE @deltaPath") + assert commit_pos < max_files_pos < declare_pos + + def test_emitted_before_declare_path_incremental(self, sample_config): + script = ScriptBuilder.build_incremental(sample_config, "SELECT * FROM @data") + commit_pos = script.index("SET @@DeltaLakeCommitCondition") + max_files_pos = script.index("SET @@MaxFileCountPerOutputFileSet") + declare_pos = script.index("#DECLARE @deltaPath") + assert commit_pos < max_files_pos < declare_pos + class TestScriptBuilderModelSQL: """Tests for model SQL handling in the new file-based approach.""" diff --git a/tests/unit/test_shutdown_cancellation.py b/tests/unit/test_shutdown_cancellation.py new file mode 100644 index 0000000..c680a29 --- /dev/null +++ b/tests/unit/test_shutdown_cancellation.py @@ -0,0 +1,531 @@ +"""Tests for graceful shutdown: cancel in-flight ADLA jobs on SIGINT/SIGTERM.""" + +from __future__ import annotations + +import signal +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest +import requests.exceptions 
+from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.scope import connections as conn_module +from dbt.adapters.scope import impl as impl_module +from dbt.adapters.scope.connections import ( + ScopeConnectionHandle, + ScopeConnectionManager, + _active_jobs, + _ActiveJobEntry, + _cancelled_job_ids, + _deregister_active_job, + _register_active_job, + _shutdown_event, + cancel_all_active_jobs, +) +from dbt.adapters.scope.credentials import ScopeCredentials + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_shutdown_state(): + """Snapshot/restore module-level state mutated by these tests.""" + _shutdown_event.clear() + saved_jobs = dict(_active_jobs) + saved_cancelled = set(_cancelled_job_ids) + saved_observed = list(impl_module._observed_credentials) + _active_jobs.clear() + _cancelled_job_ids.clear() + impl_module._observed_credentials.clear() + yield + _shutdown_event.clear() + _active_jobs.clear() + _active_jobs.update(saved_jobs) + _cancelled_job_ids.clear() + _cancelled_job_ids.update(saved_cancelled) + impl_module._observed_credentials.clear() + impl_module._observed_credentials.extend(saved_observed) + + +def _make_handle(account: str = "test-adla") -> ScopeConnectionHandle: + creds = MagicMock() + creds.adla_account = account + creds.http_timeout_seconds = 30 + creds.http_retries = 3 + with patch.object(ScopeConnectionHandle, "_build_session", return_value=MagicMock()): + return ScopeConnectionHandle(creds) + + +def _make_entry(job_id: str, handle: ScopeConnectionHandle | None = None) -> _ActiveJobEntry: + return _ActiveJobEntry( + job_id=job_id, + name=f"job-{job_id}", + handle=handle or _make_handle(), + submitted_at=time.monotonic(), + model_name=None, + ) + + +# --------------------------------------------------------------------------- +# Registry primitives +# 
--------------------------------------------------------------------------- + + +class TestActiveJobRegistry: + def test_register_and_deregister(self): + entry = _make_entry("job-1") + _register_active_job(entry) + assert "job-1" in _active_jobs + _deregister_active_job("job-1") + assert "job-1" not in _active_jobs + + def test_deregister_unknown_is_noop(self): + _deregister_active_job("does-not-exist") # must not raise + + +# --------------------------------------------------------------------------- +# submit_and_wait: register/deregister + shutdown abort +# --------------------------------------------------------------------------- + + +class TestSubmitAndWaitRegistry: + def _handle_with_request(self, fake_request): + handle = _make_handle() + handle._get_token = MagicMock(return_value="fake-token") + handle._request = MagicMock(side_effect=fake_request) + return handle + + def test_registers_and_deregisters_on_success(self): + seen_during_poll: list[str] = [] + + def fake_request(method, url, **kwargs): + if method == "PUT": + return {"state": "Running"} + seen_during_poll.extend(_active_jobs.keys()) + return {"state": "Ended", "result": "Succeeded"} + + handle = self._handle_with_request(fake_request) + job = handle.submit_and_wait(name="t", script="// s", au=10, priority=1, poll_interval=0) + assert job.succeeded + assert seen_during_poll == [job.job_id] + assert job.job_id not in _active_jobs + + def test_deregisters_on_failure(self): + def fake_request(method, url, **kwargs): + if method == "PUT": + return {"state": "Running"} + return {"state": "Ended", "result": "Failed", "errorMessage": "boom"} + + handle = self._handle_with_request(fake_request) + from dbt_common.exceptions import DbtDatabaseError + + with pytest.raises(DbtDatabaseError): + handle.submit_and_wait(name="t", script="// s", au=10, priority=1, poll_interval=0) + assert _active_jobs == {} + + def test_aborts_on_shutdown_event_and_self_cancels(self): + cancel_calls: list[tuple[str, int]] = [] 
+ + def fake_cancel_job(job_id, poll_interval=2, max_wait=120): + cancel_calls.append((job_id, max_wait)) + + def fake_request(method, url, **kwargs): + if method == "PUT": + return {"state": "Running"} + return {"state": "Running"} + + handle = self._handle_with_request(fake_request) + handle.cancel_job = MagicMock(side_effect=fake_cancel_job) + _shutdown_event.set() + + with pytest.raises(DbtRuntimeError, match="shutdown signal"): + handle.submit_and_wait( + name="t", + script="// s", + au=10, + priority=1, + poll_interval=0, + wait_on_cancel_seconds=17, + ) + assert len(cancel_calls) == 1 + assert cancel_calls[0][1] == 17 + assert _active_jobs == {} + assert cancel_calls[0][0] in _cancelled_job_ids + + def test_does_not_double_cancel_if_already_in_cancelled_set(self): + cancel_mock = MagicMock() + + def fake_request(method, url, **kwargs): + if method == "PUT": + return {"jobId": "preset-id", "state": "Running"} + return {"state": "Running"} + + handle = self._handle_with_request(fake_request) + handle.cancel_job = cancel_mock + + with patch("uuid.uuid4", return_value=MagicMock(__str__=lambda self: "preset-id")): + _cancelled_job_ids.add("preset-id") + _shutdown_event.set() + with pytest.raises(DbtRuntimeError): + handle.submit_and_wait(name="t", script="// s", au=10, priority=1, poll_interval=0) + cancel_mock.assert_not_called() + + +# --------------------------------------------------------------------------- +# cancel_all_active_jobs +# --------------------------------------------------------------------------- + + +class TestCancelAllActiveJobs: + def test_empty_registry_returns_zero(self): + assert cancel_all_active_jobs("test", wait_seconds=1) == (0, 0) + + def test_calls_cancel_job_per_entry_with_wait(self): + handle_a = _make_handle("a") + handle_b = _make_handle("b") + handle_a.cancel_job = MagicMock() + handle_b.cancel_job = MagicMock() + _register_active_job(_make_entry("job-a", handle_a)) + _register_active_job(_make_entry("job-b", handle_b)) + + 
attempted, confirmed = cancel_all_active_jobs("test", wait_seconds=11) + + assert attempted == 2 + assert confirmed == 2 + handle_a.cancel_job.assert_called_once_with("job-a", poll_interval=2, max_wait=11) + handle_b.cancel_job.assert_called_once_with("job-b", poll_interval=2, max_wait=11) + + def test_returns_attempted_and_confirmed_with_mixed_results(self): + good = _make_handle("good") + bad = _make_handle("bad") + good.cancel_job = MagicMock() + bad.cancel_job = MagicMock(side_effect=RuntimeError("network down")) + _register_active_job(_make_entry("ok", good)) + _register_active_job(_make_entry("err", bad)) + + attempted, confirmed = cancel_all_active_jobs("test", wait_seconds=5) + + assert attempted == 2 + assert confirmed == 1 + + def test_continues_on_per_job_failure(self): + h1 = _make_handle("h1") + h2 = _make_handle("h2") + h3 = _make_handle("h3") + h1.cancel_job = MagicMock() + h2.cancel_job = MagicMock(side_effect=ValueError("boom")) + h3.cancel_job = MagicMock() + _register_active_job(_make_entry("j1", h1)) + _register_active_job(_make_entry("j2", h2)) + _register_active_job(_make_entry("j3", h3)) + + cancel_all_active_jobs("test", wait_seconds=5) + + h1.cancel_job.assert_called_once() + h2.cancel_job.assert_called_once() + h3.cancel_job.assert_called_once() + + def test_respects_wait_ceiling(self): + slow = _make_handle("slow") + + def slow_cancel(job_id, poll_interval=2, max_wait=120): + time.sleep(max_wait + 10) + + slow.cancel_job = MagicMock(side_effect=slow_cancel) + _register_active_job(_make_entry("slow-job", slow)) + + start = time.monotonic() + cancel_all_active_jobs("test", wait_seconds=1) + elapsed = time.monotonic() - start + + # wait_seconds=1 + 5s grace = 6s ceiling + assert elapsed < 8, f"Cancel-all blocked {elapsed:.1f}s, expected < 8s" + + def test_skips_jobs_already_in_cancelled_set(self): + handle = _make_handle() + handle.cancel_job = MagicMock() + _register_active_job(_make_entry("already-cancelled", handle)) + 
_cancelled_job_ids.add("already-cancelled") + + attempted, confirmed = cancel_all_active_jobs("test", wait_seconds=5) + + assert attempted == 1 + assert confirmed == 1 + handle.cancel_job.assert_not_called() + + +# --------------------------------------------------------------------------- +# Observed credentials gates in impl.py +# --------------------------------------------------------------------------- + + +class TestObservedCredentialsGates: + def _make_creds(self, *, cancel=True, wait=30): + creds = ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + cancel_jobs_on_shutdown=cancel, + wait_on_cancel_seconds=wait, + ) + return creds + + def test_no_observed_defaults_to_enabled(self): + assert impl_module._any_observed_cancel_on_shutdown_enabled() is True + assert impl_module._observed_max_wait_on_cancel_seconds() == 30 + + def test_all_opt_out(self): + impl_module._observe_credentials(self._make_creds(cancel=False)) + impl_module._observe_credentials(self._make_creds(cancel=False)) + assert impl_module._any_observed_cancel_on_shutdown_enabled() is False + + def test_any_enabled_triggers_gate(self): + impl_module._observe_credentials(self._make_creds(cancel=False)) + impl_module._observe_credentials(self._make_creds(cancel=True, wait=42)) + assert impl_module._any_observed_cancel_on_shutdown_enabled() is True + + def test_observed_max_wait_returns_max_across_credentials(self): + impl_module._observe_credentials(self._make_creds(wait=30)) + impl_module._observe_credentials(self._make_creds(wait=60)) + impl_module._observe_credentials(self._make_creds(wait=10)) + assert impl_module._observed_max_wait_on_cancel_seconds() == 60 + + def test_observed_max_wait_ignores_opt_out_entries(self): + impl_module._observe_credentials(self._make_creds(cancel=False, wait=999)) + impl_module._observe_credentials(self._make_creds(cancel=True, wait=20)) + assert impl_module._observed_max_wait_on_cancel_seconds() == 20 + + def 
test_observe_credentials_is_idempotent(self): + c = self._make_creds() + impl_module._observe_credentials(c) + impl_module._observe_credentials(c) + assert len(impl_module._observed_credentials) == 1 + + +# --------------------------------------------------------------------------- +# Signal handler + atexit +# --------------------------------------------------------------------------- + + +class TestSignalHandlerCancelAll: + def test_signal_handler_invokes_cancel_all(self): + impl_module._observe_credentials( + ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + wait_on_cancel_seconds=42, + ) + ) + with ( + patch.object(impl_module, "cancel_all_active_jobs") as mock_cancel, + patch.object(impl_module, "_signal_handlers_installed", False), + ): + old_sigint = signal.getsignal(signal.SIGINT) + old_sigterm = signal.getsignal(signal.SIGTERM) + # Set SIGINT to SIG_IGN so the chained previous handler doesn't + # raise KeyboardInterrupt inside the test. + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + try: + impl_module._install_signal_handlers() + handler = signal.getsignal(signal.SIGINT) + assert callable(handler) + handler(signal.SIGINT, None) + finally: + signal.signal(signal.SIGINT, old_sigint) + signal.signal(signal.SIGTERM, old_sigterm) + impl_module._signal_handlers_installed = False + mock_cancel.assert_called_once() + args, kwargs = mock_cancel.call_args + assert args[0] == "signal:SIGINT" + assert kwargs.get("wait_seconds") == 42 + assert _shutdown_event.is_set() + + def test_signal_handler_skipped_when_all_opted_out(self): + impl_module._observe_credentials( + ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + cancel_jobs_on_shutdown=False, + ) + ) + with ( + patch.object(impl_module, "cancel_all_active_jobs") as mock_cancel, + patch.object(impl_module, "_signal_handlers_installed", False), + ): + old_sigint = signal.getsignal(signal.SIGINT) + old_sigterm = 
signal.getsignal(signal.SIGTERM) + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + try: + impl_module._install_signal_handlers() + handler = signal.getsignal(signal.SIGINT) + handler(signal.SIGINT, None) + finally: + signal.signal(signal.SIGINT, old_sigint) + signal.signal(signal.SIGTERM, old_sigterm) + impl_module._signal_handlers_installed = False + mock_cancel.assert_not_called() + # _shutdown_event is still set so in-flight loops abort + assert _shutdown_event.is_set() + + +class TestAtexitCancelAll: + def test_atexit_invokes_cancel_all_when_enabled(self): + impl_module._observe_credentials( + ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + wait_on_cancel_seconds=15, + ) + ) + with patch.object(impl_module, "cancel_all_active_jobs") as mock_cancel: + impl_module._atexit_cancel_all() + mock_cancel.assert_called_once() + args, kwargs = mock_cancel.call_args + assert args[0] == "atexit" + assert kwargs.get("wait_seconds") == 15 + + def test_atexit_skipped_when_all_opted_out(self): + impl_module._observe_credentials( + ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + cancel_jobs_on_shutdown=False, + ) + ) + with patch.object(impl_module, "cancel_all_active_jobs") as mock_cancel: + impl_module._atexit_cancel_all() + mock_cancel.assert_not_called() + + +# --------------------------------------------------------------------------- +# dbt-native cancel hooks +# --------------------------------------------------------------------------- + + +class TestIsCancelable: + def test_returns_true(self): + assert impl_module.ScopeAdapter.is_cancelable() is True + + +class TestManagerCancelDelegation: + def test_cancel_delegates_to_cancel_all(self): + connection = MagicMock() + connection.credentials = ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + wait_on_cancel_seconds=21, + ) + mgr = ScopeConnectionManager.__new__(ScopeConnectionManager) + with 
patch.object(conn_module, "cancel_all_active_jobs") as mock_cancel: + mgr.cancel(connection) + mock_cancel.assert_called_once_with("dbt-native:cancel", wait_seconds=21) + + def test_cancel_respects_opt_out_credential(self): + connection = MagicMock() + connection.credentials = ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + cancel_jobs_on_shutdown=False, + ) + mgr = ScopeConnectionManager.__new__(ScopeConnectionManager) + with patch.object(conn_module, "cancel_all_active_jobs") as mock_cancel: + mgr.cancel(connection) + mock_cancel.assert_not_called() + + def test_cancel_open_delegates_to_cancel_all(self): + with patch.object(conn_module, "cancel_all_active_jobs") as mock_cancel: + ScopeConnectionManager.cancel_open() + mock_cancel.assert_called_once_with("dbt-native:cancel_open", wait_seconds=30) + + +# --------------------------------------------------------------------------- +# Connection open hook wires everything +# --------------------------------------------------------------------------- + + +class TestOpenHookInstallation: + def test_open_hook_is_wired(self): + assert ScopeConnectionManager._on_open is not None + # impl.py sets it to _scope_open_hook + assert ScopeConnectionManager._on_open is impl_module._scope_open_hook + + def test_open_hook_observes_credentials_and_installs_handlers(self): + creds = ScopeCredentials( + database="db", + schema="sch", + adla_account="acct", + wait_on_cancel_seconds=55, + ) + with ( + patch.object(impl_module, "_install_signal_handlers") as install_mock, + patch.object(impl_module, "_register_atexit") as atexit_mock, + ): + impl_module._scope_open_hook(creds) + assert creds in impl_module._observed_credentials + install_mock.assert_called_once() + atexit_mock.assert_called_once() + + +# --------------------------------------------------------------------------- +# Sanity: poll loop tolerates exceptions raised by cancel_job during shutdown +# 
--------------------------------------------------------------------------- + + +class TestSelfCancelFailureDuringShutdown: + def test_self_cancel_exception_still_raises_runtime_error(self): + handle = _make_handle() + handle._get_token = MagicMock(return_value="fake-token") + + def fake_request(method, url, **kwargs): + if method == "PUT": + return {"state": "Running"} + return {"state": "Running"} + + handle._request = MagicMock(side_effect=fake_request) + handle.cancel_job = MagicMock(side_effect=requests.exceptions.ConnectionError("nope")) + _shutdown_event.set() + + with pytest.raises(DbtRuntimeError, match="shutdown signal"): + handle.submit_and_wait( + name="t", + script="// s", + au=10, + priority=1, + poll_interval=0, + wait_on_cancel_seconds=3, + ) + # Still deregistered + assert _active_jobs == {} + + +# --------------------------------------------------------------------------- +# Thread-safety smoke +# --------------------------------------------------------------------------- + + +class TestRegistryThreadSafety: + def test_concurrent_register_deregister(self): + def worker(i: int): + for j in range(50): + e = _make_entry(f"j-{i}-{j}") + _register_active_job(e) + _deregister_active_job(e.job_id) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + assert _active_jobs == {}