Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 64 additions & 10 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,37 @@ def _write(data: np.ndarray, path: str, *,
# re-exported above for backwards compatibility.


def _max_streaming_row_span(row_chunks, tile_h, height):
"""Worst-case source rows a tiled streaming compute materialises.

The streaming writer computes one ``tile_h``-tall band per dask
``.compute()``. For a plain windowed read that band materialises only
the source chunk-rows it overlaps. For a ``map_overlap`` source
(slope / aspect / curvature / hillshade) dask also pulls the
neighbouring chunk-row on each side to satisfy the halo, and a
windowed read materialises every touched chunk-row in full. ``depth``
can never exceed one chunk, so a one-chunk halo each side is an upper
bound. Return the largest span across all bands so the column budget
is sized from the source geometry, not the output tile height (#3007).
"""
import bisect
offsets = [0]
for h in row_chunks:
offsets.append(offsets[-1] + int(h))
n = len(row_chunks)
worst = tile_h
for r0 in range(0, height, tile_h):
r1 = min(r0 + tile_h, height)
first = bisect.bisect_right(offsets, r0) - 1
last = bisect.bisect_right(offsets, r1 - 1) - 1
lo = max(0, first - 1)
hi = min(n - 1, last + 1)
span = offsets[hi + 1] - offsets[lo]
if span > worst:
worst = span
return worst


def _write_streaming(dask_data, path: str, *,
geo_transform: 'GeoTransform | None' = None,
crs_epsg: int | None = None,
Expand Down Expand Up @@ -704,11 +735,17 @@ def _write_streaming(dask_data, path: str, *,
rasters get bounded peak memory at the cost of more dask compute
calls.

Peak materialised memory is approximately
``min(streaming_buffer_bytes, tile_height * width * bytes_per_sample
* samples)`` for tiled output, or
``rows_per_strip * width * bytes_per_sample * samples`` for stripped
output (no horizontal segmentation in strip mode).
For tiled output the horizontal-segment budget is sized from the
source chunk geometry rather than the output tile height: a
``map_overlap`` source (slope / aspect / curvature) makes one
tile-row band pull every source chunk-row it touches plus a one-chunk
halo, so a source chunked taller than the tile would otherwise blow
past the cap (#3007). ``streaming_buffer_bytes`` stays a soft cap --
the column halo of a 2D overlap adds a bounded couple of source
chunk-columns on top. Strip output (``tiled=False``) does no
horizontal segmentation; its peak is
``rows_per_strip * width * bytes_per_sample * samples`` (plus any
overlap halo).

After all pixel data is written the IFD offset and byte-count arrays
are patched in place.
Expand Down Expand Up @@ -1023,12 +1060,29 @@ def _write_streaming(dask_data, path: str, *,
# Stream pixel data
if tiled:
# Decide how many tile-columns we can buffer at once.
# bytes_per_full_tile_row = tile_h * width * dtype * samples;
# if it fits the budget we buffer the whole row (matches
# original behaviour). Otherwise segment horizontally,
# always at tile boundaries to keep slicing aligned.
# Peak bytes per ``.compute()`` are set by the SOURCE chunk
# geometry, not the output tile height: a map_overlap source
# (slope / aspect / curvature) makes a single tile-row band
# pull every source chunk-row it touches plus a one-chunk
# halo, each materialised in full by the windowed read. Size
# the budget from that row span so a tall-chunk wide raster
# cannot blow past streaming_buffer_bytes (#3007). The column
# halo adds a bounded couple of source chunk-columns on top;
# streaming_buffer_bytes stays a soft cap. A non-overlap read
# carries no halo, so this is intentionally conservative for
# it -- it may segment one step early, which only costs an
# extra compute call near the cap. Falls back to the tile
# height for numpy-from-dask or unknown-chunk arrays.
materialized_h = th
row_chunks = getattr(dask_data, 'chunks', None)
if row_chunks:
try:
materialized_h = _max_streaming_row_span(
row_chunks[0], th, height)
except (TypeError, ValueError):
materialized_h = th
bytes_per_tile_col = (
th * tw * bytes_per_sample * samples)
materialized_h * tw * bytes_per_sample * samples)
bytes_per_full_row = bytes_per_tile_col * tiles_across
if bytes_per_full_row <= streaming_buffer_bytes:
tiles_per_segment = tiles_across
Expand Down
11 changes: 7 additions & 4 deletions xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,13 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
codec and options).

Dask-backed DataArrays are written in streaming mode: one tile-row
at a time, without materialising the full array into RAM. Peak
memory is roughly ``tile_size * width * bytes_per_sample``. COG
output (``cog=True``) still materialises because overviews need the
full array.
at a time, without materialising the full array into RAM. The
per-compute budget is sized from the source chunk geometry, so a
``map_overlap`` source (e.g. ``slope`` / ``aspect``) chunked taller
than the tile stays within ``streaming_buffer_bytes`` instead of
pulling several source chunk-rows at once (#3007). COG output
(``cog=True``) still materialises because overviews need the full
array.

Automatically dispatches to GPU compression when:
- ``gpu=True`` is passed, or
Expand Down
98 changes: 98 additions & 0 deletions xrspatial/geotiff/tests/write/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,104 @@ def test_multiband_segmentation(self, tmp_path):
result = open_geotiff(path).values
np.testing.assert_array_almost_equal(result, arr, decimal=10)

def test_overlap_source_respects_buffer_3007(self, tmp_path, monkeypatch):
"""A map_overlap source must not pull far more than the buffer.

Slicing one tile-row out of a ``slope`` (map_overlap, depth 1)
result materialises every source chunk-row the band touches plus
a one-chunk halo. When the source chunks are taller than the tile
the old budget -- sized from the output tile height -- left the
whole tile-row in a single segment and the compute pulled several
times the buffer. Bound peak source bytes per compute (#3007).
"""
import dask.array as da

from xrspatial import slope

height, width, chunk = 1024, 8192, 512 # chunks taller than tile=256
base = da.random.random(
(height, width), chunks=(chunk, chunk)).astype('float32')

materialized = [] # bytes of each source chunk as dask pulls it

def _record(block):
materialized.append(block.nbytes)
return block

base = base.map_blocks(_record, dtype='float32')
y = np.linspace(40.0, 39.0, height)
x = np.linspace(-105.0, -104.0, width)
lazy = slope(xr.DataArray(base, dims=['y', 'x'],
coords={'y': y, 'x': x}))

peak = {'bytes': 0}
orig_compute = da.Array.compute

def spy_compute(self, *args, **kwargs):
before = sum(materialized)
result = orig_compute(self, *args, **kwargs)
peak['bytes'] = max(peak['bytes'], sum(materialized) - before)
return result

monkeypatch.setattr(da.Array, 'compute', spy_compute)

buf = 8 * 1024 * 1024 # 8 MB
path = str(tmp_path / 'overlap_budget_3007.tif')
to_geotiff(lazy, path, compression='zstd',
streaming_buffer_bytes=buf)

# streaming_buffer_bytes is a soft cap; the column halo of the 2D
# overlap adds a bounded couple of source chunk-columns on top of
# the row budget. Allow 2x slack. The unfixed writer pulled ~4x
# (one full-width 2-chunk-row band == 32 MB) and trips this.
assert peak['bytes'] <= 2 * buf, (
f"peak source bytes {peak['bytes']} exceeded 2x buffer {buf}")

def test_overlap_source_roundtrip_small_buffer_3007(self, tmp_path):
"""slope() of a tall-chunk wide raster round-trips under a tight cap.

Exercises the source-aware segmentation end to end: the lazy
slope written with a small buffer must match the eager slope
(#3007).
"""
import dask.array as da

from xrspatial import slope

height, width, chunk = 768, 4096, 512
rng = np.random.default_rng(3007)
npdata = rng.random((height, width)).astype('float32')
y = np.linspace(40.0, 39.0, height)
x = np.linspace(-105.0, -104.0, width)

eager = slope(xr.DataArray(npdata, dims=['y', 'x'],
coords={'y': y, 'x': x}))
base = da.from_array(npdata, chunks=(chunk, chunk))
lazy = slope(xr.DataArray(base, dims=['y', 'x'],
coords={'y': y, 'x': x}))

path = str(tmp_path / 'overlap_roundtrip_3007.tif')
to_geotiff(lazy, path, compression='zstd',
streaming_buffer_bytes=4 * 1024 * 1024)

back = open_geotiff(path).values
np.testing.assert_allclose(back, eager.values,
rtol=1e-5, atol=1e-5, equal_nan=True)


def test_max_streaming_row_span_3007():
"""The row-span helper accounts for the source chunk grid + halo."""
from xrspatial.geotiff._writer import _max_streaming_row_span

# A 256-tall band landing in a 512-tall chunk pulls the neighbour
# chunk-row too -> 1024 source rows.
assert _max_streaming_row_span((512, 512), 256, 1024) == 1024
# Uniform 256 chunks: a middle band touches three chunk-rows (its own
# plus one halo each side).
assert _max_streaming_row_span((256, 256, 256, 256), 256, 1024) == 768
# A single chunk-row spanning the whole height has no neighbour.
assert _max_streaming_row_span((256,), 256, 256) == 256


# -------------------------------------------------------------------------
# Section: parallel per-tile streaming compress (P4)
Expand Down
Loading