Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Unreleased

#### Bug fixes and improvements
- Shut down the per-tile compression `ThreadPoolExecutor` on every exit path of the streaming tiled-write code in `to_geotiff`. The old code only called `shutdown(wait=True)` after the tile-row loop completed, so any mid-stream exception (compression failure, dask compute failure, file write failure) bypassed shutdown and leaked worker threads. The loop now runs inside `try/finally`, and the `finally` block calls `shutdown(wait=True, cancel_futures=True)` so queued tiles get dropped on the error path instead of blocking the unwind. The pool's workers carry an `xrspatial-geotiff-tile-compress` `thread_name_prefix` so leak-detection tests can tell them apart from dask's own offload/scheduler pools. (#2276)
- Remove read-side emission of the 13 deprecated GeoTIFF attrs (`crs_name`, `geog_citation`, `datum_code`, `angular_units`, `semi_major_axis`, `inv_flattening`, `linear_units`, `projection_code`, `vertical_crs`, `vertical_citation`, `vertical_units`, `colormap_rgba`, `cmap`) and bump `attrs['_xrspatial_geotiff_contract']` from 1 to 2. Downstream code that read these via `attrs[key]` now sees `KeyError`; migrate to `attrs.get(key)` or derive the value from `attrs['crs']` / `attrs['crs_wkt']` with pyproj. The `.xrs.plot()` accessor still surfaces palette colormaps by building a `ListedColormap` from the canonical `attrs['colormap']`. (#2016)
- Accept numpy integer scalars as the `crs=` argument to `to_geotiff` / `write_geotiff_gpu`. The validator already allowed `numbers.Integral`, but the writers gated EPSG assignment on `isinstance(crs, int)`, so `np.int32` / `np.int64` / `np.uint16` values passed validation but then silently fell through with no EPSG written. (#2082)
- Tighten the writer's no-georef sentinel for integer x/y coords. The pre-fix check treated any integer dtype on either axis as the read-side no-georef placeholder and skipped transform inference, which also caught user-authored projected grids with integer-spaced coords (e.g. `x=[100,101,102], y=[200,199]`) and silently stripped their georef on write. The sentinel now matches only the exact reader pattern: `int64` ascending contiguous-step-1 arange on both axes. User-authored integer-coord grids that don't match (descending, non-unit step, non-uniform, or non-`int64`) now produce a real transform or raise `NonUniformCoordsError`. Coord values round-trip exactly through the new path, but their dtype changes from int to float on subsequent reads. (#2087)
Expand Down
237 changes: 132 additions & 105 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@
# carrying computed offsets, dimensions, or layout. See issue #1769.
_OVERRIDABLE_AUTO_TAG_IDS = frozenset({TAG_PHOTOMETRIC, TAG_EXTRA_SAMPLES})

# Thread-name prefix for the per-tile compression ``ThreadPoolExecutor``
# in the streaming write path. Tagging the workers lets leak-detection
# tests (issue #2276) tell our pool's threads apart from dask's
# offload/scheduler pools, which also use ``ThreadPoolExecutor`` and
# are kept alive deliberately by dask as singletons.
_TILE_POOL_THREAD_PREFIX = 'xrspatial-geotiff-tile-compress'

# TIFF Photometric Interpretation values (``PHOTOMETRIC_MINISBLACK``,
# ``PHOTOMETRIC_RGB``) and the ``_PHOTOMETRIC_NAME_MAP`` friendly-name
# table live in ``_encode.py`` and are re-exported above for
Expand Down Expand Up @@ -1023,113 +1030,133 @@ def _write_streaming(dask_data, path: str, *,
_pool_workers = min(tiles_per_segment, os.cpu_count() or 4)
_use_pool = (comp_tag != COMPRESSION_NONE
and _pool_workers > 1)
tile_pool = (ThreadPoolExecutor(max_workers=_pool_workers)
if _use_pool else None)

for tr in range(tiles_down):
r0 = tr * th
r1 = min(r0 + th, height)
actual_h = r1 - r0

for seg_start in range(0, tiles_across, tiles_per_segment):
seg_end = min(seg_start + tiles_per_segment,
tiles_across)
seg_c0 = seg_start * tw
seg_c1 = min(seg_end * tw, width)

# Compute just this horizontal segment
if dask_data.ndim == 3:
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1, :].compute())
else:
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1].compute())
if hasattr(seg_np, 'get'):
seg_np = seg_np.get()

if seg_np.dtype != out_dtype:
seg_np = seg_np.astype(out_dtype)

# NaN -> nodata sentinel
if (nodata is not None and seg_np.dtype.kind == 'f'
and not np.isnan(nodata)
and restore_sentinel):
nan_mask = np.isnan(seg_np)
if nan_mask.any():
seg_np = seg_np.copy()
seg_np[nan_mask] = seg_np.dtype.type(nodata)

# Build tile arrays for this segment
seg_tile_arrs = []
for tc in range(seg_start, seg_end):
c0 = tc * tw
c1 = min(c0 + tw, width)
actual_w = c1 - c0

local_c0 = c0 - seg_c0
local_c1 = c1 - seg_c0
tile_slice = seg_np[:, local_c0:local_c1]

if actual_h < th or actual_w < tw:
if seg_np.ndim == 3:
padded = np.zeros((th, tw, samples),
dtype=out_dtype)
# ``thread_name_prefix`` tags the worker threads so leak
# detection in tests (issue #2276) can tell our pool's
# workers apart from dask's offload/scheduler pools.
tile_pool = (
ThreadPoolExecutor(
max_workers=_pool_workers,
thread_name_prefix=_TILE_POOL_THREAD_PREFIX)
if _use_pool else None)

# Wrap the tile loop in ``try/finally`` so the pool is
# always shut down before any exception (compression
# failure, dask compute failure, file write failure)
# propagates. The previous code only called
# ``shutdown`` after the loop completed and leaked
# worker threads on any mid-stream raise. See #2276.
try:
for tr in range(tiles_down):
r0 = tr * th
r1 = min(r0 + th, height)
actual_h = r1 - r0

for seg_start in range(0, tiles_across, tiles_per_segment):
seg_end = min(seg_start + tiles_per_segment,
tiles_across)
seg_c0 = seg_start * tw
seg_c1 = min(seg_end * tw, width)

# Compute just this horizontal segment
if dask_data.ndim == 3:
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1, :].compute())
else:
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1].compute())
if hasattr(seg_np, 'get'):
seg_np = seg_np.get()

if seg_np.dtype != out_dtype:
seg_np = seg_np.astype(out_dtype)

# NaN -> nodata sentinel
if (nodata is not None and seg_np.dtype.kind == 'f'
and not np.isnan(nodata)
and restore_sentinel):
nan_mask = np.isnan(seg_np)
if nan_mask.any():
seg_np = seg_np.copy()
seg_np[nan_mask] = seg_np.dtype.type(nodata)

# Build tile arrays for this segment
seg_tile_arrs = []
for tc in range(seg_start, seg_end):
c0 = tc * tw
c1 = min(c0 + tw, width)
actual_w = c1 - c0

local_c0 = c0 - seg_c0
local_c1 = c1 - seg_c0
tile_slice = seg_np[:, local_c0:local_c1]

if actual_h < th or actual_w < tw:
if seg_np.ndim == 3:
padded = np.zeros((th, tw, samples),
dtype=out_dtype)
else:
padded = np.zeros((th, tw), dtype=out_dtype)
padded[:actual_h, :actual_w] = tile_slice
tile_arr = padded
else:
padded = np.zeros((th, tw), dtype=out_dtype)
padded[:actual_h, :actual_w] = tile_slice
tile_arr = padded
tile_arr = np.ascontiguousarray(tile_slice)

seg_tile_arrs.append(tile_arr)

# Parallel compress on the hoisted ``tile_pool``
# when it exists. zlib/zstd/LZW release the GIL,
# so threading actually parallelises the C-level
# work. Peak memory while the segment is in
# flight covers BOTH the uncompressed
# ``seg_tile_arrs`` (one full tile per column,
# released after the futures resolve) AND the
# compressed buffers ``seg_compressed`` (held
# until the sequential write loop drains them).
# Both lists are bounded by ``tiles_per_segment``
# which the streaming buffer cap sets; fall
# through to a serial path when the pool is None
# (no compression / single core) or when only
# one tile sits in this segment.
n_seg_tiles = len(seg_tile_arrs)
if tile_pool is None or n_seg_tiles <= 1:
seg_compressed = [
_compress_block(
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)
for ta in seg_tile_arrs
]
else:
tile_arr = np.ascontiguousarray(tile_slice)

seg_tile_arrs.append(tile_arr)

# Parallel compress on the hoisted ``tile_pool``
# when it exists. zlib/zstd/LZW release the GIL,
# so threading actually parallelises the C-level
# work. Peak memory while the segment is in
# flight covers BOTH the uncompressed
# ``seg_tile_arrs`` (one full tile per column,
# released after the futures resolve) AND the
# compressed buffers ``seg_compressed`` (held
# until the sequential write loop drains them).
# Both lists are bounded by ``tiles_per_segment``
# which the streaming buffer cap sets; fall
# through to a serial path when the pool is None
# (no compression / single core) or when only
# one tile sits in this segment.
n_seg_tiles = len(seg_tile_arrs)
if tile_pool is None or n_seg_tiles <= 1:
seg_compressed = [
_compress_block(
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)
for ta in seg_tile_arrs
]
else:
futures = [
tile_pool.submit(
_compress_block,
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error,
True)
for ta in seg_tile_arrs
]
seg_compressed = [
fut.result() for fut in futures]

# Sequential file write to preserve on-disk tile order
for compressed in seg_compressed:
actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
f.write(compressed)
current_offset += len(compressed)

del seg_np, seg_tile_arrs, seg_compressed

if tile_pool is not None:
tile_pool.shutdown(wait=True)
futures = [
tile_pool.submit(
_compress_block,
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error,
True)
for ta in seg_tile_arrs
]
seg_compressed = [
fut.result() for fut in futures]

# Sequential file write to preserve on-disk tile order
for compressed in seg_compressed:
actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
f.write(compressed)
current_offset += len(compressed)

del seg_np, seg_tile_arrs, seg_compressed
finally:
# ``cancel_futures=True`` (Python 3.9+) drops any
# queued-but-not-started compress jobs on the
# error path so ``wait=True`` only blocks on work
# already in flight. The previous shutdown call
# lived past the for-loop and never ran when an
# exception escaped, leaking worker threads. See
# issue #2276.
if tile_pool is not None:
tile_pool.shutdown(wait=True, cancel_futures=True)
else:
# Strip layout
for i in range(n_entries):
Expand Down
Loading
Loading