From 22ab07f4f09c76eb9f465edbd055eee65e56ae57 Mon Sep 17 00:00:00 2001 From: Marie Coolsaet Date: Tue, 24 Feb 2026 15:54:33 -0500 Subject: [PATCH 1/5] Add DPF notebook example. --- .../distributed_partition_function/README.md | 42 ++ .../dpf_example.ipynb | 677 ++++++++++++++++++ 2 files changed, 719 insertions(+) create mode 100644 samples/ml/distributed_partition_function/README.md create mode 100644 samples/ml/distributed_partition_function/dpf_example.ipynb diff --git a/samples/ml/distributed_partition_function/README.md b/samples/ml/distributed_partition_function/README.md new file mode 100644 index 00000000..ff890d90 --- /dev/null +++ b/samples/ml/distributed_partition_function/README.md @@ -0,0 +1,42 @@ +# Distributed Partition Function (DPF) — Example Walkthrough + +## Introduction + +The **Distributed Partition Function (DPF)** lets you process data in parallel across one or more nodes in a compute pool. DPF partitions your data by a specified column (or by staged files) and executes your Python function on each partition concurrently. It handles distributed orchestration, errors, observability, and artifact persistence automatically. + +This example uses a **supply chain allocation** scenario: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`. Each region is solved as an independent DPF partition. + +## Execution Modes + +DPF supports two execution modes, both demonstrated in this notebook: + +| Mode | Method | Description | +|------|--------|-------------| +| **DataFrame mode** | `run()` | Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently. | +| **Stage mode** | `run_from_stage()` | Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing. | + +## What This Notebook Covers + +1. 
**Setup** — Session, stage, scale compute, and synthetic data generation +2. **DataFrame mode** — Define a processing function, run DPF, monitor progress, retrieve results, inspect logs, restore completed runs +3. **Stage mode** — Copy data to parquet files on stage, run DPF from stage +4. **ML Jobs deployment** — Deploy DPF workloads via the `@remote` decorator + +## Prerequisites + +- A [compute pool](https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool) with at least 3 max nodes (e.g., `CPU_X64_S`), or use the system-provided `SYSTEM_COMPUTE_POOL_CPU` +- A Snowflake Notebook running on the compute pool (Container Runtime) +- Stage access permissions for storing results and artifacts + +## Getting Started + +This notebook is intended to be run in a **Snowflake Notebook** environment on Snowpark Container Services. If running locally, use the **ML Jobs deployment** section at the bottom of the notebook to submit DPF workloads via the `@remote` decorator. + +Open the [DPF Example Notebook](./dpf_example.ipynb) for a full end-to-end walkthrough. 
+ +## References + +- [DPF Documentation](https://docs.snowflake.com/en/developer-guide/snowflake-ml/process-data-across-partitions) +- [DPF API Reference](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/container-runtime/distributors.distributed_partition_function) +- [ML Jobs Documentation](https://docs.snowflake.com/developer-guide/snowflake-ml/ml-jobs/overview) +- [Many Model Training (MMT) Example](../many_model_training/mmt_example.ipynb) diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb b/samples/ml/distributed_partition_function/dpf_example.ipynb new file mode 100644 index 00000000..28040727 --- /dev/null +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -0,0 +1,677 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed Partition Function (DPF) — Example Walkthrough\n", + "\n", + "This notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n", + "\n", + "We'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n", + "\n", + "DPF supports two execution modes:\n", + "\n", + "- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n", + "- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n", + "\n", + "**Environment:** This notebook is designed to run in a Snowflake Notebook on Container Runtime. 
If running locally, see the **ML Jobs deployment** section at the bottom.\n", + "\n", + "**Prerequisites:**\n", + "- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "import json\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from snowflake.snowpark import Session\n", + "\n", + "\n", + "session = Session.builder.getOrCreate()\n", + "\n", + "# Configuration\n", + "database = session.get_current_database() or \"MY_DATABASE\" # Change to your database\n", + "schema = session.get_current_schema() or \"MY_SCHEMA\" # Change to your schema\n", + "\n", + "input_stage = \"DPF_INPUT_STAGE\"\n", + "dpf_stage = \"DPF_RESULTS_STAGE\"\n", + "input_table = \"SUPPLY_CHAIN_DATA\"\n", + "output_table = \"OPTIMIZED_SHIPPING_MANIFEST\"\n", + "\n", + "# Create stages\n", + "session.use_schema(f\"{database}.{schema}\")\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n", + "\n", + "print(f\"Session: {session}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import DPF modules and Scale Compute Nodes\n", + "Snowflake Notebook on Container Runtime only — skip this cell if running locally." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n", + " DPFRun,\n", + ")\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", + " RunStatus,\n", + " ExecutionOptions,\n", + ")\n", + "from snowflake.ml.runtime_cluster import scale_cluster\n", + "\n", + "# Scale to 3 nodes for parallel processing\n", + "scale_cluster(expected_cluster_size=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Synthetic Supply Chain Data\n", + "\n", + "Generate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_supply_chain_data(session, table_name):\n", + " \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n", + " regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n", + " np.random.seed(42)\n", + " data = []\n", + "\n", + " for reg in regions:\n", + " # 3 Factories per region (supply)\n", + " for i in range(3):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n", + " \"TYPE\": \"FACTORY\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 1000,\n", + " \"DEMAND\": 0,\n", + " }\n", + " )\n", + " # 10 Warehouses per region (demand)\n", + " for j in range(10):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n", + " \"TYPE\": \"WAREHOUSE\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 0,\n", + 
" \"DEMAND\": 250,\n", + " }\n", + " )\n", + "\n", + " df = pd.DataFrame(data)\n", + " sdf = session.create_dataframe(df)\n", + " sdf.write.mode(\"overwrite\").save_as_table(table_name)\n", + " print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n", + " return session.table(table_name)\n", + "\n", + "\n", + "supply_chain_sdf = create_supply_chain_data(session, input_table)\n", + "supply_chain_sdf.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## DataFrame Mode: Process Data by Column Partitions\n", + "\n", + "Partition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in parallel.\n", + "\n", + "1. **Define the processing function** — optimization logic that runs on each partition.\n", + "2. **Initialize and run DPF** — launch parallel execution across all partitions.\n", + "3. **Monitor progress** — track status and wait for completion.\n", + "4. **Retrieve results** — collect artifacts and output data from each partition.\n", + "5. **Restore a completed run** — access results from a previous run without re-executing.\n", + "\n", + "### Step 1: Define the Processing Function\n", + "\n", + "This function runs on each partition (region). It receives the partition's data via `data_connector` and\n", + "uses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\n", + "satisfying warehouse demand without exceeding factory capacity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def solve_allocation(data_connector, context):\n", + " \"\"\"\n", + " Solve the supply chain allocation problem for a single region.\n", + "\n", + " Uses linear programming to find the optimal shipment plan that minimizes\n", + " total transportation cost (based on distance) subject to:\n", + " - Factory capacity constraints (supply)\n", + " - Warehouse demand constraints (demand)\n", + "\n", + " Args:\n", + " data_connector: Provides access to the partition's data.\n", + " context: PartitionContext with partition_id and artifact methods.\n", + " \"\"\"\n", + " from scipy.optimize import linprog\n", + " from scipy.spatial.distance import cdist\n", + " import pandas as pd\n", + " import numpy as np\n", + " import json\n", + "\n", + " df = data_connector.to_pandas()\n", + " region = context.partition_id\n", + " print(f\"[{region}] Processing {len(df)} locations\")\n", + "\n", + " factories = df[df[\"TYPE\"] == \"FACTORY\"].reset_index(drop=True)\n", + " warehouses = df[df[\"TYPE\"] == \"WAREHOUSE\"].reset_index(drop=True)\n", + " n_fact = len(factories)\n", + " n_wh = len(warehouses)\n", + "\n", + " # Build cost matrix (Euclidean distance as proxy for shipping cost)\n", + " cost_matrix = cdist(\n", + " factories[[\"LAT\", \"LON\"]], warehouses[[\"LAT\", \"LON\"]], metric=\"euclidean\"\n", + " )\n", + " c = cost_matrix.flatten()\n", + "\n", + " # Inequality constraint: total outbound from Factory_i <= Capacity_i\n", + " A_ub = np.zeros((n_fact, n_fact * n_wh))\n", + " for i in range(n_fact):\n", + " A_ub[i, i * n_wh : (i + 1) * n_wh] = 1\n", + " b_ub = factories[\"CAPACITY\"].values.astype(float)\n", + "\n", + " # Equality constraint: total inbound to Warehouse_j == Demand_j\n", + " A_eq = np.zeros((n_wh, n_fact * n_wh))\n", + " for j in range(n_wh):\n", + " A_eq[j, j::n_wh] = 1\n", + " b_eq = warehouses[\"DEMAND\"].values.astype(float)\n", + "\n", + " # 
Solve\n", + " res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, method=\"highs\")\n", + "\n", + " if res.success:\n", + " allocation = res.x.reshape((n_fact, n_wh))\n", + " manifest = []\n", + " for i in range(n_fact):\n", + " for j in range(n_wh):\n", + " qty = allocation[i, j]\n", + " if qty > 0.1:\n", + " manifest.append(\n", + " {\n", + " \"REGION\": region,\n", + " \"ORIGIN\": factories.loc[i, \"LOCATION_ID\"],\n", + " \"DESTINATION\": warehouses.loc[j, \"LOCATION_ID\"],\n", + " \"SHIPMENT_QTY\": round(float(qty), 2),\n", + " \"UNIT_DISTANCE\": round(float(cost_matrix[i, j]), 4),\n", + " }\n", + " )\n", + "\n", + " manifest_df = pd.DataFrame(manifest)\n", + "\n", + " summary = {\n", + " \"region\": region,\n", + " \"status\": \"OPTIMAL\",\n", + " \"total_cost\": round(float(res.fun), 2),\n", + " \"shipment_count\": len(manifest),\n", + " \"total_units_shipped\": round(sum(m[\"SHIPMENT_QTY\"] for m in manifest), 2),\n", + " }\n", + " print(\n", + " f\"[{region}] Optimal cost: {summary['total_cost']}, shipments: {len(manifest)}\"\n", + " )\n", + "\n", + " # Upload summary as a stage artifact\n", + " context.upload_to_stage(\n", + " summary,\n", + " \"summary.json\",\n", + " write_function=lambda obj, path: json.dump(obj, open(path, \"w\")),\n", + " )\n", + "\n", + " # Write results to a Snowflake table using the bounded session pool\n", + " context.with_session(\n", + " lambda session: session.create_dataframe(manifest_df)\n", + " .write.mode(\"append\")\n", + " .save_as_table(output_table)\n", + " )\n", + " else:\n", + " print(f\"[{region}] Optimization failed: {res.message}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2: Initialize and Run DPF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "run 
= dpf.run(\n", + " partition_by=\"REGION\",\n", + " snowpark_dataframe=session.table(input_table),\n", + " run_id=f\"supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " execution_options=ExecutionOptions(use_head_node=True, num_cpus_per_worker=1),\n", + ")\n", + "print(f\"Launched: {run.run_id}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: Monitor Progress and Wait for Completion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_status = run.wait() # Shows progress\n", + "print(f\"Job completed with status: {final_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Quick summary of all partition statuses\n", + "progress = run.get_progress()\n", + "for status, partitions in progress.items():\n", + " print(f\"{status}: {len(partitions)} partitions\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 4: Retrieve Results from Each Partition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def print_results(summaries):\n", + " \"\"\"Format and display the supply chain optimization results.\"\"\"\n", + " for s in summaries:\n", + " print(f\" {s['region']}: cost={s['total_cost']}, shipments={s['shipment_count']}\")\n", + "\n", + " total_cost = sum(s[\"total_cost\"] for s in summaries)\n", + " total_shipments = sum(s[\"shipment_count\"] for s in summaries)\n", + " print(f\"\\n TOTAL: cost={total_cost:.2f}, shipments={total_shipments}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if final_status == RunStatus.SUCCESS:\n", + " summaries = []\n", + " for partition_id, details in run.partition_details.items():\n", + " files = details.stage_artifacts_manager.list()\n", + " print(f\"Partition '{partition_id}' 
artifacts: {files}\")\n", + "\n", + " summary = details.stage_artifacts_manager.get(\n", + " \"summary.json\",\n", + " read_function=lambda path: json.load(open(path, \"r\")),\n", + " )\n", + " summaries.append(summary)\n", + "\n", + " print_results(summaries)\n", + "else:\n", + " print(f\"Run did not succeed: {final_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inspect Partition Logs\n", + "\n", + "View stdout/stderr from individual partitions to verify processing or debug failures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View logs from each partition\n", + "for partition_id, details in run.partition_details.items():\n", + " print(f\"--- {partition_id} ---\")\n", + " print(details.logs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Debug failed partitions (if any)\n", + "# progress = run.get_progress()\n", + "# for partition in progress.get(\"FAILED\", []):\n", + "# print(f\"--- Failed: {partition.partition_id} ---\")\n", + "# print(partition.logs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5: Restore Results from a Completed Run\n", + "\n", + "Access results from a previous run without re-executing. Useful after restarting a notebook session." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "restored_run = DPFRun.restore_from(\n", + " run_id=run.run_id,\n", + " stage_name=dpf_stage,\n", + ")\n", + "\n", + "print(f\"Restored run status: {restored_run.status}\")\n", + "for partition_id, details in restored_run.partition_details.items():\n", + " print(f\" {partition_id}: {details.status}\")\n", + "\n", + "# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Stage Mode: Process Files from a Stage\n", + "\n", + "Process pre-staged parquet files where each file becomes a partition.\n", + "First, copy the data from the table to stage as parquet files, one per region." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare parquet files on stage — one file per region\n", + "session.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n", + "\n", + "session.sql(\n", + " f\"\"\"\n", + " COPY INTO @{input_stage}/supply_chain/\n", + " FROM {input_table}\n", + " PARTITION BY REGION\n", + " FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n", + " MAX_FILE_SIZE = 15728640\n", + " HEADER = TRUE\n", + "\"\"\"\n", + ").collect()\n", + "\n", + "# Verify staged files\n", + "session.sql(f\"LIST @{input_stage}/supply_chain/\").show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run DPF from Stage\n", + "\n", + "The processing function signature is the same as DataFrame mode. The `data_connector` provides access\n", + "to each file's data, and `context.partition_id` is the relative file path." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "stage_run = dpf_from_stage.run_from_stage(\n", + " stage_location=f\"@{input_stage}/supply_chain/\",\n", + " run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " file_pattern=\"*.parquet\",\n", + " execution_options=ExecutionOptions(\n", + " use_head_node=True,\n", + " num_cpus_per_worker=1,\n", + " ),\n", + ")\n", + "print(f\"Launched: {stage_run.run_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "stage_status = stage_run.wait()\n", + "print(f\"Stage mode completed with status: {stage_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Deploy with ML Jobs via `@remote`\n", + "\n", + "Run DPF in an ML Job from any IDE. ML Jobs execute inside Snowpark Container Services\n", + "and can scale across multiple nodes. Logs are available in Snowsight under Monitoring > Services & Jobs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "job_stage = \"DPF_JOB_STAGE\"\n", + "compute_pool = \"SYSTEM_COMPUTE_POOL_CPU\" # Update with your compute pool name\n", + "\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {job_stage}\").collect()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.ml.jobs import remote\n", + "\n", + "@remote(\n", + " compute_pool=compute_pool,\n", + " stage_name=job_stage,\n", + " target_instances=3,\n", + ")\n", + "def launch_supply_chain_job():\n", + " \"\"\"\n", + " Launch a DPF supply chain optimization run as an ML Job.\n", + " \"\"\"\n", + " from datetime import datetime\n", + " from snowflake.snowpark import Session\n", + " from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import (\n", + " DPF,\n", + " )\n", + " from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", + " ExecutionOptions,\n", + " )\n", + "\n", + " session = Session.builder.getOrCreate()\n", + " dpf_input = session.table(input_table)\n", + "\n", + " dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + " run = dpf.run(\n", + " partition_by=\"REGION\",\n", + " snowpark_dataframe=dpf_input,\n", + " run_id=f\"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " execution_options=ExecutionOptions(use_head_node=False),\n", + " )\n", + " run.wait()\n", + "\n", + " print(f\"DPF run complete: {run.run_id}\")\n", + " return run.run_id\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "job = launch_supply_chain_job()\n", + "print(f\"Job ID: {job.id}\")\n", + "print(f\"Status: {job.status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check the status and logs of the ML Job\n", + "print(f\"Status: {job.status}\")\n", + 
"print(job.get_logs())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Cleanup\n", + "\n", + "Scale the cluster back down to a single node when you're done." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scale_cluster(expected_cluster_size=1)\n", + "\n", + "# Uncomment to drop objects created by this notebook\n", + "# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n", + "# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dpf-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 12c7f6920306e7d90de07d43aaf1f614d55e0b66 Mon Sep 17 00:00:00 2001 From: Marie Coolsaet Date: Tue, 24 Feb 2026 16:50:31 -0500 Subject: [PATCH 2/5] Fix notebook format and metadata. 
--- .../distributed_partition_function/README.md | 10 +- .../dpf_example.ipynb | 251 +++++++++++++----- 2 files changed, 196 insertions(+), 65 deletions(-) diff --git a/samples/ml/distributed_partition_function/README.md b/samples/ml/distributed_partition_function/README.md index ff890d90..7ccda16e 100644 --- a/samples/ml/distributed_partition_function/README.md +++ b/samples/ml/distributed_partition_function/README.md @@ -1,4 +1,4 @@ -# Distributed Partition Function (DPF) — Example Walkthrough +# Distributed Partition Function (DPF) - Example Walkthrough ## Introduction @@ -17,10 +17,10 @@ DPF supports two execution modes, both demonstrated in this notebook: ## What This Notebook Covers -1. **Setup** — Session, stage, scale compute, and synthetic data generation -2. **DataFrame mode** — Define a processing function, run DPF, monitor progress, retrieve results, inspect logs, restore completed runs -3. **Stage mode** — Copy data to parquet files on stage, run DPF from stage -4. **ML Jobs deployment** — Deploy DPF workloads via the `@remote` decorator +1. **Setup** - Session, stage, scale compute, and synthetic data generation +2. **DataFrame mode** - Define a processing function, run DPF, monitor progress, retrieve results, inspect logs, restore completed runs +3. **Stage mode** - Copy data to parquet files on stage, run DPF from stage +4. 
**ML Jobs deployment** - Deploy DPF workloads via the `@remote` decorator ## Prerequisites diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb b/samples/ml/distributed_partition_function/dpf_example.ipynb index 28040727..bf5825db 100644 --- a/samples/ml/distributed_partition_function/dpf_example.ipynb +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -2,9 +2,13 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, + "id": "e16f2bf5-88f3-4dfa-8d7a-6be220007ba3", + "metadata": { + "collapsed": false, + "name": "cell0" + }, "source": [ - "# Distributed Partition Function (DPF) — Example Walkthrough\n", + "# Distributed Partition Function (DPF) - Example Walkthrough\n", "\n", "This notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n", "\n", @@ -23,7 +27,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "bebe7269-edd8-4117-957b-d19a5be03ff2", + "metadata": { + "collapsed": false, + "name": "cell1" + }, "source": [ "---\n", "## Setup" @@ -32,7 +40,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "85a70bc5-f1d0-45e7-99ee-7d9c811df886", + "metadata": { + "language": "python", + "name": "cell2" + }, "outputs": [], "source": [ "from datetime import datetime\n", @@ -65,16 +77,24 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "b8a5fdce-1b4d-4041-a1b9-862cefd1eade", + "metadata": { + "collapsed": false, + "name": "cell3" + }, "source": [ "### Import DPF modules and Scale Compute Nodes\n", - "Snowflake Notebook on Container Runtime only — skip this cell if running locally." + "Snowflake Notebook on Container Runtime only - skip this cell if running locally." 
] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "bbaef190-6856-4616-8a74-30a9b452fbf8", + "metadata": { + "language": "python", + "name": "cell4" + }, "outputs": [], "source": [ "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\n", @@ -93,7 +113,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "830b0d3a-0f5c-4100-afac-6a5be0e36a17", + "metadata": { + "collapsed": false, + "name": "cell5" + }, "source": [ "### Create Synthetic Supply Chain Data\n", "\n", @@ -103,7 +127,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "42b3e109-80a8-4fcd-a1b0-a41842e6cbd5", + "metadata": { + "language": "python", + "name": "cell6" + }, "outputs": [], "source": [ "def create_supply_chain_data(session, table_name):\n", @@ -153,18 +181,22 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "a0cafdbc-ee9b-4fb4-b228-820dc3dcf5c1", + "metadata": { + "collapsed": false, + "name": "cell7" + }, "source": [ "---\n", "## DataFrame Mode: Process Data by Column Partitions\n", "\n", "Partition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in parallel.\n", "\n", - "1. **Define the processing function** — optimization logic that runs on each partition.\n", - "2. **Initialize and run DPF** — launch parallel execution across all partitions.\n", - "3. **Monitor progress** — track status and wait for completion.\n", - "4. **Retrieve results** — collect artifacts and output data from each partition.\n", - "5. **Restore a completed run** — access results from a previous run without re-executing.\n", + "1. **Define the processing function** - optimization logic that runs on each partition.\n", + "2. **Initialize and run DPF** - launch parallel execution across all partitions.\n", + "3. **Monitor progress** - track status and wait for completion.\n", + "4. **Retrieve results** - collect artifacts and output data from each partition.\n", + "5. 
**Restore a completed run** - access results from a previous run without re-executing.\n", "\n", "### Step 1: Define the Processing Function\n", "\n", @@ -176,7 +208,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "785989a9-c244-4761-9462-8cb5a62decd1", + "metadata": { + "language": "python", + "name": "cell8" + }, "outputs": [], "source": [ "def solve_allocation(data_connector, context):\n", @@ -277,7 +313,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "7d6f4691-a3cf-4c46-8785-1aa016476b5d", + "metadata": { + "collapsed": false, + "name": "cell9" + }, "source": [ "### Step 2: Initialize and Run DPF" ] @@ -285,7 +325,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "84e07682-7bca-4ccb-b0cf-51cbf4bc478c", + "metadata": { + "language": "python", + "name": "cell10" + }, "outputs": [], "source": [ "dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", @@ -303,7 +347,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "f4b8b230-b983-48ca-a9e7-18ca97279764", + "metadata": { + "collapsed": false, + "name": "cell11" + }, "source": [ "### Step 3: Monitor Progress and Wait for Completion" ] @@ -311,7 +359,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "f5c3a1c7-d809-4a86-b395-16e56da332ec", + "metadata": { + "language": "python", + "name": "cell12" + }, "outputs": [], "source": [ "final_status = run.wait() # Shows progress\n", @@ -321,7 +373,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "2b824586-57ba-4d1d-bd96-f38f76ba2e73", + "metadata": { + "language": "python", + "name": "cell13" + }, "outputs": [], "source": [ "# Quick summary of all partition statuses\n", @@ -332,7 +388,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "2f92856d-4e0f-45cb-ba82-1b2ca6047ba6", + "metadata": { + "collapsed": false, + "name": "cell14" + }, "source": [ "### Step 4: Retrieve Results from Each Partition" ] @@ -340,7 +400,11 
@@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "fedc35ff-c572-49e3-b10e-abdb31db8053", + "metadata": { + "language": "python", + "name": "cell15" + }, "outputs": [], "source": [ "def print_results(summaries):\n", @@ -356,7 +420,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "3ab86af1-e6f6-40ea-95f5-a1f5e5fae72a", + "metadata": { + "language": "python", + "name": "cell16" + }, "outputs": [], "source": [ "if final_status == RunStatus.SUCCESS:\n", @@ -379,7 +447,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "ae7b476e-340e-4103-904e-85298e9a3cdd", + "metadata": { + "language": "python", + "name": "cell17" + }, "outputs": [], "source": [ "# View the results written to the Snowflake table\n", @@ -388,7 +460,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "7fcab272-38a5-47b4-8a10-74f5f87eb299", + "metadata": { + "collapsed": false, + "name": "cell18" + }, "source": [ "### Inspect Partition Logs\n", "\n", @@ -398,7 +474,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "0c4a28a4-469b-46c9-b0cf-79e142bef101", + "metadata": { + "language": "python", + "name": "cell19" + }, "outputs": [], "source": [ "# View logs from each partition\n", @@ -410,7 +490,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "f8d1a479-8928-4f78-a61a-e67b82df1a7f", + "metadata": { + "language": "python", + "name": "cell20" + }, "outputs": [], "source": [ "# Debug failed partitions (if any)\n", @@ -422,7 +506,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "a9ad12c6-faf5-4fb1-b5de-c2289b9f4670", + "metadata": { + "collapsed": false, + "name": "cell21" + }, "source": [ "### Step 5: Restore Results from a Completed Run\n", "\n", @@ -432,7 +520,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "5d71d432-6c2c-4a49-96a8-8ba85397844f", + "metadata": { + "language": "python", + "name": "cell22" + 
}, "outputs": [], "source": [ "restored_run = DPFRun.restore_from(\n", @@ -449,7 +541,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "9e8d3b12-9019-4153-b808-195d16356657", + "metadata": { + "collapsed": false, + "name": "cell23" + }, "source": [ "---\n", "## Stage Mode: Process Files from a Stage\n", @@ -461,10 +557,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "efc121f3-43d2-4f7f-88f6-0f7a8fce63d4", + "metadata": { + "language": "python", + "name": "cell24" + }, "outputs": [], "source": [ - "# Prepare parquet files on stage — one file per region\n", + "# Prepare parquet files on stage - one file per region\n", "session.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n", "\n", "session.sql(\n", @@ -484,7 +584,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "c5e47877-9163-49a6-a6ce-ac8bf82e23b1", + "metadata": { + "collapsed": false, + "name": "cell25" + }, "source": [ "### Run DPF from Stage\n", "\n", @@ -495,7 +599,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "1d320ec5-e978-4392-a1fd-d35d3b772dfc", + "metadata": { + "language": "python", + "name": "cell26" + }, "outputs": [], "source": [ "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n", @@ -517,7 +625,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "c4bb2ffc-6906-4341-a758-511545a2209f", + "metadata": { + "language": "python", + "name": "cell27" + }, "outputs": [], "source": [ "stage_status = stage_run.wait()\n", @@ -527,7 +639,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "1ab59f01-345e-4bda-8248-cb74f9941287", + "metadata": { + "language": "python", + "name": "cell28" + }, "outputs": [], "source": [ "# View the results written to the Snowflake table\n", @@ -536,7 +652,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "f1cad585-2349-4d16-ad60-2de39db7a30b", + "metadata": { + "collapsed": false, + "name": 
"cell29" + }, "source": [ "---\n", "## Deploy with ML Jobs via `@remote`\n", @@ -548,7 +668,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "9f5b59b6-e8d5-47c3-91ed-d3366ea16885", + "metadata": { + "language": "python", + "name": "cell30" + }, "outputs": [], "source": [ "job_stage = \"DPF_JOB_STAGE\"\n", @@ -560,7 +684,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "6c00d321-7185-4d77-9111-98d4b19b5a83", + "metadata": { + "language": "python", + "name": "cell31" + }, "outputs": [], "source": [ "from snowflake.ml.jobs import remote\n", @@ -608,7 +736,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "fadd70e4-048c-45cf-b883-dcb71d525985", + "metadata": { + "language": "python", + "name": "cell32" + }, "outputs": [], "source": [ "# Check the status and logs of the ML Job\n", @@ -619,7 +751,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "9f8aadcc-9e10-4ab3-aa86-96a7483c308f", + "metadata": { + "language": "python", + "name": "cell33" + }, "outputs": [], "source": [ "# View the results written to the Snowflake table\n", @@ -628,7 +764,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "id": "d3697df9-557d-4603-be01-d4dda3ac9b3b", + "metadata": { + "collapsed": false, + "name": "cell34" + }, "source": [ "---\n", "## Cleanup\n", @@ -639,7 +779,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "2b2d20d9-6934-42e7-9072-66723ce5884d", + "metadata": { + "language": "python", + "name": "cell35" + }, "outputs": [], "source": [ "scale_cluster(expected_cluster_size=1)\n", @@ -655,23 +799,10 @@ ], "metadata": { "kernelspec": { - "display_name": "dpf-test", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": 
"ipython3", - "version": "3.10.18" + "display_name": "Streamlit Notebook", + "name": "streamlit" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 5 } From 7c3662640e3e59eca6b10fc59c05eb256815a476 Mon Sep 17 00:00:00 2001 From: Marie Coolsaet Date: Fri, 27 Feb 2026 14:20:01 -0500 Subject: [PATCH 3/5] Removed max file size and changed to use head node in ml job. --- .../distributed_partition_function/dpf_example.ipynb | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb b/samples/ml/distributed_partition_function/dpf_example.ipynb index bf5825db..e15d71e4 100644 --- a/samples/ml/distributed_partition_function/dpf_example.ipynb +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -573,7 +573,6 @@ " FROM {input_table}\n", " PARTITION BY REGION\n", " FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n", - " MAX_FILE_SIZE = 15728640\n", " HEADER = TRUE\n", "\"\"\"\n", ").collect()\n", @@ -719,7 +718,7 @@ " partition_by=\"REGION\",\n", " snowpark_dataframe=dpf_input,\n", " run_id=f\"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", - " execution_options=ExecutionOptions(use_head_node=False),\n", + " execution_options=ExecutionOptions(use_head_node=True),\n", " )\n", " run.wait()\n", "\n", @@ -799,8 +798,13 @@ ], "metadata": { "kernelspec": { - "display_name": "Streamlit Notebook", - "name": "streamlit" + "display_name": "dpf-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.18" } }, "nbformat": 4, From b507173ea81023b480a941b2d1ae3501300253d9 Mon Sep 17 00:00:00 2001 From: Marie Coolsaet Date: Thu, 30 Apr 2026 15:00:24 -0400 Subject: [PATCH 4/5] Use fully qualified names for stages and tables --- .../dpf_example.ipynb | 1209 ++++++----------- 1 file changed, 398 insertions(+), 811 deletions(-) diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb 
b/samples/ml/distributed_partition_function/dpf_example.ipynb index e15d71e4..8dd06b19 100644 --- a/samples/ml/distributed_partition_function/dpf_example.ipynb +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -1,812 +1,399 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "e16f2bf5-88f3-4dfa-8d7a-6be220007ba3", - "metadata": { - "collapsed": false, - "name": "cell0" - }, - "source": [ - "# Distributed Partition Function (DPF) - Example Walkthrough\n", - "\n", - "This notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n", - "\n", - "We'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n", - "\n", - "DPF supports two execution modes:\n", - "\n", - "- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n", - "- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n", - "\n", - "**Environment:** This notebook is designed to run in a Snowflake Notebook on Container Runtime. 
If running locally, see the **ML Jobs deployment** section at the bottom.\n", - "\n", - "**Prerequisites:**\n", - "- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" - ] - }, - { - "cell_type": "markdown", - "id": "bebe7269-edd8-4117-957b-d19a5be03ff2", - "metadata": { - "collapsed": false, - "name": "cell1" - }, - "source": [ - "---\n", - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "85a70bc5-f1d0-45e7-99ee-7d9c811df886", - "metadata": { - "language": "python", - "name": "cell2" - }, - "outputs": [], - "source": [ - "from datetime import datetime\n", - "import json\n", - "\n", - "import pandas as pd\n", - "import numpy as np\n", - "\n", - "from snowflake.snowpark import Session\n", - "\n", - "\n", - "session = Session.builder.getOrCreate()\n", - "\n", - "# Configuration\n", - "database = session.get_current_database() or \"MY_DATABASE\" # Change to your database\n", - "schema = session.get_current_schema() or \"MY_SCHEMA\" # Change to your schema\n", - "\n", - "input_stage = \"DPF_INPUT_STAGE\"\n", - "dpf_stage = \"DPF_RESULTS_STAGE\"\n", - "input_table = \"SUPPLY_CHAIN_DATA\"\n", - "output_table = \"OPTIMIZED_SHIPPING_MANIFEST\"\n", - "\n", - "# Create stages\n", - "session.use_schema(f\"{database}.{schema}\")\n", - "session.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\n", - "session.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n", - "\n", - "print(f\"Session: {session}\")" - ] - }, - { - "cell_type": "markdown", - "id": "b8a5fdce-1b4d-4041-a1b9-862cefd1eade", - "metadata": { - "collapsed": false, - "name": "cell3" - }, - "source": [ - "### Import DPF modules and Scale Compute Nodes\n", - "Snowflake Notebook on Container Runtime only - skip this cell if running locally." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bbaef190-6856-4616-8a74-30a9b452fbf8", - "metadata": { - "language": "python", - "name": "cell4" - }, - "outputs": [], - "source": [ - "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\n", - "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n", - " DPFRun,\n", - ")\n", - "from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", - " RunStatus,\n", - " ExecutionOptions,\n", - ")\n", - "from snowflake.ml.runtime_cluster import scale_cluster\n", - "\n", - "# Scale to 3 nodes for parallel processing\n", - "scale_cluster(expected_cluster_size=3)" - ] - }, - { - "cell_type": "markdown", - "id": "830b0d3a-0f5c-4100-afac-6a5be0e36a17", - "metadata": { - "collapsed": false, - "name": "cell5" - }, - "source": [ - "### Create Synthetic Supply Chain Data\n", - "\n", - "Generate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42b3e109-80a8-4fcd-a1b0-a41842e6cbd5", - "metadata": { - "language": "python", - "name": "cell6" - }, - "outputs": [], - "source": [ - "def create_supply_chain_data(session, table_name):\n", - " \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n", - " regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n", - " np.random.seed(42)\n", - " data = []\n", - "\n", - " for reg in regions:\n", - " # 3 Factories per region (supply)\n", - " for i in range(3):\n", - " data.append(\n", - " {\n", - " \"REGION\": reg,\n", - " \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n", - " \"TYPE\": \"FACTORY\",\n", - " \"LAT\": np.random.uniform(25, 55),\n", - " \"LON\": np.random.uniform(-130, -60),\n", - " \"CAPACITY\": 1000,\n", - " \"DEMAND\": 0,\n", - " }\n", - " )\n", - " # 10 Warehouses per region (demand)\n", - " for j in range(10):\n", - " data.append(\n", - " {\n", - " \"REGION\": reg,\n", - " \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n", - " \"TYPE\": \"WAREHOUSE\",\n", - " \"LAT\": np.random.uniform(25, 55),\n", - " \"LON\": np.random.uniform(-130, -60),\n", - " \"CAPACITY\": 0,\n", - " \"DEMAND\": 250,\n", - " }\n", - " )\n", - "\n", - " df = pd.DataFrame(data)\n", - " sdf = session.create_dataframe(df)\n", - " sdf.write.mode(\"overwrite\").save_as_table(table_name)\n", - " print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n", - " return session.table(table_name)\n", - "\n", - "\n", - "supply_chain_sdf = create_supply_chain_data(session, input_table)\n", - "supply_chain_sdf.show()" - ] - }, - { - "cell_type": "markdown", - "id": "a0cafdbc-ee9b-4fb4-b228-820dc3dcf5c1", - "metadata": { - "collapsed": false, - "name": "cell7" - }, - "source": [ - "---\n", - "## DataFrame Mode: Process Data by Column Partitions\n", - "\n", - "Partition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in 
parallel.\n", - "\n", - "1. **Define the processing function** - optimization logic that runs on each partition.\n", - "2. **Initialize and run DPF** - launch parallel execution across all partitions.\n", - "3. **Monitor progress** - track status and wait for completion.\n", - "4. **Retrieve results** - collect artifacts and output data from each partition.\n", - "5. **Restore a completed run** - access results from a previous run without re-executing.\n", - "\n", - "### Step 1: Define the Processing Function\n", - "\n", - "This function runs on each partition (region). It receives the partition's data via `data_connector` and\n", - "uses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\n", - "satisfying warehouse demand without exceeding factory capacity." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "785989a9-c244-4761-9462-8cb5a62decd1", - "metadata": { - "language": "python", - "name": "cell8" - }, - "outputs": [], - "source": [ - "def solve_allocation(data_connector, context):\n", - " \"\"\"\n", - " Solve the supply chain allocation problem for a single region.\n", - "\n", - " Uses linear programming to find the optimal shipment plan that minimizes\n", - " total transportation cost (based on distance) subject to:\n", - " - Factory capacity constraints (supply)\n", - " - Warehouse demand constraints (demand)\n", - "\n", - " Args:\n", - " data_connector: Provides access to the partition's data.\n", - " context: PartitionContext with partition_id and artifact methods.\n", - " \"\"\"\n", - " from scipy.optimize import linprog\n", - " from scipy.spatial.distance import cdist\n", - " import pandas as pd\n", - " import numpy as np\n", - " import json\n", - "\n", - " df = data_connector.to_pandas()\n", - " region = context.partition_id\n", - " print(f\"[{region}] Processing {len(df)} locations\")\n", - "\n", - " factories = df[df[\"TYPE\"] == \"FACTORY\"].reset_index(drop=True)\n", - " warehouses = 
df[df[\"TYPE\"] == \"WAREHOUSE\"].reset_index(drop=True)\n", - " n_fact = len(factories)\n", - " n_wh = len(warehouses)\n", - "\n", - " # Build cost matrix (Euclidean distance as proxy for shipping cost)\n", - " cost_matrix = cdist(\n", - " factories[[\"LAT\", \"LON\"]], warehouses[[\"LAT\", \"LON\"]], metric=\"euclidean\"\n", - " )\n", - " c = cost_matrix.flatten()\n", - "\n", - " # Inequality constraint: total outbound from Factory_i <= Capacity_i\n", - " A_ub = np.zeros((n_fact, n_fact * n_wh))\n", - " for i in range(n_fact):\n", - " A_ub[i, i * n_wh : (i + 1) * n_wh] = 1\n", - " b_ub = factories[\"CAPACITY\"].values.astype(float)\n", - "\n", - " # Equality constraint: total inbound to Warehouse_j == Demand_j\n", - " A_eq = np.zeros((n_wh, n_fact * n_wh))\n", - " for j in range(n_wh):\n", - " A_eq[j, j::n_wh] = 1\n", - " b_eq = warehouses[\"DEMAND\"].values.astype(float)\n", - "\n", - " # Solve\n", - " res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, method=\"highs\")\n", - "\n", - " if res.success:\n", - " allocation = res.x.reshape((n_fact, n_wh))\n", - " manifest = []\n", - " for i in range(n_fact):\n", - " for j in range(n_wh):\n", - " qty = allocation[i, j]\n", - " if qty > 0.1:\n", - " manifest.append(\n", - " {\n", - " \"REGION\": region,\n", - " \"ORIGIN\": factories.loc[i, \"LOCATION_ID\"],\n", - " \"DESTINATION\": warehouses.loc[j, \"LOCATION_ID\"],\n", - " \"SHIPMENT_QTY\": round(float(qty), 2),\n", - " \"UNIT_DISTANCE\": round(float(cost_matrix[i, j]), 4),\n", - " }\n", - " )\n", - "\n", - " manifest_df = pd.DataFrame(manifest)\n", - "\n", - " summary = {\n", - " \"region\": region,\n", - " \"status\": \"OPTIMAL\",\n", - " \"total_cost\": round(float(res.fun), 2),\n", - " \"shipment_count\": len(manifest),\n", - " \"total_units_shipped\": round(sum(m[\"SHIPMENT_QTY\"] for m in manifest), 2),\n", - " }\n", - " print(\n", - " f\"[{region}] Optimal cost: {summary['total_cost']}, shipments: {len(manifest)}\"\n", - " )\n", - "\n", - " # Upload 
summary as a stage artifact\n", - " context.upload_to_stage(\n", - " summary,\n", - " \"summary.json\",\n", - " write_function=lambda obj, path: json.dump(obj, open(path, \"w\")),\n", - " )\n", - "\n", - " # Write results to a Snowflake table using the bounded session pool\n", - " context.with_session(\n", - " lambda session: session.create_dataframe(manifest_df)\n", - " .write.mode(\"append\")\n", - " .save_as_table(output_table)\n", - " )\n", - " else:\n", - " print(f\"[{region}] Optimization failed: {res.message}\")" - ] - }, - { - "cell_type": "markdown", - "id": "7d6f4691-a3cf-4c46-8785-1aa016476b5d", - "metadata": { - "collapsed": false, - "name": "cell9" - }, - "source": [ - "### Step 2: Initialize and Run DPF" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84e07682-7bca-4ccb-b0cf-51cbf4bc478c", - "metadata": { - "language": "python", - "name": "cell10" - }, - "outputs": [], - "source": [ - "dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", - "\n", - "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", - "\n", - "run = dpf.run(\n", - " partition_by=\"REGION\",\n", - " snowpark_dataframe=session.table(input_table),\n", - " run_id=f\"supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", - " execution_options=ExecutionOptions(use_head_node=True, num_cpus_per_worker=1),\n", - ")\n", - "print(f\"Launched: {run.run_id}\")" - ] - }, - { - "cell_type": "markdown", - "id": "f4b8b230-b983-48ca-a9e7-18ca97279764", - "metadata": { - "collapsed": false, - "name": "cell11" - }, - "source": [ - "### Step 3: Monitor Progress and Wait for Completion" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5c3a1c7-d809-4a86-b395-16e56da332ec", - "metadata": { - "language": "python", - "name": "cell12" - }, - "outputs": [], - "source": [ - "final_status = run.wait() # Shows progress\n", - "print(f\"Job completed with status: {final_status}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"2b824586-57ba-4d1d-bd96-f38f76ba2e73", - "metadata": { - "language": "python", - "name": "cell13" - }, - "outputs": [], - "source": [ - "# Quick summary of all partition statuses\n", - "progress = run.get_progress()\n", - "for status, partitions in progress.items():\n", - " print(f\"{status}: {len(partitions)} partitions\")" - ] - }, - { - "cell_type": "markdown", - "id": "2f92856d-4e0f-45cb-ba82-1b2ca6047ba6", - "metadata": { - "collapsed": false, - "name": "cell14" - }, - "source": [ - "### Step 4: Retrieve Results from Each Partition" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fedc35ff-c572-49e3-b10e-abdb31db8053", - "metadata": { - "language": "python", - "name": "cell15" - }, - "outputs": [], - "source": [ - "def print_results(summaries):\n", - " \"\"\"Format and display the supply chain optimization results.\"\"\"\n", - " for s in summaries:\n", - " print(f\" {s['region']}: cost={s['total_cost']}, shipments={s['shipment_count']}\")\n", - "\n", - " total_cost = sum(s[\"total_cost\"] for s in summaries)\n", - " total_shipments = sum(s[\"shipment_count\"] for s in summaries)\n", - " print(f\"\\n TOTAL: cost={total_cost:.2f}, shipments={total_shipments}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3ab86af1-e6f6-40ea-95f5-a1f5e5fae72a", - "metadata": { - "language": "python", - "name": "cell16" - }, - "outputs": [], - "source": [ - "if final_status == RunStatus.SUCCESS:\n", - " summaries = []\n", - " for partition_id, details in run.partition_details.items():\n", - " files = details.stage_artifacts_manager.list()\n", - " print(f\"Partition '{partition_id}' artifacts: {files}\")\n", - "\n", - " summary = details.stage_artifacts_manager.get(\n", - " \"summary.json\",\n", - " read_function=lambda path: json.load(open(path, \"r\")),\n", - " )\n", - " summaries.append(summary)\n", - "\n", - " print_results(summaries)\n", - "else:\n", - " print(f\"Run did not succeed: {final_status}\")" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "ae7b476e-340e-4103-904e-85298e9a3cdd", - "metadata": { - "language": "python", - "name": "cell17" - }, - "outputs": [], - "source": [ - "# View the results written to the Snowflake table\n", - "session.table(output_table).show()" - ] - }, - { - "cell_type": "markdown", - "id": "7fcab272-38a5-47b4-8a10-74f5f87eb299", - "metadata": { - "collapsed": false, - "name": "cell18" - }, - "source": [ - "### Inspect Partition Logs\n", - "\n", - "View stdout/stderr from individual partitions to verify processing or debug failures." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c4a28a4-469b-46c9-b0cf-79e142bef101", - "metadata": { - "language": "python", - "name": "cell19" - }, - "outputs": [], - "source": [ - "# View logs from each partition\n", - "for partition_id, details in run.partition_details.items():\n", - " print(f\"--- {partition_id} ---\")\n", - " print(details.logs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8d1a479-8928-4f78-a61a-e67b82df1a7f", - "metadata": { - "language": "python", - "name": "cell20" - }, - "outputs": [], - "source": [ - "# Debug failed partitions (if any)\n", - "# progress = run.get_progress()\n", - "# for partition in progress.get(\"FAILED\", []):\n", - "# print(f\"--- Failed: {partition.partition_id} ---\")\n", - "# print(partition.logs)" - ] - }, - { - "cell_type": "markdown", - "id": "a9ad12c6-faf5-4fb1-b5de-c2289b9f4670", - "metadata": { - "collapsed": false, - "name": "cell21" - }, - "source": [ - "### Step 5: Restore Results from a Completed Run\n", - "\n", - "Access results from a previous run without re-executing. Useful after restarting a notebook session." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5d71d432-6c2c-4a49-96a8-8ba85397844f", - "metadata": { - "language": "python", - "name": "cell22" - }, - "outputs": [], - "source": [ - "restored_run = DPFRun.restore_from(\n", - " run_id=run.run_id,\n", - " stage_name=dpf_stage,\n", - ")\n", - "\n", - "print(f\"Restored run status: {restored_run.status}\")\n", - "for partition_id, details in restored_run.partition_details.items():\n", - " print(f\" {partition_id}: {details.status}\")\n", - "\n", - "# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." - ] - }, - { - "cell_type": "markdown", - "id": "9e8d3b12-9019-4153-b808-195d16356657", - "metadata": { - "collapsed": false, - "name": "cell23" - }, - "source": [ - "---\n", - "## Stage Mode: Process Files from a Stage\n", - "\n", - "Process pre-staged parquet files where each file becomes a partition.\n", - "First, copy the data from the table to stage as parquet files, one per region." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "efc121f3-43d2-4f7f-88f6-0f7a8fce63d4", - "metadata": { - "language": "python", - "name": "cell24" - }, - "outputs": [], - "source": [ - "# Prepare parquet files on stage - one file per region\n", - "session.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n", - "\n", - "session.sql(\n", - " f\"\"\"\n", - " COPY INTO @{input_stage}/supply_chain/\n", - " FROM {input_table}\n", - " PARTITION BY REGION\n", - " FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n", - " HEADER = TRUE\n", - "\"\"\"\n", - ").collect()\n", - "\n", - "# Verify staged files\n", - "session.sql(f\"LIST @{input_stage}/supply_chain/\").show()" - ] - }, - { - "cell_type": "markdown", - "id": "c5e47877-9163-49a6-a6ce-ac8bf82e23b1", - "metadata": { - "collapsed": false, - "name": "cell25" - }, - "source": [ - "### Run DPF from Stage\n", - "\n", - "The processing function signature is the same as DataFrame mode. 
The `data_connector` provides access\n", - "to each file's data, and `context.partition_id` is the relative file path." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1d320ec5-e978-4392-a1fd-d35d3b772dfc", - "metadata": { - "language": "python", - "name": "cell26" - }, - "outputs": [], - "source": [ - "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n", - "\n", - "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", - "\n", - "stage_run = dpf_from_stage.run_from_stage(\n", - " stage_location=f\"@{input_stage}/supply_chain/\",\n", - " run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n", - " file_pattern=\"*.parquet\",\n", - " execution_options=ExecutionOptions(\n", - " use_head_node=True,\n", - " num_cpus_per_worker=1,\n", - " ),\n", - ")\n", - "print(f\"Launched: {stage_run.run_id}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c4bb2ffc-6906-4341-a758-511545a2209f", - "metadata": { - "language": "python", - "name": "cell27" - }, - "outputs": [], - "source": [ - "stage_status = stage_run.wait()\n", - "print(f\"Stage mode completed with status: {stage_status}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ab59f01-345e-4bda-8248-cb74f9941287", - "metadata": { - "language": "python", - "name": "cell28" - }, - "outputs": [], - "source": [ - "# View the results written to the Snowflake table\n", - "session.table(output_table).show()" - ] - }, - { - "cell_type": "markdown", - "id": "f1cad585-2349-4d16-ad60-2de39db7a30b", - "metadata": { - "collapsed": false, - "name": "cell29" - }, - "source": [ - "---\n", - "## Deploy with ML Jobs via `@remote`\n", - "\n", - "Run DPF in an ML Job from any IDE. ML Jobs execute inside Snowpark Container Services\n", - "and can scale across multiple nodes. Logs are available in Snowsight under Monitoring > Services & Jobs." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f5b59b6-e8d5-47c3-91ed-d3366ea16885", - "metadata": { - "language": "python", - "name": "cell30" - }, - "outputs": [], - "source": [ - "job_stage = \"DPF_JOB_STAGE\"\n", - "compute_pool = \"SYSTEM_COMPUTE_POOL_CPU\" # Update with your compute pool name\n", - "\n", - "session.sql(f\"CREATE STAGE IF NOT EXISTS {job_stage}\").collect()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6c00d321-7185-4d77-9111-98d4b19b5a83", - "metadata": { - "language": "python", - "name": "cell31" - }, - "outputs": [], - "source": [ - "from snowflake.ml.jobs import remote\n", - "\n", - "@remote(\n", - " compute_pool=compute_pool,\n", - " stage_name=job_stage,\n", - " target_instances=3,\n", - ")\n", - "def launch_supply_chain_job():\n", - " \"\"\"\n", - " Launch a DPF supply chain optimization run as an ML Job.\n", - " \"\"\"\n", - " from datetime import datetime\n", - " from snowflake.snowpark import Session\n", - " from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import (\n", - " DPF,\n", - " )\n", - " from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", - " ExecutionOptions,\n", - " )\n", - "\n", - " session = Session.builder.getOrCreate()\n", - " dpf_input = session.table(input_table)\n", - "\n", - " dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n", - " run = dpf.run(\n", - " partition_by=\"REGION\",\n", - " snowpark_dataframe=dpf_input,\n", - " run_id=f\"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n", - " execution_options=ExecutionOptions(use_head_node=True),\n", - " )\n", - " run.wait()\n", - "\n", - " print(f\"DPF run complete: {run.run_id}\")\n", - " return run.run_id\n", - "\n", - "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", - "\n", - "job = launch_supply_chain_job()\n", - "print(f\"Job ID: {job.id}\")\n", - "print(f\"Status: {job.status}\")" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "fadd70e4-048c-45cf-b883-dcb71d525985", - "metadata": { - "language": "python", - "name": "cell32" - }, - "outputs": [], - "source": [ - "# Check the status and logs of the ML Job\n", - "print(f\"Status: {job.status}\")\n", - "print(job.get_logs())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f8aadcc-9e10-4ab3-aa86-96a7483c308f", - "metadata": { - "language": "python", - "name": "cell33" - }, - "outputs": [], - "source": [ - "# View the results written to the Snowflake table\n", - "session.table(output_table).show()" - ] - }, - { - "cell_type": "markdown", - "id": "d3697df9-557d-4603-be01-d4dda3ac9b3b", - "metadata": { - "collapsed": false, - "name": "cell34" - }, - "source": [ - "---\n", - "## Cleanup\n", - "\n", - "Scale the cluster back down to a single node when you're done." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2b2d20d9-6934-42e7-9072-66723ce5884d", - "metadata": { - "language": "python", - "name": "cell35" - }, - "outputs": [], - "source": [ - "scale_cluster(expected_cluster_size=1)\n", - "\n", - "# Uncomment to drop objects created by this notebook\n", - "# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n", - "# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", - "# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n", - "# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n", - "# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "dpf-test", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.18" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "cells": [ + { + "cell_type": "markdown", + "id": "e16f2bf5-88f3-4dfa-8d7a-6be220007ba3", + "metadata": { + "collapsed": false, + "name": "cell0", + "codeCollapsed": true + }, + "source": "# Distributed Partition 
Function (DPF) - Example Walkthrough\n\nThis notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n\nWe'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n\nDPF supports two execution modes:\n\n- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n\n**Environment:** This notebook is designed to run in a Snowflake Notebook on Container Runtime. 
If running locally, see the **ML Jobs deployment** section at the bottom.\n\n**Prerequisites:**\n- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" + }, + { + "cell_type": "markdown", + "id": "bebe7269-edd8-4117-957b-d19a5be03ff2", + "metadata": { + "collapsed": false, + "name": "cell1", + "codeCollapsed": true + }, + "source": "---\n## Setup" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85a70bc5-f1d0-45e7-99ee-7d9c811df886", + "metadata": { + "language": "python", + "name": "cell2" + }, + "outputs": [], + "source": "from datetime import datetime\nimport json\n\nimport pandas as pd\nimport numpy as np\n\nfrom snowflake.snowpark import Session\n\n\nsession = Session.builder.getOrCreate()\n\n# Configuration — set target database/schema\ndatabase = \"MY_DATABASE\"\nschema = \"MY_SCHEMA\"\n\ninput_stage = f\"{database}.{schema}.DPF_INPUT_STAGE\"\ndpf_stage = f\"{database}.{schema}.DPF_RESULTS_STAGE\"\ninput_table = f\"{database}.{schema}.SUPPLY_CHAIN_DATA\"\noutput_table = f\"{database}.{schema}.OPTIMIZED_SHIPPING_MANIFEST\"\n\nsession.use_schema(f\"{database}.{schema}\")\nsession.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\nsession.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n\nprint(f\"Session: {session}\")" + }, + { + "cell_type": "markdown", + "id": "b8a5fdce-1b4d-4041-a1b9-862cefd1eade", + "metadata": { + "collapsed": false, + "name": "cell3", + "codeCollapsed": true + }, + "source": "### Import DPF modules and Scale Compute Nodes\nSnowflake Notebook on Container Runtime only - skip this cell if running locally." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbaef190-6856-4616-8a74-30a9b452fbf8", + "metadata": { + "language": "python", + "name": "cell4" + }, + "outputs": [], + "source": "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\nfrom snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n DPFRun,\n)\nfrom snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n RunStatus,\n ExecutionOptions,\n)\nfrom snowflake.ml.runtime_cluster import scale_cluster\n\n# Scale to 3 nodes for parallel processing\nscale_cluster(expected_cluster_size=3)" + }, + { + "cell_type": "markdown", + "id": "830b0d3a-0f5c-4100-afac-6a5be0e36a17", + "metadata": { + "collapsed": false, + "name": "cell5", + "codeCollapsed": true + }, + "source": "### Create Synthetic Supply Chain Data\n\nGenerate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42b3e109-80a8-4fcd-a1b0-a41842e6cbd5", + "metadata": { + "language": "python", + "name": "cell6" + }, + "outputs": [], + "source": "def create_supply_chain_data(session, table_name):\n \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n np.random.seed(42)\n data = []\n\n for reg in regions:\n # 3 Factories per region (supply)\n for i in range(3):\n data.append(\n {\n \"REGION\": reg,\n \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n \"TYPE\": \"FACTORY\",\n \"LAT\": np.random.uniform(25, 55),\n \"LON\": np.random.uniform(-130, -60),\n \"CAPACITY\": 1000,\n \"DEMAND\": 0,\n }\n )\n # 10 Warehouses per region (demand)\n for j in range(10):\n data.append(\n {\n \"REGION\": reg,\n \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n \"TYPE\": \"WAREHOUSE\",\n \"LAT\": np.random.uniform(25, 55),\n \"LON\": np.random.uniform(-130, 
-60),\n \"CAPACITY\": 0,\n \"DEMAND\": 250,\n }\n )\n\n df = pd.DataFrame(data)\n sdf = session.create_dataframe(df)\n sdf.write.mode(\"overwrite\").save_as_table(table_name)\n print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n return session.table(table_name)\n\n\nsupply_chain_sdf = create_supply_chain_data(session, input_table)\nsupply_chain_sdf.show()" + }, + { + "cell_type": "markdown", + "id": "a0cafdbc-ee9b-4fb4-b228-820dc3dcf5c1", + "metadata": { + "collapsed": false, + "name": "cell7", + "codeCollapsed": true + }, + "source": "---\n## DataFrame Mode: Process Data by Column Partitions\n\nPartition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in parallel.\n\n1. **Define the processing function** - optimization logic that runs on each partition.\n2. **Initialize and run DPF** - launch parallel execution across all partitions.\n3. **Monitor progress** - track status and wait for completion.\n4. **Retrieve results** - collect artifacts and output data from each partition.\n5. **Restore a completed run** - access results from a previous run without re-executing.\n\n### Step 1: Define the Processing Function\n\nThis function runs on each partition (region). It receives the partition's data via `data_connector` and\nuses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\nsatisfying warehouse demand without exceeding factory capacity." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "785989a9-c244-4761-9462-8cb5a62decd1", + "metadata": { + "language": "python", + "name": "cell8" + }, + "outputs": [], + "source": "def solve_allocation(data_connector, context):\n \"\"\"\n Solve the supply chain allocation problem for a single region.\n\n Uses linear programming to find the optimal shipment plan that minimizes\n total transportation cost (based on distance) subject to:\n - Factory capacity constraints (supply)\n - Warehouse demand constraints (demand)\n\n Args:\n data_connector: Provides access to the partition's data.\n context: PartitionContext with partition_id and artifact methods.\n \"\"\"\n from scipy.optimize import linprog\n from scipy.spatial.distance import cdist\n import pandas as pd\n import numpy as np\n import json\n\n df = data_connector.to_pandas()\n region = context.partition_id\n print(f\"[{region}] Processing {len(df)} locations\")\n\n factories = df[df[\"TYPE\"] == \"FACTORY\"].reset_index(drop=True)\n warehouses = df[df[\"TYPE\"] == \"WAREHOUSE\"].reset_index(drop=True)\n n_fact = len(factories)\n n_wh = len(warehouses)\n\n # Build cost matrix (Euclidean distance as proxy for shipping cost)\n cost_matrix = cdist(\n factories[[\"LAT\", \"LON\"]], warehouses[[\"LAT\", \"LON\"]], metric=\"euclidean\"\n )\n c = cost_matrix.flatten()\n\n # Inequality constraint: total outbound from Factory_i <= Capacity_i\n A_ub = np.zeros((n_fact, n_fact * n_wh))\n for i in range(n_fact):\n A_ub[i, i * n_wh : (i + 1) * n_wh] = 1\n b_ub = factories[\"CAPACITY\"].values.astype(float)\n\n # Equality constraint: total inbound to Warehouse_j == Demand_j\n A_eq = np.zeros((n_wh, n_fact * n_wh))\n for j in range(n_wh):\n A_eq[j, j::n_wh] = 1\n b_eq = warehouses[\"DEMAND\"].values.astype(float)\n\n # Solve\n res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, method=\"highs\")\n\n if res.success:\n allocation = res.x.reshape((n_fact, n_wh))\n manifest = []\n for i in 
range(n_fact):\n for j in range(n_wh):\n qty = allocation[i, j]\n if qty > 0.1:\n manifest.append(\n {\n \"REGION\": region,\n \"ORIGIN\": factories.loc[i, \"LOCATION_ID\"],\n \"DESTINATION\": warehouses.loc[j, \"LOCATION_ID\"],\n \"SHIPMENT_QTY\": round(float(qty), 2),\n \"UNIT_DISTANCE\": round(float(cost_matrix[i, j]), 4),\n }\n )\n\n manifest_df = pd.DataFrame(manifest)\n\n summary = {\n \"region\": region,\n \"status\": \"OPTIMAL\",\n \"total_cost\": round(float(res.fun), 2),\n \"shipment_count\": len(manifest),\n \"total_units_shipped\": round(sum(m[\"SHIPMENT_QTY\"] for m in manifest), 2),\n }\n print(\n f\"[{region}] Optimal cost: {summary['total_cost']}, shipments: {len(manifest)}\"\n )\n\n # Upload summary as a stage artifact\n context.upload_to_stage(\n summary,\n \"summary.json\",\n write_function=lambda obj, path: json.dump(obj, open(path, \"w\")),\n )\n\n # Write results to a Snowflake table using the bounded session pool\n context.with_session(\n lambda session: session.create_dataframe(manifest_df)\n .write.mode(\"append\")\n .save_as_table(output_table)\n )\n else:\n print(f\"[{region}] Optimization failed: {res.message}\")" + }, + { + "cell_type": "markdown", + "id": "7d6f4691-a3cf-4c46-8785-1aa016476b5d", + "metadata": { + "collapsed": false, + "name": "cell9", + "codeCollapsed": true + }, + "source": "### Step 2: Initialize and Run DPF" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e07682-7bca-4ccb-b0cf-51cbf4bc478c", + "metadata": { + "language": "python", + "name": "cell10" + }, + "outputs": [], + "source": "dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n\nsession.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n\nrun = dpf.run(\n partition_by=\"REGION\",\n snowpark_dataframe=session.table(input_table),\n run_id=f\"supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n execution_options=ExecutionOptions(use_head_node=True, num_cpus_per_worker=1),\n)\nprint(f\"Launched: {run.run_id}\")" + }, + { + "cell_type": 
"markdown", + "id": "f4b8b230-b983-48ca-a9e7-18ca97279764", + "metadata": { + "collapsed": false, + "name": "cell11", + "codeCollapsed": true + }, + "source": "### Step 3: Monitor Progress and Wait for Completion" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5c3a1c7-d809-4a86-b395-16e56da332ec", + "metadata": { + "language": "python", + "name": "cell12" + }, + "outputs": [], + "source": "final_status = run.wait() # Shows progress\nprint(f\"Job completed with status: {final_status}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b824586-57ba-4d1d-bd96-f38f76ba2e73", + "metadata": { + "language": "python", + "name": "cell13" + }, + "outputs": [], + "source": "# Quick summary of all partition statuses\nprogress = run.get_progress()\nfor status, partitions in progress.items():\n print(f\"{status}: {len(partitions)} partitions\")" + }, + { + "cell_type": "markdown", + "id": "2f92856d-4e0f-45cb-ba82-1b2ca6047ba6", + "metadata": { + "collapsed": false, + "name": "cell14", + "codeCollapsed": true + }, + "source": "### Step 4: Retrieve Results from Each Partition" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fedc35ff-c572-49e3-b10e-abdb31db8053", + "metadata": { + "language": "python", + "name": "cell15" + }, + "outputs": [], + "source": "def print_results(summaries):\n \"\"\"Format and display the supply chain optimization results.\"\"\"\n for s in summaries:\n print(f\" {s['region']}: cost={s['total_cost']}, shipments={s['shipment_count']}\")\n\n total_cost = sum(s[\"total_cost\"] for s in summaries)\n total_shipments = sum(s[\"shipment_count\"] for s in summaries)\n print(f\"\\n TOTAL: cost={total_cost:.2f}, shipments={total_shipments}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ab86af1-e6f6-40ea-95f5-a1f5e5fae72a", + "metadata": { + "language": "python", + "name": "cell16" + }, + "outputs": [], + "source": "if final_status == RunStatus.SUCCESS:\n summaries = []\n for 
partition_id, details in run.partition_details.items():\n files = details.stage_artifacts_manager.list()\n print(f\"Partition '{partition_id}' artifacts: {files}\")\n\n summary = details.stage_artifacts_manager.get(\n \"summary.json\",\n read_function=lambda path: json.load(open(path, \"r\")),\n )\n summaries.append(summary)\n\n print_results(summaries)\nelse:\n print(f\"Run did not succeed: {final_status}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae7b476e-340e-4103-904e-85298e9a3cdd", + "metadata": { + "language": "python", + "name": "cell17" + }, + "outputs": [], + "source": "# View the results written to the Snowflake table\nsession.table(output_table).show()" + }, + { + "cell_type": "markdown", + "id": "7fcab272-38a5-47b4-8a10-74f5f87eb299", + "metadata": { + "collapsed": false, + "name": "cell18", + "codeCollapsed": true + }, + "source": "### Inspect Partition Logs\n\nView stdout/stderr from individual partitions to verify processing or debug failures." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c4a28a4-469b-46c9-b0cf-79e142bef101", + "metadata": { + "language": "python", + "name": "cell19" + }, + "outputs": [], + "source": "# View logs from each partition\nfor partition_id, details in run.partition_details.items():\n print(f\"--- {partition_id} ---\")\n print(details.logs)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8d1a479-8928-4f78-a61a-e67b82df1a7f", + "metadata": { + "language": "python", + "name": "cell20" + }, + "outputs": [], + "source": "# Debug failed partitions (if any)\n# progress = run.get_progress()\n# for partition in progress.get(\"FAILED\", []):\n# print(f\"--- Failed: {partition.partition_id} ---\")\n# print(partition.logs)" + }, + { + "cell_type": "markdown", + "id": "a9ad12c6-faf5-4fb1-b5de-c2289b9f4670", + "metadata": { + "collapsed": false, + "name": "cell21", + "codeCollapsed": true + }, + "source": "### Step 5: Restore Results from a Completed Run\n\nAccess results from a previous run without re-executing. Useful after restarting a notebook session." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d71d432-6c2c-4a49-96a8-8ba85397844f", + "metadata": { + "language": "python", + "name": "cell22" + }, + "outputs": [], + "source": "restored_run = DPFRun.restore_from(\n run_id=run.run_id,\n stage_name=dpf_stage,\n)\n\nprint(f\"Restored run status: {restored_run.status}\")\nfor partition_id, details in restored_run.partition_details.items():\n print(f\" {partition_id}: {details.status}\")\n\n# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." 
+ }, + { + "cell_type": "markdown", + "id": "9e8d3b12-9019-4153-b808-195d16356657", + "metadata": { + "collapsed": false, + "name": "cell23", + "codeCollapsed": true + }, + "source": "---\n## Stage Mode: Process Files from a Stage\n\nProcess pre-staged parquet files where each file becomes a partition.\nFirst, copy the data from the table to stage as parquet files, one per region." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efc121f3-43d2-4f7f-88f6-0f7a8fce63d4", + "metadata": { + "language": "python", + "name": "cell24" + }, + "outputs": [], + "source": "# Prepare parquet files on stage - one file per region\nsession.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n\nsession.sql(\n f\"\"\"\n COPY INTO @{input_stage}/supply_chain/\n FROM {input_table}\n PARTITION BY REGION\n FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n HEADER = TRUE\n\"\"\"\n).collect()\n\n# Verify staged files\nsession.sql(f\"LIST @{input_stage}/supply_chain/\").show()" + }, + { + "cell_type": "markdown", + "id": "c5e47877-9163-49a6-a6ce-ac8bf82e23b1", + "metadata": { + "collapsed": false, + "name": "cell25", + "codeCollapsed": true + }, + "source": "### Run DPF from Stage\n\nThe processing function signature is the same as DataFrame mode. The `data_connector` provides access\nto each file's data, and `context.partition_id` is the relative file path." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d320ec5-e978-4392-a1fd-d35d3b772dfc", + "metadata": { + "language": "python", + "name": "cell26" + }, + "outputs": [], + "source": "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n\nsession.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n\nstage_run = dpf_from_stage.run_from_stage(\n stage_location=f\"@{input_stage}/supply_chain/\",\n run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n file_pattern=\"*.parquet\",\n execution_options=ExecutionOptions(\n use_head_node=True,\n num_cpus_per_worker=1,\n ),\n)\nprint(f\"Launched: {stage_run.run_id}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4bb2ffc-6906-4341-a758-511545a2209f", + "metadata": { + "language": "python", + "name": "cell27" + }, + "outputs": [], + "source": "stage_status = stage_run.wait()\nprint(f\"Stage mode completed with status: {stage_status}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ab59f01-345e-4bda-8248-cb74f9941287", + "metadata": { + "language": "python", + "name": "cell28" + }, + "outputs": [], + "source": "# View the results written to the Snowflake table\nsession.table(output_table).show()" + }, + { + "cell_type": "markdown", + "id": "f1cad585-2349-4d16-ad60-2de39db7a30b", + "metadata": { + "collapsed": false, + "name": "cell29", + "codeCollapsed": true + }, + "source": "---\n## Deploy with ML Jobs via `@remote`\n\nRun DPF in an ML Job from any IDE. ML Jobs execute inside Snowpark Container Services\nand can scale across multiple nodes. Logs are available in Snowsight under Monitoring > Services & Jobs." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f5b59b6-e8d5-47c3-91ed-d3366ea16885", + "metadata": { + "language": "python", + "name": "cell30" + }, + "outputs": [], + "source": "job_stage = f\"{database}.{schema}.DPF_JOB_STAGE\"\ncompute_pool = \"SYSTEM_COMPUTE_POOL_CPU\" # Update with your compute pool name\n\nsession.sql(f\"CREATE STAGE IF NOT EXISTS {job_stage}\").collect()" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c00d321-7185-4d77-9111-98d4b19b5a83", + "metadata": { + "language": "python", + "name": "cell31" + }, + "outputs": [], + "source": "from snowflake.ml.jobs import remote\n\n@remote(\n compute_pool=compute_pool,\n stage_name=job_stage,\n target_instances=3,\n)\ndef launch_supply_chain_job():\n \"\"\"\n Launch a DPF supply chain optimization run as an ML Job.\n \"\"\"\n from datetime import datetime\n from snowflake.snowpark import Session\n from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import (\n DPF,\n )\n from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n ExecutionOptions,\n )\n\n session = Session.builder.getOrCreate()\n dpf_input = session.table(input_table)\n\n dpf = DPF(func=solve_allocation, stage_name=dpf_stage)\n run = dpf.run(\n partition_by=\"REGION\",\n snowpark_dataframe=dpf_input,\n run_id=f\"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}\",\n execution_options=ExecutionOptions(use_head_node=True),\n )\n run.wait()\n\n print(f\"DPF run complete: {run.run_id}\")\n return run.run_id\n\nsession.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n\njob = launch_supply_chain_job()\nprint(f\"Job ID: {job.id}\")\nprint(f\"Status: {job.status}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fadd70e4-048c-45cf-b883-dcb71d525985", + "metadata": { + "language": "python", + "name": "cell32" + }, + "outputs": [], + "source": "# Check the status and logs of the ML Job\nprint(f\"Status: 
{job.status}\")\nprint(job.get_logs())" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f8aadcc-9e10-4ab3-aa86-96a7483c308f", + "metadata": { + "language": "python", + "name": "cell33" + }, + "outputs": [], + "source": "# View the results written to the Snowflake table\nsession.table(output_table).show()" + }, + { + "cell_type": "markdown", + "id": "d3697df9-557d-4603-be01-d4dda3ac9b3b", + "metadata": { + "collapsed": false, + "name": "cell34", + "codeCollapsed": true + }, + "source": "---\n## Cleanup\n\nScale the cluster back down to a single node when you're done." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b2d20d9-6934-42e7-9072-66723ce5884d", + "metadata": { + "language": "python", + "name": "cell35" + }, + "outputs": [], + "source": "scale_cluster(expected_cluster_size=1)\n\n# Uncomment to drop objects created by this notebook\n# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" + } + ], + "metadata": { + "kernelspec": { + "display_name": "dpf-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 2ef784ea93a83f167fc59f52675223f593d88054 Mon Sep 17 00:00:00 2001 From: Marie Coolsaet Date: Wed, 6 May 2026 14:48:01 -0400 Subject: [PATCH 5/5] Update snowflake notebook language. 
--- .../distributed_partition_function/README.md | 2 +- .../dpf_example.ipynb | 560 +++++++++++++++--- 2 files changed, 494 insertions(+), 68 deletions(-) diff --git a/samples/ml/distributed_partition_function/README.md b/samples/ml/distributed_partition_function/README.md index 7ccda16e..63a3d7b6 100644 --- a/samples/ml/distributed_partition_function/README.md +++ b/samples/ml/distributed_partition_function/README.md @@ -30,7 +30,7 @@ DPF supports two execution modes, both demonstrated in this notebook: ## Getting Started -This notebook is intended to be run in a **Snowflake Notebook** environment on Snowpark Container Services. If running locally, use the **ML Jobs deployment** section at the bottom of the notebook to submit DPF workloads via the `@remote` decorator. +This notebook is intended to be run in a **Snowflake Notebook**. If running locally, use the **ML Jobs deployment** section at the bottom of the notebook to submit DPF workloads via the `@remote` decorator. Open the [DPF Example Notebook](./dpf_example.ipynb) for a full end-to-end walkthrough. diff --git a/samples/ml/distributed_partition_function/dpf_example.ipynb b/samples/ml/distributed_partition_function/dpf_example.ipynb index 8dd06b19..b7fe3faf 100644 --- a/samples/ml/distributed_partition_function/dpf_example.ipynb +++ b/samples/ml/distributed_partition_function/dpf_example.ipynb @@ -4,21 +4,40 @@ "cell_type": "markdown", "id": "e16f2bf5-88f3-4dfa-8d7a-6be220007ba3", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell0", - "codeCollapsed": true - }, - "source": "# Distributed Partition Function (DPF) - Example Walkthrough\n\nThis notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. 
DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n\nWe'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n\nDPF supports two execution modes:\n\n- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n\n**Environment:** This notebook is designed to run in a Snowflake Notebook on Container Runtime. If running locally, see the **ML Jobs deployment** section at the bottom.\n\n**Prerequisites:**\n- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" + "name": "cell0" + }, + "source": [ + "# Distributed Partition Function (DPF) - Example Walkthrough\n", + "\n", + "This notebook demonstrates how to use the **Distributed Partition Function (DPF)** to process data in parallel across multiple nodes in a compute pool. 
DPF partitions your data and executes your Python function on each partition concurrently, handling distributed orchestration, errors, observability, and artifact persistence automatically.\n", + "\n", + "We'll use a **supply chain allocation** scenario as the example: given factories with limited capacity and warehouses with specific demand, find the optimal shipping plan per region using `scipy.optimize.linprog`.\n", + "\n", + "DPF supports two execution modes:\n", + "\n", + "- **DataFrame mode** (`run()`): Partition a Snowpark DataFrame by column values and execute your function on each partition concurrently.\n", + "- **Stage mode** (`run_from_stage()`): Process files from a Snowflake stage where each file becomes a partition. Ideal for large-scale file processing with predictable memory usage.\n", + "\n", + "**Environment:** This notebook is designed to run in a Snowflake Notebook. If running locally, see the **ML Jobs deployment** section at the bottom.\n", + "\n", + "**Prerequisites:**\n", + "- A compute pool with max nodes >= 3 (e.g., `CPU_X64_S`), or the system-provided `SYSTEM_COMPUTE_POOL_CPU`" + ] }, { "cell_type": "markdown", "id": "bebe7269-edd8-4117-957b-d19a5be03ff2", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell1", - "codeCollapsed": true + "name": "cell1" }, - "source": "---\n## Setup" + "source": [ + "---\n", + "## Setup" + ] }, { "cell_type": "code", @@ -29,17 +48,46 @@ "name": "cell2" }, "outputs": [], - "source": "from datetime import datetime\nimport json\n\nimport pandas as pd\nimport numpy as np\n\nfrom snowflake.snowpark import Session\n\n\nsession = Session.builder.getOrCreate()\n\n# Configuration — set target database/schema\ndatabase = \"MY_DATABASE\"\nschema = \"MY_SCHEMA\"\n\ninput_stage = f\"{database}.{schema}.DPF_INPUT_STAGE\"\ndpf_stage = f\"{database}.{schema}.DPF_RESULTS_STAGE\"\ninput_table = f\"{database}.{schema}.SUPPLY_CHAIN_DATA\"\noutput_table = 
f\"{database}.{schema}.OPTIMIZED_SHIPPING_MANIFEST\"\n\nsession.use_schema(f\"{database}.{schema}\")\nsession.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\nsession.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n\nprint(f\"Session: {session}\")" + "source": [ + "from datetime import datetime\n", + "import json\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from snowflake.snowpark import Session\n", + "\n", + "\n", + "session = Session.builder.getOrCreate()\n", + "\n", + "# Configuration — set target database/schema\n", + "database = \"MY_DATABASE\"\n", + "schema = \"MY_SCHEMA\"\n", + "\n", + "input_stage = f\"{database}.{schema}.DPF_INPUT_STAGE\"\n", + "dpf_stage = f\"{database}.{schema}.DPF_RESULTS_STAGE\"\n", + "input_table = f\"{database}.{schema}.SUPPLY_CHAIN_DATA\"\n", + "output_table = f\"{database}.{schema}.OPTIMIZED_SHIPPING_MANIFEST\"\n", + "\n", + "session.use_schema(f\"{database}.{schema}\")\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {dpf_stage}\").collect()\n", + "session.sql(f\"CREATE STAGE IF NOT EXISTS {input_stage}\").collect()\n", + "\n", + "print(f\"Session: {session}\")" + ] }, { "cell_type": "markdown", "id": "b8a5fdce-1b4d-4041-a1b9-862cefd1eade", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell3", - "codeCollapsed": true + "name": "cell3" }, - "source": "### Import DPF modules and Scale Compute Nodes\nSnowflake Notebook on Container Runtime only - skip this cell if running locally." + "source": [ + "### Import DPF modules and Scale Compute Nodes\n", + "Snowflake Notebook on Container Runtime only - skip this cell if running locally." 
+ ] }, { "cell_type": "code", @@ -50,17 +98,34 @@ "name": "cell4" }, "outputs": [], - "source": "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\nfrom snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n DPFRun,\n)\nfrom snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n RunStatus,\n ExecutionOptions,\n)\nfrom snowflake.ml.runtime_cluster import scale_cluster\n\n# Scale to 3 nodes for parallel processing\nscale_cluster(expected_cluster_size=3)" + "source": [ + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import DPF\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.dpf_run import (\n", + " DPFRun,\n", + ")\n", + "from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (\n", + " RunStatus,\n", + " ExecutionOptions,\n", + ")\n", + "from snowflake.ml.runtime_cluster import scale_cluster\n", + "\n", + "# Scale to 3 nodes for parallel processing\n", + "scale_cluster(expected_cluster_size=3)" + ] }, { "cell_type": "markdown", "id": "830b0d3a-0f5c-4100-afac-6a5be0e36a17", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell5", - "codeCollapsed": true + "name": "cell5" }, - "source": "### Create Synthetic Supply Chain Data\n\nGenerate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." + "source": [ + "### Create Synthetic Supply Chain Data\n", + "\n", + "Generate a dataset with 5 regions, each containing 3 factories (supply) and 10 warehouses (demand)." 
+ ] }, { "cell_type": "code", @@ -71,17 +136,78 @@ "name": "cell6" }, "outputs": [], - "source": "def create_supply_chain_data(session, table_name):\n \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n np.random.seed(42)\n data = []\n\n for reg in regions:\n # 3 Factories per region (supply)\n for i in range(3):\n data.append(\n {\n \"REGION\": reg,\n \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n \"TYPE\": \"FACTORY\",\n \"LAT\": np.random.uniform(25, 55),\n \"LON\": np.random.uniform(-130, -60),\n \"CAPACITY\": 1000,\n \"DEMAND\": 0,\n }\n )\n # 10 Warehouses per region (demand)\n for j in range(10):\n data.append(\n {\n \"REGION\": reg,\n \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n \"TYPE\": \"WAREHOUSE\",\n \"LAT\": np.random.uniform(25, 55),\n \"LON\": np.random.uniform(-130, -60),\n \"CAPACITY\": 0,\n \"DEMAND\": 250,\n }\n )\n\n df = pd.DataFrame(data)\n sdf = session.create_dataframe(df)\n sdf.write.mode(\"overwrite\").save_as_table(table_name)\n print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n return session.table(table_name)\n\n\nsupply_chain_sdf = create_supply_chain_data(session, input_table)\nsupply_chain_sdf.show()" + "source": [ + "def create_supply_chain_data(session, table_name):\n", + " \"\"\"Generate synthetic supply chain data with factories and warehouses across regions.\"\"\"\n", + " regions = [\"NA_WEST\", \"NA_EAST\", \"EMEA_CENTRAL\", \"APAC_SOUTH\", \"LATAM\"]\n", + " np.random.seed(42)\n", + " data = []\n", + "\n", + " for reg in regions:\n", + " # 3 Factories per region (supply)\n", + " for i in range(3):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"FACT_{reg}_{i}\",\n", + " \"TYPE\": \"FACTORY\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 1000,\n", + " \"DEMAND\": 0,\n", + " 
}\n", + " )\n", + " # 10 Warehouses per region (demand)\n", + " for j in range(10):\n", + " data.append(\n", + " {\n", + " \"REGION\": reg,\n", + " \"LOCATION_ID\": f\"WH_{reg}_{j}\",\n", + " \"TYPE\": \"WAREHOUSE\",\n", + " \"LAT\": np.random.uniform(25, 55),\n", + " \"LON\": np.random.uniform(-130, -60),\n", + " \"CAPACITY\": 0,\n", + " \"DEMAND\": 250,\n", + " }\n", + " )\n", + "\n", + " df = pd.DataFrame(data)\n", + " sdf = session.create_dataframe(df)\n", + " sdf.write.mode(\"overwrite\").save_as_table(table_name)\n", + " print(f\"Created '{table_name}' with {len(df)} rows across {len(regions)} regions\")\n", + " return session.table(table_name)\n", + "\n", + "\n", + "supply_chain_sdf = create_supply_chain_data(session, input_table)\n", + "supply_chain_sdf.show()" + ] }, { "cell_type": "markdown", "id": "a0cafdbc-ee9b-4fb4-b228-820dc3dcf5c1", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell7", - "codeCollapsed": true - }, - "source": "---\n## DataFrame Mode: Process Data by Column Partitions\n\nPartition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in parallel.\n\n1. **Define the processing function** - optimization logic that runs on each partition.\n2. **Initialize and run DPF** - launch parallel execution across all partitions.\n3. **Monitor progress** - track status and wait for completion.\n4. **Retrieve results** - collect artifacts and output data from each partition.\n5. **Restore a completed run** - access results from a previous run without re-executing.\n\n### Step 1: Define the Processing Function\n\nThis function runs on each partition (region). It receives the partition's data via `data_connector` and\nuses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\nsatisfying warehouse demand without exceeding factory capacity." 
+ "name": "cell7" + }, + "source": [ + "---\n", + "## DataFrame Mode: Process Data by Column Partitions\n", + "\n", + "Partition the `SUPPLY_CHAIN_DATA` table by `REGION` and solve each region's allocation in parallel.\n", + "\n", + "1. **Define the processing function** - optimization logic that runs on each partition.\n", + "2. **Initialize and run DPF** - launch parallel execution across all partitions.\n", + "3. **Monitor progress** - track status and wait for completion.\n", + "4. **Retrieve results** - collect artifacts and output data from each partition.\n", + "5. **Restore a completed run** - access results from a previous run without re-executing.\n", + "\n", + "### Step 1: Define the Processing Function\n", + "\n", + "This function runs on each partition (region). It receives the partition's data via `data_connector` and\n", + "uses `scipy.optimize.linprog` to solve the transportation problem: minimize shipping cost while\n", + "satisfying warehouse demand without exceeding factory capacity." 
def solve_allocation(data_connector, context):
    """
    Solve the supply chain allocation problem for one partition's data.

    Builds a linear program that minimizes total transportation cost, where
    the unit cost of a factory -> warehouse lane is the Euclidean distance
    between their (LAT, LON) pairs, subject to:
      - Factory capacity: total shipped out of each factory <= CAPACITY
      - Warehouse demand: total shipped into each warehouse == DEMAND

    On success the function uploads a ``summary.json`` artifact through
    ``context.upload_to_stage`` and appends the shipment manifest to
    ``output_table`` through ``context.with_session``. When the solver fails,
    or the partition has no factories or no warehouses, it only prints a
    message and uploads nothing.

    Args:
        data_connector: Object whose ``to_pandas()`` returns the partition's
            rows with columns TYPE, LOCATION_ID, LAT, LON, CAPACITY, DEMAND
            (and optionally REGION).
        context: Object providing ``partition_id``, ``upload_to_stage`` and
            ``with_session``.

    Returns:
        None.

    Note:
        ``output_table`` is not a parameter; it is looked up in the enclosing
        (notebook) globals at the moment the table write executes.
    """
    # Imports stay local, as in the original -- presumably so the function is
    # self-contained when DPF executes it on workers (TODO confirm).
    import json

    import numpy as np
    import pandas as pd
    from scipy.optimize import linprog
    from scipy.spatial.distance import cdist

    df = data_connector.to_pandas()

    # In stage mode partition_id is the relative file path (per this notebook's
    # stage-mode notes), so label results with the data's own REGION value when
    # the partition holds exactly one; otherwise fall back to partition_id.
    region = context.partition_id
    if "REGION" in df.columns:
        regions_in_data = df["REGION"].dropna().unique()
        if len(regions_in_data) == 1:
            region = str(regions_in_data[0])
    print(f"[{region}] Processing {len(df)} locations")

    factories = df[df["TYPE"] == "FACTORY"].reset_index(drop=True)
    warehouses = df[df["TYPE"] == "WAREHOUSE"].reset_index(drop=True)
    n_fact = len(factories)
    n_wh = len(warehouses)

    if n_fact == 0 or n_wh == 0:
        # No lanes exist, so there is nothing to optimize.
        print(
            f"[{region}] Optimization failed: partition has "
            f"{n_fact} factories and {n_wh} warehouses"
        )
        return

    # Cost matrix: Euclidean distance as a proxy for shipping cost.
    cost_matrix = cdist(
        factories[["LAT", "LON"]], warehouses[["LAT", "LON"]], metric="euclidean"
    )
    # Decision variable x[i * n_wh + j] is the quantity shipped from factory i
    # to warehouse j (row-major flattening of an n_fact x n_wh matrix).
    c = cost_matrix.flatten()

    # Row i sums factory i's outbound lanes: sum_j x[i, j] <= CAPACITY_i.
    A_ub = np.kron(np.eye(n_fact), np.ones((1, n_wh)))
    b_ub = factories["CAPACITY"].values.astype(float)

    # Row j sums warehouse j's inbound lanes: sum_i x[i, j] == DEMAND_j.
    A_eq = np.kron(np.ones((1, n_fact)), np.eye(n_wh))
    b_eq = warehouses["DEMAND"].values.astype(float)

    res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, method="highs")

    if not res.success:
        print(f"[{region}] Optimization failed: {res.message}")
        return

    allocation = res.x.reshape((n_fact, n_wh))
    # Only lanes carrying more than 0.1 units are reported.
    manifest = [
        {
            "REGION": region,
            "ORIGIN": factories.loc[i, "LOCATION_ID"],
            "DESTINATION": warehouses.loc[j, "LOCATION_ID"],
            "SHIPMENT_QTY": round(float(allocation[i, j]), 2),
            "UNIT_DISTANCE": round(float(cost_matrix[i, j]), 4),
        }
        for i in range(n_fact)
        for j in range(n_wh)
        if allocation[i, j] > 0.1
    ]
    manifest_df = pd.DataFrame(manifest)

    summary = {
        "region": region,
        "status": "OPTIMAL",
        "total_cost": round(float(res.fun), 2),
        "shipment_count": len(manifest),
        "total_units_shipped": round(sum(m["SHIPMENT_QTY"] for m in manifest), 2),
    }
    print(
        f"[{region}] Optimal cost: {summary['total_cost']}, shipments: {len(manifest)}"
    )

    def write_json(obj, path):
        # The context manager closes (and flushes) the file once the dump finishes.
        with open(path, "w") as handle:
            json.dump(obj, handle)

    # Upload summary as a stage artifact.
    context.upload_to_stage(summary, "summary.json", write_function=write_json)

    # Write results to a Snowflake table using the bounded session pool.
    # An empty manifest has no columns to define a table schema, so skip it.
    if not manifest_df.empty:
        context.with_session(
            lambda session: session.create_dataframe(manifest_df)
            .write.mode("append")
            .save_as_table(output_table)
        )
false, - "name": "cell11", - "codeCollapsed": true + "name": "cell11" }, - "source": "### Step 3: Monitor Progress and Wait for Completion" + "source": [ + "### Step 3: Monitor Progress and Wait for Completion" + ] }, { "cell_type": "code", @@ -134,7 +371,10 @@ "name": "cell12" }, "outputs": [], - "source": "final_status = run.wait() # Shows progress\nprint(f\"Job completed with status: {final_status}\")" + "source": [ + "final_status = run.wait() # Shows progress\n", + "print(f\"Job completed with status: {final_status}\")" + ] }, { "cell_type": "code", @@ -145,17 +385,24 @@ "name": "cell13" }, "outputs": [], - "source": "# Quick summary of all partition statuses\nprogress = run.get_progress()\nfor status, partitions in progress.items():\n print(f\"{status}: {len(partitions)} partitions\")" + "source": [ + "# Quick summary of all partition statuses\n", + "progress = run.get_progress()\n", + "for status, partitions in progress.items():\n", + " print(f\"{status}: {len(partitions)} partitions\")" + ] }, { "cell_type": "markdown", "id": "2f92856d-4e0f-45cb-ba82-1b2ca6047ba6", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell14", - "codeCollapsed": true + "name": "cell14" }, - "source": "### Step 4: Retrieve Results from Each Partition" + "source": [ + "### Step 4: Retrieve Results from Each Partition" + ] }, { "cell_type": "code", @@ -166,7 +413,16 @@ "name": "cell15" }, "outputs": [], - "source": "def print_results(summaries):\n \"\"\"Format and display the supply chain optimization results.\"\"\"\n for s in summaries:\n print(f\" {s['region']}: cost={s['total_cost']}, shipments={s['shipment_count']}\")\n\n total_cost = sum(s[\"total_cost\"] for s in summaries)\n total_shipments = sum(s[\"shipment_count\"] for s in summaries)\n print(f\"\\n TOTAL: cost={total_cost:.2f}, shipments={total_shipments}\")" + "source": [ + "def print_results(summaries):\n", + " \"\"\"Format and display the supply chain optimization results.\"\"\"\n", + " for s in 
def read_json_file(path):
    """Load a JSON document from ``path``, closing the file when done."""
    with open(path, "r") as handle:
        return json.load(handle)


# Collect the per-partition summary artifacts only when the whole run succeeded.
if final_status == RunStatus.SUCCESS:
    summaries = []
    for partition_id, details in run.partition_details.items():
        files = details.stage_artifacts_manager.list()
        print(f"Partition '{partition_id}' artifacts: {files}")

        # NOTE(review): solve_allocation uploads no summary.json when its
        # optimization fails -- confirm how get() behaves for such a partition.
        summary = details.stage_artifacts_manager.get(
            "summary.json",
            read_function=read_json_file,
        )
        summaries.append(summary)

    print_results(summaries)
else:
    print(f"Run did not succeed: {final_status}")
Inspect Partition Logs\n\nView stdout/stderr from individual partitions to verify processing or debug failures." + "source": [ + "### Inspect Partition Logs\n", + "\n", + "View stdout/stderr from individual partitions to verify processing or debug failures." + ] }, { "cell_type": "code", @@ -209,7 +488,12 @@ "name": "cell19" }, "outputs": [], - "source": "# View logs from each partition\nfor partition_id, details in run.partition_details.items():\n print(f\"--- {partition_id} ---\")\n print(details.logs)" + "source": [ + "# View logs from each partition\n", + "for partition_id, details in run.partition_details.items():\n", + " print(f\"--- {partition_id} ---\")\n", + " print(details.logs)" + ] }, { "cell_type": "code", @@ -220,17 +504,27 @@ "name": "cell20" }, "outputs": [], - "source": "# Debug failed partitions (if any)\n# progress = run.get_progress()\n# for partition in progress.get(\"FAILED\", []):\n# print(f\"--- Failed: {partition.partition_id} ---\")\n# print(partition.logs)" + "source": [ + "# Debug failed partitions (if any)\n", + "# progress = run.get_progress()\n", + "# for partition in progress.get(\"FAILED\", []):\n", + "# print(f\"--- Failed: {partition.partition_id} ---\")\n", + "# print(partition.logs)" + ] }, { "cell_type": "markdown", "id": "a9ad12c6-faf5-4fb1-b5de-c2289b9f4670", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell21", - "codeCollapsed": true + "name": "cell21" }, - "source": "### Step 5: Restore Results from a Completed Run\n\nAccess results from a previous run without re-executing. Useful after restarting a notebook session." + "source": [ + "### Step 5: Restore Results from a Completed Run\n", + "\n", + "Access results from a previous run without re-executing. Useful after restarting a notebook session." 
+ ] }, { "cell_type": "code", @@ -241,17 +535,34 @@ "name": "cell22" }, "outputs": [], - "source": "restored_run = DPFRun.restore_from(\n run_id=run.run_id,\n stage_name=dpf_stage,\n)\n\nprint(f\"Restored run status: {restored_run.status}\")\nfor partition_id, details in restored_run.partition_details.items():\n print(f\" {partition_id}: {details.status}\")\n\n# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." + "source": [ + "restored_run = DPFRun.restore_from(\n", + " run_id=run.run_id,\n", + " stage_name=dpf_stage,\n", + ")\n", + "\n", + "print(f\"Restored run status: {restored_run.status}\")\n", + "for partition_id, details in restored_run.partition_details.items():\n", + " print(f\" {partition_id}: {details.status}\")\n", + "\n", + "# Note: Restored runs are read-only. You cannot call wait() or cancel() on them." + ] }, { "cell_type": "markdown", "id": "9e8d3b12-9019-4153-b808-195d16356657", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell23", - "codeCollapsed": true + "name": "cell23" }, - "source": "---\n## Stage Mode: Process Files from a Stage\n\nProcess pre-staged parquet files where each file becomes a partition.\nFirst, copy the data from the table to stage as parquet files, one per region." + "source": [ + "---\n", + "## Stage Mode: Process Files from a Stage\n", + "\n", + "Process pre-staged parquet files where each file becomes a partition.\n", + "First, copy the data from the table to stage as parquet files, one per region." 
+ ] }, { "cell_type": "code", @@ -262,17 +573,38 @@ "name": "cell24" }, "outputs": [], - "source": "# Prepare parquet files on stage - one file per region\nsession.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n\nsession.sql(\n f\"\"\"\n COPY INTO @{input_stage}/supply_chain/\n FROM {input_table}\n PARTITION BY REGION\n FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n HEADER = TRUE\n\"\"\"\n).collect()\n\n# Verify staged files\nsession.sql(f\"LIST @{input_stage}/supply_chain/\").show()" + "source": [ + "# Prepare parquet files on stage - one file per region\n", + "session.sql(f\"REMOVE @{input_stage}/supply_chain/\").collect()\n", + "\n", + "session.sql(\n", + " f\"\"\"\n", + " COPY INTO @{input_stage}/supply_chain/\n", + " FROM {input_table}\n", + " PARTITION BY REGION\n", + " FILE_FORMAT = (TYPE = PARQUET COMPRESSION = SNAPPY)\n", + " HEADER = TRUE\n", + "\"\"\"\n", + ").collect()\n", + "\n", + "# Verify staged files\n", + "session.sql(f\"LIST @{input_stage}/supply_chain/\").show()" + ] }, { "cell_type": "markdown", "id": "c5e47877-9163-49a6-a6ce-ac8bf82e23b1", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell25", - "codeCollapsed": true + "name": "cell25" }, - "source": "### Run DPF from Stage\n\nThe processing function signature is the same as DataFrame mode. The `data_connector` provides access\nto each file's data, and `context.partition_id` is the relative file path." + "source": [ + "### Run DPF from Stage\n", + "\n", + "The processing function signature is the same as DataFrame mode. The `data_connector` provides access\n", + "to each file's data, and `context.partition_id` is the relative file path." 
+ ] }, { "cell_type": "code", @@ -283,7 +615,22 @@ "name": "cell26" }, "outputs": [], - "source": "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n\nsession.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n\nstage_run = dpf_from_stage.run_from_stage(\n stage_location=f\"@{input_stage}/supply_chain/\",\n run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n file_pattern=\"*.parquet\",\n execution_options=ExecutionOptions(\n use_head_node=True,\n num_cpus_per_worker=1,\n ),\n)\nprint(f\"Launched: {stage_run.run_id}\")" + "source": [ + "dpf_from_stage = DPF(func=solve_allocation, stage_name=dpf_stage)\n", + "\n", + "session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "\n", + "stage_run = dpf_from_stage.run_from_stage(\n", + " stage_location=f\"@{input_stage}/supply_chain/\",\n", + " run_id=f\"supply_chain_stage_{datetime.now():%Y%m%d_%H%M%S}\",\n", + " file_pattern=\"*.parquet\",\n", + " execution_options=ExecutionOptions(\n", + " use_head_node=True,\n", + " num_cpus_per_worker=1,\n", + " ),\n", + ")\n", + "print(f\"Launched: {stage_run.run_id}\")" + ] }, { "cell_type": "code", @@ -294,7 +641,10 @@ "name": "cell27" }, "outputs": [], - "source": "stage_status = stage_run.wait()\nprint(f\"Stage mode completed with status: {stage_status}\")" + "source": [ + "stage_status = stage_run.wait()\n", + "print(f\"Stage mode completed with status: {stage_status}\")" + ] }, { "cell_type": "code", @@ -305,17 +655,26 @@ "name": "cell28" }, "outputs": [], - "source": "# View the results written to the Snowflake table\nsession.table(output_table).show()" + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] }, { "cell_type": "markdown", "id": "f1cad585-2349-4d16-ad60-2de39db7a30b", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell29", - "codeCollapsed": true + "name": "cell29" }, - "source": "---\n## Deploy with ML Jobs via `@remote`\n\nRun DPF 
@remote(
    compute_pool=compute_pool,
    stage_name=job_stage,
    target_instances=3,
)
def launch_supply_chain_job():
    """
    Launch a DPF supply chain optimization run as an ML Job.

    Reads ``input_table``, partitions it by REGION, runs ``solve_allocation``
    on every partition through DPF, waits for the run to finish and logs the
    final status returned by ``run.wait()``.

    ``input_table``, ``dpf_stage`` and ``solve_allocation`` are not parameters;
    they are taken from the notebook globals.

    Returns:
        The ``run_id`` of the DPF run, whatever its final status.
    """
    from datetime import datetime
    from snowflake.snowpark import Session
    from snowflake.ml.modeling.distributors.distributed_partition_function.dpf import (
        DPF,
    )
    from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (
        ExecutionOptions,
    )

    session = Session.builder.getOrCreate()
    dpf_input = session.table(input_table)

    dpf = DPF(func=solve_allocation, stage_name=dpf_stage)
    run = dpf.run(
        partition_by="REGION",
        snowpark_dataframe=dpf_input,
        run_id=f"job_supply_chain_{datetime.now():%Y%m%d_%H%M%S}",
        execution_options=ExecutionOptions(use_head_node=True),
    )

    # Keep the status so the job log shows how the run actually ended
    # instead of reporting "complete" unconditionally.
    final_status = run.wait()

    print(f"DPF run complete: {run.run_id} (status: {final_status})")
    return run.run_id
"source": "# View the results written to the Snowflake table\nsession.table(output_table).show()" + "source": [ + "# View the results written to the Snowflake table\n", + "session.table(output_table).show()" + ] }, { "cell_type": "markdown", "id": "d3697df9-557d-4603-be01-d4dda3ac9b3b", "metadata": { + "codeCollapsed": true, "collapsed": false, - "name": "cell34", - "codeCollapsed": true + "name": "cell34" }, - "source": "---\n## Cleanup\n\nScale the cluster back down to a single node when you're done." + "source": [ + "---\n", + "## Cleanup\n", + "\n", + "Scale the cluster back down to a single node when you're done." + ] }, { "cell_type": "code", @@ -380,7 +797,16 @@ "name": "cell35" }, "outputs": [], - "source": "scale_cluster(expected_cluster_size=1)\n\n# Uncomment to drop objects created by this notebook\n# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" + "source": [ + "scale_cluster(expected_cluster_size=1)\n", + "\n", + "# Uncomment to drop objects created by this notebook\n", + "# session.sql(f\"DROP TABLE IF EXISTS {input_table}\").collect()\n", + "# session.sql(f\"DROP TABLE IF EXISTS {output_table}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {dpf_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {input_stage}\").collect()\n", + "# session.sql(f\"DROP STAGE IF EXISTS {job_stage}\").collect()" + ] } ], "metadata": { @@ -396,4 +822,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +}