Skip to content

Commit 288ff83

Browse files
committed
geotiff: address PR #1952 review nits
1 parent bbec25c commit 288ff83

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

xrspatial/geotiff/_writers/gpu.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,17 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
493493
# issue #1948.
494494
current = arr
495495
cumulative_factor = 1
496+
# ``make_overview_gpu`` preserves dtype, so the sentinel cast is
497+
# loop-invariant. Hoist it (and the float/finite gate) out of the
498+
# inner ``while`` to skip redundant per-level scalar work.
499+
rewrite_nodata = (
500+
nodata is not None
501+
and np_dtype.kind == 'f'
502+
and not np.isnan(float(nodata))
503+
)
504+
sentinel_scalar = (
505+
np_dtype.type(nodata) if rewrite_nodata else None
506+
)
496507
for target_factor in overview_levels:
497508
# Halve repeatedly until the cumulative decimation matches
498509
# the requested factor. Validation has already established
@@ -502,14 +513,10 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
502513
current = make_overview_gpu(current, method=overview_resampling,
503514
nodata=nodata)
504515
cumulative_factor *= 2
505-
if (nodata is not None
506-
and np.dtype(str(current.dtype)).kind == 'f'
507-
and not np.isnan(float(nodata))):
516+
if rewrite_nodata:
508517
nan_mask = cupy.isnan(current)
509518
if bool(nan_mask.any().item()):
510-
cupy.putmask(
511-
current, nan_mask,
512-
np.dtype(str(current.dtype)).type(nodata))
519+
cupy.putmask(current, nan_mask, sentinel_scalar)
513520
oh, ow = current.shape[:2]
514521
parts.append(_gpu_compress_to_part(current, ow, oh, samples))
515522

xrspatial/geotiff/tests/test_gpu_writer_overview_inplace_1948.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,16 @@ def test_gpu_writer_overview_loop_uses_putmask_1948():
8383
)
8484
# The legacy two-line pattern would have ``current = current.copy()``
8585
# right before the indexed write. Ensure the overview branch no
86-
# longer contains that exact line.
87-
overview_branch = src[idx_overview:idx_overview + 1500]
86+
# longer contains that exact line. Anchor the slice on the next
87+
# statement after the inner ``while`` (``parts.append(...)``) so
88+
# the window tracks the real loop body instead of a fixed
89+
# character count that drifts as surrounding code changes.
90+
idx_parts_append = src.find("parts.append(", idx_overview)
91+
assert idx_parts_append != -1, (
92+
"could not locate the ``parts.append(`` sentinel that closes "
93+
"the overview-loop body"
94+
)
95+
overview_branch = src[idx_overview:idx_parts_append]
8896
assert "current = current.copy()" not in overview_branch, (
8997
"overview loop should no longer copy the cupy buffer before "
9098
"the in-place sentinel rewrite (issue #1948)."

0 commit comments

Comments
 (0)