diff --git a/.gitignore b/.gitignore index f085d7ff..11a8c9f2 100644 --- a/.gitignore +++ b/.gitignore @@ -97,3 +97,4 @@ dmypy.json .asv/ xrspatial-examples/ *.zarr/ +.claude/worktrees/ diff --git a/examples/dask/distributed_reprojection.ipynb b/examples/dask/distributed_reprojection.ipynb index 2c80452e..5f868727 100644 --- a/examples/dask/distributed_reprojection.ipynb +++ b/examples/dask/distributed_reprojection.ipynb @@ -35,10 +35,11 @@ "import xarray as xr\n", "import dask\n", "import dask.array as da\n", - "import matplotlib.pyplot as plt\n", + "\n", "from dask.distributed import Client, LocalCluster\n", "from pathlib import Path\n", "\n", + "import xrspatial\n", "from xrspatial import reproject" ] }, @@ -56,9 +57,9 @@ "outputs": [], "source": [ "cluster = LocalCluster(\n", - " n_workers=20,\n", - " threads_per_worker=1,\n", - " memory_limit=\"2GB\",\n", + " n_workers=4,\n", + " threads_per_worker=2,\n", + " memory_limit=\"10GB\",\n", ")\n", "client = Client(cluster)\n", "client" @@ -88,24 +89,15 @@ "source": [ "ZARR_PATH = Path.home() / \"elevation\" / \"usgs10m_dem_c6.zarr\"\n", "\n", - "ds = xr.open_zarr(ZARR_PATH)\n", + "ds = xr.open_zarr(ZARR_PATH).xrs.rechunk_no_shuffle(target_mb=512)\n", "ds" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ds.xrs.preview().plot()" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Nothing has been read from disk yet. The repr above shows the Dask task graph backing the array. Each chunk is 2048 x 2048 pixels.\n", + "Nothing has been read from disk yet. The repr above shows the Dask task graph backing the array. `rechunk_no_shuffle` detects the Zarr source and re-opens it with larger chunks, so each dask task reads multiple storage chunks in one call. This keeps the task graph small even for a 29 TB store.\n", "\n", "Let's clip to Colorado. Good mix of flat plains and mountains, and small enough to finish in a reasonable time." 
] @@ -146,12 +138,14 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "dem.xrs.preview().plot()" + "The clipped DEM is still lazy: nothing has been read or computed yet.\n", + "\n", + "Next we reproject it. The cell below explains the choice of target CRS\n", + "and why the output chunk size is set up front rather than being\n", + "rechunked afterwards." ] }, { @@ -162,7 +156,9 @@ "\n", "The source CRS is EPSG:4269 (NAD83, geographic lat/lon). We'll reproject to EPSG:5070 (NAD83 / Conus Albers Equal Area Conic), which gives equal-area cells in meters. That matters any time pixel area feeds into a calculation, like drainage area or cut/fill volumes, and it makes the DEM compatible with other projected datasets.\n", "\n", - "`xrspatial.reproject` handles Dask arrays natively. It builds a lazy task graph where each output chunk is reprojected independently using numba-JIT'd resampling kernels." + "`xrspatial.reproject` handles Dask arrays natively. It builds a lazy task graph where each output chunk is reprojected independently using numba-JIT'd resampling kernels.\n", + "\n", + "Set `chunk_size` to the desired **output** chunk size here rather than rechunking afterwards. A rechunk-merge layer after reproject creates intermediate results that pile up in cluster memory (every reproject result must be held until its merge group is complete). Writing the reproject output directly to Zarr avoids this entirely -- each result is consumed immediately." 
] }, { @@ -182,18 +178,11 @@ " resolution=TARGET_RES,\n", " resampling=\"nearest\",\n", " nodata=np.nan,\n", - " chunk_size=2048,\n", + " chunk_size=4096,\n", ")\n", "dem_projected" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The result is still lazy. The repr shows the projected coordinate arrays and the new shape. No pixels have been resampled yet." - ] - }, { "cell_type": "code", "execution_count": null, @@ -259,7 +248,17 @@ "metadata": {}, "outputs": [], "source": [ - "ds_check.xrs.preview().plot()" + "small_ds = ds_check.xrs.preview()\n", + "small_ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "small_ds.xrs.plot()" ] }, { @@ -324,4 +323,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb index f890229d..93166010 100644 --- a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb +++ b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb @@ -9,7 +9,7 @@ "When working with large dask-backed rasters, rechunking to bigger blocks can\n", "speed up downstream operations like `slope()` or `focal_mean()` that use\n", "`map_overlap`. But if the new chunk size is not an exact multiple of the\n", - "original, dask has to split and recombine blocks — essentially a shuffle —\n", + "original, dask has to split and recombine blocks \u2014 essentially a shuffle \u2014\n", "which tanks performance.\n", "\n", "`rechunk_no_shuffle` picks the largest whole-chunk multiple that fits your\n", @@ -133,7 +133,7 @@ "## Non-dask arrays pass through unchanged\n", "\n", "If the input is a plain numpy-backed DataArray, the function returns it\n", - "as-is — no copy, no error." + "as-is \u2014 no copy, no error." 
] }, { @@ -162,4 +162,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index c621afab..e6195214 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -63,8 +63,9 @@ def plot(self, **kwargs): # Create a figure with sensible size if none provided. if 'ax' not in kwargs: fig, ax = plt.subplots( - figsize=kwargs.pop('figsize', (8, 6)), + figsize=kwargs.get('figsize', (8, 6)), ) + kwargs.pop('figsize', None) kwargs['ax'] = ax result = da.plot(**kwargs) @@ -1029,11 +1030,3 @@ def open_geotiff(self, source, **kwargs): def rechunk_no_shuffle(self, **kwargs): from .utils import rechunk_no_shuffle return rechunk_no_shuffle(self._obj, **kwargs) - - def fused_overlap(self, *stages, **kwargs): - from .utils import fused_overlap - return fused_overlap(self._obj, *stages, **kwargs) - - def multi_overlap(self, func, n_outputs, **kwargs): - from .utils import multi_overlap - return multi_overlap(self._obj, func, n_outputs, **kwargs) diff --git a/xrspatial/tests/test_rechunk_no_shuffle.py b/xrspatial/tests/test_rechunk_no_shuffle.py index be6faa93..feb4f9e3 100644 --- a/xrspatial/tests/test_rechunk_no_shuffle.py +++ b/xrspatial/tests/test_rechunk_no_shuffle.py @@ -102,6 +102,29 @@ def test_rejects_non_dataarray(): rechunk_no_shuffle(np.zeros((10, 10))) +def test_dataset_rechunk(): + """Dataset without zarr backing rechunks via the map() fallback.""" + ds = xr.Dataset({ + "elev": xr.DataArray( + da.from_array(np.random.rand(100, 100).astype(np.float32), + chunks=(10, 10)), + dims=["y", "x"], + ), + "slope": xr.DataArray( + da.from_array(np.random.rand(100, 100).astype(np.float32), + chunks=(10, 10)), + dims=["y", "x"], + ), + }) + result = rechunk_no_shuffle(ds, target_mb=1) + assert isinstance(result, xr.Dataset) + for name in ds.data_vars: + xr.testing.assert_equal(ds[name], result[name]) + # Chunks should be at least as large as the originals. 
+ for orig, new in zip(ds[name].chunks, result[name].chunks): + assert new[0] >= orig[0] + + def test_rejects_nonpositive_target(): raster = _make_dask_raster() with pytest.raises(ValueError, match="target_mb must be > 0"): diff --git a/xrspatial/utils.py b/xrspatial/utils.py index 63de67ce..26013b81 100644 --- a/xrspatial/utils.py +++ b/xrspatial/utils.py @@ -1028,8 +1028,35 @@ def _sample_windows_min_max( return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float))) +def _no_shuffle_chunks(chunks, dtype, dims, target_mb): + """Compute target chunk dict that is an exact multiple of *chunks*. + + Returns a ``{dim: size}`` dict, or ``None`` when the current + chunks already meet or exceed the target. + """ + base = tuple(c[0] for c in chunks) + + current_bytes = dtype.itemsize + for b in base: + current_bytes *= b + + target_bytes = target_mb * 1024 * 1024 + + if current_bytes >= target_bytes: + return None + + ndim = len(base) + ratio = target_bytes / current_bytes + multiplier = max(1, int(ratio ** (1.0 / ndim))) + + if multiplier <= 1: + return None + + return {dim: b * multiplier for dim, b in zip(dims, base)} + + def rechunk_no_shuffle(agg, target_mb=128): - """Rechunk a dask-backed DataArray without triggering a shuffle. + """Rechunk a dask-backed DataArray or Dataset without triggering a shuffle. Computes an integer multiplier per dimension so that each new chunk is an exact multiple of the original chunk size. This lets dask @@ -1038,9 +1065,10 @@ def rechunk_no_shuffle(agg, target_mb=128): Parameters ---------- - agg : xr.DataArray - Input raster. If not backed by a dask array the input is - returned unchanged. + agg : xr.DataArray or xr.Dataset + Input raster(s). If not backed by a dask array the input is + returned unchanged. For Datasets, each variable is rechunked + independently. target_mb : int or float Target chunk size in megabytes. 
The actual chunk size will be the closest multiple of the source chunk that does not exceed @@ -1048,13 +1076,13 @@ def rechunk_no_shuffle(agg, target_mb=128): Returns ------- - xr.DataArray - Rechunked DataArray. Coordinates and attributes are preserved. + xr.DataArray or xr.Dataset + Rechunked object. Coordinates and attributes are preserved. Raises ------ TypeError - If *agg* is not an ``xr.DataArray``. + If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``. ValueError If *target_mb* is not positive. @@ -1066,9 +1094,11 @@ def rechunk_no_shuffle(agg, target_mb=128): >>> big = rechunk_no_shuffle(arr, target_mb=64) >>> big.chunks # multiples of 256 """ + if isinstance(agg, xr.Dataset): + return _rechunk_dataset_no_shuffle(agg, target_mb) if not isinstance(agg, xr.DataArray): raise TypeError( - f"rechunk_no_shuffle(): expected xr.DataArray, " + f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, " f"got {type(agg).__name__}" ) if target_mb <= 0: @@ -1079,27 +1109,39 @@ def rechunk_no_shuffle(agg, target_mb=128): if not has_dask_array() or not isinstance(agg.data, da.Array): return agg - chunks = agg.chunks # tuple of tuples - base = tuple(c[0] for c in chunks) - - current_bytes = agg.dtype.itemsize - for b in base: - current_bytes *= b + new_chunks = _no_shuffle_chunks( + agg.chunks, agg.dtype, agg.dims, target_mb, + ) + if new_chunks is None: + return agg + return agg.chunk(new_chunks) - target_bytes = target_mb * 1024 * 1024 - if current_bytes >= target_bytes: - return agg +def _rechunk_dataset_no_shuffle(ds, target_mb): + """Rechunk every variable in a Dataset without triggering a shuffle.""" + if target_mb <= 0: + raise ValueError( + f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}" + ) - ndim = len(base) - ratio = target_bytes / current_bytes - multiplier = max(1, int(ratio ** (1.0 / ndim))) + if not has_dask_array(): + return ds + + # Compute target chunks from the first dask-backed variable. 
+ # This assumes all variables share the same chunk layout and dtype; + # for mixed-dtype Datasets the budget may overshoot on smaller types. + new_chunks = None + for var in ds.data_vars.values(): + if isinstance(var.data, da.Array): + new_chunks = _no_shuffle_chunks( + var.chunks, var.dtype, var.dims, target_mb, + ) + break - if multiplier <= 1: - return agg + if new_chunks is None: + return ds - new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)} - return agg.chunk(new_chunks) + return ds.chunk(new_chunks) def _normalize_depth(depth, ndim):