diff --git a/pyproject.toml b/pyproject.toml index b100e00..6bd09f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "numcodecs-combinators[xarray]~=0.2.13", "numcodecs-observers~=0.1.2", "numcodecs-replace==0.1.0", - "numcodecs-safeguards==0.1.0b2", + "numcodecs-safeguards==0.1.0b5", "numcodecs-wasm==0.2.2", "numcodecs-wasm-bit-round==0.4.0", "numcodecs-wasm-ebcc==0.3.1a0", diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py index 2fb5669..9ca2e18 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py @@ -1,5 +1,6 @@ __all__ = [ "SafeguardedBitRoundPco", + "SafeguardedEbcc", "SafeguardedSperr", "SafeguardedSz3", "SafeguardedZero", @@ -8,6 +9,7 @@ ] from .bitround_pco import SafeguardedBitRoundPco +from .ebcc import SafeguardedEbcc from .sperr import SafeguardedSperr from .sz3 import SafeguardedSz3 from .zero import SafeguardedZero diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/ebcc.py b/src/climatebenchpress/compressor/compressors/safeguarded/ebcc.py new file mode 100644 index 0000000..98184a8 --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguarded/ebcc.py @@ -0,0 +1,73 @@ +__all__ = ["SafeguardedEbcc"] + +import numcodecs.astype +import numcodecs_replace +import numcodecs_safeguards +import numcodecs_wasm_ebcc +import numpy as np +from numcodecs_combinators.stack import CodecStack + +from ..abc import Compressor + + +class SafeguardedEbcc(Compressor): + """Safeguarded EBCC compressor.""" + + name = "safeguarded-ebcc" + description = "Safeguarded(EBCC)" + + @staticmethod + def abs_bound_codec(error_bound, dtype=None, **kwargs): + assert dtype is not None, "dtype must be provided" + + return numcodecs_safeguards.SafeguardedCodec( + codec=CodecStack( + # EBCC only supports float32 data + numcodecs.astype.AsType( + encode_dtype="float32", + decode_dtype=dtype.name, + ), + # inspired by H5Z-SPERR's treatment of NaN values: + # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5z-sperr.c#L464-L473 + # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5zsperr_helper.cpp#L179-L212 + numcodecs_replace.ReplaceFilterCodec(replacements={np.nan: "nan_mean"}), + numcodecs_wasm_ebcc.Ebcc( + # reasonable default recommended by Langwen Huang + base_cr=100, + residual="absolute", + error=error_bound, + chunk_shape="auto", + ), + ), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) + + @staticmethod + def rel_bound_codec(error_bound, dtype=None, **kwargs): + assert dtype is not None, "dtype must be provided" + + return numcodecs_safeguards.SafeguardedCodec( + codec=CodecStack( + # EBCC only supports float32 data + numcodecs.astype.AsType( + encode_dtype="float32", + decode_dtype=dtype.name, + ), + # inspired by H5Z-SPERR's treatment of NaN values: + # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5z-sperr.c#L464-L473 + # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5zsperr_helper.cpp#L179-L212 + numcodecs_replace.ReplaceFilterCodec(replacements={np.nan: "nan_mean"}), + numcodecs_wasm_ebcc.Ebcc( + # reasonable default recommended by Langwen Huang + base_cr=100, + residual="relative", + error=error_bound, + chunk_shape="auto", + ), + ), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + )