Skip to content

Commit 56cc261

Browse files
authored
geotiff: parallelize strip writer, adaptive tile threshold, optional libdeflate (closes #1800) (#1801)
* geotiff: parallelize strip writer and add optional libdeflate backend (#1800) The deflate strip-write path was 3.7x slower than rioxarray/GDAL because `_write_stripped` ran zlib.compress serially while the tile writer already parallelized via a thread pool. Three changes: 1. Mirror `_write_tiled`'s ThreadPoolExecutor pattern in `_write_stripped`. Strip preparation is hoisted into a new `_prepare_strip` helper so the same code drives both the serial and parallel paths. A 2048x2048 deflate strip write drops from 405 ms to 70 ms (5.8x speedup, beats rioxarray's 102 ms). 2. Replace the tile writer's `n_tiles <= 4` sequential cutoff with a bytes-based threshold (`_PARALLEL_MIN_BYTES = 4 MiB`). Pre-fix, `tile_size=1024` on a 2048x2048 image produced n_tiles=4 and forced the slow path; now those writes parallelize too. 3. Route `deflate_compress` through the optional `libdeflate` package when installed (1.5-2x faster than stdlib zlib at the same level; GDAL >= 3.7 already uses it). Output is wire-compatible -- decoded streams round-trip through `zlib.decompress` unchanged. Compressors are cached per thread via `threading.local`. * geotiff: harden thread-local cache test against pool scheduling (#1800) PR #1801's review flagged that `test_libdeflate_compressor_cache_is_thread_local` could pass with a single observed cache id: `ThreadPoolExecutor(max_workers=2).map(...)` is free to run both submissions on the same worker if the first returns quickly. Force both tasks to occupy a worker at the same time with a `threading.Barrier`, record `threading.get_ident()` so the assertion fails loudly if only one thread actually ran, and use the executor as a context manager so the pool is shut down on assertion failure.
1 parent a90f77f commit 56cc261

3 files changed

Lines changed: 422 additions & 45 deletions

File tree

xrspatial/geotiff/_compression.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,48 @@
11
"""Compression codecs: deflate (zlib) and LZW (Numba), plus horizontal predictor."""
22
from __future__ import annotations
33

4+
import threading
45
import zlib
56

67
import numpy as np
78

89
from xrspatial.utils import ngjit
910

11+
# -- Optional libdeflate backend --------------------------------------------
12+
#
13+
# When the ``libdeflate`` package is installed, ``deflate_compress`` routes
14+
# through it: libdeflate is typically 1.5-2x faster than ``zlib`` at the
15+
# same compression level, and GDAL >= 3.7 already uses it when available
16+
# so our installs match throughput. Output is wire-compatible (zlib
17+
# format), so encoded streams round-trip through stdlib ``zlib.decompress``
18+
# unchanged.
19+
#
20+
# libdeflate's ``Compressor`` objects are not thread-safe, so we keep one
21+
# per (thread, level) pair via ``threading.local``. The writer drives
22+
# compression from a ``ThreadPoolExecutor``; per-thread caching avoids
23+
# allocating a fresh compressor per strip/tile.
24+
try: # pragma: no cover - exercised only when libdeflate is installed
25+
import libdeflate as _libdeflate
26+
_HAVE_LIBDEFLATE = True
27+
except ImportError:
28+
_libdeflate = None
29+
_HAVE_LIBDEFLATE = False
30+
31+
_libdeflate_thread_local = threading.local()
32+
33+
34+
def _libdeflate_compressor(level: int):
35+
"""Return a thread-local libdeflate Compressor for *level*."""
36+
cache = getattr(_libdeflate_thread_local, 'cache', None)
37+
if cache is None:
38+
cache = {}
39+
_libdeflate_thread_local.cache = cache
40+
comp = cache.get(level)
41+
if comp is None:
42+
comp = _libdeflate.Compressor(level)
43+
cache[level] = comp
44+
return comp
45+
1046
# -- Decompression-bomb defenses ---------------------------------------------
1147
#
1248
# A malicious TIFF can declare a small strip/tile compressed payload that
@@ -98,7 +134,14 @@ def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes:
98134

99135

100136
def deflate_compress(data: bytes, level: int = 6) -> bytes:
101-
"""Compress data with deflate/zlib."""
137+
"""Compress data with deflate/zlib.
138+
139+
Uses ``libdeflate`` when installed (1.5-2x faster than ``zlib``) and
140+
falls back to ``zlib.compress`` otherwise. Output is wire-compatible
141+
either way: the stdlib ``zlib.decompress`` accepts both.
142+
"""
143+
if _HAVE_LIBDEFLATE:
144+
return _libdeflate_compressor(level).compress(data, _libdeflate.Format.ZLIB)
102145
return zlib.compress(data, level)
103146

104147

xrspatial/geotiff/_writer.py

Lines changed: 102 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ def _compression_tag(compression_name: str) -> int:
225225
#: override.
226226
_MAX_OVERVIEW_LEVELS = 8
227227

228+
#: Total uncompressed payload (bytes) below which the strip and tile
229+
#: writers stay sequential. The thread-pool startup cost dominates on
230+
#: small rasters; above this size the per-block compression cost more
231+
#: than pays for it. 4 MiB was chosen empirically on a 20-core box:
232+
#: parallel becomes a net win around ~2 MiB, and the 4 MiB margin keeps
233+
#: a few-tile / two-strip layout from incurring a slowdown.
234+
_PARALLEL_MIN_BYTES = 4 * 1024 * 1024
235+
228236

229237
def _validate_overview_levels(overview_levels, height=None, width=None):
230238
"""Validate and normalise an explicit ``overview_levels`` list.
@@ -651,12 +659,50 @@ def _build_ifd(tags: list[tuple], overflow_base: int,
651659
# Strip writer
652660
# ---------------------------------------------------------------------------
653661

662+
def _prepare_strip(data, i, rows_per_strip, height, width, samples, dtype,
663+
bytes_per_sample, predictor: int, compression,
664+
compression_level=None, max_z_error: float = 0.0):
665+
"""Extract and compress a single strip. Thread-safe."""
666+
r0 = i * rows_per_strip
667+
r1 = min(r0 + rows_per_strip, height)
668+
strip_rows = r1 - r0
669+
670+
if compression == COMPRESSION_JPEG:
671+
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
672+
return jpeg_compress(strip_data, width, strip_rows, samples)
673+
if predictor != 1 and compression != COMPRESSION_NONE:
674+
strip_arr = np.ascontiguousarray(data[r0:r1])
675+
buf = strip_arr.view(np.uint8).ravel().copy()
676+
buf = _apply_predictor_encode(
677+
buf, predictor, width, strip_rows, bytes_per_sample, samples)
678+
strip_data = buf.tobytes()
679+
else:
680+
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
681+
682+
if compression == COMPRESSION_JPEG2000:
683+
from ._compression import jpeg2000_compress
684+
return jpeg2000_compress(
685+
strip_data, width, strip_rows, samples=samples, dtype=dtype)
686+
if compression == COMPRESSION_LERC:
687+
from ._compression import lerc_compress
688+
return lerc_compress(
689+
strip_data, width, strip_rows, samples=samples, dtype=dtype,
690+
max_z_error=max_z_error)
691+
if compression_level is None:
692+
return compress(strip_data, compression)
693+
return compress(strip_data, compression, level=compression_level)
694+
695+
654696
def _write_stripped(data: np.ndarray, compression: int, predictor: int,
655697
rows_per_strip: int = 256,
656698
compression_level: int | None = None,
657699
max_z_error: float = 0.0) -> tuple[list, list, list]:
658700
"""Compress data as strips.
659701
702+
For compressed formats (deflate, lzw, zstd, lz4, ...) strips are
703+
compressed in parallel using a thread pool: zlib, zstandard, lz4,
704+
and the Numba LZW kernel all release the GIL during compression.
705+
660706
Returns
661707
-------
662708
(offsets_placeholder, byte_counts, compressed_chunks)
@@ -668,53 +714,60 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,
668714
dtype = data.dtype
669715
bytes_per_sample = dtype.itemsize
670716

671-
strips = []
717+
num_strips = math.ceil(height / rows_per_strip)
718+
719+
total_bytes = int(data.nbytes)
720+
721+
# Sequential path: uncompressed, few strips, or small payload. The
722+
# threshold mirrors the tile writer so we don't pay thread-pool
723+
# overhead on tiny rasters.
724+
use_parallel = (
725+
compression != COMPRESSION_NONE
726+
and num_strips > 2
727+
and total_bytes > _PARALLEL_MIN_BYTES
728+
)
729+
730+
if not use_parallel:
731+
strips = []
732+
rel_offsets = []
733+
byte_counts = []
734+
current_offset = 0
735+
for i in range(num_strips):
736+
compressed = _prepare_strip(
737+
data, i, rows_per_strip, height, width, samples, dtype,
738+
bytes_per_sample, predictor, compression,
739+
compression_level, max_z_error,
740+
)
741+
rel_offsets.append(current_offset)
742+
byte_counts.append(len(compressed))
743+
strips.append(compressed)
744+
current_offset += len(compressed)
745+
return rel_offsets, byte_counts, strips
746+
747+
# Parallel strip compression -- zlib/zstd/lz4/LZW all release the GIL.
748+
from concurrent.futures import ThreadPoolExecutor
749+
import os
750+
751+
n_workers = min(num_strips, os.cpu_count() or 4)
752+
with ThreadPoolExecutor(max_workers=n_workers) as pool:
753+
compressed_strips = list(pool.map(
754+
lambda i: _prepare_strip(
755+
data, i, rows_per_strip, height, width, samples, dtype,
756+
bytes_per_sample, predictor, compression,
757+
compression_level, max_z_error,
758+
),
759+
range(num_strips),
760+
))
761+
672762
rel_offsets = []
673763
byte_counts = []
674764
current_offset = 0
675-
676-
num_strips = math.ceil(height / rows_per_strip)
677-
for i in range(num_strips):
678-
r0 = i * rows_per_strip
679-
r1 = min(r0 + rows_per_strip, height)
680-
strip_rows = r1 - r0
681-
682-
if compression == COMPRESSION_JPEG:
683-
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
684-
compressed = jpeg_compress(strip_data, width, strip_rows, samples)
685-
elif predictor != 1 and compression != COMPRESSION_NONE:
686-
strip_arr = np.ascontiguousarray(data[r0:r1])
687-
buf = strip_arr.view(np.uint8).ravel().copy()
688-
buf = _apply_predictor_encode(
689-
buf, predictor, width, strip_rows, bytes_per_sample, samples)
690-
strip_data = buf.tobytes()
691-
if compression_level is None:
692-
compressed = compress(strip_data, compression)
693-
else:
694-
compressed = compress(strip_data, compression, level=compression_level)
695-
else:
696-
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
697-
698-
if compression == COMPRESSION_JPEG2000:
699-
from ._compression import jpeg2000_compress
700-
compressed = jpeg2000_compress(
701-
strip_data, width, strip_rows, samples=samples, dtype=dtype)
702-
elif compression == COMPRESSION_LERC:
703-
from ._compression import lerc_compress
704-
compressed = lerc_compress(
705-
strip_data, width, strip_rows, samples=samples, dtype=dtype,
706-
max_z_error=max_z_error)
707-
elif compression_level is None:
708-
compressed = compress(strip_data, compression)
709-
else:
710-
compressed = compress(strip_data, compression, level=compression_level)
711-
765+
for cs in compressed_strips:
712766
rel_offsets.append(current_offset)
713-
byte_counts.append(len(compressed))
714-
strips.append(compressed)
715-
current_offset += len(compressed)
767+
byte_counts.append(len(cs))
768+
current_offset += len(cs)
716769

717-
return rel_offsets, byte_counts, strips
770+
return rel_offsets, byte_counts, compressed_strips
718771

719772

720773
# ---------------------------------------------------------------------------
@@ -841,8 +894,13 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int,
841894

842895
return rel_offsets, byte_counts, tiles
843896

844-
if n_tiles <= 4:
845-
# Very few tiles: sequential (thread pool overhead not worth it)
897+
# Sequential path: very few tiles, or small total payload. A previous
898+
# ``n_tiles <= 4`` cutoff sent ``tile_size=1024`` writes on a 2048x2048
899+
# image down the serial path (n_tiles=4) and made them ~8x slower than
900+
# the parallel path. Switching to a bytes-based threshold lets
901+
# large-tile writes parallelize while still skipping the pool on
902+
# small rasters where its setup cost dominates.
903+
if n_tiles <= 2 or int(data.nbytes) <= _PARALLEL_MIN_BYTES:
846904
tiles = []
847905
rel_offsets = []
848906
byte_counts = []

0 commit comments

Comments
 (0)