diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index 910055ea..7752abc0 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -673,6 +673,37 @@ def _write(data: np.ndarray, path: str, *, # re-exported above for backwards compatibility. +def _max_streaming_row_span(row_chunks, tile_h, height): + """Worst-case source rows a tiled streaming compute materialises. + + The streaming writer computes one ``tile_h``-tall band per dask + ``.compute()``. For a plain windowed read that band materialises only + the source chunk-rows it overlaps. For a ``map_overlap`` source + (slope / aspect / curvature / hillshade) dask also pulls the + neighbouring chunk-row on each side to satisfy the halo, and a + windowed read materialises every touched chunk-row in full. ``depth`` + can never exceed one chunk, so a one-chunk halo each side is an upper + bound. Return the largest span across all bands so the column budget + is sized from the source geometry, not the output tile height (#3007). + """ + import bisect + offsets = [0] + for h in row_chunks: + offsets.append(offsets[-1] + int(h)) + n = len(row_chunks) + worst = tile_h + for r0 in range(0, height, tile_h): + r1 = min(r0 + tile_h, height) + first = bisect.bisect_right(offsets, r0) - 1 + last = bisect.bisect_right(offsets, r1 - 1) - 1 + lo = max(0, first - 1) + hi = min(n - 1, last + 1) + span = offsets[hi + 1] - offsets[lo] + if span > worst: + worst = span + return worst + + def _write_streaming(dask_data, path: str, *, geo_transform: 'GeoTransform | None' = None, crs_epsg: int | None = None, @@ -704,11 +735,17 @@ def _write_streaming(dask_data, path: str, *, rasters get bounded peak memory at the cost of more dask compute calls. - Peak materialised memory is approximately - ``min(streaming_buffer_bytes, tile_height * width * bytes_per_sample - * samples)`` for tiled output, or - ``rows_per_strip * width * bytes_per_sample * samples`` for stripped - output (no horizontal segmentation in strip mode). + For tiled output the horizontal-segment budget is sized from the + source chunk geometry rather than the output tile height: a + ``map_overlap`` source (slope / aspect / curvature) makes one + tile-row band pull every source chunk-row it touches plus a one-chunk + halo, so a source chunked taller than the tile would otherwise blow + past the cap (#3007). ``streaming_buffer_bytes`` stays a soft cap -- + the column halo of a 2D overlap adds a bounded couple of source + chunk-columns on top. Strip output (``tiled=False``) does no + horizontal segmentation; its peak is + ``rows_per_strip * width * bytes_per_sample * samples`` (plus any + overlap halo). After all pixel data is written the IFD offset and byte-count arrays are patched in place. @@ -1023,12 +1060,29 @@ def _write_streaming(dask_data, path: str, *, # Stream pixel data if tiled: # Decide how many tile-columns we can buffer at once. - # bytes_per_full_tile_row = tile_h * width * dtype * samples; - # if it fits the budget we buffer the whole row (matches - # original behaviour). Otherwise segment horizontally, - # always at tile boundaries to keep slicing aligned. + # Peak bytes per ``.compute()`` are set by the SOURCE chunk + # geometry, not the output tile height: a map_overlap source + # (slope / aspect / curvature) makes a single tile-row band + # pull every source chunk-row it touches plus a one-chunk + # halo, each materialised in full by the windowed read. Size + # the budget from that row span so a tall-chunk wide raster + # cannot blow past streaming_buffer_bytes (#3007). The column + # halo adds a bounded couple of source chunk-columns on top; + # streaming_buffer_bytes stays a soft cap. A non-overlap read + # carries no halo, so this is intentionally conservative for + # it -- it may segment one step early, which only costs an + # extra compute call near the cap. Falls back to the tile + # height for numpy-from-dask or unknown-chunk arrays. + materialized_h = th + row_chunks = getattr(dask_data, 'chunks', None) + if row_chunks: + try: + materialized_h = _max_streaming_row_span( + row_chunks[0], th, height) + except (TypeError, ValueError): + materialized_h = th bytes_per_tile_col = ( - th * tw * bytes_per_sample * samples) + materialized_h * tw * bytes_per_sample * samples) bytes_per_full_row = bytes_per_tile_col * tiles_across if bytes_per_full_row <= streaming_buffer_bytes: tiles_per_segment = tiles_across diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py index bd78a91d..7b6ab612 100644 --- a/xrspatial/geotiff/_writers/eager.py +++ b/xrspatial/geotiff/_writers/eager.py @@ -99,10 +99,13 @@ def to_geotiff(data: xr.DataArray | np.ndarray, codec and options). Dask-backed DataArrays are written in streaming mode: one tile-row - at a time, without materialising the full array into RAM. Peak - memory is roughly ``tile_size * width * bytes_per_sample``. COG - output (``cog=True``) still materialises because overviews need the - full array. + at a time, without materialising the full array into RAM. The + per-compute budget is sized from the source chunk geometry, so a + ``map_overlap`` source (e.g. ``slope`` / ``aspect``) chunked taller + than the tile stays within ``streaming_buffer_bytes`` instead of + pulling several source chunk-rows at once (#3007). COG output + (``cog=True``) still materialises because overviews need the full + array. Automatically dispatches to GPU compression when: - ``gpu=True`` is passed, or diff --git a/xrspatial/geotiff/tests/write/test_streaming.py b/xrspatial/geotiff/tests/write/test_streaming.py index bdc9c416..85ca0917 100644 --- a/xrspatial/geotiff/tests/write/test_streaming.py +++ b/xrspatial/geotiff/tests/write/test_streaming.py @@ -563,6 +563,104 @@ def test_multiband_segmentation(self, tmp_path): result = open_geotiff(path).values np.testing.assert_array_almost_equal(result, arr, decimal=10) + def test_overlap_source_respects_buffer_3007(self, tmp_path, monkeypatch): + """A map_overlap source must not pull far more than the buffer. + + Slicing one tile-row out of a ``slope`` (map_overlap, depth 1) + result materialises every source chunk-row the band touches plus + a one-chunk halo. When the source chunks are taller than the tile + the old budget -- sized from the output tile height -- left the + whole tile-row in a single segment and the compute pulled several + times the buffer. Bound peak source bytes per compute (#3007). + """ + import dask.array as da + + from xrspatial import slope + + height, width, chunk = 1024, 8192, 512 # chunks taller than tile=256 + base = da.random.random( + (height, width), chunks=(chunk, chunk)).astype('float32') + + materialized = [] # bytes of each source chunk as dask pulls it + + def _record(block): + materialized.append(block.nbytes) + return block + + base = base.map_blocks(_record, dtype='float32') + y = np.linspace(40.0, 39.0, height) + x = np.linspace(-105.0, -104.0, width) + lazy = slope(xr.DataArray(base, dims=['y', 'x'], + coords={'y': y, 'x': x})) + + peak = {'bytes': 0} + orig_compute = da.Array.compute + + def spy_compute(self, *args, **kwargs): + before = sum(materialized) + result = orig_compute(self, *args, **kwargs) + peak['bytes'] = max(peak['bytes'], sum(materialized) - before) + return result + + monkeypatch.setattr(da.Array, 'compute', spy_compute) + + buf = 8 * 1024 * 1024 # 8 MB + path = str(tmp_path / 'overlap_budget_3007.tif') + to_geotiff(lazy, path, compression='zstd', + streaming_buffer_bytes=buf) + + # streaming_buffer_bytes is a soft cap; the column halo of the 2D + # overlap adds a bounded couple of source chunk-columns on top of + # the row budget. Allow 2x slack. The unfixed writer pulled ~4x + # (one full-width 2-chunk-row band == 32 MB) and trips this. + assert peak['bytes'] <= 2 * buf, ( + f"peak source bytes {peak['bytes']} exceeded 2x buffer {buf}") + + def test_overlap_source_roundtrip_small_buffer_3007(self, tmp_path): + """slope() of a tall-chunk wide raster round-trips under a tight cap. + + Exercises the source-aware segmentation end to end: the lazy + slope written with a small buffer must match the eager slope + (#3007). + """ + import dask.array as da + + from xrspatial import slope + + height, width, chunk = 768, 4096, 512 + rng = np.random.default_rng(3007) + npdata = rng.random((height, width)).astype('float32') + y = np.linspace(40.0, 39.0, height) + x = np.linspace(-105.0, -104.0, width) + + eager = slope(xr.DataArray(npdata, dims=['y', 'x'], + coords={'y': y, 'x': x})) + base = da.from_array(npdata, chunks=(chunk, chunk)) + lazy = slope(xr.DataArray(base, dims=['y', 'x'], + coords={'y': y, 'x': x})) + + path = str(tmp_path / 'overlap_roundtrip_3007.tif') + to_geotiff(lazy, path, compression='zstd', + streaming_buffer_bytes=4 * 1024 * 1024) + + back = open_geotiff(path).values + np.testing.assert_allclose(back, eager.values, + rtol=1e-5, atol=1e-5, equal_nan=True) + + +def test_max_streaming_row_span_3007(): + """The row-span helper accounts for the source chunk grid + halo.""" + from xrspatial.geotiff._writer import _max_streaming_row_span + + # A 256-tall band landing in a 512-tall chunk pulls the neighbour + # chunk-row too -> 1024 source rows. + assert _max_streaming_row_span((512, 512), 256, 1024) == 1024 + # Uniform 256 chunks: a middle band touches three chunk-rows (its own + # plus one halo each side). + assert _max_streaming_row_span((256, 256, 256, 256), 256, 1024) == 768 + # A single chunk-row spanning the whole height has no neighbour. + assert _max_streaming_row_span((256,), 256, 256) == 256 + # ------------------------------------------------------------------------- # Section: parallel per-tile streaming compress (P4)