Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions chuck_data/clients/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import requests
import time
import urllib.parse
from datetime import datetime
from datetime import datetime, timezone
from chuck_data.config import get_warehouse_id
from chuck_data.clients.amperity import get_amperity_url
from chuck_data.databricks.url_utils import (
Expand Down Expand Up @@ -576,15 +576,38 @@ def submit_sql_statement(
# Jobs methods
#

def _build_libraries(self, data_provider=None):
@staticmethod
def _generate_jar_volume_path(init_script_path):
    """Return a timestamped JAR path in the same directory as the init script.

    Mirrors the chuck-api side (CATALYST-253): the cluster init script
    copies /opt/amperity/job.jar to JOB_JAR_VOL_PATH during cluster
    startup, and the Run_Stitch task's library entry points at this
    same path so Databricks blocks task start until the JAR is staged.

    The filename carries a UTC timestamp with one-second resolution
    (``job-YYYYMMDD-HHMMSS.jar``) so runs submitted at different times
    do not clobber each other's JARs. Two calls made within the same
    second for the same directory return the same path.

    Args:
        init_script_path: POSIX-style path of the init script, e.g.
            ``/Volumes/<catalog>/<schema>/<volume>/init.sh``.

    Returns:
        The sibling path ``<parent dir>/job-<timestamp>.jar``. A bare
        filename without any directory component yields a bare
        ``job-<timestamp>.jar`` instead of a path nested under the
        script's own name.
    """
    # Function-scope import: posixpath always uses "/" separators, which is
    # what volume paths need regardless of the OS running the client.
    import posixpath

    parent_dir = posixpath.dirname(init_script_path)
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    # join() omits the separator when parent_dir is empty, so the result is
    # always a sibling of the init script.
    return posixpath.join(parent_dir, f"job-{timestamp}.jar")

def _build_libraries(self, data_provider=None, main_jar_path=None):
"""Return the Maven/JAR library list for a Stitch job run.

The main JAR is always included. The connector Maven dependency is chosen
based on the data provider:
- "snowflake" → Snowflake Spark connector
- "aws_redshift" → Redshift community connector + Avro support

Args:
data_provider: Optional connector selector (see above).
main_jar_path: Optional override for the main JAR path. Defaults to
the local `file:///opt/amperity/job.jar` path. Callers that
stage the JAR into a Unity Catalog volume should pass that
volume path so Databricks blocks the task until the init
script has copied the JAR into the volume.
"""
libraries: list = [{"jar": "file:///opt/amperity/job.jar"}]
libraries: list = [{"jar": main_jar_path or "file:///opt/amperity/job.jar"}]
if data_provider == "snowflake":
libraries.append(
{"maven": {"coordinates": "net.snowflake:spark-snowflake_2.12:3.1.3"}}
Expand Down Expand Up @@ -634,9 +657,11 @@ def submit_job_run(
# Define the task and cluster for the one-time run
# Create base cluster configuration
# Detect init script location (S3 vs Volumes) and configure accordingly
jar_vol_path = None
if init_script_path.startswith("s3://"):
# S3 init script (for Redshift data source)
# Get region from config
# S3 init script (for Redshift data source). Volumes are a
# Unity Catalog concept and don't apply here, so we leave the
# library jar pointing at the local /opt/amperity/job.jar.
from chuck_data.config import get_aws_region

region = get_aws_region() or "us-west-2"
Expand All @@ -650,14 +675,27 @@ def submit_job_run(
}
]
else:
# Volumes init script (for Unity Catalog data source)
# Volumes init script (Unity Catalog). Stage the JAR into the
# same volume so the Run_Stitch task's library entry resolves
# against the volume path -- the chuck-api init script reads
# JOB_JAR_VOL_PATH from spark_env_vars and copies the JAR there.
init_scripts_config = [
{
"volumes": {
"destination": init_script_path,
}
}
]
jar_vol_path = self._generate_jar_volume_path(init_script_path)

spark_env_vars = {
"JNAME": "zulu17-ca-amd64",
"CHUCK_API_URL": f"https://{get_amperity_url()}",
"DEBUG_INIT_SRIPT_URL": init_script_path,
"DEBUG_CONFIG_PATH": config_path,
}
if jar_vol_path:
spark_env_vars["JOB_JAR_VOL_PATH"] = jar_vol_path

cluster_config = {
"cluster_name": "",
Expand All @@ -669,12 +707,7 @@ def submit_job_run(
"sys": "chuck",
"tenant": "amperity",
},
"spark_env_vars": {
"JNAME": "zulu17-ca-amd64",
"CHUCK_API_URL": f"https://{get_amperity_url()}",
"DEBUG_INIT_SRIPT_URL": init_script_path,
"DEBUG_CONFIG_PATH": config_path,
},
"spark_env_vars": spark_env_vars,
"enable_elastic_disk": False,
"data_security_mode": "SINGLE_USER",
"runtime_engine": "STANDARD",
Expand Down Expand Up @@ -704,7 +737,10 @@ def submit_job_run(
],
"run_as_repl": True,
},
"libraries": self._build_libraries(data_provider),
"libraries": self._build_libraries(
data_provider=data_provider,
main_jar_path=jar_vol_path,
),
"timeout_seconds": 0,
"email_notifications": {},
"webhook_notifications": {},
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/test_workspace_and_init_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,59 @@ def test_submit_job_run_with_volumes_init_script(self, client):

assert result == {"run_id": 123}

def test_submit_job_run_volumes_stages_jar_into_volume(self, client):
    """A Volumes init script makes the run stage its JAR beside the script.

    The submitted payload must carry a timestamped JOB_JAR_VOL_PATH env
    var, and the task's only jar library must be that exact path so
    Databricks blocks the task until the init script has staged the JAR.
    """
    import re

    with patch.object(client, "post", return_value={"run_id": 321}) as post_mock:
        client.submit_job_run(
            init_script_path="/Volumes/cat/schema/vol/init.sh",
            config_path="/Volumes/cat/schema/vol/config.json",
        )

    # The request payload is the second positional argument given to post().
    positional_args, _ = post_mock.call_args
    stitch_task = positional_args[1]["tasks"][0]
    cluster_env = stitch_task["new_cluster"]["spark_env_vars"]

    # The env var names a timestamped job-*.jar in the init script's directory.
    staged_jar = cluster_env["JOB_JAR_VOL_PATH"]
    expected_path = re.compile(r"/Volumes/cat/schema/vol/job-\d{8}-\d{6}\.jar")
    assert expected_path.fullmatch(staged_jar) is not None

    # The single jar library entry must equal the staged path exactly.
    library_jars = [
        library["jar"] for library in stitch_task["libraries"] if "jar" in library
    ]
    assert library_jars == [staged_jar]

def test_submit_job_run_s3_does_not_stage_jar_into_volume(self, client):
    """An S3 init script (Redshift) leaves the JAR at its local file:// path.

    With no Unity Catalog volume to stage into, the payload must not set
    JOB_JAR_VOL_PATH and the jar library must stay on the local default.
    """
    region_patch = patch(
        "chuck_data.config.get_aws_region", return_value="us-east-1"
    )
    post_patch = patch.object(client, "post", return_value={"run_id": 654})

    with region_patch, post_patch as post_mock:
        client.submit_job_run(
            init_script_path="s3://bucket/init.sh",
            config_path="s3://bucket/config.json",
        )

    # The request payload is the second positional argument given to post().
    positional_args, _ = post_mock.call_args
    stitch_task = positional_args[1]["tasks"][0]
    cluster_env = stitch_task["new_cluster"]["spark_env_vars"]

    assert "JOB_JAR_VOL_PATH" not in cluster_env

    library_jars = [
        library["jar"] for library in stitch_task["libraries"] if "jar" in library
    ]
    assert library_jars == ["file:///opt/amperity/job.jar"]

@patch("chuck_data.config.get_aws_region")
def test_submit_job_run_with_s3_init_script(self, mock_get_region, client):
"""Test submit_job_run uses s3 format for S3 paths."""
Expand Down
Loading