From c6d556a2d4285b192d570c8fe4f7e20e7bba44ca Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 12:09:50 +0800 Subject: [PATCH 1/8] feat(latent_codecs): add containerized infrastructure for Family 1 codecs Lay the groundwork for refactoring channel-slice models (STF, WACNN, TCM, CCA) toward an ELIC-style containerized layout where the model owns only g_a/g_s and the latent_codec owns the entire entropy stack (h_a, h_s, z bottleneck, per-slice channel context, per-slice leaves). Codec primitives (compressai/latent_codecs): - DualHyperSynthesis(h_mean_s, h_scale_s) adapter so HyperpriorLatentCodec can wrap two parallel hyper-synthesis heads while keeping their state_dict paths split. - LRPGaussianLatentCodec(GaussianConditionalLatentCodec): subclass adding the lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat))) refinement used by Zhu2022 and follow-ups; lives in the same file as its base. - ChannelGroupsLatentCodec gains optional max_support_slices and support_filter parameters (defaults -1 / None preserve ELIC's use-all-prior behaviour). Enables STF/TCM-style support clamping and CCA-aux skip-most-recent selection without sibling codec classes. - _slice_helpers hosts make_entropy_transform, slice_support_channels, lrp_support_channels, infer_num_slices, infer_max_support_slices, with default state_dict prefixes pointing at the post-refactor layout (latent_codec.latent_codec.y.channel_context.yK.mean_cc.*). Application-layer helpers (compressai/models/_helpers): - build_channel_slice_codec: per-slice factory that hides the y0..yK-1 dictionary boilerplate when assembling ChannelGroupsLatentCodec. - MeanScaleContextHead + build_mean_scale_head: independent mean_cc and scale_cc Sequentials with optional per-path support transforms (for SWAtten / NAFTransform). No model wiring is changed in this commit; STF/WACNN continue to use the existing _bases/slice_entropy + ChannelSliceLatentCodec path. 
Coverage: 28 new unit tests across tests/test_latent_codecs.py and tests/test_models_helpers.py; tests/test_models.py::TestStf still passes; importing compressai / compressai.zoo / compressai.latent_codecs introduces no new dependencies. --- compressai/latent_codecs/__init__.py | 5 +- compressai/latent_codecs/_hyper_synthesis.py | 60 ++++ compressai/latent_codecs/_slice_helpers.py | 162 ++++++++++ compressai/latent_codecs/channel_groups.py | 17 +- .../latent_codecs/gaussian_conditional.py | 64 ++++ compressai/models/_helpers/__init__.py | 37 +++ compressai/models/_helpers/channel_context.py | 137 ++++++++ compressai/models/_helpers/channel_slice.py | 107 +++++++ tests/test_latent_codecs.py | 300 ++++++++++++++++++ tests/test_models_helpers.py | 178 +++++++++++ 10 files changed, 1064 insertions(+), 3 deletions(-) create mode 100644 compressai/latent_codecs/_hyper_synthesis.py create mode 100644 compressai/latent_codecs/_slice_helpers.py create mode 100644 compressai/models/_helpers/__init__.py create mode 100644 compressai/models/_helpers/channel_context.py create mode 100644 compressai/models/_helpers/channel_slice.py create mode 100644 tests/test_latent_codecs.py create mode 100644 tests/test_models_helpers.py diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 82a41947..bd587257 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -27,13 +27,14 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from ._hyper_synthesis import DualHyperSynthesis from .base import LatentCodec from .channel_groups import ChannelGroupsLatentCodec from .channel_slice import ChannelSliceLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec -from .gaussian_conditional import GaussianConditionalLatentCodec +from .gaussian_conditional import GaussianConditionalLatentCodec, LRPGaussianLatentCodec from .hyper import HyperLatentCodec from .hyperprior import HyperpriorLatentCodec from .rasterscan import RasterScanLatentCodec @@ -43,11 +44,13 @@ "ChannelGroupsLatentCodec", "ChannelSliceLatentCodec", "CheckerboardLatentCodec", + "DualHyperSynthesis", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", "GainHyperpriorLatentCodec", "GaussianConditionalLatentCodec", "HyperLatentCodec", "HyperpriorLatentCodec", + "LRPGaussianLatentCodec", "RasterScanLatentCodec", ] diff --git a/compressai/latent_codecs/_hyper_synthesis.py b/compressai/latent_codecs/_hyper_synthesis.py new file mode 100644 index 00000000..69bbaf2a --- /dev/null +++ b/compressai/latent_codecs/_hyper_synthesis.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
class DualHyperSynthesis(nn.Module):
    """Adapter that exposes two hyper-synthesis heads as a single module.

    Both heads are applied to the same ``z_hat`` and their outputs are joined
    along the channel axis (``dim=1``) with the mean head first, i.e.
    ``cat(h_mean_s(z_hat), h_scale_s(z_hat))``.

    The heads are registered as the submodules ``h_mean_s`` and ``h_scale_s``,
    so their parameters keep separate ``h_mean_s.*`` / ``h_scale_s.*``
    state-dict prefixes. An instance is meant to be passed as the ``h_s``
    argument of :class:`~compressai.latent_codecs.HyperpriorLatentCodec`.

    Args:
        h_mean_s: Head producing the first (mean) half of the parameters.
        h_scale_s: Head producing the second (scale) half of the parameters.
    """

    h_mean_s: nn.Module
    h_scale_s: nn.Module

    def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None:
        super().__init__()
        # Assigned mean-first so the submodule registration order is stable.
        self.h_mean_s, self.h_scale_s = h_mean_s, h_scale_s

    def forward(self, z_hat: Tensor) -> Tensor:
        """Return the channel-wise concatenation of both heads' outputs."""
        mean_part = self.h_mean_s(z_hat)
        scale_part = self.h_scale_s(z_hat)
        return torch.cat((mean_part, scale_part), dim=1)
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Channel-slice support helpers shared by Family 1 codecs. + +These functions support the Family 1 (pure 1-pass channel-slice) entropy +models — STF / WACNN / TCM / CCA / DCAE / MambaVC. They were previously +hosted in ``compressai.models._bases.slice_entropy``; the canonical home is +now this module so they sit alongside the latent-codec primitives that +consume them. ``_DEFAULT_NUM_SLICES_PREFIX`` reflects the post-refactor +state-dict layout used by the containerised +:class:`~compressai.latent_codecs.ChannelGroupsLatentCodec` wiring. +""" + +from __future__ import annotations + +from typing import Dict, Sequence + +import torch.nn as nn + +from torch import Tensor + +from compressai.models.utils import conv + +__all__ = [ + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] + + +# Post-refactor state-dict layout: ``ChannelGroupsLatentCodec`` lives at +# ``latent_codec.latent_codec.y`` and stores per-slice mean / scale heads +# under ``channel_context.y{k}.{mean,scale}_cc.0.weight``. Slice 0 has no +# channel context, so prefix scans should expect ``k >= 1``. 
# Default state-dict key pattern scanned by the ``infer_*`` helpers below:
# ``<prefix><k><suffix>`` where ``k`` is the slice index. Slice 0 has no
# channel-context entry, so matching keys are expected to have ``k >= 1``.
_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.latent_codec.y.channel_context.y"
_DEFAULT_KEY_SUFFIX = ".mean_cc.0.weight"


def slice_support_channels(
    latent_channels: int,
    slice_channels: int,
    index: int,
    max_support_slices: int,
) -> int:
    """Return the input width of the channel-context head of slice ``index``.

    The width is ``latent_channels`` plus ``slice_channels`` for every
    supporting slice. All ``index`` previous slices are counted when
    ``max_support_slices`` is negative; otherwise at most
    ``max_support_slices`` of them are.
    """
    if max_support_slices < 0:
        return latent_channels + slice_channels * index
    return latent_channels + slice_channels * min(index, max_support_slices)


def lrp_support_channels(
    latent_channels: int,
    slice_channels: int,
    index: int,
    max_support_slices: int,
) -> int:
    """Return the input width of the LRP transform of slice ``index``.

    Same as :func:`slice_support_channels` but with one additional slice
    counted (``index + 1`` slices, clamped to ``max_support_slices + 1`` when
    ``max_support_slices`` is non-negative).
    """
    if max_support_slices < 0:
        return latent_channels + slice_channels * (index + 1)
    return latent_channels + slice_channels * min(index + 1, max_support_slices + 1)


def make_entropy_transform(
    in_channels: int,
    out_channels: int,
    *,
    widths: Sequence[int] = (224, 128),
) -> nn.Sequential:
    """Build a conv stack with GELU activations between the convolutions.

    Every layer is created with ``conv(..., stride=1, kernel_size=3)``. One
    conv + GELU pair is emitted per entry of ``widths`` and a final conv maps
    to ``out_channels`` with no activation after it, so the default
    ``widths=(224, 128)`` yields three convs and
    ``widths=(224, 176, 128, 64)`` yields five.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by the last conv.
        widths: Hidden channel widths, in order.
    """
    layers: list[nn.Module] = []
    prev = in_channels
    for width in widths:
        layers.append(conv(prev, width, stride=1, kernel_size=3))
        layers.append(nn.GELU())
        prev = width
    layers.append(conv(prev, out_channels, stride=1, kernel_size=3))
    return nn.Sequential(*layers)


def infer_num_slices(
    state_dict: Dict[str, Tensor],
    *,
    prefix: str = _DEFAULT_NUM_SLICES_PREFIX,
    suffix: str = _DEFAULT_KEY_SUFFIX,
) -> int:
    """Infer the number of slices from channel-context keys in ``state_dict``.

    Keys of the form ``<prefix><k><suffix>`` are scanned and the distinct
    integer indices ``k`` are counted. Slice 0 has no channel-context entry,
    so the returned slice count is that number **plus one** (callers must not
    add one again).

    Returns:
        ``len(distinct k) + 1`` when at least one key matches, otherwise
        ``0``. Note that a state dict without any matching key is therefore
        reported as ``0`` slices, not ``1``.

    Raises:
        ValueError: If the text between ``prefix`` and the next ``.`` of a
            matching key is not an integer (propagated from ``int``).
    """
    slice_indices = {
        int(key[len(prefix) :].split(".", 1)[0])
        for key in state_dict
        if key.startswith(prefix) and key.endswith(suffix)
    }
    if not slice_indices:
        return 0
    return len(slice_indices) + 1


def infer_max_support_slices(
    state_dict: Dict[str, Tensor],
    latent_channels: int,
    num_slices: int,
    *,
    prefix: str = _DEFAULT_NUM_SLICES_PREFIX,
    suffix: str = _DEFAULT_KEY_SUFFIX,
    extra_factor: int = 1,
) -> int:
    """Infer ``max_support_slices`` from the widest matching head input.

    The widest ``size(1)`` among tensors whose key matches
    ``<prefix>...<suffix>`` is taken, ``extra_factor * latent_channels`` is
    subtracted, and the remainder is divided by the per-slice width
    ``latent_channels // num_slices``.

    Args:
        state_dict: Mapping of parameter names to tensors.
        latent_channels: Total channel count of the latent.
        num_slices: Number of slices the latent is split into.
        prefix: Key prefix identifying channel-context entries.
        suffix: Key suffix identifying the first conv weight of a head.
        extra_factor: Number of ``latent_channels``-wide inputs that are fed
            to the head in addition to the previous-slice support.

    Returns:
        The inferred number of support slices (never negative), or ``0`` when
        no key matches. With widths as produced by
        :func:`slice_support_channels`, an unclamped layout (``-1``) is
        reported as ``num_slices - 1`` because the two cannot be told apart
        from tensor shapes alone.

    Raises:
        ValueError: If matching keys exist but ``latent_channels // num_slices``
            is not a positive slice width.
    """
    input_widths = [
        tensor.size(1)
        for key, tensor in state_dict.items()
        if key.startswith(prefix) and key.endswith(suffix)
    ]
    # Decide the "no channel context" case before dividing: ``infer_num_slices``
    # reports 0 slices for such a state dict, which must not raise here.
    if not input_widths:
        return 0
    slice_channels = latent_channels // num_slices if num_slices > 0 else 0
    if slice_channels <= 0:
        raise ValueError(
            "cannot infer max_support_slices: latent_channels="
            f"{latent_channels} and num_slices={num_slices} do not give a "
            "positive slice width"
        )
    support_channels = max(input_widths) - extra_factor * latent_channels
    return max(0, support_channels // slice_channels)
from itertools import accumulate -from typing import Any, Dict, List, Mapping, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple import torch import torch.nn as nn @@ -74,12 +74,16 @@ def __init__( channel_context: Mapping[str, nn.Module], *, groups: List[int], + max_support_slices: int = -1, + support_filter: Optional[Callable[[int, List[Tensor]], List[Tensor]]] = None, **kwargs, ): super().__init__() self._kwargs = kwargs self.groups = list(groups) self.groups_acc = list(accumulate(self.groups, initial=0)) + self.max_support_slices = int(max_support_slices) + self.support_filter = support_filter self.channel_context = nn.ModuleDict(channel_context) self.latent_codec = nn.ModuleDict(latent_codec) @@ -165,5 +169,14 @@ def _get_ctx_params( ) -> Tensor: if k == 0: return side_params - ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*y_hat_[:k])) + support = self._select_support(k, y_hat_) + ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*support)) return self.merge_params(ch_ctx_params, side_params) + + def _select_support(self, k: int, y_hat_: List[Tensor]) -> List[Tensor]: + prior = list(y_hat_[:k]) + if self.support_filter is not None: + return list(self.support_filter(k, prior)) + if self.max_support_slices < 0: + return prior + return prior[: self.max_support_slices] diff --git a/compressai/latent_codecs/gaussian_conditional.py b/compressai/latent_codecs/gaussian_conditional.py index e422f681..456d87bc 100644 --- a/compressai/latent_codecs/gaussian_conditional.py +++ b/compressai/latent_codecs/gaussian_conditional.py @@ -29,6 +29,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import torch import torch.nn as nn from torch import Tensor @@ -41,6 +42,7 @@ __all__ = [ "GaussianConditionalLatentCodec", + "LRPGaussianLatentCodec", ] @@ -142,3 +144,65 @@ def _chunk(self, params: Tensor) -> Tuple[Tensor, Tensor]: if self.chunks == ("means", "scales"): means, scales = params.chunk(2, 1) return scales, 
@register_module("LRPGaussianLatentCodec")
class LRPGaussianLatentCodec(GaussianConditionalLatentCodec):
    """Gaussian conditional codec followed by latent residual prediction (LRP).

    The base :class:`GaussianConditionalLatentCodec` is run unchanged; the
    ``y_hat`` it returns is then refined additively::

        y_hat = y_hat + lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat)))

    The refinement is applied identically in :meth:`forward`,
    :meth:`compress` and :meth:`decompress`, and only the ``"y_hat"`` entry of
    the base codec's output dict is replaced.

    Intended as the per-slice leaf of Family 1 channel-slice models (STF /
    WACNN / TCM / CCA). The LRP refinement follows [Zhu2022].

    [Zhu2022]: "Transformer-based Transform Coding", by Yinhao Zhu, Yang Yang
    and Taco Cohen, ICLR 2022.

    Args:
        lrp_transform: Module mapping ``cat(ctx_params, y_hat)`` (joined on
            ``dim=1``) to a residual that is added to ``y_hat``.
        lrp_scale: Multiplier applied to the ``tanh``-squashed residual.
        **gc_kwargs: Forwarded verbatim to the base class constructor.
    """

    lrp_transform: nn.Module

    def __init__(
        self,
        lrp_transform: nn.Module,
        *,
        lrp_scale: float = 0.5,
        **gc_kwargs: Any,
    ) -> None:
        super().__init__(**gc_kwargs)
        self.lrp_transform = lrp_transform
        self.lrp_scale = float(lrp_scale)

    def _apply_lrp(self, ctx_params: Tensor, y_hat: Tensor) -> Tensor:
        """Return ``y_hat`` plus the scaled, ``tanh``-bounded LRP residual."""
        lrp_input = torch.cat([ctx_params, y_hat], dim=1)
        residual = self.lrp_scale * torch.tanh(self.lrp_transform(lrp_input))
        return y_hat + residual

    def _refine(self, result: Dict[str, Any], ctx_params: Tensor) -> Dict[str, Any]:
        """Overwrite ``result["y_hat"]`` with its LRP-refined version."""
        result["y_hat"] = self._apply_lrp(ctx_params, result["y_hat"])
        return result

    def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
        return self._refine(super().forward(y, ctx_params), ctx_params)

    def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
        return self._refine(super().compress(y, ctx_params), ctx_params)

    def decompress(
        self,
        strings: List[List[bytes]],
        shape: Tuple[int, int],
        ctx_params: Tensor,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        base_out = super().decompress(strings, shape, ctx_params, **kwargs)
        return self._refine(base_out, ctx_params)
+ +"""Application-layer helpers for assembling Family 1 channel-slice codecs. + +These helpers wrap :class:`~compressai.latent_codecs.ChannelGroupsLatentCodec` +behind a per-slice factory interface, removing the +``{"y0": ..., "y1": ..., ...}`` dictionary boilerplate that would otherwise +appear in every Family 1 model. They live outside ``compressai.latent_codecs`` +because they are application-layer ergonomics, not codec primitives. +""" diff --git a/compressai/models/_helpers/channel_context.py b/compressai/models/_helpers/channel_context.py new file mode 100644 index 00000000..e17e1298 --- /dev/null +++ b/compressai/models/_helpers/channel_context.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. 
class MeanScaleContextHead(nn.Module):
    """Channel-context head with separate mean / scale sub-networks.

    ``forward(x)`` returns
    ``cat([mean_cc(mean_support_transform(x)),
    scale_cc(scale_support_transform(x))], dim=1)``; the mean branch always
    comes first. Each support transform runs independently on the same input
    before its sub-network and defaults to :class:`torch.nn.Identity`.

    Args:
        mean_cc: Sub-network producing the mean half of the output.
        scale_cc: Sub-network producing the scale half of the output.
        mean_support_transform: Optional module applied to the input of
            ``mean_cc``. ``None`` selects ``nn.Identity()``.
        scale_support_transform: Optional module applied to the input of
            ``scale_cc``. ``None`` selects ``nn.Identity()``.
    """

    mean_cc: nn.Module
    scale_cc: nn.Module
    mean_support_transform: nn.Module
    scale_support_transform: nn.Module

    def __init__(
        self,
        mean_cc: nn.Module,
        scale_cc: nn.Module,
        mean_support_transform: Optional[nn.Module] = None,
        scale_support_transform: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.mean_cc = mean_cc
        self.scale_cc = scale_cc
        # Test against ``None`` explicitly. Modules that define ``__len__``
        # (e.g. an empty ``nn.Sequential``) are falsy, so ``x or nn.Identity()``
        # would silently discard a module the caller actually supplied.
        self.mean_support_transform = (
            nn.Identity() if mean_support_transform is None else mean_support_transform
        )
        self.scale_support_transform = (
            nn.Identity()
            if scale_support_transform is None
            else scale_support_transform
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean and scale branch outputs concatenated on ``dim=1``."""
        mean = self.mean_cc(self.mean_support_transform(x))
        scale = self.scale_cc(self.scale_support_transform(x))
        return torch.cat([mean, scale], dim=1)


def build_mean_scale_head(
    slice_ch: int,
    support_ch: int,
    *,
    widths: Sequence[int] = (224, 128),
    support_transform_factory: Optional[Callable[[int, int], nn.Module]] = None,
) -> MeanScaleContextHead:
    """Construct a :class:`MeanScaleContextHead` with default conv-stack heads.

    ``mean_cc`` and ``scale_cc`` are two independent
    ``make_entropy_transform(support_ch, slice_ch, widths=widths)`` stacks.

    Parameters
    ----------
    slice_ch
        Output channel count of each sub-head.
    support_ch
        Input channel count of ``mean_cc`` / ``scale_cc``. The caller is
        responsible for accounting for any extra channels concatenated
        upstream.
    widths
        Hidden conv widths forwarded to ``make_entropy_transform``.
    support_transform_factory
        ``(in_ch, out_ch) -> nn.Module``. When supplied it is called twice
        with ``(support_ch, support_ch)`` to build independent transforms for
        the mean and scale paths; when ``None`` both paths use the identity.
    """
    # Construction order (mean_cc, scale_cc, mean transform, scale transform)
    # is kept fixed so parameter initialisation consumes the RNG in a stable
    # order.
    mean_cc = make_entropy_transform(support_ch, slice_ch, widths=widths)
    scale_cc = make_entropy_transform(support_ch, slice_ch, widths=widths)
    mean_support: Optional[nn.Module] = None
    scale_support: Optional[nn.Module] = None
    if support_transform_factory is not None:
        mean_support = support_transform_factory(support_ch, support_ch)
        scale_support = support_transform_factory(support_ch, support_ch)
    return MeanScaleContextHead(
        mean_cc=mean_cc,
        scale_cc=scale_cc,
        mean_support_transform=mean_support,
        scale_support_transform=scale_support,
    )
def build_channel_slice_codec(
    *,
    groups: List[int],
    leaf_factory: Callable[[int, int], LatentCodec],
    channel_context_factory: Optional[Callable[[int, int, int], nn.Module]] = None,
    max_support_slices: int = -1,
    support_filter: Optional[Callable[[int, List[Tensor]], List[Tensor]]] = None,
) -> ChannelGroupsLatentCodec:
    """Assemble a :class:`ChannelGroupsLatentCodec` from per-slice factories.

    Builds the ``channel_context`` mapping for slices ``y1 .. y{K-1}`` (slice
    0 gets no channel-context module) and the ``latent_codec`` mapping for
    slices ``y0 .. y{K-1}``, where ``K = len(groups)``. All channel-context
    modules are created before any leaf.

    Parameters
    ----------
    groups
        Per-slice channel counts.
    leaf_factory
        ``(k, slice_ch_k) -> LatentCodec``; called once per slice.
    channel_context_factory
        ``(k, slice_ch_k, support_ch_k) -> nn.Module``; called for ``k >= 1``.
        ``support_ch_k`` is the sum of the first ``k`` entries of ``groups``,
        or of the first ``min(k, max_support_slices)`` entries when
        ``max_support_slices`` is non-negative. ``None`` uses
        :class:`~torch.nn.Identity` for every slice.
    max_support_slices
        Forwarded to :class:`ChannelGroupsLatentCodec`; negative means no
        clamp.
    support_filter
        Forwarded to :class:`ChannelGroupsLatentCodec` unchanged. It is not
        consulted when computing ``support_ch_k``, so a factory used together
        with a filter must derive its own input width.
    """

    def _identity_context(*_unused: int) -> nn.Module:
        return nn.Identity()

    make_context = (
        _identity_context if channel_context_factory is None else channel_context_factory
    )
    num_slices = len(groups)

    contexts = {}
    for k in range(1, num_slices):
        # Number of leading groups whose channels feed slice k's context.
        n_support = k if max_support_slices < 0 else min(k, max_support_slices)
        contexts[f"y{k}"] = make_context(k, groups[k], sum(groups[:n_support]))

    leaves = {}
    for k in range(num_slices):
        leaves[f"y{k}"] = leaf_factory(k, groups[k])

    return ChannelGroupsLatentCodec(
        latent_codec=leaves,
        channel_context=contexts,
        groups=list(groups),
        max_support_slices=max_support_slices,
        support_filter=support_filter,
    )
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import torch +import torch.nn as nn + +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + DualHyperSynthesis, + GaussianConditionalLatentCodec, + LRPGaussianLatentCodec, +) +from compressai.latent_codecs._slice_helpers import ( + infer_max_support_slices, + infer_num_slices, + lrp_support_channels, + make_entropy_transform, + slice_support_channels, +) + + +class TestDualHyperSynthesis: + def test_concatenates_dual_heads(self): + h_mean_s = nn.Conv2d(4, 6, 1) + h_scale_s = nn.Conv2d(4, 6, 1) + wrapper = DualHyperSynthesis(h_mean_s, h_scale_s) + z_hat = torch.randn(2, 4, 8, 8) + out = wrapper(z_hat) + assert out.shape == (2, 12, 8, 8) + expected = torch.cat([h_mean_s(z_hat), h_scale_s(z_hat)], dim=1) + assert torch.allclose(out, expected) + + def test_state_dict_paths_split_per_head(self): + wrapper = DualHyperSynthesis(nn.Conv2d(4, 6, 1), nn.Conv2d(4, 6, 1)) + keys = set(wrapper.state_dict().keys()) + assert "h_mean_s.weight" in keys + assert "h_mean_s.bias" in keys + assert "h_scale_s.weight" in keys + assert "h_scale_s.bias" in keys + + +class TestLRPGaussianLatentCodec: + def _make(self, slice_ch=4, ctx_ch=8): + # entropy_parameters maps ctx_ch -> 2*slice_ch (chunked into scales/means) + entropy_parameters = nn.Conv2d(ctx_ch, 2 * slice_ch, 1) + # lrp_transform input = ctx_ch + slice_ch (cat ctx_params with y_hat) + lrp_transform = nn.Sequential( + nn.Conv2d(ctx_ch + slice_ch, 8, 3, padding=1), + nn.GELU(), + nn.Conv2d(8, slice_ch, 3, padding=1), + ) + return LRPGaussianLatentCodec( + lrp_transform=lrp_transform, + entropy_parameters=entropy_parameters, + ) + + def test_forward_shapes(self): + codec = self._make() + y = torch.randn(2, 4, 8, 8) + ctx = torch.randn(2, 8, 8, 8) + out = codec(y, ctx) + assert out["y_hat"].shape == (2, 4, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 4, 8, 8) + + def test_lrp_changes_y_hat_relative_to_base(self): + # With identical entropy parameters, LRP variant should produce a + # different y_hat than the 
un-refined GaussianConditionalLatentCodec. + torch.manual_seed(0) + slice_ch, ctx_ch = 4, 8 + entropy_parameters = nn.Conv2d(ctx_ch, 2 * slice_ch, 1) + lrp_transform = nn.Sequential( + nn.Conv2d(ctx_ch + slice_ch, 8, 3, padding=1), + nn.GELU(), + nn.Conv2d(8, slice_ch, 3, padding=1), + ) + # Push lrp through tanh's slope-1 region: zero biases keep tanh near 0 + # so the refinement is small but non-trivial. + for m in lrp_transform.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.1) + nn.init.zeros_(m.bias) + base = GaussianConditionalLatentCodec( + entropy_parameters=entropy_parameters + ).eval() + refined = LRPGaussianLatentCodec( + lrp_transform=lrp_transform, entropy_parameters=entropy_parameters + ).eval() + y = torch.randn(2, slice_ch, 8, 8) + ctx = torch.randn(2, ctx_ch, 8, 8) + with torch.no_grad(): + base_out = base(y, ctx) + ref_out = refined(y, ctx) + assert not torch.allclose(base_out["y_hat"], ref_out["y_hat"]) + # Same y_likelihoods because the Gaussian step is identical. + assert torch.allclose(base_out["likelihoods"]["y"], ref_out["likelihoods"]["y"]) + + def test_state_dict_round_trip(self): + codec = self._make().eval() + keys = set(codec.state_dict().keys()) + # Inherits entropy_parameters / gaussian_conditional from base, adds lrp_transform. 
+ assert any(k.startswith("lrp_transform.") for k in keys) + assert any(k.startswith("entropy_parameters.") for k in keys) + assert any(k.startswith("gaussian_conditional.") for k in keys) + + reconstructed = self._make().eval() + reconstructed.load_state_dict(codec.state_dict()) + y = torch.randn(2, 4, 8, 8) + ctx = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out_a = codec(y, ctx) + out_b = reconstructed(y, ctx) + assert torch.allclose(out_a["y_hat"], out_b["y_hat"]) + + def test_lrp_scale_zero_collapses_to_base(self): + torch.manual_seed(1) + codec = self._make().eval() + codec.lrp_scale = 0.0 + y = torch.randn(2, 4, 8, 8) + ctx = torch.randn(2, 8, 8, 8) + base_codec = GaussianConditionalLatentCodec( + entropy_parameters=codec.entropy_parameters, + gaussian_conditional=codec.gaussian_conditional, + ).eval() + with torch.no_grad(): + base_out = base_codec(y, ctx) + ref_out = codec(y, ctx) + assert torch.allclose(base_out["y_hat"], ref_out["y_hat"]) + + +class TestChannelGroupsLatentCodecExtensions: + def _make_codec( + self, + groups=(4, 4, 4), + side_ch=8, + max_support_slices=-1, + support_filter=None, + ): + K = len(groups) + # Channel-context input is sum of (clamped) prev y_hat slice channels; + # use Identity which simply forwards the concatenated tensor unchanged. + channel_context = {f"y{k}": nn.Identity() for k in range(1, K)} + + def _ctx_in(k): + if k == 0: + return side_ch + if max_support_slices < 0: + count = k + else: + count = min(k, max_support_slices) + return side_ch + sum(groups[:count]) + + # Each leaf needs an entropy_parameters MLP sized to its own ctx input. 
+ latent_codec = { + f"y{k}": GaussianConditionalLatentCodec( + entropy_parameters=nn.Conv2d(_ctx_in(k), 2 * groups[k], 1), + ) + for k in range(K) + } + return ChannelGroupsLatentCodec( + latent_codec=latent_codec, + channel_context=channel_context, + groups=list(groups), + max_support_slices=max_support_slices, + support_filter=support_filter, + ) + + def test_default_select_support_uses_all_prior(self): + codec = self._make_codec() + slices = [torch.zeros(1, 4, 4, 4) for _ in range(3)] + assert codec._select_support(0, slices) == [] + assert codec._select_support(1, slices) == slices[:1] + assert codec._select_support(3, slices) == slices[:3] + + def test_max_support_slices_clamps(self): + codec = self._make_codec(max_support_slices=2) + slices = [torch.zeros(1, 4, 4, 4) for _ in range(3)] + # k=3 with clamp=2 -> drop the most recent slice (index 2) + result = codec._select_support(3, slices) + assert len(result) == 2 + assert result == slices[:2] + + def test_support_filter_overrides_max_support(self): + # CCA-aux skip-most-recent pattern. + def skip_recent(k, prior): + return prior[: max(k - 1, 0)] + + codec = self._make_codec(max_support_slices=10, support_filter=skip_recent) + slices = [torch.zeros(1, 4, 4, 4) for _ in range(4)] + assert codec._select_support(0, slices) == [] + assert codec._select_support(1, slices) == [] # k-1 = 0 + assert codec._select_support(3, slices) == slices[:2] # skip slice 2 + + def test_default_forward_matches_pre_extension_behaviour(self): + # With defaults the new constructor should be drop-in for ELIC-style use. 
+ torch.manual_seed(7) + codec = self._make_codec() + groups = codec.groups + M = sum(groups) + y = torch.randn(2, M, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + out = codec(y, side_params) + assert out["y_hat"].shape == (2, M, 8, 8) + assert out["likelihoods"]["y"].shape == (2, M, 8, 8) + + def test_max_support_slices_changes_forward_output(self): + # Build a codec whose channel_context input width matches a clamped support; + # then verify that forward produces a different y_hat than the un-clamped version. + torch.manual_seed(3) + codec_clamp = self._make_codec(groups=(4, 4, 4, 4), max_support_slices=1) + # Reuse all leaf weights from clamp codec on a fresh "no clamp" codec for + # an apples-to-apples comparison; we expect clamp to drop information for + # slices k >= 2 only (their leaf input width differs, so we only need to + # check that clamp codec runs end-to-end). + y = torch.randn(2, 16, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + out = codec_clamp(y, side_params) + assert out["y_hat"].shape == (2, 16, 8, 8) + + +class TestSliceHelpers: + def test_slice_support_channels_default_use_all(self): + # With max_support_slices = -1 the helper returns the full latent + k slices. + assert slice_support_channels(64, 8, 0, -1) == 64 + assert slice_support_channels(64, 8, 5, -1) == 64 + 8 * 5 + + def test_slice_support_channels_clamps(self): + assert slice_support_channels(64, 8, 5, 3) == 64 + 8 * 3 + assert slice_support_channels(64, 8, 1, 3) == 64 + 8 * 1 + + def test_lrp_support_channels(self): + assert lrp_support_channels(64, 8, 0, -1) == 64 + 8 + assert lrp_support_channels(64, 8, 5, 3) == 64 + 8 * 4 + + def test_make_entropy_transform_default_widths(self): + net = make_entropy_transform(40, 8) + # Default widths (224, 128): conv-gelu-conv-gelu-conv -> 5 modules. 
+ assert len(net) == 5 + x = torch.randn(2, 40, 8, 8) + y = net(x) + assert y.shape == (2, 8, 8, 8) + + def test_make_entropy_transform_custom_widths(self): + net = make_entropy_transform(40, 8, widths=(64, 32)) + x = torch.randn(2, 40, 8, 8) + y = net(x) + assert y.shape == (2, 8, 8, 8) + + def test_infer_num_slices_new_path(self): + # New state-dict layout: channel_context entries exist for k >= 1. + # For 4 slices total, we expect 3 mean_cc keys -> infer returns 4. + sd = { + f"latent_codec.latent_codec.y.channel_context.y{k}.mean_cc.0.weight": ( + torch.zeros(8, 4) + ) + for k in range(1, 4) + } + assert infer_num_slices(sd) == 4 + + def test_infer_num_slices_empty(self): + assert infer_num_slices({}) == 0 + + def test_infer_max_support_slices_new_path(self): + # mean_cc.0 takes (latent_means + slice_channels * support) input channels. + # With M=64, num_slices=8, slice_channels=8, support=2 -> input ch = 64 + 16 = 80. + sd = { + "latent_codec.latent_codec.y.channel_context.y2.mean_cc.0.weight": ( + torch.zeros(64, 80, 3, 3) + ), + "latent_codec.latent_codec.y.channel_context.y3.mean_cc.0.weight": ( + torch.zeros(64, 80, 3, 3) + ), + } + # extra_factor=1 is the Family 1 default (single latent_means concat). + assert infer_max_support_slices(sd, latent_channels=64, num_slices=8) == 2 diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py new file mode 100644 index 00000000..4da40400 --- /dev/null +++ b/tests/test_models_helpers.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import torch +import torch.nn as nn + +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + GaussianConditionalLatentCodec, +) +from compressai.models._helpers.channel_context import ( + MeanScaleContextHead, + build_mean_scale_head, +) +from compressai.models._helpers.channel_slice import build_channel_slice_codec + + +class TestMeanScaleContextHead: + def test_forward_shape_concatenates_mean_and_scale(self): + slice_ch, support_ch = 4, 12 + head = build_mean_scale_head(slice_ch, support_ch, widths=(8, 8)) + x = torch.randn(2, support_ch, 4, 4) + out = head(x) + assert out.shape == (2, 2 * slice_ch, 4, 4) + + def test_state_dict_paths_split_mean_and_scale(self): + head = build_mean_scale_head(4, 12, widths=(8, 8)) + keys = set(head.state_dict().keys()) + assert any(k.startswith("mean_cc.") for k in keys) + assert any(k.startswith("scale_cc.") for k in keys) + # No support transforms by default -> no associated state. + assert not any(k.startswith("mean_support_transform.") for k in keys) + assert not any(k.startswith("scale_support_transform.") for k in keys) + + def test_support_transform_factory_wraps_inputs(self): + # Use 1x1 conv that preserves channel count. + def factory(c_in, c_out): + return nn.Conv2d(c_in, c_out, 1) + + head = build_mean_scale_head( + 4, 12, widths=(8, 8), support_transform_factory=factory + ) + keys = set(head.state_dict().keys()) + assert any(k.startswith("mean_support_transform.") for k in keys) + assert any(k.startswith("scale_support_transform.") for k in keys) + # mean and scale support transforms are independent instances (not shared). 
+ assert head.mean_support_transform is not head.scale_support_transform + + def test_direct_construction_round_trip(self): + torch.manual_seed(0) + mean_cc = nn.Conv2d(12, 4, 1) + scale_cc = nn.Conv2d(12, 4, 1) + head = MeanScaleContextHead(mean_cc=mean_cc, scale_cc=scale_cc) + rebuilt = MeanScaleContextHead( + mean_cc=nn.Conv2d(12, 4, 1), scale_cc=nn.Conv2d(12, 4, 1) + ) + rebuilt.load_state_dict(head.state_dict()) + x = torch.randn(2, 12, 4, 4) + with torch.no_grad(): + assert torch.allclose(head(x), rebuilt(x)) + + +class TestBuildChannelSliceCodec: + def _leaf_factory(self, side_ch=8): + # Return a leaf factory whose entropy_parameters width matches what + # ChannelGroupsLatentCodec hands the leaf at slice k. + def factory(k, slice_ch): + if k == 0: + ctx_in = side_ch + else: + ctx_in = ( + side_ch + 2 * slice_ch + ) # ch_ctx (= 2*slice_ch from MeanScaleHead) + side + return GaussianConditionalLatentCodec( + entropy_parameters=nn.Conv2d(ctx_in, 2 * slice_ch, 1), + ) + + return factory + + def test_dict_keys_y0_through_yK_minus_one(self): + codec = build_channel_slice_codec( + groups=[4, 4, 4], + leaf_factory=self._leaf_factory(side_ch=8), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + ch, sup, widths=(8, 8) + ), + ) + latent_keys = set(codec.latent_codec.keys()) + ctx_keys = set(codec.channel_context.keys()) + assert latent_keys == {"y0", "y1", "y2"} + # Slice 0 has no channel context entry by design. + assert ctx_keys == {"y1", "y2"} + + def test_state_dict_paths_match_design_doc(self): + codec = build_channel_slice_codec( + groups=[4, 4], + leaf_factory=self._leaf_factory(side_ch=8), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + ch, sup, widths=(8,) + ), + ) + keys = set(codec.state_dict().keys()) + # Design doc paths (relative to ChannelGroupsLatentCodec root): + # channel_context.y{k}.mean_cc..weight + # channel_context.y{k}.scale_cc..weight + # latent_codec.y{k}.gaussian_conditional. 
+ assert any(k.startswith("channel_context.y1.mean_cc.") for k in keys) + assert any(k.startswith("channel_context.y1.scale_cc.") for k in keys) + assert any(k.startswith("latent_codec.y0.") for k in keys) + assert any(k.startswith("latent_codec.y1.") for k in keys) + + def test_returns_channel_groups_latent_codec(self): + codec = build_channel_slice_codec( + groups=[4, 4, 4], + leaf_factory=self._leaf_factory(side_ch=8), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + ch, sup, widths=(8, 8) + ), + ) + assert isinstance(codec, ChannelGroupsLatentCodec) + assert codec.groups == [4, 4, 4] + assert codec.max_support_slices == -1 + assert codec.support_filter is None + + def test_max_support_slices_propagates(self): + codec = build_channel_slice_codec( + groups=[4, 4, 4, 4], + leaf_factory=self._leaf_factory(side_ch=8), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + ch, sup, widths=(8,) + ), + max_support_slices=2, + ) + assert codec.max_support_slices == 2 + # support_ch passed to channel_context.y3 should be clamped to 2 slices. + # The MeanScaleContextHead's mean_cc input width is the first conv's + # in_channels. 
+ head_y3 = codec.channel_context["y3"] + first_mean_conv = next(m for m in head_y3.mean_cc if isinstance(m, nn.Conv2d)) + assert first_mean_conv.in_channels == 2 * 4 # 2 slices * 4 ch each + + def test_support_filter_propagates(self): + def skip_recent(k, prior): + return prior[: max(k - 1, 0)] + + codec = build_channel_slice_codec( + groups=[4, 4, 4], + leaf_factory=self._leaf_factory(side_ch=8), + channel_context_factory=lambda k, ch, sup: nn.Identity(), + support_filter=skip_recent, + ) + assert codec.support_filter is skip_recent From 8b3ea4df062b9de4935df0f6b17910a7a7110b62 Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 13:10:34 +0800 Subject: [PATCH 2/8] refactor(models/stf): migrate WACNN + SymmetricalTransFormer to containerized codec Replace the SliceEntropyCompressionModel base + monolithic ChannelSliceLatentCodec wiring with an ELIC-style HyperpriorLatentCodec that owns h_a, h_s, the z entropy bottleneck, and the per-slice channel context. Both WACNN and SymmetricalTransFormer now inherit CompressionModel directly with 5-line forward / compress / decompress methods that delegate to self.latent_codec. Supporting infra extensions (built on the prior containerized scaffolding commit 51f536f): - ChannelGroupsLatentCodec gains side_in_context: bool. Off (default) is the existing ELIC behaviour. On, _get_ctx_params: (a) routes side_params through channel_context.y0 instead of returning it raw, (b) for k>=1 feeds cat(side_params, prev_y_hat) to channel_context.y_k, (c) skips the trailing cat with side_params at the leaf level. Family 1 models opt in. - MeanScaleContextHead gains side_split (split leading 2*side_split into latent_means / latent_scales and route to mean_cc / scale_cc separately) and emit_mean_support (append cat(latent_means, prev_y_hat) to the output so downstream LRP can recover the upstream input layout). 
- LRPGaussianLatentCodec gains mean_support_trail_channels: when > 0, the leaf splits ctx_params into [gaussian_params, mean_support] and feeds the trailing block to the LRP transform. This restores the upstream cat(latent_means, *prev_y_hat, y_hat) LRP input shape, so the per-slice lrp_transforms.{k} weights from upstream Zou et al. checkpoints transfer byte-for-byte. - build_channel_slice_codec accepts side_in_context + side_channels and builds the y0 channel_context entry on opt-in. - _slice_helpers.infer_num_slices auto-detects whether y0 is present in state_dict and adjusts the count accordingly. Upstream checkpoint converter (convert_upstream_stf_state_dict): - Strips DataParallel module. prefix. - Re-roots cc_mean_transforms / cc_scale_transforms / lrp_transforms / gaussian_conditional / entropy_bottleneck / h_a / h_mean_s / h_scale_s under their new latent_codec.* paths. - Replicates the single shared gaussian_conditional buffer set into per-slice leaves (driven by the discovered slice count). - Nests upstream conv_b..attn.{qkv, proj, relative_position_*} keys under the WMSA wrapper level (.attn.attn.) via _nest_winmsa_keys, so WindowAttention parameters land on the right submodule. Verified end-to-end with the upstream Zou et al. checkpoints candidate/cnn_0018_best.pth.tar (WACNN) and candidate/stf_0018_best.pth.tar (SymmetricalTransFormer): both WACNN.from_state_dict and SymmetricalTransFormer.from_state_dict succeed under strict loading and forward pass. State-dict round-trip is exercised by an updated tests/test_models.py::TestStf with self-checks on the new key paths. 
--- compressai/latent_codecs/_slice_helpers.py | 28 +- compressai/latent_codecs/channel_groups.py | 25 + .../latent_codecs/gaussian_conditional.py | 51 ++- compressai/models/_helpers/channel_context.py | 117 ++++- compressai/models/_helpers/channel_slice.py | 59 ++- compressai/models/stf.py | 432 ++++++++++++++---- tests/test_latent_codecs.py | 93 +++- tests/test_models.py | 64 ++- tests/test_models_helpers.py | 88 ++++ 9 files changed, 806 insertions(+), 151 deletions(-) diff --git a/compressai/latent_codecs/_slice_helpers.py b/compressai/latent_codecs/_slice_helpers.py index 8c017cd2..6332d1c0 100644 --- a/compressai/latent_codecs/_slice_helpers.py +++ b/compressai/latent_codecs/_slice_helpers.py @@ -57,11 +57,14 @@ ] -# Post-refactor state-dict layout: ``ChannelGroupsLatentCodec`` lives at -# ``latent_codec.latent_codec.y`` and stores per-slice mean / scale heads -# under ``channel_context.y{k}.{mean,scale}_cc.0.weight``. Slice 0 has no -# channel context, so prefix scans should expect ``k >= 1``. -_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.latent_codec.y.channel_context.y" +# Post-refactor state-dict layout: ``HyperpriorLatentCodec`` exposes +# ``ChannelGroupsLatentCodec`` as ``self.y`` (the inner ``self.latent_codec`` +# dict is not a registered nn.Module), so the channel-context entries live +# under ``latent_codec.y.channel_context.y{k}``. Slice 0 has no channel +# context entry by default (``side_in_context=False`` ELIC mode); Family 1 +# ``side_in_context=True`` mode adds a ``y0`` entry whose presence triggers +# the auto-detection in :func:`infer_num_slices`. +_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.y.channel_context.y" _DEFAULT_KEY_SUFFIX = ".mean_cc.0.weight" @@ -119,9 +122,16 @@ def infer_num_slices( ) -> int: """Count distinct ``y{k}`` channel-context entries in ``state_dict``. 
- Slice 0 has no channel-context entry (it consumes ``side_params`` only), - so the count returned is ``num_slices - 1``; callers wanting the slice - count should add one whenever any channel context is present. + Two layouts are supported: + + - ELIC default: channel_context starts at ``y1`` (slice 0 bypasses it), + so the count returned is ``num_slices - 1`` and we add ``1`` to recover + ``num_slices``. + - Family 1 ``side_in_context=True``: channel_context covers every + slice including ``y0``, so the count is already ``num_slices``. + + The two cases are auto-detected by whether ``y0`` appears in the matched + keys. """ slice_indices = { int(key[len(prefix) :].split(".", 1)[0]) @@ -130,6 +140,8 @@ def infer_num_slices( } if not slice_indices: return 0 + if 0 in slice_indices: + return len(slice_indices) return len(slice_indices) + 1 diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 06204ab2..85159393 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -76,6 +76,7 @@ def __init__( groups: List[int], max_support_slices: int = -1, support_filter: Optional[Callable[[int, List[Tensor]], List[Tensor]]] = None, + side_in_context: bool = False, **kwargs, ): super().__init__() @@ -84,8 +85,15 @@ def __init__( self.groups_acc = list(accumulate(self.groups, initial=0)) self.max_support_slices = int(max_support_slices) self.support_filter = support_filter + self.side_in_context = bool(side_in_context) self.channel_context = nn.ModuleDict(channel_context) self.latent_codec = nn.ModuleDict(latent_codec) + if self.side_in_context and "y0" not in self.channel_context: + raise ValueError( + "side_in_context=True requires a channel_context entry for 'y0' " + "(slice 0's channel_context absorbs side_params instead of " + "ChannelGroupsLatentCodec returning side_params raw)" + ) def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: y_ = torch.split(y, 
self.groups, dim=1) @@ -167,12 +175,29 @@ def merge_params(self, *args): def _get_ctx_params( self, k: int, side_params: Tensor, y_hat_: List[Tensor] ) -> Tensor: + if self.side_in_context: + return self._get_ctx_params_side_in_context(k, side_params, y_hat_) if k == 0: return side_params support = self._select_support(k, y_hat_) ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*support)) return self.merge_params(ch_ctx_params, side_params) + def _get_ctx_params_side_in_context( + self, k: int, side_params: Tensor, y_hat_: List[Tensor] + ) -> Tensor: + # Family 1 layout (STF / WACNN / TCM / CCA): ``channel_context.y{k}`` + # absorbs ``side_params`` directly so its mean_cc / scale_cc heads can + # see the hyperprior latent_means / latent_scales alongside the + # previously decoded slices. The head's output already encodes the + # final per-slice (mean, scale) prediction, so no further cat with + # ``side_params`` is needed before the leaf. + if k == 0: + return self.channel_context["y0"](side_params) + support = self._select_support(k, y_hat_) + ch_input = self.merge_params(side_params, self.merge_y(*support)) + return self.channel_context[f"y{k}"](ch_input) + def _select_support(self, k: int, y_hat_: List[Tensor]) -> List[Tensor]: prior = list(y_hat_[:k]) if self.support_filter is not None: diff --git a/compressai/latent_codecs/gaussian_conditional.py b/compressai/latent_codecs/gaussian_conditional.py index 456d87bc..b9fb34cb 100644 --- a/compressai/latent_codecs/gaussian_conditional.py +++ b/compressai/latent_codecs/gaussian_conditional.py @@ -152,16 +152,32 @@ class LRPGaussianLatentCodec(GaussianConditionalLatentCodec): Wraps :class:`GaussianConditionalLatentCodec` and applies an additive LRP head to the quantized latent ``y_hat``. 
The LRP head receives - ``cat(ctx_params, y_hat)`` and produces a residual scaled by ``lrp_scale`` - and squashed by ``tanh``:: + ``cat(mean_support, y_hat)`` and produces a residual scaled by + ``lrp_scale`` and squashed by ``tanh``:: - y_hat = y_hat + lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat))) + y_hat = y_hat + lrp_scale * tanh(lrp_transform(cat(mean_support, y_hat))) Used as the per-slice leaf for Family 1 channel-slice models (STF / WACNN / TCM / CCA-main, plus the first ``K-2`` slices of CCA-aux). The LRP refinement variant was introduced in [Zhu2022] and is widely adopted by follow-up work (ELIC checkerboard variants, MLIC++, TCM, ...). + The ``mean_support`` tensor that feeds the LRP head depends on + ``mean_support_trail_channels``: + + - ``0`` (default): ``mean_support = ctx_params`` — LRP sees the full + ctx_params concatenated with ``y_hat``. + - ``> 0``: ``ctx_params`` is expected to be laid out as + ``cat(gaussian_params, mean_support)`` where the trailing + ``mean_support_trail_channels`` block carries + ``cat(latent_means, *prev_y_hat)``. The leaf forwards only + ``gaussian_params = ctx_params[:, :-mean_support_trail_channels]`` to + the underlying :class:`GaussianConditionalLatentCodec` (so chunk + semantics are preserved) and uses the trailing block as + ``mean_support`` for LRP. This recovers the upstream STF / WACNN LRP + input layout (``cat(latent_means, *prev_y_hat, y_hat)``), enabling + byte-for-byte transfer of upstream LRP weights. + [Zhu2022]: `"Transformer-based Transform Coding" <https://openreview.net/forum?id=IDwN6xjHnK8>`_, by Yinhao Zhu, Yang Yang and Taco Cohen, ICLR 2022.
@@ -174,26 +190,38 @@ def __init__( lrp_transform: nn.Module, *, lrp_scale: float = 0.5, + mean_support_trail_channels: int = 0, **gc_kwargs: Any, ) -> None: super().__init__(**gc_kwargs) self.lrp_transform = lrp_transform self.lrp_scale = float(lrp_scale) + self.mean_support_trail_channels = int(mean_support_trail_channels) + + def _split_ctx_params(self, ctx_params: Tensor) -> Tuple[Tensor, Tensor]: + if self.mean_support_trail_channels <= 0: + return ctx_params, ctx_params + trail = self.mean_support_trail_channels + gaussian_params = ctx_params[:, :-trail] + mean_support = ctx_params[:, -trail:] + return gaussian_params, mean_support - def _apply_lrp(self, ctx_params: Tensor, y_hat: Tensor) -> Tensor: + def _apply_lrp(self, mean_support: Tensor, y_hat: Tensor) -> Tensor: lrp = self.lrp_scale * torch.tanh( - self.lrp_transform(torch.cat([ctx_params, y_hat], dim=1)) + self.lrp_transform(torch.cat([mean_support, y_hat], dim=1)) ) return y_hat + lrp def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: - out = super().forward(y, ctx_params) - out["y_hat"] = self._apply_lrp(ctx_params, out["y_hat"]) + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().forward(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) return out def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: - out = super().compress(y, ctx_params) - out["y_hat"] = self._apply_lrp(ctx_params, out["y_hat"]) + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().compress(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) return out def decompress( @@ -203,6 +231,7 @@ def decompress( ctx_params: Tensor, **kwargs: Any, ) -> Dict[str, Any]: - out = super().decompress(strings, shape, ctx_params, **kwargs) - out["y_hat"] = self._apply_lrp(ctx_params, out["y_hat"]) + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = 
super().decompress(strings, shape, gaussian_params, **kwargs) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) return out diff --git a/compressai/models/_helpers/channel_context.py b/compressai/models/_helpers/channel_context.py index e17e1298..1523f67c 100644 --- a/compressai/models/_helpers/channel_context.py +++ b/compressai/models/_helpers/channel_context.py @@ -62,10 +62,42 @@ class MeanScaleContextHead(nn.Module): mean_cc: in_channels -> ... -> slice_ch scale_cc: in_channels -> ... -> slice_ch - Forward output is ``cat([mean_cc(...), scale_cc(...)], dim=1)`` of shape - ``(B, 2 * slice_ch, H, W)``. Optional ``mean_support_transform`` / - ``scale_support_transform`` run independently on the input before the - sub-networks (used for SWAtten in TCM and NAFTransform in CCA). + Forward output is ``cat([scale_cc(...), mean_cc(...)], dim=1)`` of shape + ``(B, 2 * slice_ch, H, W)`` — order matches + :class:`GaussianConditionalLatentCodec` ``chunks=("scales", "means")``. + Optional ``mean_support_transform`` / ``scale_support_transform`` run + independently on the input before the sub-networks (used for SWAtten in + TCM and NAFTransform in CCA). + + When ``side_split > 0`` the head expects its input to be the + concatenation ``cat(latent_means(side_split), latent_scales(side_split), + *prev_y_hat)`` produced by + :class:`~compressai.latent_codecs.ChannelGroupsLatentCodec` running in + ``side_in_context=True`` mode. The head splits the leading + ``2 * side_split`` channels back into ``latent_means`` / + ``latent_scales`` and routes: + + - ``mean_cc(cat(latent_means, *prev_y_hat))`` + - ``scale_cc(cat(latent_scales, *prev_y_hat))`` + + so each sub-network sees the same input shape it would have under the + pre-refactor STF / WACNN / TCM / CCA wiring (``cc_mean_transforms[k]`` / + ``cc_scale_transforms[k]``). This keeps state-dict weights compatible + with the legacy layout when migrating via + ``convert_*_checkpoint.py``. 
+ + When ``side_split == 0`` (default) the head is generic: ``mean_cc`` and + ``scale_cc`` both see the full input, no internal split. + + When ``emit_mean_support=True`` (only meaningful with ``side_split > 0``) + the head appends the ``mean_in = cat(latent_means, *prev_y_hat)`` tensor + to the output, producing + ``cat(scale, mean, mean_in)`` of shape + ``(B, 2*slice_ch + side_split + sum(prev_groups), H, W)``. This trailing + block is consumed by :class:`LRPGaussianLatentCodec` (with matching + ``mean_support_trail_channels``) to recover the upstream STF / WACNN + LRP input layout (``cat(latent_means, *prev_y_hat, y_hat)``), enabling + byte-for-byte transfer of upstream LRP weights. """ mean_cc: nn.Module @@ -79,17 +111,39 @@ def __init__( scale_cc: nn.Module, mean_support_transform: Optional[nn.Module] = None, scale_support_transform: Optional[nn.Module] = None, + *, + side_split: int = 0, + emit_mean_support: bool = False, ) -> None: super().__init__() self.mean_cc = mean_cc self.scale_cc = scale_cc self.mean_support_transform = mean_support_transform or nn.Identity() self.scale_support_transform = scale_support_transform or nn.Identity() + self.side_split = int(side_split) + self.emit_mean_support = bool(emit_mean_support) + if self.emit_mean_support and self.side_split <= 0: + raise ValueError( + "emit_mean_support=True requires side_split > 0 to recover " + "the legacy mean_support layout cat(latent_means, *prev_y_hat)." 
+ ) def forward(self, x: Tensor) -> Tensor: - mean = self.mean_cc(self.mean_support_transform(x)) - scale = self.scale_cc(self.scale_support_transform(x)) - return torch.cat([mean, scale], dim=1) + if self.side_split > 0: + split = self.side_split + latent_means = x[:, :split] + latent_scales = x[:, split : 2 * split] + prev_y_hat = x[:, 2 * split :] + mean_in = torch.cat([latent_means, prev_y_hat], dim=1) + scale_in = torch.cat([latent_scales, prev_y_hat], dim=1) + else: + mean_in = scale_in = x + mean = self.mean_cc(self.mean_support_transform(mean_in)) + scale = self.scale_cc(self.scale_support_transform(scale_in)) + out = torch.cat([scale, mean], dim=1) + if self.emit_mean_support: + out = torch.cat([out, mean_in], dim=1) + return out def build_mean_scale_head( @@ -98,6 +152,8 @@ def build_mean_scale_head( *, widths: Sequence[int] = (224, 128), support_transform_factory: Optional[Callable[[int, int], nn.Module]] = None, + side_split: int = 0, + emit_mean_support: bool = False, ) -> MeanScaleContextHead: """Construct a :class:`MeanScaleContextHead` with default conv-stack heads. @@ -106,9 +162,14 @@ def build_mean_scale_head( slice_ch Channel count of the slice being predicted (per-sub-head output). support_ch - Input channel count to ``mean_cc`` / ``scale_cc`` (post-support- - transform). Caller is responsible for accounting for any extra - channels that the application's wiring concatenates upstream. + FULL input channel count to the head (i.e., what + :class:`ChannelGroupsLatentCodec` will hand it). When + ``side_split > 0`` this equals ``2 * side_split + slice_ch * + support_count``; the head will internally split off ``2 * side_split`` + channels and route ``side_split`` each to ``mean_cc`` / ``scale_cc``, + so each sub-network receives ``support_ch - side_split`` channels. + When ``side_split == 0`` ``mean_cc`` / ``scale_cc`` see the full + ``support_ch`` directly. widths Hidden conv widths inside the ``mean_cc`` / ``scale_cc`` Sequentials. 
STF / WACNN use ``(224, 176, 128, 64)``; TCM / CCA use @@ -117,15 +178,39 @@ def build_mean_scale_head( ``(in_ch, out_ch) -> nn.Module``. When supplied, builds independent instances for the mean and scale paths (e.g., per-slice SWAtten in TCM or NAFTransform in CCA). Both transforms are expected to - preserve channel count. + preserve channel count and are applied to the per-path input + (``support_ch - side_split`` channels). + side_split + Number of leading channels in the input that hold ``latent_means`` + (with ``latent_scales`` immediately after, also ``side_split`` wide). + Set to the hyper-synthesis output channel count ``M`` for the + Family 1 ``side_in_context=True`` wiring; leave ``0`` for generic + usage. + emit_mean_support + Forwarded to :class:`MeanScaleContextHead`. Why this flag exists: + the upstream STF / WACNN / TCM / CCA LRP transform consumes + ``cat(latent_means, *prev_y_hat, y_hat)`` (i.e. ``M + slice_ch * + (support_count + 1)`` channels — variable per slice). The Phase 3 + leaf only sees the channel-context ``ctx_params`` (= 2*slice_ch) and + ``y_hat``, which would force an architectural change to the LRP + transform input width and prevent byte-for-byte transfer of upstream + LRP weights. Setting ``emit_mean_support=True`` makes the head + append ``mean_in = cat(latent_means, *prev_y_hat)`` to its output; + :class:`LRPGaussianLatentCodec` (with matching + ``mean_support_trail_channels``) then strips that trailing block off + ``ctx_params``, feeds only the leading ``2*slice_ch`` to the + Gaussian conditional's ``chunks=("scales","means")`` step, and uses + the trailing block as the LRP input — recovering the upstream layout + exactly. 
""" - mean_cc = make_entropy_transform(support_ch, slice_ch, widths=widths) - scale_cc = make_entropy_transform(support_ch, slice_ch, widths=widths) + sub_in_ch = support_ch - side_split + mean_cc = make_entropy_transform(sub_in_ch, slice_ch, widths=widths) + scale_cc = make_entropy_transform(sub_in_ch, slice_ch, widths=widths) mean_support: Optional[nn.Module] scale_support: Optional[nn.Module] if support_transform_factory is not None: - mean_support = support_transform_factory(support_ch, support_ch) - scale_support = support_transform_factory(support_ch, support_ch) + mean_support = support_transform_factory(sub_in_ch, sub_in_ch) + scale_support = support_transform_factory(sub_in_ch, sub_in_ch) else: mean_support = None scale_support = None @@ -134,4 +219,6 @@ def build_mean_scale_head( scale_cc=scale_cc, mean_support_transform=mean_support, scale_support_transform=scale_support, + side_split=side_split, + emit_mean_support=emit_mean_support, ) diff --git a/compressai/models/_helpers/channel_slice.py b/compressai/models/_helpers/channel_slice.py index 3ac728f3..f6891e5d 100644 --- a/compressai/models/_helpers/channel_slice.py +++ b/compressai/models/_helpers/channel_slice.py @@ -50,12 +50,18 @@ def build_channel_slice_codec( channel_context_factory: Optional[Callable[[int, int, int], nn.Module]] = None, max_support_slices: int = -1, support_filter: Optional[Callable[[int, List[Tensor]], List[Tensor]]] = None, + side_in_context: bool = False, + side_channels: int = 0, ) -> ChannelGroupsLatentCodec: """Assemble a :class:`ChannelGroupsLatentCodec` with per-slice modules. Generates the ``{"y0".."yK-1"}`` ``latent_codec`` dict and the ``{"y1".."yK-1"}`` ``channel_context`` dict (slice 0 has no channel - context — it consumes ``side_params`` only). + context — it consumes ``side_params`` only). 
When + ``side_in_context=True`` the ``channel_context`` dict additionally + includes a ``"y0"`` entry whose input is just ``side_params``; the + leaf for slice 0 then receives the head's output (already shaped + ``2 * groups[0]``) instead of raw ``side_params``. Parameters ---------- @@ -68,33 +74,59 @@ def build_channel_slice_codec( :class:`GaussianConditionalLatentCodec`. channel_context_factory ``(k, slice_ch_k, support_ch_k) -> nn.Module``. Constructs the - channel-context module for slice ``k`` (``k >= 1``). ``support_ch_k`` - is the total channel count of the previous slices that will be fed - in (post ``max_support_slices`` clamp). Default ``None`` uses - :class:`~torch.nn.Identity`, which is rarely useful in practice but - keeps the API parallel with ``leaf_factory``. + channel-context module for slice ``k``. ``support_ch_k`` is the + TOTAL channel count of the head's input — i.e., what + :class:`ChannelGroupsLatentCodec._get_ctx_params` will hand it. + For ELIC default mode (``side_in_context=False``) ``support_ch_k = + sum(groups[:clamped_k])`` and only ``k >= 1`` entries are built. + For Family 1 mode (``side_in_context=True``) ``support_ch_k = + side_channels + sum(groups[:clamped_k])`` and a ``y0`` entry with + ``support_ch_0 = side_channels`` is built too. max_support_slices Forwarded to :class:`ChannelGroupsLatentCodec`. Default ``-1`` uses all previous slices (ELIC / CCA-main behaviour). support_filter Forwarded to :class:`ChannelGroupsLatentCodec`. Used by CCA-aux for skip-most-recent support selection. + side_in_context + Forwarded to :class:`ChannelGroupsLatentCodec`. When ``True`` the + ``channel_context`` for ``y0`` consumes ``side_params`` and + downstream ``y_k`` heads receive ``cat(side_params, prev_y_hat)``. + side_channels + Width of ``side_params`` (= hyper-synthesis output channel count). + Required when ``side_in_context=True`` so the factory can size + ``support_ch`` correctly. 
""" if channel_context_factory is None: channel_context_factory = lambda *_: nn.Identity() # noqa: E731 + if side_in_context and side_channels <= 0: + raise ValueError( + "side_in_context=True requires side_channels > 0 so the factory " + "can size the channel_context heads (== side_channels for k=0; " + "side_channels + sum(groups[:k]) clamped, for k>=1)." + ) K = len(groups) - def _support_ch(k: int) -> int: + def _support_count(k: int) -> int: if max_support_slices < 0: - count = k - else: - count = min(k, max_support_slices) - return sum(groups[:count]) + return k + return min(k, max_support_slices) + def _support_ch(k: int) -> int: + prior_ch = sum(groups[: _support_count(k)]) + if side_in_context: + return side_channels + prior_ch + return prior_ch + + if side_in_context: + # y0 entry: head sees only side_params (no prev_y_hat yet). + ctx_keys = range(0, K) + else: + # ELIC default: slice 0 bypasses channel_context entirely. + ctx_keys = range(1, K) channel_context = { - f"y{k}": channel_context_factory(k, groups[k], _support_ch(k)) - for k in range(1, K) + f"y{k}": channel_context_factory(k, groups[k], _support_ch(k)) for k in ctx_keys } latent_codec = {f"y{k}": leaf_factory(k, groups[k]) for k in range(K)} @@ -104,4 +136,5 @@ def _support_ch(k: int) -> int: groups=list(groups), max_support_slices=max_support_slices, support_filter=support_filter, + side_in_context=side_in_context, ) diff --git a/compressai/models/stf.py b/compressai/models/stf.py index 5074d5b6..14ae9a0a 100644 --- a/compressai/models/stf.py +++ b/compressai/models/stf.py @@ -36,6 +36,7 @@ from __future__ import annotations import math +import re from typing import Dict, Optional, Sequence, Tuple, Type @@ -46,17 +47,27 @@ from timm.models.swin_transformer import SwinTransformerBlock as _TimmSwinBlock from torch import Tensor -from compressai.layers import GDN, conv1x1, conv3x3, subpel_conv3x3 +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + 
DualHyperSynthesis, + EntropyBottleneckLatentCodec, + HyperpriorLatentCodec, + LRPGaussianLatentCodec, +) +from compressai.latent_codecs._slice_helpers import ( + infer_max_support_slices, + infer_num_slices, + make_entropy_transform, +) +from compressai.layers import GDN, conv3x3, subpel_conv3x3 from compressai.layers.attn import ( PatchMerging, PatchSplit, WinNoShiftAttention, ) -from compressai.models._bases import ( - SliceEntropyCompressionModel, - infer_max_support_slices, - infer_num_slices, -) +from compressai.models._helpers.channel_context import build_mean_scale_head +from compressai.models._helpers.channel_slice import build_channel_slice_codec +from compressai.models.base import CompressionModel from compressai.models.utils import conv, deconv from compressai.registry import register_model @@ -209,22 +220,67 @@ def forward(self, input_tensor: Tensor) -> Tensor: "gaussian_conditional", ) +# Top-level rename map applied AFTER per-slice cc_/lrp_/gaussian_conditional +# rerooting. Keys are matched as exact prefixes (with the trailing dot). +_UPSTREAM_TOP_LEVEL_RENAMES: Dict[str, str] = { + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} + +# Upstream STF places the WindowAttention parameters directly under +# ``conv_b..attn.{qkv,proj,relative_position_*}``. CompressAI wraps the +# WindowAttention inside a :class:`compressai.layers.attn.swin.WMSA` shim, so +# the live model keeps ``WMSA.attn = WindowAttention(...)`` and the +# parameters land at ``conv_b..attn.attn.*``. This regex inserts the extra +# ``.attn`` so renamed upstream keys round-trip into the WMSA wrapper without +# changing the model topology. 
+_WMSA_NEST_PATTERN = re.compile( + r"(\.conv_b\.\d+\.attn)\.(qkv\.|proj\.|relative_position_)" +) + -def convert_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: +def _nest_winmsa_keys(key: str) -> str: + """Insert the WMSA wrapper level (``.attn``) into upstream + ``conv_b.*.attn.{qkv,proj,relative_position_*}`` keys.""" + return _WMSA_NEST_PATTERN.sub(r"\1.attn.\2", key) + + +def convert_upstream_stf_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: """Translate a candidate ``STF`` / ``WACNN`` state dict into compressai layout. Upstream checkpoints (``stf__best.pth.tar`` / ``cnn__best.pth.tar`` from `Zou et al. 2022 `_) are saved from a ``DataParallel``-wrapped module and place the channel-conditional entropy - transforms at the model root. compressai houses those transforms (plus the - Gaussian conditional) under ``latent_codec.*``. This helper: + transforms at the model root. After the H+G containerised refactor + compressai houses those transforms (plus the Gaussian conditional and + the hyperprior backbone) inside ``latent_codec.*``. This helper: - strips the leading ``module.`` prefix added by ``DataParallel``; - - re-roots ``cc_mean_transforms`` / ``cc_scale_transforms`` / - ``lrp_transforms`` / ``gaussian_conditional`` under ``latent_codec.``; - - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / ``syn_layers`` - / ``end_conv`` / ``h_a`` / ``h_mean_s`` / ``h_scale_s`` / - ``entropy_bottleneck`` keys unchanged. 
+ - re-roots ``cc_mean_transforms.{k}`` / ``cc_scale_transforms.{k}`` / + ``lrp_transforms.{k}`` under + ``latent_codec.y.channel_context.y{k}.{mean_cc,scale_cc}.*`` / + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; + - replicates the single shared ``gaussian_conditional.*`` buffer set + under each per-slice leaf (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``); + - moves ``entropy_bottleneck.*`` / ``h_a.*`` / ``h_mean_s.*`` / + ``h_scale_s.*`` under ``latent_codec.*`` per the new layout; + - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / + ``syn_layers`` / ``end_conv`` keys unchanged. + + .. caveat:: + The Phase 3 wiring sets ``emit_mean_support=True`` on the + ``MeanScaleContextHead`` so the upstream LRP layout + (``cat(latent_means, *prev_y_hat, y_hat)``) is recoverable inside the + leaf — upstream ``lrp_transforms.{k}`` weights therefore transfer + byte-for-byte. The model's ``WinNoShiftAttention`` consumers wrap + their windowed-attention layers in a :class:`WMSA` shim, so the + conversion also nests upstream ``conv_b.{i}.attn.{qkv,proj, + relative_position_*}`` keys under the extra ``.attn`` level (see + :func:`_nest_winmsa_keys`). The returned dict can be loaded by :meth:`WACNN.from_state_dict` or :meth:`SymmetricalTransFormer.from_state_dict`. Both ``from_state_dict`` @@ -232,12 +288,73 @@ def convert_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, direct invocation is only needed when persisting the converted dict. """ converted: Dict[str, Tensor] = {} + + _LEGACY_ROOT_HEADS = set(_UPSTREAM_LATENT_CODEC_PREFIXES) | { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + } + + # Pass 1: strip module. prefix, fold the upstream single-``attn`` window + # attention path back into compressai's WMSA wrapper layout, and inventory + # which keys exist. 
+ cleaned: Dict[str, Tensor] = {} + has_legacy_root_keys = False for key, value in state_dict.items(): new_key = key[len("module.") :] if key.startswith("module.") else key - head = new_key.split(".", 1)[0] - if head in _UPSTREAM_LATENT_CODEC_PREFIXES: - new_key = "latent_codec." + new_key - converted[new_key] = value + new_key = _nest_winmsa_keys(new_key) + cleaned[new_key] = value + if new_key.split(".", 1)[0] in _LEGACY_ROOT_HEADS: + has_legacy_root_keys = True + + if not has_legacy_root_keys: + # Already in (or near) the new layout — return cleaned dict as-is. + return cleaned + + # Pass 2: discover slice indices to drive gaussian_conditional replication + # and per-slice rerooting. + slice_indices = sorted( + { + int(key.split(".")[1]) + for key in cleaned + if key.startswith("cc_mean_transforms.") + } + ) + num_slices = len(slice_indices) + + for key, value in cleaned.items(): + head = key.split(".", 1)[0] + if head == "cc_mean_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.mean_cc." + ".".join(rest) + converted[new_key] = value + elif head == "cc_scale_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.scale_cc." + ".".join(rest) + converted[new_key] = value + elif head == "lrp_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.latent_codec.y{k}.lrp_transform." + ".".join( + rest + ) + converted[new_key] = value + elif head == "gaussian_conditional": + # Replicate the single shared instance to per-slice leaves. 
+ tail = key[len("gaussian_conditional.") :] + for k in range(num_slices): + new_key = ( + f"latent_codec.y.latent_codec.y{k}" f".gaussian_conditional.{tail}" + ) + converted[new_key] = value + else: + renamed = key + for prefix, replacement in _UPSTREAM_TOP_LEVEL_RENAMES.items(): + if key.startswith(prefix): + renamed = replacement + key[len(prefix) :] + break + converted[renamed] = value + return converted @@ -248,15 +365,19 @@ def _is_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> bool: for key in state_dict: if key.startswith("module."): return True - if key.startswith("cc_mean_transforms.") or key.startswith( - "gaussian_conditional." - ): + head = key.split(".", 1)[0] + if head in _UPSTREAM_LATENT_CODEC_PREFIXES or head in { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + }: return True return False @register_model("stf-wacnn") -class WACNN(SliceEntropyCompressionModel): +class WACNN(CompressionModel): r"""WACNN model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image Compression" `_, IEEE/CVF Conf. on Computer Vision @@ -267,6 +388,11 @@ class WACNN(SliceEntropyCompressionModel): ``output_proj=False``) inside the analysis/synthesis transforms, paired with a Minnen2020-style channel-wise autoregressive entropy model. + The entropy stack is a fully containerised + :class:`HyperpriorLatentCodec` that owns ``h_a``, ``h_s``, the ``z`` + bottleneck and the per-slice ``ChannelGroupsLatentCodec`` running in + Family 1 ``side_in_context=True`` mode. + Args: N (int): Number of channels in the hyperprior backbone. M (int): Number of channels in the latent representation. 
@@ -282,6 +408,10 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + slice_ch = M // num_slices + self.g_a = nn.Sequential( conv(3, N, kernel_size=5, stride=2), GDN(N), @@ -312,7 +442,8 @@ def __init__( GDN(N, inverse=True), deconv(N, 3, kernel_size=5, stride=2), ) - self.h_a = nn.Sequential( + + h_a = nn.Sequential( conv3x3(M, M), nn.GELU(), conv3x3(M, 288), @@ -323,50 +454,41 @@ def __init__( nn.GELU(), conv3x3(224, N, stride=2), ) - self.h_mean_s = nn.Sequential( - conv3x3(N, N), - nn.GELU(), - subpel_conv3x3(N, 224, 2), - nn.GELU(), - conv3x3(224, 256), - nn.GELU(), - subpel_conv3x3(256, 288, 2), - nn.GELU(), - conv3x3(288, M), - ) - self.h_scale_s = nn.Sequential( - conv3x3(N, N), - nn.GELU(), - subpel_conv3x3(N, 224, 2), - nn.GELU(), - conv3x3(224, 256), - nn.GELU(), - subpel_conv3x3(256, 288, 2), - nn.GELU(), - conv3x3(288, M), - ) - self._init_slice_entropy( - M, - N, - num_slices, - max_support_slices, + h_mean_s = _build_stf_h_subpel(N, M) + h_scale_s = _build_stf_h_subpel(N, M) + + self.latent_codec = _build_family1_latent_codec( + N=N, + M=M, + slice_ch=slice_ch, + num_slices=num_slices, + max_support_slices=max_support_slices, + widths=(224, 176, 128, 64), + h_a=h_a, + h_mean_s=h_mean_s, + h_scale_s=h_scale_s, ) def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: y = self.g_a(x) - latent_output = self._forward_latent_output(y) + y_out = self.latent_codec(y) return { - "x_hat": self.g_s(latent_output["y_hat"]), - "likelihoods": latent_output["likelihoods"], + "x_hat": self.g_s(y_out["y_hat"]), + "likelihoods": y_out["likelihoods"], } def compress(self, x: Tensor) -> Dict[str, object]: - return self._compress_latent(self.g_a(x)) + y = self.g_a(x) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} def decompress( - self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int] + self, + 
strings: Sequence[Sequence[bytes]], + shape: Dict[str, Tuple[int, ...]] | Tuple[int, int], ) -> Dict[str, Tensor]: - return {"x_hat": self.g_s(self._decompress_latent(strings, shape)).clamp_(0, 1)} + y_out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(y_out["y_hat"]).clamp_(0, 1)} @classmethod def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": @@ -386,8 +508,130 @@ def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": return net +def _build_stf_h_subpel(N: int, M: int) -> nn.Sequential: + """Default ``h_mean_s`` / ``h_scale_s`` stack used by both WACNN and + SymmetricalTransFormer's WACNN-shaped variant: 5 conv / subpel blocks + going from ``N -> N -> 224 -> 256 -> 288 -> M`` with GELU activations. + """ + return nn.Sequential( + conv3x3(N, N), + nn.GELU(), + subpel_conv3x3(N, 224, 2), + nn.GELU(), + conv3x3(224, 256), + nn.GELU(), + subpel_conv3x3(256, 288, 2), + nn.GELU(), + conv3x3(288, M), + ) + + +def _build_stf_transformer_h_subpel( + bottleneck_channels: int, latent_channels: int, embed_dim: int +) -> nn.Sequential: + """Hyper-synthesis stack used by :class:`SymmetricalTransFormer`. + + Mirrors the original Zou et al. STF Transformer configuration: widths + derived from the per-stage channel counts (``latent_channels - k * + embed_dim``) instead of the WACNN-style fixed ladder. 
+ """ + return nn.Sequential( + conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), + nn.GELU(), + subpel_conv3x3( + latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2 + ), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), + nn.GELU(), + conv3x3(latent_channels, latent_channels), + ) + + +def _build_family1_latent_codec( + *, + N: int, + M: int, + slice_ch: int, + num_slices: int, + max_support_slices: int, + widths: Sequence[int], + h_a: nn.Module, + h_mean_s: nn.Module, + h_scale_s: nn.Module, +) -> HyperpriorLatentCodec: + """Assemble the Phase 3 Family 1 entropy stack: ``HyperpriorLatentCodec`` + wrapping ``DualHyperSynthesis`` and a per-slice + ``ChannelGroupsLatentCodec`` (``side_in_context=True``) whose channel + contexts are :class:`MeanScaleContextHead` instances and leaves are + :class:`LRPGaussianLatentCodec` (STE-quantised). ``side_channels = 2 * + M`` because ``DualHyperSynthesis`` cats ``h_mean_s(z_hat)`` and + ``h_scale_s(z_hat)``. + + The channel-context heads run with ``emit_mean_support=True`` so each + head appends ``cat(latent_means, *prev_y_hat)`` to its output; the leaf + splits that trailing block off (``mean_support_trail_channels``) and + uses it as the LRP input. This reproduces the upstream STF / WACNN LRP + layout (``cat(latent_means, *prev_y_hat, y_hat)``), so the + ``lrp_transforms.{k}`` weights from upstream Zou et al. checkpoints + transfer byte-for-byte after the rename pass in + :func:`convert_upstream_stf_state_dict`. + """ + side_channels = 2 * M + + def _support_count(k: int) -> int: + if max_support_slices < 0: + return k + return min(k, max_support_slices) + + def _mean_support_ch(k: int) -> int: + # cat(latent_means(M), *prev_y_hat(slice_ch * support_count)). 
+ return M + slice_ch * _support_count(k) + + def _leaf(k: int, _slice_ch: int) -> LRPGaussianLatentCodec: + ms_ch = _mean_support_ch(k) + return LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + ms_ch + _slice_ch, # cat(mean_support, y_hat) + _slice_ch, + widths=widths, + ), + mean_support_trail_channels=ms_ch, + quantizer="ste", + ) + + def _channel_context(_k: int, _slice_ch: int, support_ch: int) -> nn.Module: + return build_mean_scale_head( + slice_ch=_slice_ch, + support_ch=support_ch, + widths=widths, + side_split=M, + emit_mean_support=True, + ) + + return HyperpriorLatentCodec( + h_a=h_a, + h_s=DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), quantizer="noise" + ), + "y": build_channel_slice_codec( + groups=[slice_ch] * num_slices, + side_channels=side_channels, + side_in_context=True, + max_support_slices=max_support_slices, + leaf_factory=_leaf, + channel_context_factory=_channel_context, + ), + }, + ) + + @register_model("stf") -class SymmetricalTransFormer(SliceEntropyCompressionModel): +class SymmetricalTransFormer(CompressionModel): r"""Symmetrical Transformer model (STF) from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image Compression" `_, IEEE/CVF Conf. on @@ -395,7 +639,10 @@ class SymmetricalTransFormer(SliceEntropyCompressionModel): Transformer-based companion of :class:`WACNN` that builds the analysis/synthesis transforms with stacked Swin-style basic layers and a - channel-wise autoregressive entropy model. + channel-wise autoregressive entropy model. The entropy stack mirrors + :class:`WACNN`'s containerised :class:`HyperpriorLatentCodec` (Family 1 + ``side_in_context=True`` mode), with widths derived from the + transformer's stage channel counts. Args: embed_dim (int): Patch-embedding dimension. 
@@ -502,7 +749,14 @@ def __init__( latent_channels = int(embed_dim * 2 ** (self.num_layers - 1)) bottleneck_channels = latent_channels // 2 - self.h_a = nn.Sequential( + if latent_channels % num_slices != 0: + raise ValueError("latent_channels must be divisible by num_slices") + slice_ch = latent_channels // num_slices + resolved_max_support = ( + num_slices // 2 if max_support_slices is None else max_support_slices + ) + + h_a = nn.Sequential( conv3x3(latent_channels, latent_channels), nn.GELU(), conv3x3(latent_channels, latent_channels - embed_dim), @@ -515,37 +769,23 @@ def __init__( nn.GELU(), conv3x3(latent_channels - 3 * embed_dim, bottleneck_channels, stride=2), ) - self.h_mean_s = nn.Sequential( - conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), - nn.GELU(), - subpel_conv3x3( - latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2 - ), - nn.GELU(), - conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), - nn.GELU(), - subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), - nn.GELU(), - conv3x3(latent_channels, latent_channels), + h_mean_s = _build_stf_transformer_h_subpel( + bottleneck_channels, latent_channels, embed_dim ) - self.h_scale_s = nn.Sequential( - conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), - nn.GELU(), - subpel_conv3x3( - latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2 - ), - nn.GELU(), - conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), - nn.GELU(), - subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), - nn.GELU(), - conv3x3(latent_channels, latent_channels), + h_scale_s = _build_stf_transformer_h_subpel( + bottleneck_channels, latent_channels, embed_dim ) - self._init_slice_entropy( - latent_channels, - bottleneck_channels, - num_slices, - num_slices // 2 if max_support_slices is None else max_support_slices, + + self.latent_codec = _build_family1_latent_codec( + N=bottleneck_channels, + M=latent_channels, + 
slice_ch=slice_ch, + num_slices=num_slices, + max_support_slices=resolved_max_support, + widths=(224, 176, 128, 64), + h_a=h_a, + h_mean_s=h_mean_s, + h_scale_s=h_scale_s, ) def _analysis_transform(self, x: Tensor) -> Tuple[Tensor, int, int]: @@ -576,20 +816,24 @@ def _synthesis_transform(self, y_hat: Tensor, height: int, width: int) -> Tensor def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: y, height, width = self._analysis_transform(x) - latent_output = self._forward_latent_output(y) + y_out = self.latent_codec(y) return { - "x_hat": self._synthesis_transform(latent_output["y_hat"], height, width), - "likelihoods": latent_output["likelihoods"], + "x_hat": self._synthesis_transform(y_out["y_hat"], height, width), + "likelihoods": y_out["likelihoods"], } def compress(self, x: Tensor) -> Dict[str, object]: y, _, _ = self._analysis_transform(x) - return self._compress_latent(y) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} def decompress( - self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int] + self, + strings: Sequence[Sequence[bytes]], + shape: Dict[str, Tuple[int, ...]] | Tuple[int, int], ) -> Dict[str, Tensor]: - y_hat = self._decompress_latent(strings, shape) + y_out = self.latent_codec.decompress(strings, shape) + y_hat = y_out["y_hat"] height, width = y_hat.shape[2:] return {"x_hat": self._synthesis_transform(y_hat, height, width).clamp_(0, 1)} diff --git a/tests/test_latent_codecs.py b/tests/test_latent_codecs.py index 89776eb9..cbf508b1 100644 --- a/tests/test_latent_codecs.py +++ b/tests/test_latent_codecs.py @@ -27,6 +27,7 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import pytest import torch import torch.nn as nn @@ -243,6 +244,75 @@ def test_max_support_slices_changes_forward_output(self): assert out["y_hat"].shape == (2, 16, 8, 8) +class TestChannelGroupsSideInContext: + """Phase 3 ``side_in_context`` mode used by Family 1 codecs.""" + + def _make_family1_codec( + self, + groups=(4, 4, 4), + side_ch=8, + max_support_slices=-1, + ): + K = len(groups) + + # channel_context for k=0: input is just side_params (= side_ch). + # channel_context for k>=1: input is cat(side_params, *prev_y_hat). + # Use 1x1 convs that map to 2 * groups[k] (the leaf chunks into scales/means). + def _ctx_in(k): + count = k if max_support_slices < 0 else min(k, max_support_slices) + return side_ch + sum(groups[:count]) + + channel_context = { + f"y{k}": nn.Conv2d(_ctx_in(k), 2 * groups[k], 1) for k in range(K) + } + # Leaves see only the channel_context output (no re-cat with side_params) + # in side_in_context mode -> entropy_parameters can be Identity since + # channel_context already shaped the tensor to 2 * slice_ch. + latent_codec = {f"y{k}": GaussianConditionalLatentCodec() for k in range(K)} + return ChannelGroupsLatentCodec( + latent_codec=latent_codec, + channel_context=channel_context, + groups=list(groups), + max_support_slices=max_support_slices, + side_in_context=True, + ) + + def test_constructor_requires_y0_entry(self): + # side_in_context=True but missing y0 channel_context -> ValueError. 
+ with pytest.raises(ValueError, match="y0"): + ChannelGroupsLatentCodec( + latent_codec={"y0": GaussianConditionalLatentCodec()}, + channel_context={}, # missing y0 + groups=[4], + side_in_context=True, + ) + + def test_forward_routes_through_y0_channel_context(self): + torch.manual_seed(11) + codec = self._make_family1_codec(groups=(4, 4)) + y = torch.randn(2, 8, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 8, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 8, 8, 8) + + def test_get_ctx_params_for_k_zero_calls_y0(self): + codec = self._make_family1_codec(groups=(4, 4)) + side_params = torch.zeros(1, 8, 4, 4) + ctx = codec._get_ctx_params(0, side_params, []) + # Output shape == channel_context.y0(side_params) -> (1, 2 * groups[0], 4, 4). + assert ctx.shape == (1, 8, 4, 4) + + def test_get_ctx_params_for_k_positive_concats_side(self): + codec = self._make_family1_codec(groups=(4, 4)) + side_params = torch.zeros(1, 8, 4, 4) + prev_y_hat = [torch.zeros(1, 4, 4, 4)] + ctx = codec._get_ctx_params(1, side_params, prev_y_hat) + # Channel_context.y1 input width = side_ch + groups[0] = 8 + 4 = 12; + # output = 2 * groups[1] = 8. + assert ctx.shape == (1, 8, 4, 4) + + class TestSliceHelpers: def test_slice_support_channels_default_use_all(self): # With max_support_slices = -1 the helper returns the full latent + k slices. @@ -272,16 +342,25 @@ def test_make_entropy_transform_custom_widths(self): assert y.shape == (2, 8, 8, 8) def test_infer_num_slices_new_path(self): - # New state-dict layout: channel_context entries exist for k >= 1. - # For 4 slices total, we expect 3 mean_cc keys -> infer returns 4. + # New state-dict layout (ELIC default): channel_context entries exist + # for k >= 1. For 4 slices total, we expect 3 mean_cc keys -> infer + # returns 4 (helper adds 1 because y0 is missing). 
sd = { - f"latent_codec.latent_codec.y.channel_context.y{k}.mean_cc.0.weight": ( - torch.zeros(8, 4) - ) + f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight": (torch.zeros(8, 4)) for k in range(1, 4) } assert infer_num_slices(sd) == 4 + def test_infer_num_slices_side_in_context(self): + # Family 1 side_in_context=True layout: channel_context covers every + # slice (y0..yK-1). Helper auto-detects via the presence of y0 and + # does NOT add 1. + sd = { + f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight": (torch.zeros(8, 4)) + for k in range(0, 4) + } + assert infer_num_slices(sd) == 4 + def test_infer_num_slices_empty(self): assert infer_num_slices({}) == 0 @@ -289,10 +368,10 @@ def test_infer_max_support_slices_new_path(self): # mean_cc.0 takes (latent_means + slice_channels * support) input channels. # With M=64, num_slices=8, slice_channels=8, support=2 -> input ch = 64 + 16 = 80. sd = { - "latent_codec.latent_codec.y.channel_context.y2.mean_cc.0.weight": ( + "latent_codec.y.channel_context.y2.mean_cc.0.weight": ( torch.zeros(64, 80, 3, 3) ), - "latent_codec.latent_codec.y.channel_context.y3.mean_cc.0.weight": ( + "latent_codec.y.channel_context.y3.mean_cc.0.weight": ( torch.zeros(64, 80, 3, 3) ), } diff --git a/tests/test_models.py b/tests/test_models.py index c23b1865..64136e61 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -290,6 +290,27 @@ def test_wacnn_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] + # Phase 3 containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # side_in_context=True -> channel_context covers y0..y(K-1). 
+ assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic paths should be gone. + assert not any( + k.startswith("latent_codec.cc_mean_transforms.") for k in sd_keys + ) + assert "h_a.0.weight" not in sd_keys # moved under latent_codec. + loaded = WACNN.from_state_dict(model.state_dict()).eval() with torch.no_grad(): out_loaded = loaded(x) @@ -312,6 +333,13 @@ def test_symmetrical_transformer_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] + sd_keys = set(model.state_dict().keys()) + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + loaded = SymmetricalTransFormer.from_state_dict(model.state_dict()).eval() with torch.no_grad(): out_loaded = loaded(x) @@ -325,14 +353,44 @@ def test_stf_upstream_state_dict_conversion(self): upstream = { "module.g_a.0.weight": torch.zeros(2), "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.cc_mean_transforms.1.0.weight": torch.zeros(2), + "module.cc_scale_transforms.0.0.weight": torch.zeros(2), + "module.lrp_transforms.0.0.weight": torch.zeros(2), "module.gaussian_conditional.scale_table": torch.zeros(2), "module.h_a.0.weight": torch.zeros(2), + "module.h_mean_s.0.weight": torch.zeros(2), + "module.h_scale_s.0.weight": torch.zeros(2), + "module.entropy_bottleneck.quantiles": 
torch.zeros(2), } converted = convert_upstream_stf_state_dict(upstream) + # g_a passes through unchanged. assert "g_a.0.weight" in converted - assert "latent_codec.cc_mean_transforms.0.0.weight" in converted - assert "latent_codec.gaussian_conditional.scale_table" in converted - assert "h_a.0.weight" in converted + # Hyperprior backbone moves under latent_codec. + assert "latent_codec.h_a.0.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + # cc_mean / cc_scale re-rooted per slice. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + # gaussian_conditional replicated to every slice (driven by mean_cc count). + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + # LRP weights are now retained: emit_mean_support=True on the head + # makes the leaf consume cat(latent_means, *prev_y_hat) as the LRP + # input, matching upstream's M + slice_ch*(support+1) input width. + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + # Old root-level paths should be gone after conversion. + assert "h_a.0.weight" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted def test_scale_table_default(): diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py index 4da40400..963c0994 100644 --- a/tests/test_models_helpers.py +++ b/tests/test_models_helpers.py @@ -27,6 +27,7 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import pytest import torch import torch.nn as nn @@ -85,6 +86,38 @@ def test_direct_construction_round_trip(self): with torch.no_grad(): assert torch.allclose(head(x), rebuilt(x)) + def test_side_split_routes_means_to_mean_cc_and_scales_to_scale_cc(self): + # side_split=8 means input is cat(latent_means(8), latent_scales(8), prev_y_hat(4)); + # mean_cc should see cat(latent_means(8), prev_y_hat(4)) = 12 channels; + # scale_cc same width but reading latent_scales instead of latent_means. + torch.manual_seed(0) + head = build_mean_scale_head( + slice_ch=4, support_ch=20, widths=(8,), side_split=8 + ) + # Sub-network input width = support_ch - side_split = 12. + first_mean_conv = next(m for m in head.mean_cc if isinstance(m, nn.Conv2d)) + first_scale_conv = next(m for m in head.scale_cc if isinstance(m, nn.Conv2d)) + assert first_mean_conv.in_channels == 12 + assert first_scale_conv.in_channels == 12 + + latent_means = torch.randn(2, 8, 4, 4) + latent_scales = torch.randn(2, 8, 4, 4) + prev_y_hat = torch.randn(2, 4, 4, 4) + x = torch.cat([latent_means, latent_scales, prev_y_hat], dim=1) + with torch.no_grad(): + head_out = head(x) + assert head_out.shape == (2, 8, 4, 4) + # Verify routing: mean_cc(cat(latent_means, prev_y_hat)) appears as + # the second half of head_out (chunks=("scales","means")). 
+ with torch.no_grad(): + expected_mean = head.mean_cc(torch.cat([latent_means, prev_y_hat], dim=1)) + expected_scale = head.scale_cc( + torch.cat([latent_scales, prev_y_hat], dim=1) + ) + scale_out, mean_out = head_out.chunk(2, dim=1) + assert torch.allclose(scale_out, expected_scale) + assert torch.allclose(mean_out, expected_mean) + class TestBuildChannelSliceCodec: def _leaf_factory(self, side_ch=8): @@ -176,3 +209,58 @@ def skip_recent(k, prior): support_filter=skip_recent, ) assert codec.support_filter is skip_recent + + def test_side_in_context_builds_y0_entry(self): + # In side_in_context mode the leaf gets only channel_context output + # (already shaped 2*slice_ch); each leaf can use Identity entropy_parameters. + codec = build_channel_slice_codec( + groups=[4, 4, 4], + side_channels=8, + side_in_context=True, + leaf_factory=lambda k, ch: GaussianConditionalLatentCodec(), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + slice_ch=ch, support_ch=sup, side_split=4, widths=(8,) + ), + ) + ctx_keys = set(codec.channel_context.keys()) + assert ctx_keys == {"y0", "y1", "y2"} + assert codec.side_in_context is True + + # y0 head input width = side_channels = 8. + head_y0 = codec.channel_context["y0"] + first_conv_y0 = next(m for m in head_y0.mean_cc if isinstance(m, nn.Conv2d)) + # mean_cc input = support_ch - side_split = 8 - 4 = 4. + assert first_conv_y0.in_channels == 4 + + # y2 head input width = side_channels + groups[0] + groups[1] = 8 + 8 = 16; + # mean_cc input = 16 - 4 = 12. 
+ head_y2 = codec.channel_context["y2"] + first_conv_y2 = next(m for m in head_y2.mean_cc if isinstance(m, nn.Conv2d)) + assert first_conv_y2.in_channels == 12 + + def test_side_in_context_requires_side_channels(self): + with pytest.raises(ValueError, match="side_channels"): + build_channel_slice_codec( + groups=[4, 4], + side_in_context=True, + leaf_factory=lambda k, ch: GaussianConditionalLatentCodec(), + ) + + def test_side_in_context_forward_runs_end_to_end(self): + torch.manual_seed(0) + codec = build_channel_slice_codec( + groups=[4, 4], + side_channels=8, + side_in_context=True, + leaf_factory=lambda k, ch: GaussianConditionalLatentCodec(), + channel_context_factory=lambda k, ch, sup: build_mean_scale_head( + slice_ch=ch, support_ch=sup, side_split=4, widths=(8,) + ), + ) + codec.eval() + y = torch.randn(2, 8, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 8, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 8, 8, 8) From c2d931a8aa2d5df5edf2c89acc9306a9550d8f09 Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 13:48:07 +0800 Subject: [PATCH 3/8] feat(models): add TCM with containerized codec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrate TCM (Liu et al., CVPR 2023) to the H+G containerized entropy stack used by STF / WACNN. The hyperprior backbone (h_a, h_mean_s, h_scale_s, EntropyBottleneck) plus the per-slice channel-conditional entropy heads now live under a single HyperpriorLatentCodec — matching ELIC-pattern wiring with three Family 1 specializations: - DualHyperSynthesis(h_mean_s, h_scale_s) cats the two parallel hyper-synthesis outputs into side_params of width 2*M. - ChannelGroupsLatentCodec(side_in_context=True) routes side_params into every channel_context head (incl. y0). - LRPGaussianLatentCodec(mean_support_trail_channels=...) 
plus the MeanScaleContextHead(emit_mean_support=True) recover the upstream cat(latent_means, *prev_y_hat, y_hat) LRP layout for byte-for-byte weight transfer from upstream LIC_TCM checkpoints. TCM-specific kwargs vs STF/WACNN: 3-conv widths=(224, 128) plus support_transform_factory=SWAtten (independent windowed-attention per mean / scale path), mirroring upstream atten_mean[k] / atten_scale[k]. use_cca / use_auxt are intentionally not implemented in this PR: - use_cca will be re-added once Phase 5 lands the containerized _CCAAuxEntropyModel. - use_auxt (AuxT, ICLR 2025) depends on layers (WLS / iWLS / OLP) not in this branch and is out of scope. Verified on the LIC_TCM 0.05.pth.tar (N=64) and mse_lambda_0.05.pth.tar (N=128) candidate checkpoints: strict load succeeds, sinusoidal smoke test reaches PSNR 39.15 dB / 39.41 dB respectively (vs 5.41 dB for a fresh-init model), confirming all weights — including LRP — transfer byte-for-byte through convert_upstream_tcm_state_dict. Also: append a Family 1 wiring comment block to compressai/latent_codecs/__init__.py so reviewers can see how the ELIC-style upstream codecs compose into the STF / WACNN / TCM / CCA / DCAE / MambaVC pattern without reading model source. Tests: - tests/test_models.py::TestTcm::test_tcm_forward_and_state_dict_round_trip — forward + new state_dict path self-check + round-trip allclose. - tests/test_models.py::TestTcm::test_tcm_upstream_state_dict_conversion — synthetic upstream LIC_TCM-style state_dict, asserts MSA buffer reshape + per-slice / SWAtten-wrapper / hyperprior re-rooting. make static-analysis clean. pytest tests/test_models.py tests/test_latent_codecs.py tests/test_models_helpers.py tests/test_layers.py tests/test_init.py: 71/71. import compressai / compressai.zoo / compressai.latent_codecs still trigger zero timm imports (TCM follows the STF lazy-load convention — not re-exported from compressai.models.__init__). 
--- compressai/latent_codecs/__init__.py | 66 +++ compressai/models/tcm.py | 761 +++++++++++++++++++++++++++ examples/convert_tcm_checkpoint.py | 122 +++++ tests/test_models.py | 152 ++++++ 4 files changed, 1101 insertions(+) create mode 100644 compressai/models/tcm.py create mode 100644 examples/convert_tcm_checkpoint.py diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index bd587257..1d5ea9c6 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -54,3 +54,69 @@ "LRPGaussianLatentCodec", "RasterScanLatentCodec", ] + + +# ---------------------------------------------------------------------------- +# Family 1 wiring (STF / WACNN / TCM / CCA / DCAE / MambaVC) +# ---------------------------------------------------------------------------- +# +# "Family 1" is the set of channel-slice models that share the same outer +# entropy-stack shape: +# +# HyperpriorLatentCodec( +# h_a=h_a, +# h_s=DualHyperSynthesis(h_mean_s, h_scale_s), # cat(mean_s, scale_s) +# latent_codec={ +# "z": EntropyBottleneckLatentCodec(EntropyBottleneck(N), ...), +# "y": ChannelGroupsLatentCodec( # side_in_context=True mode +# latent_codec={"y0": LRPGaussianLatentCodec(...), ...}, +# channel_context={"y0": MeanScaleContextHead(...), ...}, +# groups=[M//K]*K, +# max_support_slices=MS, +# side_in_context=True, +# ), +# }, +# ) +# +# Compared to the ELIC-style channel-slice wiring it differs in three +# places, all reproducible through optional kwargs on the upstream codecs: +# +# 1. Two parallel ``h_s`` heads instead of one — DualHyperSynthesis cats +# them into a single ``side_params`` tensor of width ``2*M``. +# 2. ``ChannelGroupsLatentCodec(side_in_context=True)`` routes +# ``side_params`` into every channel_context head (including ``y0``) +# instead of only handing it to the leaves; the head is then +# responsible for re-splitting ``side_params`` into mean / scale. +# 3. 
The leaf is :class:`LRPGaussianLatentCodec` (mostly), which adds a +# learned residual prediction on top of ``y_hat``. With matching +# ``mean_support_trail_channels`` the leaf reads the LRP input from a +# trailing block of ``ctx_params`` produced by the head's +# ``emit_mean_support=True`` mode, recovering the upstream +# ``cat(latent_means, *prev_y_hat, y_hat)`` layout for byte-for-byte +# weight transfer. +# +# Application-layer helpers in +# :mod:`compressai.models._helpers.channel_slice` and +# :mod:`compressai.models._helpers.channel_context` +# (``build_channel_slice_codec``, ``MeanScaleContextHead``, +# ``build_mean_scale_head``) wire these pieces declaratively. Per-model +# variations stay in the kwargs: +# +# - **STF / WACNN**: 5-conv cc heads ``widths=(224, 176, 128, 64)``, no +# support transform. +# - **TCM**: 3-conv cc heads ``widths=(224, 128)``, +# ``support_transform_factory=SWAtten`` (independent windowed-attention +# transforms per mean / scale path). +# - **CCA-main**: variable-length slices (``groups=resolved_slice_sizes``), +# ``support_transform_factory=NAFTransform``, +# ``EntropyBottleneckLatentCodec(quantizer="ste")`` for the ``z`` leaf. +# - **CCA-aux**: lives outside the hyperprior container (separate +# ``ChannelGroupsLatentCodec``), uses ``support_filter`` for +# skip-most-recent prior selection, and mixes +# :class:`LRPGaussianLatentCodec` (early slices) with +# :class:`GaussianConditionalLatentCodec` (last two slices). +# - **DCAE / MambaVC**: future Family 1 follow-ups; same shape, different +# support transforms. +# +# See :mod:`compressai.models.stf` and :mod:`compressai.models.tcm` for +# end-to-end examples. diff --git a/compressai/models/tcm.py b/compressai/models/tcm.py new file mode 100644 index 00000000..a5d8e713 --- /dev/null +++ b/compressai/models/tcm.py @@ -0,0 +1,761 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. 
+# +# This file adapts code from https://github.com/jmliu206/LIC_TCM +# (originally distributed under the MIT License). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import re + +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + DualHyperSynthesis, + EntropyBottleneckLatentCodec, + HyperpriorLatentCodec, + LRPGaussianLatentCodec, +) +from compressai.latent_codecs._slice_helpers import ( + infer_max_support_slices, + infer_num_slices, + make_entropy_transform, +) +from compressai.layers import ( + ResidualBlockUpsample, + ResidualBlockWithStride, + conv3x3, + subpel_conv3x3, +) +from compressai.layers.attn import ConvTransBlock, SWAtten +from compressai.models._helpers.channel_context import build_mean_scale_head +from compressai.models._helpers.channel_slice import build_channel_slice_codec +from compressai.models.base import CompressionModel +from compressai.models.utils import conv +from compressai.registry import register_model + +__all__ = [ + "TCM", + "convert_upstream_tcm_state_dict", +] + + +# ---------------------------------------------------------------------------- +# Upstream LIC_TCM checkpoint conversion +# ---------------------------------------------------------------------------- + + +# Heads from upstream LIC_TCM (Liu et al. 2023) checkpoints that move under +# ``latent_codec.*`` after the H+G containerised refactor. +_UPSTREAM_LATENT_CODEC_PREFIXES = ( + "cc_mean_transforms", + "cc_scale_transforms", + "lrp_transforms", + "atten_mean", + "atten_scale", + "mean_support_transforms", + "scale_support_transforms", + "gaussian_conditional", +) + +# Top-level rename map applied AFTER per-slice rerooting. Keys are matched as +# exact prefixes (with the trailing dot). 
+_UPSTREAM_TOP_LEVEL_RENAMES: Dict[str, str] = { + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} + +# Upstream LIC_TCM wraps each ``SWAtten`` in an ``nn.Sequential`` and stores +# parameters at ``atten_mean.{k}.0.<...>``. Compressai's :class:`SWAtten` +# lives directly at ``mean_support_transform.<...>`` after rerooting, so the +# leading ``.0`` wrapper level is stripped. +_UPSTREAM_SWATTEN_WRAPPER = re.compile( + r"^(atten_mean|atten_scale|mean_support_transforms|scale_support_transforms)\.(\d+)\.0\." +) + + +def _rename_msa_keys(key: str, value: Tensor) -> Tuple[str, Tensor]: + """Translate upstream LIC_TCM ConvTransBlock-internal MSA layout to + compressai's :class:`WMSA` wrapper layout. + + Three kinds of upstream keys appear inside ``g_a`` / ``g_s`` / ``h_a`` / + ``h_mean_s`` / ``h_scale_s`` blocks: + + - ``.msa.relative_position_params`` is a ``(num_heads, 2*win-1, 2*win-1)`` + buffer; compressai's ``WindowAttention`` registers it as a flat + ``(N, num_heads)`` ``relative_position_bias_table``. The value is + permuted and reshaped accordingly. + - ``.msa.embedding_layer`` is upstream's name for the fused ``qkv`` + linear; compressai exposes it as ``.msa.attn.qkv.<...>``. + - ``.msa.linear`` is upstream's optional output projection; compressai + keeps it as ``.msa.output_proj`` beside an identity ``.msa.attn.proj`` — + see :func:`_ensure_identity_attention_projection` for the identity + injection that keeps strict ``load_state_dict`` round-trips clean. + """ + if ".msa.relative_position_params" in key: + new_key = key.replace( + ".msa.relative_position_params", + ".msa.attn.relative_position_bias_table", + ) + new_value = value.permute(1, 2, 0).reshape(-1, value.size(0)).contiguous() + return new_key, new_value + if ".msa.embedding_layer."
in key: + return key.replace(".msa.embedding_layer.", ".msa.attn.qkv."), value + if ".msa.linear." in key: + return key.replace(".msa.linear.", ".msa.output_proj."), value + return key, value + + +def _ensure_identity_attention_projection( + state_dict: Dict[str, Tensor], + output_proj_key: str, + output_proj_value: Tensor, +) -> None: + """Inject an identity ``WindowAttention.proj`` for upstream blocks whose + output projection sits outside the attention module (``.msa.linear`` → + ``.msa.output_proj``). The model has both ``.msa.attn.proj`` (inside + WindowAttention, identity-initialised here) and ``.msa.output_proj`` + (the actual learned projection) so strict ``load_state_dict`` succeeds. + """ + prefix, suffix = output_proj_key.rsplit(".msa.output_proj.", 1) + attn_proj_key = f"{prefix}.msa.attn.proj.{suffix}" + if attn_proj_key in state_dict: + return + if suffix == "weight": + dimension = output_proj_value.size(0) + state_dict[attn_proj_key] = torch.eye( + dimension, + dtype=output_proj_value.dtype, + device=output_proj_value.device, + ) + return + if suffix == "bias": + state_dict[attn_proj_key] = torch.zeros_like(output_proj_value) + + +def _is_upstream_tcm_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream LIC_TCM checkpoints either carry the ``module.`` + prefix from ``DataParallel`` saving, the ``.msa.relative_position_params`` + buffer, or the per-slice entropy heads (``cc_mean_transforms`` / + ``atten_mean`` / ``lrp_transforms`` / ``gaussian_conditional`` / ``h_a`` + / ``h_mean_s`` / ``h_scale_s`` / ``entropy_bottleneck``) at the model + root rather than under ``latent_codec.*``. + """ + legacy_roots = set(_UPSTREAM_LATENT_CODEC_PREFIXES) | { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + } + for key in state_dict: + if key.startswith("module."): + return True + if ( + ".msa.relative_position_params" in key + or ".msa.embedding_layer." in key + or ".msa.linear." 
in key + ): + return True + if key.split(".", 1)[0] in legacy_roots: + return True + return False + + +def convert_upstream_tcm_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Translate an upstream LIC_TCM state dict into compressai layout. + + Upstream checkpoints (e.g. ``0.013.pth.tar`` from + `Liu et al. 2023 `_, + https://github.com/jmliu206/LIC_TCM) place the channel-conditional entropy + transforms and the hyperprior backbone at the model root. After the H+G + containerised refactor compressai houses those transforms (plus the + Gaussian conditional and the ``z`` bottleneck) inside ``latent_codec.*``. + This helper: + + - strips the leading ``module.`` prefix added by ``DataParallel``; + - rewrites ConvTransBlock attention buffers via :func:`_rename_msa_keys` + (``.msa.relative_position_params`` / ``.msa.embedding_layer`` / + ``.msa.linear``) and standard layer-name renames (``ln1`` → ``norm1``, + ``mlp.0`` / ``mlp.2`` → ``mlp.fc1`` / ``mlp.fc2``); + - unwraps the upstream ``nn.Sequential`` wrapper around each ``SWAtten`` + (``atten_mean.{k}.0.<...>`` → ``atten_mean.{k}.<...>``); + - re-roots ``cc_mean_transforms.{k}`` / ``cc_scale_transforms.{k}`` / + ``lrp_transforms.{k}`` under + ``latent_codec.y.channel_context.y{k}.{mean_cc,scale_cc}.*`` / + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; + - re-roots ``atten_mean.{k}`` / ``atten_scale.{k}`` (or their + ``mean_support_transforms`` / ``scale_support_transforms`` aliases) + under ``latent_codec.y.channel_context.y{k}.{mean,scale}_support_transform.*``; + - replicates the single shared ``gaussian_conditional.*`` buffer set + under each per-slice leaf + (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``); + - moves ``entropy_bottleneck.*`` / ``h_a.*`` / ``h_mean_s.*`` / + ``h_scale_s.*`` under ``latent_codec.*`` per the new layout; + - leaves ``g_a`` / ``g_s`` keys (other than the MSA renames inside their + ConvTransBlocks) untouched.
+ + The Phase 3 wiring sets ``emit_mean_support=True`` on the + :class:`MeanScaleContextHead`, so the upstream LRP layout + (``cat(latent_means, *prev_y_hat, y_hat)``) is recoverable inside the + leaf — upstream ``lrp_transforms.{k}`` weights therefore transfer + byte-for-byte. + + The returned dict can be loaded by :meth:`TCM.from_state_dict`, which + auto-detects the upstream layout and calls this helper, so direct + invocation is only needed when persisting the converted dict. + """ + # Pass 1: strip ``module.`` prefix; rewrite ConvTransBlock attention + # buffers and layer names; unwrap the SWAtten ``nn.Sequential`` wrapper; + # alias ``atten_mean`` / ``atten_scale`` to the canonical + # ``mean_support_transforms`` / ``scale_support_transforms`` names so the + # per-slice rerooting in Pass 2 only has to handle one form. + cleaned: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = key[len("module.") :] if key.startswith("module.") else key + new_key, value = _rename_msa_keys(new_key, value) + wrapper = _UPSTREAM_SWATTEN_WRAPPER.match(new_key) + if wrapper: + new_key = ( + f"{wrapper.group(1)}.{wrapper.group(2)}." + new_key[wrapper.end() :] + ) + if new_key.startswith("atten_mean."): + new_key = "mean_support_transforms." + new_key[len("atten_mean.") :] + elif new_key.startswith("atten_scale."): + new_key = "scale_support_transforms." + new_key[len("atten_scale.") :] + new_key = new_key.replace(".ln1.", ".norm1.") + new_key = new_key.replace(".ln2.", ".norm2.") + new_key = new_key.replace(".mlp.0.", ".mlp.fc1.") + new_key = new_key.replace(".mlp.2.", ".mlp.fc2.") + if ".msa.output_proj." in new_key: + _ensure_identity_attention_projection(cleaned, new_key, value) + cleaned[new_key] = value + + # Pass 2: discover slice indices to drive ``gaussian_conditional`` + # replication, then reroot per-slice and top-level keys. 
+ converted: Dict[str, Tensor] = {} + slice_indices = sorted( + { + int(key.split(".")[1]) + for key in cleaned + if key.startswith("cc_mean_transforms.") + } + ) + num_slices = len(slice_indices) + + for key, value in cleaned.items(): + head = key.split(".", 1)[0] + if head == "cc_mean_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.mean_cc." + ".".join(rest) + converted[new_key] = value + elif head == "cc_scale_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.scale_cc." + ".".join(rest) + converted[new_key] = value + elif head == "mean_support_transforms": + _, k, *rest = key.split(".") + new_key = ( + f"latent_codec.y.channel_context.y{k}.mean_support_transform." + + ".".join(rest) + ) + converted[new_key] = value + elif head == "scale_support_transforms": + _, k, *rest = key.split(".") + new_key = ( + f"latent_codec.y.channel_context.y{k}.scale_support_transform." + + ".".join(rest) + ) + converted[new_key] = value + elif head == "lrp_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.latent_codec.y{k}.lrp_transform." 
+ ".".join( + rest + ) + converted[new_key] = value + elif head == "gaussian_conditional": + tail = key[len("gaussian_conditional.") :] + for k in range(num_slices): + new_key = ( + f"latent_codec.y.latent_codec.y{k}.gaussian_conditional.{tail}" + ) + converted[new_key] = value + else: + renamed = key + for prefix, replacement in _UPSTREAM_TOP_LEVEL_RENAMES.items(): + if key.startswith(prefix): + renamed = replacement + key[len(prefix) :] + break + converted[renamed] = value + + return converted + + +# ---------------------------------------------------------------------------- +# Architecture inference helpers (state_dict -> hyperparameters) +# ---------------------------------------------------------------------------- + + +def _group_consecutive(indices: Iterable[int]) -> List[List[int]]: + grouped: List[List[int]] = [] + for index in sorted(indices): + if not grouped or index != grouped[-1][-1] + 1: + grouped.append([index]) + continue + grouped[-1].append(index) + return grouped + + +def _infer_stage_groups(state_dict: Dict[str, Tensor], prefix: str) -> List[List[int]]: + indices = { + int(key.split(".")[1]) + for key in state_dict + if key.startswith(f"{prefix}.") and ".conv1_1.weight" in key + } + return _group_consecutive(indices) + + +def _infer_stage_depths(state_dict: Dict[str, Tensor]) -> Optional[List[int]]: + g_a_groups = _infer_stage_groups(state_dict, "g_a") + g_s_groups = _infer_stage_groups(state_dict, "g_s") + if len(g_a_groups) != 3 or len(g_s_groups) != 3: + return None + return [len(group) for group in g_a_groups + g_s_groups] + + +def _infer_head_dims(state_dict: Dict[str, Tensor], N: int) -> Optional[List[int]]: + head_dims: List[int] = [] + for prefix in ("g_a", "g_s"): + for group in _infer_stage_groups(state_dict, prefix): + if not group: + continue + table_key = ( + f"{prefix}.{group[0]}.trans_block.msa.attn.relative_position_bias_table" + ) + if table_key not in state_dict: + return None + num_heads = state_dict[table_key].size(1) + 
head_dims.append(N // num_heads) + return head_dims if len(head_dims) == 6 else None + + +def _infer_hyper_head_dim(state_dict: Dict[str, Tensor], N: int, default: int) -> int: + """Infer the hyperprior ConvTransBlock head dimension from a state dict. + + Probes the attention bias table of the first hyper-analysis / + hyper-mean-synthesis ConvTransBlock, whose second dimension is the head + count, and returns ``N // num_heads``. The containerised + ``latent_codec.*`` paths are tried first because + :meth:`TCM.from_state_dict` only ever passes that layout; the root-level + paths are kept as a fallback for dicts whose hyperprior backbone has not + been re-rooted. Returns ``default`` when no table is present. + """ + for key in ( + "latent_codec.h_a.1.trans_block.msa.attn.relative_position_bias_table", + "latent_codec.h_s.h_mean_s.1.trans_block.msa.attn.relative_position_bias_table", + "h_a.1.trans_block.msa.attn.relative_position_bias_table", + "h_mean_s.1.trans_block.msa.attn.relative_position_bias_table", + ): + if key in state_dict: + return N // state_dict[key].size(1) + return default + + +# ---------------------------------------------------------------------------- +# Architecture building blocks +# ---------------------------------------------------------------------------- + + +def _make_mixed_stage( + depth: int, + branch_channels: int, + head_dim: int, + window_size: int, + drop_paths: Sequence[float], + tail: nn.Module, +) -> List[nn.Module]: + if len(drop_paths) != depth: + raise ValueError("drop_paths must match stage depth") + blocks = [ + ConvTransBlock( + branch_channels, + branch_channels, + head_dim, + window_size, + drop_paths[index], + type="W" if index % 2 == 0 else "SW", + ) + for index in range(depth) + ] + return [*blocks, tail] + + +# ---------------------------------------------------------------------------- +# TCM model +# ---------------------------------------------------------------------------- + + +@register_model("lic-tcm") +@register_model("tcm") +class TCM(CompressionModel): + r"""TCM model from J. Liu, H. Sun, J. Katto: `"Learned Image Compression + with Mixed Transformer-CNN Architectures" + `_, IEEE/CVF Conf. on Computer Vision + and Pattern Recognition (CVPR), 2023 (Highlight). + + Stacks parallel Transformer-CNN Mixture (TCM) blocks for the + analysis/synthesis transforms and uses a channel-wise autoregressive + entropy model with parameter-efficient swin-transformer attention + (``SWAtten``).
+ + The entropy stack is a fully containerised + :class:`HyperpriorLatentCodec` that owns ``h_a``, ``h_s``, the ``z`` + bottleneck and the per-slice ``ChannelGroupsLatentCodec`` running in + Family 1 ``side_in_context=True`` mode. The channel-context heads run + with ``support_transform_factory=SWAtten`` so per-slice ``mean_in`` / + ``scale_in`` are routed through independent SWAtten instances before the + 3-conv ``mean_cc`` / ``scale_cc`` stacks (TCM's distinctive widths + ``(224, 128)``). + + Args: + N (int): Channel width of the analysis/synthesis transform branches. + M (int): Channels in the latent representation ``y``. + hyper_channels (int): Channels in the hyperprior backbone ``z``. + num_slices (int): Number of channel slices for the entropy model. + max_support_slices (int): Per-slice context cap. + """ + + def __init__( + self, + config: Optional[Sequence[int]] = None, + head_dim: Optional[Sequence[int]] = None, + drop_path_rate: float = 0.0, + N: int = 128, + M: int = 320, + hyper_channels: int = 192, + num_slices: int = 5, + max_support_slices: int = 5, + window_size: int = 8, + hyper_window_size: int = 4, + hyper_head_dim: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + config = tuple(int(value) for value in (config or (2, 2, 2, 2, 2, 2))) + head_dim = tuple(int(value) for value in (head_dim or (8, 16, 32, 32, 16, 8))) + if len(config) != 6: + raise ValueError("config must provide six stage depths") + if len(head_dim) != 6: + raise ValueError("head_dim must provide six stage head dimensions") + if any(value < 0 for value in config): + raise ValueError("config values must be non-negative") + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + if any(N % value != 0 for value in head_dim): + raise ValueError("Each head_dim must divide N") + if N % hyper_head_dim != 0: + raise ValueError("hyper_head_dim must divide N") + + self.config = config + self.head_dim = head_dim + self.window_size = int(window_size) 
+ self.hyper_window_size = int(hyper_window_size) + self.hyper_head_dim = int(hyper_head_dim) + self.N = int(N) + self.M = int(M) + self.hyper_channels = int(hyper_channels) + self.num_slices = int(num_slices) + self.max_support_slices = int(max_support_slices) + + drop_paths = torch.linspace(0, drop_path_rate, sum(config)).tolist() + offset = 0 + + def stage_drop_paths(depth: int) -> List[float]: + nonlocal offset + values = [float(value) for value in drop_paths[offset : offset + depth]] + offset += depth + return values + + self.g_a = nn.Sequential( + ResidualBlockWithStride(3, 2 * N, stride=2), + *_make_mixed_stage( + config[0], + N, + head_dim[0], + self.window_size, + stage_drop_paths(config[0]), + ResidualBlockWithStride(2 * N, 2 * N, stride=2), + ), + *_make_mixed_stage( + config[1], + N, + head_dim[1], + self.window_size, + stage_drop_paths(config[1]), + ResidualBlockWithStride(2 * N, 2 * N, stride=2), + ), + *_make_mixed_stage( + config[2], + N, + head_dim[2], + self.window_size, + stage_drop_paths(config[2]), + conv3x3(2 * N, M, stride=2), + ), + ) + self.g_s = nn.Sequential( + ResidualBlockUpsample(M, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + head_dim[3], + self.window_size, + stage_drop_paths(config[3]), + ResidualBlockUpsample(2 * N, 2 * N, 2), + ), + *_make_mixed_stage( + config[4], + N, + head_dim[4], + self.window_size, + stage_drop_paths(config[4]), + ResidualBlockUpsample(2 * N, 2 * N, 2), + ), + *_make_mixed_stage( + config[5], + N, + head_dim[5], + self.window_size, + stage_drop_paths(config[5]), + subpel_conv3x3(2 * N, 3, 2), + ), + ) + + h_a = nn.Sequential( + ResidualBlockWithStride(M, 2 * N, 2), + *_make_mixed_stage( + config[0], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * config[0], + conv3x3(2 * N, hyper_channels, stride=2), + ), + ) + h_mean_s = nn.Sequential( + ResidualBlockUpsample(hyper_channels, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * 
config[3], + subpel_conv3x3(2 * N, M, 2), + ), + ) + h_scale_s = nn.Sequential( + ResidualBlockUpsample(hyper_channels, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * config[3], + subpel_conv3x3(2 * N, M, 2), + ), + ) + + slice_ch = M // num_slices + self.latent_codec = _build_tcm_latent_codec( + hyper_channels=hyper_channels, + M=M, + slice_ch=slice_ch, + num_slices=num_slices, + max_support_slices=max_support_slices, + window_size=self.window_size, + h_a=h_a, + h_mean_s=h_mean_s, + h_scale_s=h_scale_s, + ) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y = self.g_a(x) + y_out = self.latent_codec(y) + return { + "x_hat": self.g_s(y_out["y_hat"]), + "likelihoods": y_out["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} + + def decompress( + self, + strings: Sequence[Sequence[bytes]], + shape: Dict[str, Tuple[int, ...]] | Tuple[int, int], + ) -> Dict[str, Tensor]: + y_out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(y_out["y_hat"]).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "TCM": + if _is_upstream_tcm_state_dict(state_dict): + state_dict = convert_upstream_tcm_state_dict(state_dict) + N = state_dict["g_a.0.conv1.weight"].size(0) // 2 + M = state_dict["latent_codec.h_a.0.conv1.weight"].size(1) + config = _infer_stage_depths(state_dict) or [2, 2, 2, 2, 2, 2] + head_dim = _infer_head_dims(state_dict, N) or [8, 16, 32, 32, 16, 8] + hyper_channels = state_dict["latent_codec.z.entropy_bottleneck.quantiles"].size( + 0 + ) + num_slices = infer_num_slices(state_dict) or 5 + max_support_slices = infer_max_support_slices(state_dict, M, num_slices) + net = cls( + config=config, + head_dim=head_dim, + N=N, + M=M, + hyper_channels=hyper_channels, + 
num_slices=num_slices, + max_support_slices=max_support_slices, + hyper_head_dim=_infer_hyper_head_dim(state_dict, N, 32), + ) + # ConvTransBlock's WindowAttention registers + # ``relative_position_index`` as a non-persistent buffer, so it is + # absent from saved state dicts. Tolerate the missing keys. + incompatible_keys = net.load_state_dict(state_dict, strict=False) + allowed_missing = { + key for key in net.state_dict() if key.endswith("relative_position_index") + } + missing_keys = set(incompatible_keys.missing_keys) - allowed_missing + if missing_keys or incompatible_keys.unexpected_keys: + raise RuntimeError( + "Unexpected incompatibility while loading TCM state_dict: " + f"missing={sorted(missing_keys)}, " + f"unexpected={sorted(incompatible_keys.unexpected_keys)}" + ) + return net + + +def _build_tcm_latent_codec( + *, + hyper_channels: int, + M: int, + slice_ch: int, + num_slices: int, + max_support_slices: int, + window_size: int, + h_a: nn.Module, + h_mean_s: nn.Module, + h_scale_s: nn.Module, +) -> HyperpriorLatentCodec: + """Assemble TCM's Family 1 entropy stack: ``HyperpriorLatentCodec`` + wrapping ``DualHyperSynthesis`` and a per-slice + ``ChannelGroupsLatentCodec`` (``side_in_context=True``). + + Differences from the STF / WACNN wiring (see + :func:`compressai.models.stf._build_family1_latent_codec`): + + - ``widths=(224, 128)`` — TCM's 3-conv ``mean_cc`` / ``scale_cc`` stack + in place of STF's 5-conv ``(224, 176, 128, 64)`` ladder. + - ``support_transform_factory=SWAtten`` — independent windowed-attention + transforms wrap each slice's ``mean_in`` / ``scale_in`` before the + conv stack. The two SWAtten instances per slice mirror upstream + ``atten_mean[k]`` / ``atten_scale[k]``. 
+ + Like STF, the channel-context heads run with ``emit_mean_support=True`` + and the leaves with matching ``mean_support_trail_channels`` so the + upstream LRP layout (``cat(latent_means, *prev_y_hat, y_hat)``) is + preserved — upstream ``lrp_transforms.{k}`` weights transfer + byte-for-byte after :func:`convert_upstream_tcm_state_dict`. + """ + widths = (224, 128) + side_channels = 2 * M + + def _support_count(k: int) -> int: + if max_support_slices < 0: + return k + return min(k, max_support_slices) + + def _mean_support_ch(k: int) -> int: + # cat(latent_means(M), *prev_y_hat(slice_ch * support_count)). + return M + slice_ch * _support_count(k) + + def _leaf(k: int, _slice_ch: int) -> LRPGaussianLatentCodec: + ms_ch = _mean_support_ch(k) + return LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + ms_ch + _slice_ch, # cat(mean_support, y_hat) + _slice_ch, + widths=widths, + ), + mean_support_trail_channels=ms_ch, + quantizer="ste", + ) + + def _swatten_factory(c_in: int, c_out: int) -> nn.Module: + return SWAtten( + input_dim=c_in, + output_dim=c_out, + head_dim=16, + window_size=window_size, + drop_path=0.0, + inter_dim=128, + ) + + def _channel_context(_k: int, _slice_ch: int, support_ch: int) -> nn.Module: + return build_mean_scale_head( + slice_ch=_slice_ch, + support_ch=support_ch, + widths=widths, + side_split=M, + emit_mean_support=True, + support_transform_factory=_swatten_factory, + ) + + return HyperpriorLatentCodec( + h_a=h_a, + h_s=DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(hyper_channels), + quantizer="noise", + ), + "y": build_channel_slice_codec( + groups=[slice_ch] * num_slices, + side_channels=side_channels, + side_in_context=True, + max_support_slices=max_support_slices, + leaf_factory=_leaf, + channel_context_factory=_channel_context, + ), + }, + ) diff --git a/examples/convert_tcm_checkpoint.py b/examples/convert_tcm_checkpoint.py new file 
mode 100644 index 00000000..77ab2700 --- /dev/null +++ b/examples/convert_tcm_checkpoint.py @@ -0,0 +1,122 @@ +"""Convert an upstream LIC-TCM checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. ``0.05.pth.tar`` or +``mse_lambda_0.05.pth.tar`` from the LIC_TCM repo, +https://github.com/jmliu206/LIC_TCM), translates it to compressai's module +layout, and writes a state dict that ``compressai.models.tcm.TCM.from_state_dict`` +can load directly. Optionally reports forward-pass sanity numbers +(PSNR / bpp) on a synthetic input. + +The upstream-vs-compressai key differences (``module.`` ``DataParallel`` +prefix, the ``nn.Sequential`` wrapper around each ``SWAtten``, +``atten_mean`` -> ``latent_codec.y.channel_context.y{k}.mean_support_transform``, +ConvTransBlock MSA buffer layouts, layer-norm names, the H+G containerised +re-rooting under ``latent_codec.*``, etc.) are all handled inside +``convert_upstream_tcm_state_dict``; this script is a thin CLI around it. + +Example:: + + python examples/convert_tcm_checkpoint.py \\ + --src candidate/TCM/0.05.pth.tar \\ + --dst /tmp/tcm_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path + +import torch + +from compressai.models.tcm import TCM, convert_upstream_tcm_state_dict + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream LIC-TCM checkpoint (e.g. 0.05.pth.tar).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." 
+ ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + converted = convert_upstream_tcm_state_dict(upstream) + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = TCM.from_state_dict(upstream) + net.eval() + print( + "variant: " + f"N={net.N}, M={net.M}, num_slices={net.num_slices}, " + f"config={tuple(net.config)}, head_dim={tuple(net.head_dim)}, " + f"hyper_channels={net.hyper_channels}, " + f"max_support_slices={net.max_support_slices}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_models.py b/tests/test_models.py index 
64136e61..99532a42 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -393,6 +393,158 @@ def test_stf_upstream_state_dict_conversion(self): assert "lrp_transforms.0.0.weight" not in converted +class TestTcm: + def test_tcm_forward_and_state_dict_round_trip(self): + from compressai.models.tcm import TCM + + model = TCM( + N=32, + M=64, + hyper_channels=48, + num_slices=4, + max_support_slices=2, + ).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + # Phase 4 containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + # Hyperprior backbone moved under latent_codec.* (TCM's h_a / h_*_s + # use ResidualBlockWithStride / ResidualBlockUpsample, so the first + # learnable weight is conv1 / conv). + assert "latent_codec.h_a.0.conv1.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.conv.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.conv.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # side_in_context=True -> channel_context covers y0..y(K-1). + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # SWAtten support transforms (TCM-specific; absent on STF/WACNN). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.in_conv.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.in_conv.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic / pr-stf-wacnn paths should be gone. 
+ assert not any( + k.startswith("latent_codec.cc_mean_transforms.") for k in sd_keys + ) + assert not any(k.startswith("latent_codec.atten_mean.") for k in sd_keys) + assert "h_a.0.conv1.weight" not in sd_keys # moved under latent_codec. + + loaded = TCM.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert loaded.N == 32 + assert loaded.M == 64 + assert loaded.hyper_channels == 48 + assert loaded.num_slices == 4 + assert loaded.max_support_slices == 2 + + def test_tcm_upstream_state_dict_conversion(self): + from compressai.models.tcm import convert_upstream_tcm_state_dict + + # Synthetic upstream LIC_TCM-style state_dict: DataParallel ``module.`` + # prefix, raw entropy heads at the root, the SWAtten ``nn.Sequential`` + # wrapper level (``atten_mean.{k}.0.``), and a ConvTransBlock attention + # buffer in upstream layout (``.msa.relative_position_params``). + upstream = { + "module.g_a.0.conv1.weight": torch.zeros(2), + "module.g_a.1.trans_block.msa.relative_position_params": torch.zeros( + 4, 15, 15 + ), + "module.g_a.1.trans_block.msa.embedding_layer.weight": torch.zeros(2), + "module.g_a.1.trans_block.ln1.weight": torch.zeros(2), + "module.g_a.1.trans_block.mlp.0.weight": torch.zeros(2), + "module.g_a.1.trans_block.mlp.2.weight": torch.zeros(2), + "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.cc_mean_transforms.1.0.weight": torch.zeros(2), + "module.cc_scale_transforms.0.0.weight": torch.zeros(2), + "module.atten_mean.0.0.in_conv.weight": torch.zeros(2), + "module.atten_scale.0.0.in_conv.weight": torch.zeros(2), + "module.lrp_transforms.0.0.weight": torch.zeros(2), + "module.gaussian_conditional.scale_table": torch.zeros(2), + "module.h_a.0.conv1.weight": torch.zeros(2), + "module.h_mean_s.0.conv.weight": torch.zeros(2), + "module.h_scale_s.0.conv.weight": torch.zeros(2), + "module.entropy_bottleneck.quantiles": torch.zeros(2), + } + 
converted = convert_upstream_tcm_state_dict(upstream) + + # ``module.`` prefix gone; g_a / ConvTransBlock pass through with the + # MSA / layer-name renames applied. + assert "g_a.0.conv1.weight" in converted + # ``relative_position_params`` -> ``relative_position_bias_table`` + # with shape permuted from (2*win-1, 2*win-1, num_heads) = + # (15, 15, 4) into the flat (225, 4) layout. + assert "g_a.1.trans_block.msa.attn.relative_position_bias_table" in converted + assert converted[ + "g_a.1.trans_block.msa.attn.relative_position_bias_table" + ].shape == (15 * 15, 4) + # ``embedding_layer`` -> ``attn.qkv``. + assert "g_a.1.trans_block.msa.attn.qkv.weight" in converted + # ``ln1`` -> ``norm1``; ``mlp.0`` / ``mlp.2`` -> ``mlp.fc1`` / ``fc2``. + assert "g_a.1.trans_block.norm1.weight" in converted + assert "g_a.1.trans_block.mlp.fc1.weight" in converted + assert "g_a.1.trans_block.mlp.fc2.weight" in converted + + # Hyperprior backbone moves under latent_codec. + assert "latent_codec.h_a.0.conv1.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.conv.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.conv.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # cc_mean / cc_scale re-rooted per slice. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + + # SWAtten wrapper unwrapped: ``atten_mean.0.0.<...>`` -> + # ``...mean_support_transform.<...>`` (no extra ``.0`` level). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.in_conv.weight" + in converted + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.in_conv.weight" + in converted + ) + + # gaussian_conditional replicated per slice (driven by mean_cc count). 
+ assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + + # LRP weights retained byte-for-byte (emit_mean_support=True path). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + + # Old root-level paths should be gone after conversion. + assert "h_a.0.conv1.weight" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "atten_mean.0.0.in_conv.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "module.g_a.0.conv1.weight" not in converted + + def test_scale_table_default(): table = get_scale_table() assert SCALES_MIN == 0.11 From 1e636f9b893c9892b7df1522b6cb5618d57b3b28 Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 14:20:05 +0800 Subject: [PATCH 4/8] feat(models): add CCA model and loss with containerized codec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Causal Context Adjustment (CCA) standalone autoencoder from Han et al., NeurIPS 2024 (https://arxiv.org/abs/2410.04847) using the H+G containerized entropy stack already adopted by STF / WACNN / TCM. The hyperprior backbone (h_a, h_mean_s, h_scale_s, EntropyBottleneck) plus the per-slice channel-conditional heads live under a single HyperpriorLatentCodec — same Family 1 pattern, with two CCA-specific specializations: - Variable-length channel slices (slice_proportions defaults to the upstream M=320 layout (8, 28, 56, 92, 136)) instead of the equal slices used by STF / WACNN / TCM. - Per-slice NAFTransform mean / scale support transforms (analogous to TCM's SWAtten), wired via build_mean_scale_head with support_transform_factory. 
Auxiliary CCA branch (cca_training=True) adds a _CCAAuxEntropyModel field that re-encodes y with skip-most-recent support selection (support_filter=lambda k, prior: prior[: max(k - 1, 0)]) and produces y_aux / y_cca likelihoods consumed by CCARateDistortionLoss. All slices use LRPGaussianLatentCodec — the upstream published checkpoints carry LRP weights for every slice; the unused last-two-slice LRPs are benign because support_filter excludes those slices' y_hat from any later slice's prior. Three small infra additions to support the wiring above: - MeanScaleContextHead.emit_mean_support extended to bool|"pre"|"post". CCA's upstream LRP heads consume the *post*-NAFTransform mean_support, while STF / TCM use the raw pre layout (Identity transform makes pre/post equivalent for them, so True still maps to "pre" for back-compat). - build_channel_slice_codec.support_count_fn lets the caller declare the prior-slice count seen by each channel_context head when a custom support_filter selects a non-default count (CCA-aux's skip-most-recent: lambda k: max(k - 1, 0)). - ChannelGroupsLatentCodec._get_ctx_params_side_in_context falls back to side_params alone when support_filter returns an empty list (CCA-aux at k=1), avoiding torch.cat() on an empty list. CCAModel.forward in cca_training=True replays the hyperprior path to recover latent_means / latent_scales for the aux branch instead of piercing the HyperpriorLatentCodec abstraction — small cost, zero interface churn for the main codec. Verified on candidate/CCA/checkpoint_lambda_0.3.pth.tar (M=320, slice_sizes=[8, 28, 56, 92, 136], em_hidden=224, em_layers=4, cca_training=True; 97M params): strict-loads after convert_upstream_cca_state_dict, sinusoidal smoke yields PSNR 50.07 dB / total bpp 0.072 (vs ~5 dB for a fresh-init model), confirming all weights — including LRP and aux — transfer byte-for-byte. 
Upstream → compressai key count delta of +56 is explained by replicating the single shared gaussian_conditional buffer set across each per-slice leaf (7 buffers × 4 extra copies per branch × 2 branches = 56), matching the channel_context.y{k} layout. Tests: - tests/test_models.py::TestCca::test_cca_forward_and_state_dict_round_trip - tests/test_models.py::TestCca::test_cca_training_branch_forward_and_round_trip - tests/test_models.py::TestCca::test_cca_upstream_state_dict_conversion Also drop a few stray internal phase-tracking labels from existing docstrings (stf.py / tcm.py / channel_context.py / test_latent_codecs.py / test_models.py) so the comments describe the wiring directly rather than referencing project-internal sequencing. --- compressai/latent_codecs/channel_groups.py | 6 + compressai/losses/__init__.py | 2 + compressai/losses/cca.py | 134 ++ compressai/models/_helpers/channel_context.py | 79 +- compressai/models/_helpers/channel_slice.py | 14 +- compressai/models/cca.py | 1078 +++++++++++++++++ compressai/models/stf.py | 4 +- compressai/models/tcm.py | 2 +- examples/convert_cca_checkpoint.py | 124 ++ tests/test_latent_codecs.py | 2 +- tests/test_models.py | 252 +++- 11 files changed, 1665 insertions(+), 32 deletions(-) create mode 100644 compressai/losses/cca.py create mode 100644 compressai/models/cca.py create mode 100644 examples/convert_cca_checkpoint.py diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 85159393..69670925 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -195,6 +195,12 @@ def _get_ctx_params_side_in_context( if k == 0: return self.channel_context["y0"](side_params) support = self._select_support(k, y_hat_) + # ``support`` can be empty when ``support_filter`` skips the most + # recent slice for k=1 (e.g., CCA-aux's + # ``lambda k, prior: prior[: max(k - 1, 0)]``); in that case the + # head sees only ``side_params``. 
+ if not support: + return self.channel_context[f"y{k}"](side_params) ch_input = self.merge_params(side_params, self.merge_y(*support)) return self.channel_context[f"y{k}"](ch_input) diff --git a/compressai/losses/__init__.py b/compressai/losses/__init__.py index b0863e85..9e8f8ea6 100644 --- a/compressai/losses/__init__.py +++ b/compressai/losses/__init__.py @@ -28,10 +28,12 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from . import pointcloud +from .cca import CCARateDistortionLoss from .pointcloud import * from .rate_distortion import RateDistortionLoss __all__ = [ *pointcloud.__all__, + "CCARateDistortionLoss", "RateDistortionLoss", ] diff --git a/compressai/losses/cca.py b/compressai/losses/cca.py new file mode 100644 index 00000000..3a4b7009 --- /dev/null +++ b/compressai/losses/cca.py @@ -0,0 +1,134 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Causal Context Adjustment rate-distortion loss (Han et al., NeurIPS 2024). + +Companion criterion for :class:`compressai.models.cca.CCAModel`. Requires +the model's ``forward`` to return ``aux_likelihoods = {"y_aux", "y_cca"}`` +(populated when ``cca_training=True``) so this loss can add the auxiliary +CCA terms on top of the standard rate-distortion objective. +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn + +from pytorch_msssim import ms_ssim + +from compressai.registry import register_criterion + + +@register_criterion("CCARateDistortionLoss") +class CCARateDistortionLoss(nn.Module): + r"""Causal Context Adjustment rate-distortion loss from M. Han, S. Jiang, + S. Li, X. Deng, M. Xu, C. Zhu, S. Gu: `"Causal Context Adjustment Loss + for Learned Image Compression" `_, + Adv. in Neural Information Processing Systems 37 (NeurIPS), 2024. + + Combines the standard rate (``bpp``) and distortion (MSE / MS-SSIM) + terms with the CCA term that measures the gap between the main and + auxiliary causal-context likelihoods produced by + :class:`compressai.models.cca.CCAModel` (with ``cca_training=True``). + + Args: + lmbda: Distortion weight. 
+ metric: Distortion metric, ``"mse"`` or ``"ms-ssim"``. + return_type: ``"all"`` returns the dict of components; otherwise + return the named scalar component (e.g. ``"loss"``). + alpha: Weight on the CCA loss term. + beta: Weight on the bit-rate term. + """ + + def __init__( + self, + lmbda: float = 0.01, + metric: str = "mse", + return_type: str = "all", + alpha: float = 1.0, + beta: float = 1.0, + ) -> None: + super().__init__() + if metric == "mse": + self.metric = nn.MSELoss() + elif metric == "ms-ssim": + self.metric = ms_ssim + else: + raise NotImplementedError(f"{metric} is not implemented!") + + self.lmbda = float(lmbda) + self.alpha = float(alpha) + self.beta = float(beta) + self.return_type = return_type + + def forward(self, output, target): + if "aux_likelihoods" not in output or output["aux_likelihoods"] is None: + raise KeyError( + "output must contain aux_likelihoods for CCARateDistortionLoss; " + "ensure CCAModel was constructed with cca_training=True" + ) + + aux_likelihoods = output["aux_likelihoods"] + if "y_aux" not in aux_likelihoods or "y_cca" not in aux_likelihoods: + raise KeyError("aux_likelihoods must contain y_aux and y_cca") + + batch_size, _, height, width = target.size() + num_pixels = batch_size * height * width + out = {} + + out["cca_loss"] = ( + torch.log(output["likelihoods"]["y"]).sum() / (-math.log(2)) + - torch.log(aux_likelihoods["y_cca"]).sum() / (-math.log(2)) + ) / num_pixels + out["aux2_loss"] = torch.sum( + aux_likelihoods["y_cca"] * torch.log(aux_likelihoods["y_aux"]) + ) / (-math.log(2) * num_pixels) + out["bpp_loss"] = sum( + (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) + for likelihoods in output["likelihoods"].values() + ) + + if self.metric == ms_ssim: + out["ms_ssim_loss"] = self.metric(output["x_hat"], target, data_range=1) + distortion = 1 - out["ms_ssim_loss"] + else: + out["mse_loss"] = self.metric(output["x_hat"], target) + distortion = 255**2 * out["mse_loss"] + + out["loss"] = ( + 
self.lmbda * distortion + + self.beta * out["bpp_loss"] + + self.alpha * out["cca_loss"] + + out["aux2_loss"] + ) + if self.return_type == "all": + return out + return out[self.return_type] diff --git a/compressai/models/_helpers/channel_context.py b/compressai/models/_helpers/channel_context.py index 1523f67c..d9b3c008 100644 --- a/compressai/models/_helpers/channel_context.py +++ b/compressai/models/_helpers/channel_context.py @@ -39,7 +39,7 @@ from __future__ import annotations -from typing import Callable, Optional, Sequence +from typing import Callable, Literal, Optional, Sequence, Union import torch import torch.nn as nn @@ -89,15 +89,27 @@ class MeanScaleContextHead(nn.Module): When ``side_split == 0`` (default) the head is generic: ``mean_cc`` and ``scale_cc`` both see the full input, no internal split. - When ``emit_mean_support=True`` (only meaningful with ``side_split > 0``) - the head appends the ``mean_in = cat(latent_means, *prev_y_hat)`` tensor - to the output, producing - ``cat(scale, mean, mean_in)`` of shape - ``(B, 2*slice_ch + side_split + sum(prev_groups), H, W)``. This trailing - block is consumed by :class:`LRPGaussianLatentCodec` (with matching - ``mean_support_trail_channels``) to recover the upstream STF / WACNN - LRP input layout (``cat(latent_means, *prev_y_hat, y_hat)``), enabling - byte-for-byte transfer of upstream LRP weights. + When ``emit_mean_support`` is truthy (only meaningful with + ``side_split > 0``) the head appends a copy of the mean-path tensor to + the output, producing + ``cat(scale, mean, mean_support)`` of shape + ``(B, 2*slice_ch + side_split + sum(prev_groups), H, W)``. Two flavours: + + - ``"pre"`` (legacy ``True``) — emit the raw ``mean_in = + cat(latent_means, *prev_y_hat)`` (i.e., before + ``mean_support_transform``). STF / WACNN / TCM use this because their + ``mean_support_transform`` is :class:`Identity` (or the upstream LRP + input is the un-transformed mean_in). 
+ - ``"post"`` — emit ``mean_support_transform(mean_in)`` (the same tensor + that feeds ``mean_cc``). CCA-main / CCA-aux use this because their + upstream ``lrp_transforms`` consume the *post*-NAFTransform mean + support; emitting "pre" would produce wrong LRP outputs even though + the channel widths match. + + The trailing block is consumed by :class:`LRPGaussianLatentCodec` (with + matching ``mean_support_trail_channels``) to recover the upstream LRP + input layout (``cat(mean_support, y_hat)``), enabling byte-for-byte + transfer of upstream LRP weights. """ mean_cc: nn.Module @@ -113,7 +125,7 @@ def __init__( scale_support_transform: Optional[nn.Module] = None, *, side_split: int = 0, - emit_mean_support: bool = False, + emit_mean_support: Union[bool, Literal["pre", "post"]] = False, ) -> None: super().__init__() self.mean_cc = mean_cc @@ -121,11 +133,22 @@ def __init__( self.mean_support_transform = mean_support_transform or nn.Identity() self.scale_support_transform = scale_support_transform or nn.Identity() self.side_split = int(side_split) - self.emit_mean_support = bool(emit_mean_support) + self.emit_mean_support: Literal[False, "pre", "post"] + if emit_mean_support is True: + self.emit_mean_support = "pre" + elif emit_mean_support is False: + self.emit_mean_support = False + elif emit_mean_support in ("pre", "post"): + self.emit_mean_support = emit_mean_support + else: + raise ValueError( + f"emit_mean_support must be False, True, 'pre', or 'post'; " + f"got {emit_mean_support!r}" + ) if self.emit_mean_support and self.side_split <= 0: raise ValueError( - "emit_mean_support=True requires side_split > 0 to recover " - "the legacy mean_support layout cat(latent_means, *prev_y_hat)." + "emit_mean_support requires side_split > 0 to recover the " + "legacy mean_support layout cat(latent_means, *prev_y_hat)." 
) def forward(self, x: Tensor) -> Tensor: @@ -138,11 +161,14 @@ def forward(self, x: Tensor) -> Tensor: scale_in = torch.cat([latent_scales, prev_y_hat], dim=1) else: mean_in = scale_in = x - mean = self.mean_cc(self.mean_support_transform(mean_in)) + mean_support = self.mean_support_transform(mean_in) + mean = self.mean_cc(mean_support) scale = self.scale_cc(self.scale_support_transform(scale_in)) out = torch.cat([scale, mean], dim=1) - if self.emit_mean_support: + if self.emit_mean_support == "pre": out = torch.cat([out, mean_in], dim=1) + elif self.emit_mean_support == "post": + out = torch.cat([out, mean_support], dim=1) return out @@ -153,7 +179,7 @@ def build_mean_scale_head( widths: Sequence[int] = (224, 128), support_transform_factory: Optional[Callable[[int, int], nn.Module]] = None, side_split: int = 0, - emit_mean_support: bool = False, + emit_mean_support: Union[bool, Literal["pre", "post"]] = False, ) -> MeanScaleContextHead: """Construct a :class:`MeanScaleContextHead` with default conv-stack heads. @@ -189,19 +215,22 @@ def build_mean_scale_head( emit_mean_support Forwarded to :class:`MeanScaleContextHead`. Why this flag exists: the upstream STF / WACNN / TCM / CCA LRP transform consumes - ``cat(latent_means, *prev_y_hat, y_hat)`` (i.e. ``M + slice_ch * - (support_count + 1)`` channels — variable per slice). The Phase 3 + ``cat(mean_support, y_hat)`` (i.e. ``M + slice_ch * + (support_count + 1)`` channels — variable per slice). A naive leaf only sees the channel-context ``ctx_params`` (= 2*slice_ch) and ``y_hat``, which would force an architectural change to the LRP transform input width and prevent byte-for-byte transfer of upstream - LRP weights. Setting ``emit_mean_support=True`` makes the head - append ``mean_in = cat(latent_means, *prev_y_hat)`` to its output; - :class:`LRPGaussianLatentCodec` (with matching - ``mean_support_trail_channels``) then strips that trailing block off + LRP weights. 
Setting ``emit_mean_support`` to ``"pre"`` (or legacy + ``True``) makes the head append ``mean_in = cat(latent_means, + *prev_y_hat)`` to its output; setting it to ``"post"`` appends + ``mean_support_transform(mean_in)`` instead (CCA-main / CCA-aux, + whose upstream LRP heads consume the *post*-NAFTransform mean + support). :class:`LRPGaussianLatentCodec` (with matching + ``mean_support_trail_channels``) strips that trailing block off ``ctx_params``, feeds only the leading ``2*slice_ch`` to the Gaussian conditional's ``chunks=("scales","means")`` step, and uses - the trailing block as the LRP input — recovering the upstream layout - exactly. + the trailing block as the LRP input — recovering the upstream + layout exactly. """ sub_in_ch = support_ch - side_split mean_cc = make_entropy_transform(sub_in_ch, slice_ch, widths=widths) diff --git a/compressai/models/_helpers/channel_slice.py b/compressai/models/_helpers/channel_slice.py index f6891e5d..28d236d2 100644 --- a/compressai/models/_helpers/channel_slice.py +++ b/compressai/models/_helpers/channel_slice.py @@ -50,6 +50,7 @@ def build_channel_slice_codec( channel_context_factory: Optional[Callable[[int, int, int], nn.Module]] = None, max_support_slices: int = -1, support_filter: Optional[Callable[[int, List[Tensor]], List[Tensor]]] = None, + support_count_fn: Optional[Callable[[int], int]] = None, side_in_context: bool = False, side_channels: int = 0, ) -> ChannelGroupsLatentCodec: @@ -88,6 +89,15 @@ def build_channel_slice_codec( support_filter Forwarded to :class:`ChannelGroupsLatentCodec`. Used by CCA-aux for skip-most-recent support selection. + support_count_fn + ``(k) -> int``. Override for the number of prior slices that + ``channel_context.y{k}`` will see at *runtime*. Required when + ``support_filter`` selects a non-default count (e.g., CCA-aux's + skip-most-recent ``lambda k: max(k - 1, 0)``); the factory uses + this count when sizing the channel_context heads. 
Defaults to + ``min(k, max_support_slices)`` (or ``k`` when + ``max_support_slices < 0``), matching ``ChannelGroupsLatentCodec``'s + own clamp logic when ``support_filter`` is unset. side_in_context Forwarded to :class:`ChannelGroupsLatentCodec`. When ``True`` the ``channel_context`` for ``y0`` consumes ``side_params`` and @@ -108,11 +118,13 @@ def build_channel_slice_codec( K = len(groups) - def _support_count(k: int) -> int: + def _default_support_count(k: int) -> int: if max_support_slices < 0: return k return min(k, max_support_slices) + _support_count = support_count_fn or _default_support_count + def _support_ch(k: int) -> int: prior_ch = sum(groups[: _support_count(k)]) if side_in_context: diff --git a/compressai/models/cca.py b/compressai/models/cca.py new file mode 100644 index 00000000..fe892bd2 --- /dev/null +++ b/compressai/models/cca.py @@ -0,0 +1,1078 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/CVL-UESTC/CCA +# (originally distributed under the MIT License). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Causal Context Adjustment (CCA) standalone autoencoder. + +Mirror of the upstream ``LICAutoencoder`` from +M. Han, S. Jiang, S. Li, X. Deng, M. Xu, C. Zhu, S. Gu: +`"Causal Context Adjustment Loss for Learned Image Compression" +`_, NeurIPS 2024. + +Family 1 wiring (see :mod:`compressai.latent_codecs.__init__`): the main +entropy stack is a fully containerised :class:`HyperpriorLatentCodec` +running ``side_in_context=True`` with variable-length slice groups and +per-slice :class:`_NAFTransform` support transforms. The optional +auxiliary CCA branch (:class:`_CCAAuxEntropyModel`) is a separate +``nn.Module`` that re-encodes ``y`` with the skip-most-recent +``support_filter`` selection used by +:class:`compressai.losses.CCARateDistortionLoss` to align the causal +context with the rate-distortion objective. 
+""" + +from __future__ import annotations + +import math + +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + DualHyperSynthesis, + EntropyBottleneckLatentCodec, + GaussianConditionalLatentCodec, + HyperpriorLatentCodec, + LRPGaussianLatentCodec, +) +from compressai.latent_codecs._slice_helpers import make_entropy_transform +from compressai.layers.layers import conv1x1 +from compressai.models._helpers.channel_context import build_mean_scale_head +from compressai.models._helpers.channel_slice import build_channel_slice_codec +from compressai.models.base import CompressionModel, get_scale_table +from compressai.models.sensetime import ResidualBottleneckBlock +from compressai.models.utils import conv, deconv +from compressai.registry import register_model + +__all__ = [ + "CCAModel", + "convert_upstream_cca_state_dict", +] + + +# ---------------------------------------------------------------------------- +# Slice-size resolver. 
+# ---------------------------------------------------------------------------- + + +def _resolve_slice_sizes( + latent_channels: int, slice_proportions: Sequence[int] +) -> List[int]: + if len(slice_proportions) == 0: + raise ValueError("slice_proportions must contain at least one entry") + total = sum(slice_proportions) + if total <= 0: + raise ValueError("slice_proportions must sum to a positive integer") + sizes = [ + int(math.floor(latent_channels * proportion / total)) + for proportion in slice_proportions + ] + sizes[-1] += latent_channels - sum(sizes) + if any(size <= 0 for size in sizes): + raise ValueError("resolved slice sizes must all be positive") + return sizes + + +# ---------------------------------------------------------------------------- +# NAF (Non-linear Activation Free) building blocks +# ---------------------------------------------------------------------------- + + +class _SimpleGate(nn.Module): + def forward(self, input_tensor: Tensor) -> Tensor: + gate_tensor, value_tensor = input_tensor.chunk(2, dim=1) + return gate_tensor * value_tensor + + +class _NAFBlock(nn.Module): + """Non-linear Activation Free residual block. + + Used by both the CCA entropy-model auxiliary transforms and the CCA + image-compression model's analysis / synthesis stacks. State-dict keys + (``norm1`` / ``pointwise_depthwise`` / ``channel_attention`` / + ``project`` / ``feed_forward`` / ``beta`` / ``gamma``) match upstream + after :func:`convert_upstream_cca_state_dict` so released checkpoints + load 1:1. 
+ """ + + def __init__(self, channels: int) -> None: + super().__init__() + from timm.layers import LayerNorm2d + + expanded_channels = channels * 2 + self.norm1 = LayerNorm2d(channels) + self.pointwise_depthwise = nn.Sequential( + conv1x1(channels, expanded_channels), + nn.Conv2d( + expanded_channels, + expanded_channels, + kernel_size=3, + padding=1, + groups=expanded_channels, + ), + ) + self.gate = _SimpleGate() + self.channel_attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + conv1x1(channels, channels), + ) + self.project = conv1x1(channels, channels) + self.norm2 = LayerNorm2d(channels) + self.feed_forward = nn.Sequential( + conv1x1(channels, expanded_channels), + _SimpleGate(), + conv1x1(channels, channels), + ) + self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1)) + self.gamma = nn.Parameter(torch.zeros(1, channels, 1, 1)) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.norm1(input_tensor) + output = self.pointwise_depthwise(output) + output = self.gate(output) + output = output * self.channel_attention(output) + output = self.project(output) + output = input_tensor + self.beta * output + return output + self.gamma * self.feed_forward(self.norm2(output)) + + +class _NAFTransform(nn.Module): + """``Conv1x1 -> NAFBlock x N -> Conv1x1`` per-slice support transform.""" + + def __init__( + self, + input_channels: int, + output_channels: int, + hidden_channels: int, + num_layers: int, + ) -> None: + super().__init__() + if num_layers < 1: + raise ValueError("num_layers must be positive") + + self.input_projection = conv1x1(input_channels, hidden_channels) + self.blocks = nn.Sequential( + *(_NAFBlock(hidden_channels) for _ in range(num_layers)) + ) + self.output_projection = conv1x1(hidden_channels, output_channels) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.input_projection(input_tensor) + return self.output_projection(output + self.blocks(output)) + + +# 
---------------------------------------------------------------------------- +# Analysis / synthesis transforms (NAFBlock + ResidualBottleneckBlock). +# ---------------------------------------------------------------------------- + + +class _CCAEncoder(nn.Module): + """NAFBlock + ResidualBottleneckBlock analysis transform (4 strides).""" + + def __init__( + self, + in_channels: int, + latent_channels: int, + stage_dims: Sequence[int], + stage_layers: Sequence[int], + ) -> None: + super().__init__() + if len(stage_dims) != len(stage_layers): + raise ValueError("stage_dims and stage_layers must have matching length") + self.depth = len(stage_dims) + all_dims = [in_channels, *stage_dims, latent_channels] + self.down = nn.ModuleList( + conv(all_dims[index], all_dims[index + 1], kernel_size=5, stride=2) + for index in range(self.depth + 1) + ) + self.blocks = nn.ModuleList( + nn.Sequential( + *( + ResidualBottleneckBlock(stage_dims[index], stage_dims[index]) + for _ in range(3) + ), + *(_NAFBlock(stage_dims[index]) for _ in range(stage_layers[index])), + ) + for index in range(self.depth) + ) + + def forward(self, x: Tensor) -> Tensor: + for index in range(self.depth): + x = self.down[index](x) + x = self.blocks[index](x) + return self.down[self.depth](x) + + +class _CCADecoder(nn.Module): + """NAFBlock + ResidualBottleneckBlock synthesis transform (4 strides).""" + + def __init__( + self, + out_channels: int, + latent_channels: int, + stage_dims: Sequence[int], + stage_layers: Sequence[int], + ) -> None: + super().__init__() + if len(stage_dims) != len(stage_layers): + raise ValueError("stage_dims and stage_layers must have matching length") + self.depth = len(stage_dims) + all_dims = [out_channels, *stage_dims, latent_channels] + self.up = nn.ModuleList( + deconv(all_dims[index + 1], all_dims[index], kernel_size=5, stride=2) + for index in reversed(range(self.depth + 1)) + ) + self.blocks = nn.ModuleList( + nn.Sequential( + *(_NAFBlock(stage_dims[index]) for _ in 
range(stage_layers[index])), + *( + ResidualBottleneckBlock(stage_dims[index], stage_dims[index]) + for _ in range(3) + ), + ) + for index in reversed(range(self.depth)) + ) + + def forward(self, x: Tensor) -> Tensor: + for index in range(self.depth): + x = self.up[index](x) + x = self.blocks[index](x) + return self.up[self.depth](x) + + +# ---------------------------------------------------------------------------- +# Family 1 entropy-stack builders (main + auxiliary). +# ---------------------------------------------------------------------------- + + +def _build_cca_main_latent_codec( + *, + M: int, + hyper_channels: int, + slice_sizes: Sequence[int], + em_hidden_channels: int, + em_num_layers: int, + h_a: nn.Module, + h_mean_s: nn.Module, + h_scale_s: nn.Module, +) -> HyperpriorLatentCodec: + """Main entropy stack: ``HyperpriorLatentCodec`` wrapping + ``DualHyperSynthesis`` and a per-slice ``ChannelGroupsLatentCodec``. + + Distinctive choices vs. STF/TCM (other Family 1 models): + + - ``groups`` is a variable-length list (resolved from + ``slice_proportions``); STF / WACNN / TCM use uniform ``[M//K]*K``. + - ``support_transform_factory`` builds a per-slice + :class:`_NAFTransform` for both mean and scale paths (vs. STF + identity / TCM SWAtten). + - The leaf is :class:`LRPGaussianLatentCodec` with + ``mean_support_trail_channels`` matching + ``M + sum(slice_sizes[:k])``, paired with + ``MeanScaleContextHead(emit_mean_support="post")`` so the LRP head + receives the *post*-NAFTransform mean support — replicating the + upstream LIC LRP layout for byte-for-byte weight transfer. + - The ``z`` leaf uses ``EntropyBottleneckLatentCodec(quantizer="ste")`` + to recover upstream's ``quantize_ste(z - z_offset) + z_offset`` + behaviour without a model-side hack. + """ + cumulative = list(_cumsum_with_zero(slice_sizes)) + side_channels = 2 * M + K = len(slice_sizes) + + def _support_count(k: int) -> int: + # use-all-prior; matches upstream LIC main path (no skip). 
+ return k + + def _mean_support_ch(k: int) -> int: + # cat(latent_means(M), *prev_y_hat(sum(slice_sizes[:k]))). + return M + cumulative[k] + + def _leaf(k: int, slice_ch: int) -> LRPGaussianLatentCodec: + ms_ch = _mean_support_ch(k) + return LRPGaussianLatentCodec( + lrp_transform=_make_cca_head( + ms_ch + slice_ch, # cat(mean_support, y_hat) + em_hidden_channels, + slice_ch, + ), + mean_support_trail_channels=ms_ch, + quantizer="ste", + ) + + def _naf_factory(c_in: int, c_out: int) -> nn.Module: + return _NAFTransform(c_in, c_out, em_hidden_channels, em_num_layers) + + def _channel_context(_k: int, slice_ch: int, support_ch: int) -> nn.Module: + return build_mean_scale_head( + slice_ch=slice_ch, + support_ch=support_ch, + widths=(em_hidden_channels, 128), + side_split=M, + emit_mean_support="post", + support_transform_factory=_naf_factory, + ) + + if K == 0: + raise ValueError("slice_sizes must contain at least one entry") + + return HyperpriorLatentCodec( + h_a=h_a, + h_s=DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(hyper_channels), + quantizer="ste", + ), + "y": build_channel_slice_codec( + groups=list(slice_sizes), + side_channels=side_channels, + side_in_context=True, + max_support_slices=-1, + support_count_fn=_support_count, + leaf_factory=_leaf, + channel_context_factory=_channel_context, + ), + }, + ) + + +def _make_cca_head( + in_channels: int, hidden_channels: int, out_channels: int +) -> nn.Sequential: + """Three-conv stack ``in -> hidden -> 128 -> out`` (kernel 3, stride 1). + + Matches upstream ``mean_cc_transforms[k]`` / ``lrp_transforms[k]`` + layout. Wraps :func:`make_entropy_transform` with the CCA-specific + ``widths=(hidden_channels, 128)``. 
+ """ + return make_entropy_transform( + in_channels, out_channels, widths=(hidden_channels, 128) + ) + + +def _cumsum_with_zero(values: Sequence[int]) -> List[int]: + """Return ``[0, values[0], values[0]+values[1], ...]`` (length ``len+1``).""" + out = [0] + running = 0 + for value in values: + running += int(value) + out.append(running) + return out + + +class _CCAAuxEntropyModel(nn.Module): + """Auxiliary CCA entropy branch (skip-most-recent-slice support). + + Produces the ``y_aux`` (factorised) and ``y_cca`` (Gaussian-conditional) + likelihoods used by :class:`compressai.losses.CCARateDistortionLoss`. + + Mirrors the upstream ``AuxEntropyModel`` in + ``candidate/CCA/models/aux_em.py``: for slice ``i`` the support is + ``cat(latent_means, *y_hat_slices[: max(i - 1, 0)])`` (i.e., skip the + *most recent* decoded slice). This is wired declaratively through + :func:`build_channel_slice_codec` with + ``support_filter=lambda k, prior: prior[: max(k - 1, 0)]`` and a + matching ``support_count_fn`` to size the channel-context heads. + + Although upstream only *uses* the LRP path on the first ``num_slices - + 2`` slices, the published checkpoints carry LRP weights for *all* + slices. To strict-load those checkpoints every leaf is a + :class:`LRPGaussianLatentCodec`; the LRP applied to the trailing two + slices is benign (those slices' ``y_hat`` is excluded from every + later slice's support_filter, so it never feeds back into the + likelihoods). 
+ """ + + def __init__( + self, + latent_channels: int, + slice_sizes: Sequence[int], + hidden_channels: int, + num_layers: int, + ) -> None: + super().__init__() + self.latent_channels = int(latent_channels) + self.slice_sizes: List[int] = list(map(int, slice_sizes)) + self.num_slices = len(self.slice_sizes) + self.hidden_channels = int(hidden_channels) + self.num_layers = int(num_layers) + + cumulative = _cumsum_with_zero(self.slice_sizes) + side_channels = 2 * self.latent_channels + + def _support_count(k: int) -> int: + return max(k - 1, 0) + + def _support_filter(k: int, prior: List[Tensor]) -> List[Tensor]: + return prior[: max(k - 1, 0)] + + def _mean_support_ch(k: int) -> int: + return self.latent_channels + cumulative[_support_count(k)] + + def _leaf(k: int, slice_ch: int) -> LRPGaussianLatentCodec: + ms_ch = _mean_support_ch(k) + return LRPGaussianLatentCodec( + lrp_transform=_make_cca_head( + ms_ch + slice_ch, + self.hidden_channels, + slice_ch, + ), + mean_support_trail_channels=ms_ch, + quantizer="ste", + ) + + def _naf_factory(c_in: int, c_out: int) -> nn.Module: + return _NAFTransform(c_in, c_out, self.hidden_channels, self.num_layers) + + def _channel_context(_k: int, slice_ch: int, support_ch: int) -> nn.Module: + return build_mean_scale_head( + slice_ch=slice_ch, + support_ch=support_ch, + widths=(self.hidden_channels, 128), + side_split=self.latent_channels, + emit_mean_support="post", + support_transform_factory=_naf_factory, + ) + + self.y_entropy_bottleneck = EntropyBottleneck(self.latent_channels) + self.inner_codec = build_channel_slice_codec( + groups=list(self.slice_sizes), + side_channels=side_channels, + side_in_context=True, + max_support_slices=-1, + support_filter=_support_filter, + support_count_fn=_support_count, + leaf_factory=_leaf, + channel_context_factory=_channel_context, + ) + + def forward( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Tensor]: + _, y_aux_likelihoods = 
self.y_entropy_bottleneck(y) + side_params = torch.cat([latent_means, latent_scales], dim=1) + inner_out = self.inner_codec(y, side_params) + return { + "y_aux": y_aux_likelihoods, + "y_cca": inner_out["likelihoods"]["y"], + } + + +# ---------------------------------------------------------------------------- +# Top-level CCAModel. +# ---------------------------------------------------------------------------- + + +@register_model("cca") +class CCAModel(CompressionModel): + r"""Causal Context Adjustment standalone autoencoder. + + Mirrors the upstream ``LICAutoencoder`` from M. Han et al., NeurIPS 2024 + (`Causal Context Adjustment Loss for Learned Image Compression + `_). + + The entropy stack is a Family 1 :class:`HyperpriorLatentCodec` (see + :mod:`compressai.latent_codecs.__init__` for the full pattern) with + variable-length channel slices (``slice_proportions``), per-slice + :class:`_NAFTransform` support transforms, and a STE-quantised ``z`` + leaf. When ``cca_training=True`` an auxiliary + :class:`_CCAAuxEntropyModel` branch is added that produces ``y_aux`` / + ``y_cca`` likelihoods consumed by + :class:`compressai.losses.CCARateDistortionLoss`. + + Args: + latent_channels: Number of channels in the latent (``M``). + hyper_channels: Number of channels in the hyper-latent (``N_z``). + slice_proportions: Per-slice channel proportions; the actual slice + channel widths are computed as + ``floor(latent_channels * p / sum(p))`` with the residual added + to the last slice. Pass ``[1] * num_slices`` for equal-sized + slices; pass ``[8, 28, 56, 92, 136]`` to reproduce the upstream + published M=320 layout. + encoder_dims: Per-stage feature widths for the analysis transform + (3 stages by default). + encoder_layers: Per-stage NAFBlock counts for the analysis transform. + em_hidden_channels: Hidden width inside the per-slice NAFTransforms + and channel-context heads. + em_num_layers: NAFBlock count inside each per-slice NAFTransform. 
+ cca_training: When ``True``, allocate the auxiliary CCA entropy + branch so that ``forward`` populates ``aux_likelihoods``. + """ + + def __init__( + self, + latent_channels: int = 320, + hyper_channels: int = 192, + slice_proportions: Sequence[int] = (8, 28, 56, 92, 136), + encoder_dims: Sequence[int] = (192, 224, 256), + encoder_layers: Sequence[int] = (4, 4, 4), + em_hidden_channels: int = 224, + em_num_layers: int = 4, + cca_training: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + encoder_dims = tuple(encoder_dims) + encoder_layers = tuple(encoder_layers) + slice_proportions = tuple(int(value) for value in slice_proportions) + + self.M = int(latent_channels) + self.N = int(hyper_channels) + self.encoder_dims = encoder_dims + self.encoder_layers = encoder_layers + self.slice_proportions = slice_proportions + self.em_hidden_channels = int(em_hidden_channels) + self.em_num_layers = int(em_num_layers) + self.cca_training = bool(cca_training) + + self.slice_sizes: List[int] = _resolve_slice_sizes(self.M, slice_proportions) + self.num_slices = len(self.slice_sizes) + + self.g_a = _CCAEncoder(3, self.M, encoder_dims, encoder_layers) + self.g_s = _CCADecoder(3, self.M, encoder_dims, encoder_layers) + + last_encoder_dim = encoder_dims[-1] + h_a = nn.Sequential( + conv(self.M, last_encoder_dim, kernel_size=3, stride=1), + nn.GELU(), + conv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + conv(last_encoder_dim, self.N, kernel_size=5, stride=2), + ) + h_mean_s = nn.Sequential( + deconv(self.N, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, self.M, kernel_size=3, stride=1), + ) + h_scale_s = nn.Sequential( + deconv(self.N, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, self.M, 
kernel_size=3, stride=1), + ) + + self.latent_codec = _build_cca_main_latent_codec( + M=self.M, + hyper_channels=self.N, + slice_sizes=self.slice_sizes, + em_hidden_channels=self.em_hidden_channels, + em_num_layers=self.em_num_layers, + h_a=h_a, + h_mean_s=h_mean_s, + h_scale_s=h_scale_s, + ) + + if self.cca_training: + self.aux_entropy_model = _CCAAuxEntropyModel( + self.M, + self.slice_sizes, + self.em_hidden_channels, + self.em_num_layers, + ) + + def forward(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec(y) + result: Dict[str, object] = { + "y": y, + "x_hat": self.g_s(y_out["y_hat"]), + "likelihoods": y_out["likelihoods"], + } + if self.cca_training: + # ``self.latent_codec.h_s`` is the ``DualHyperSynthesis``; its + # output is ``cat(latent_means, latent_scales)`` of width 2*M. + # Recover them from the inner ``z`` round-trip so the aux + # branch sees the same hyperprior context as the main path. + z_out = self.latent_codec.latent_codec["z"](self.latent_codec.h_a(y)) + side_params = self.latent_codec.h_s(z_out["y_hat"]) + latent_means, latent_scales = torch.split(side_params, self.M, dim=1) + result["aux_likelihoods"] = self.aux_entropy_model( + y, latent_means, latent_scales + ) + else: + result["aux_likelihoods"] = None + return result + + def compress(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} + + def decompress( + self, + strings: Sequence[Sequence[bytes]], + shape: Tuple[int, int], + ) -> Dict[str, Tensor]: + y_out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(y_out["y_hat"]).clamp_(0, 1)} + + def update( + self, scale_table: Optional[Tensor] = None, force: bool = False, **kwargs + ) -> bool: + if scale_table is None: + scale_table = get_scale_table() + return super().update(scale_table=scale_table, force=force, **kwargs) + + def load_state_dict(self, state_dict: Dict[str, 
Tensor], strict: bool = True): + if _is_upstream_cca_state_dict(state_dict): + state_dict = convert_upstream_cca_state_dict(state_dict) + return super().load_state_dict(state_dict, strict=strict) + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "CCAModel": + if _is_upstream_cca_state_dict(state_dict): + state_dict = convert_upstream_cca_state_dict(state_dict) + cfg = _infer_config_from_state_dict(state_dict) + net = cls(**cfg) + net.load_state_dict(state_dict) + return net + + +# ---------------------------------------------------------------------------- +# Upstream → compressai state-dict conversion. +# ---------------------------------------------------------------------------- + + +# NAFBlock interior renames (upstream -> compressai). These are scoped to +# detected NAFBlock prefixes so they don't accidentally rewrite ``conv1`` in +# unrelated modules (e.g. ResidualBottleneckBlock has its own ``conv1``). +_NAF_BLOCK_RENAMES = { + "dwconv.": "pointwise_depthwise.", + "sca.": "channel_attention.", + "FFN.": "feed_forward.", + "conv1.": "project.", +} +# NAFTransform interior renames. +_NAF_TRANSFORM_RENAMES = { + "in_conv.": "input_projection.", + "out_conv.": "output_projection.", +} +# Top-level rename map applied AFTER NAFBlock / NAFTransform interior renames +# and BEFORE per-slice rerooting. Used for hyperprior backbone and aux module. +_TOPLEVEL_RENAMES: Dict[str, str] = { + "aux_entropymodel.": "aux_entropy_model.", + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "z_entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} +# Upstream uses ``mean_NAF_transforms`` / ``scale_NAF_transforms``; this PR +# stores them at ``{mean,scale}_support_transform`` inside the channel-context +# head (singular per slice). Aliasing here keeps the per-slice rerooting pass +# uniform across main and aux branches. 
+_NAMED_PART_RENAMES: Dict[str, str] = { + "mean_NAF_transforms.": "mean_support_transforms.", + "scale_NAF_transforms.": "scale_support_transforms.", +} + + +def _is_upstream_cca_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic detector for upstream ``LICAutoencoder`` checkpoints.""" + for key in state_dict: + if ( + key.startswith("mean_NAF_transforms.") + or key.startswith("scale_NAF_transforms.") + or key.startswith("aux_entropymodel.") + or key.startswith("z_entropy_bottleneck.") + or key.startswith("mean_cc_transforms.") + or key.startswith("scale_cc_transforms.") + or key.startswith("lrp_transforms.") + ): + return True + return False + + +def _find_naf_block_prefixes(state_dict: Dict[str, Tensor]) -> List[str]: + """Locate every NAFBlock instance by matching the ``.beta`` / ``.gamma`` + / ``.dwconv.0.weight`` / ``.FFN.0.weight`` 4-tuple at the same scope. + """ + suffix = ".beta" + out: List[str] = [] + for key in state_dict: + if not key.endswith(suffix): + continue + base = key[: -len(suffix)] + if ( + f"{base}.gamma" in state_dict + and f"{base}.dwconv.0.weight" in state_dict + and f"{base}.FFN.0.weight" in state_dict + ): + out.append(base) + return out + + +def _find_naf_transform_prefixes(state_dict: Dict[str, Tensor]) -> List[str]: + """Locate every NAFTransform instance by matching the ``.in_conv.weight`` + / ``.out_conv.weight`` / ``.blocks.0.beta`` triple at the same scope. 
+ """ + suffix = ".in_conv.weight" + out: List[str] = [] + for key in state_dict: + if not key.endswith(suffix): + continue + base = key[: -len(suffix)] + if ( + f"{base}.out_conv.weight" in state_dict + and f"{base}.blocks.0.beta" in state_dict + ): + out.append(base) + return out + + +def _strip_prefix(key: str, prefix: str) -> Optional[str]: + return key[len(prefix) :] if key.startswith(prefix) else None + + +def _rename_with_table( + key: str, + base_prefixes: Sequence[str], + rename_map: Dict[str, str], +) -> str: + for base in base_prefixes: + head = base + "." + rest = _strip_prefix(key, head) + if rest is None: + continue + for old, new in rename_map.items(): + inner = _strip_prefix(rest, old) + if inner is not None: + return head + new + inner + return key + return key + + +def _reroot_per_slice_keys( + cleaned: Dict[str, Tensor], + converted: Dict[str, Tensor], + *, + legacy_prefix: str, + container_prefix: str, + sub_name: str, + num_slices: int, + consume: List[str], +) -> None: + """Move ``legacy_prefix.{k}.<...>`` keys to + ``container_prefix.y{k}.sub_name.<...>``. + + Keys that match are removed from ``cleaned`` (recorded in ``consume`` + for a later bulk drop) and inserted into ``converted`` under the new + path. 
+ """ + for key in list(cleaned.keys()): + rest = _strip_prefix(key, legacy_prefix + ".") + if rest is None: + continue + idx_str, _, tail = rest.partition(".") + try: + idx = int(idx_str) + except ValueError: + continue + if idx >= num_slices: + continue + new_key = ( + f"{container_prefix}.y{idx}.{sub_name}.{tail}" + if tail + else f"{container_prefix}.y{idx}.{sub_name}" + ) + converted[new_key] = cleaned[key] + consume.append(key) + + +def _replicate_gaussian_conditional( + cleaned: Dict[str, Tensor], + converted: Dict[str, Tensor], + *, + legacy_prefix: str, + new_prefix: str, + num_slices: int, + consume: List[str], +) -> None: + """Copy a single shared ``gaussian_conditional.<...>`` buffer set under + every per-slice leaf so the per-slice + :class:`GaussianConditionalLatentCodec` copies all strict-load. + """ + for key in list(cleaned.keys()): + tail = _strip_prefix(key, legacy_prefix + ".") + if tail is None: + continue + for k in range(num_slices): + new_key = f"{new_prefix}.y{k}.gaussian_conditional.{tail}" + converted[new_key] = cleaned[key] + consume.append(key) + + +def convert_upstream_cca_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Translate an upstream CCA ``LICAutoencoder`` state dict to the + compressai layout produced by :class:`CCAModel`. + + Conversion runs three logical passes: + + 1. Interior renames: ``NAFBlock`` (``dwconv`` → ``pointwise_depthwise``, + etc.) and ``NAFTransform`` (``in_conv`` → ``input_projection``, + etc.). Detection is by structural fingerprint + (:func:`_find_naf_block_prefixes`) so the renames apply uniformly to + NAFBlocks anywhere in the state dict (``g_a`` / ``g_s`` / per-slice + support transforms / aux module). + 2. 
Top-level renames: ``aux_entropymodel`` → ``aux_entropy_model``, + hyperprior backbone (``h_a`` / ``h_mean_s`` / ``h_scale_s``) and + ``z_entropy_bottleneck`` are moved under ``latent_codec.*``; + ``mean_NAF_transforms`` / ``scale_NAF_transforms`` are aliased to + the singular ``{mean,scale}_support_transforms`` form so the + per-slice rerooting in pass 3 only handles one name. + 3. Per-slice rerooting: ``mean_cc_transforms.{k}`` / + ``scale_cc_transforms.{k}`` move to + ``latent_codec.y.channel_context.y{k}.{mean,scale}_cc.*``; + ``mean_support_transforms.{k}`` / ``scale_support_transforms.{k}`` + move to + ``latent_codec.y.channel_context.y{k}.{mean,scale}_support_transform.*``; + ``lrp_transforms.{k}`` moves to + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; the single + shared ``gaussian_conditional.*`` buffer set is replicated under + every per-slice leaf + (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``). The + same rerooting is applied to ``aux_entropy_model.*`` (after the + top-level rename) under ``aux_entropy_model.inner_codec.*``. + + The returned dict can be loaded by :meth:`CCAModel.from_state_dict`, + which auto-detects the upstream layout via + :func:`_is_upstream_cca_state_dict`, so direct invocation is only + needed when persisting the converted dict. + """ + naf_blocks = _find_naf_block_prefixes(state_dict) + naf_transforms = _find_naf_transform_prefixes(state_dict) + + # Pass 1+2: interior + top-level renames. + cleaned: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = _rename_with_table(key, naf_blocks, _NAF_BLOCK_RENAMES) + new_key = _rename_with_table(new_key, naf_transforms, _NAF_TRANSFORM_RENAMES) + for old, new in _NAMED_PART_RENAMES.items(): + new_key = new_key.replace(old, new) + for old, new in _TOPLEVEL_RENAMES.items(): + if new_key.startswith(old): + new_key = new + new_key[len(old) :] + break + cleaned[new_key] = value + + # Pass 3a: per-slice rerooting for the main entropy stack. 
Discover + # ``num_slices`` from ``mean_cc_transforms`` first, then drive the rest. + main_indices = sorted( + { + int(key[len("mean_cc_transforms.") :].split(".", 1)[0]) + for key in cleaned + if key.startswith("mean_cc_transforms.") + } + ) + num_slices_main = len(main_indices) + + converted: Dict[str, Tensor] = {} + consumed: List[str] = [] + + if num_slices_main: + for legacy, container, sub in ( + ("mean_cc_transforms", "latent_codec.y.channel_context", "mean_cc"), + ("scale_cc_transforms", "latent_codec.y.channel_context", "scale_cc"), + ( + "mean_support_transforms", + "latent_codec.y.channel_context", + "mean_support_transform", + ), + ( + "scale_support_transforms", + "latent_codec.y.channel_context", + "scale_support_transform", + ), + ("lrp_transforms", "latent_codec.y.latent_codec", "lrp_transform"), + ): + _reroot_per_slice_keys( + cleaned, + converted, + legacy_prefix=legacy, + container_prefix=container, + sub_name=sub, + num_slices=num_slices_main, + consume=consumed, + ) + _replicate_gaussian_conditional( + cleaned, + converted, + legacy_prefix="gaussian_conditional", + new_prefix="latent_codec.y.latent_codec", + num_slices=num_slices_main, + consume=consumed, + ) + + # Pass 3b: per-slice rerooting inside the aux entropy module. Discover + # ``num_slices_aux`` from ``aux_entropy_model.mean_cc_transforms``. 
+ aux_indices = sorted( + { + int(key[len("aux_entropy_model.mean_cc_transforms.") :].split(".", 1)[0]) + for key in cleaned + if key.startswith("aux_entropy_model.mean_cc_transforms.") + } + ) + num_slices_aux = len(aux_indices) + if num_slices_aux: + for legacy, container, sub in ( + ( + "aux_entropy_model.mean_cc_transforms", + "aux_entropy_model.inner_codec.channel_context", + "mean_cc", + ), + ( + "aux_entropy_model.scale_cc_transforms", + "aux_entropy_model.inner_codec.channel_context", + "scale_cc", + ), + ( + "aux_entropy_model.mean_support_transforms", + "aux_entropy_model.inner_codec.channel_context", + "mean_support_transform", + ), + ( + "aux_entropy_model.scale_support_transforms", + "aux_entropy_model.inner_codec.channel_context", + "scale_support_transform", + ), + ( + "aux_entropy_model.lrp_transforms", + "aux_entropy_model.inner_codec.latent_codec", + "lrp_transform", + ), + ): + _reroot_per_slice_keys( + cleaned, + converted, + legacy_prefix=legacy, + container_prefix=container, + sub_name=sub, + num_slices=num_slices_aux, + consume=consumed, + ) + _replicate_gaussian_conditional( + cleaned, + converted, + legacy_prefix="aux_entropy_model.gaussian_conditional", + new_prefix="aux_entropy_model.inner_codec.latent_codec", + num_slices=num_slices_aux, + consume=consumed, + ) + + for key in consumed: + cleaned.pop(key, None) + # Remaining keys (g_a / g_s / latent_codec.* hyperprior backbone / + # aux_entropy_model.y_entropy_bottleneck / etc.) pass through unchanged. + converted.update(cleaned) + return converted + + +# ---------------------------------------------------------------------------- +# Architecture inference helpers (state_dict -> hyperparameters). 
+# ---------------------------------------------------------------------------- + + +def _infer_config_from_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, object]: + """Recover constructor kwargs from a compressai-layout CCA state dict.""" + encoder_dims = ( + state_dict["g_a.down.0.weight"].size(0), + state_dict["g_a.down.1.weight"].size(0), + state_dict["g_a.down.2.weight"].size(0), + ) + latent_channels = state_dict["g_a.down.3.weight"].size(0) + hyper_channels = state_dict["latent_codec.h_a.4.weight"].size(0) + + encoder_layers: List[int] = [] + for stage in range(3): + index = 0 + while f"g_a.blocks.{stage}.{index}.beta" in state_dict or _has_resblock( + state_dict, stage, index + ): + index += 1 + encoder_layers.append(index - 3) + + cc_keys = [ + key + for key in state_dict + if key.startswith("latent_codec.y.channel_context.y") + and key.endswith(".mean_cc.4.weight") + ] + cc_keys.sort(key=lambda key: int(key.split(".")[3][1:])) # ".y{k}." -> k + if not cc_keys: + raise RuntimeError("state dict does not contain channel-context mean_cc heads") + slice_sizes = [int(state_dict[key].size(0)) for key in cc_keys] + + em_hidden_channels = int( + state_dict[ + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + ].size(0) + ) + + em_num_layers = 0 + while ( + f"latent_codec.y.channel_context.y0.mean_support_transform.blocks.{em_num_layers}.beta" + in state_dict + ): + em_num_layers += 1 + + cca_training = any(key.startswith("aux_entropy_model.") for key in state_dict) + + return { + "latent_channels": int(latent_channels), + "hyper_channels": int(hyper_channels), + "slice_proportions": tuple(slice_sizes), + "encoder_dims": tuple(int(value) for value in encoder_dims), + "encoder_layers": tuple(int(value) for value in encoder_layers), + "em_hidden_channels": em_hidden_channels, + "em_num_layers": em_num_layers, + "cca_training": cca_training, + } + + +def _has_resblock(state_dict: Dict[str, Tensor], stage: int, sub_index: int) 
-> bool: + return f"g_a.blocks.{stage}.{sub_index}.conv2.weight" in state_dict and ( + f"g_a.blocks.{stage}.{sub_index}.beta" not in state_dict + ) diff --git a/compressai/models/stf.py b/compressai/models/stf.py index 14ae9a0a..d2ed1348 100644 --- a/compressai/models/stf.py +++ b/compressai/models/stf.py @@ -272,7 +272,7 @@ def convert_upstream_stf_state_dict( ``syn_layers`` / ``end_conv`` keys unchanged. .. caveat:: - The Phase 3 wiring sets ``emit_mean_support=True`` on the + The wiring sets ``emit_mean_support=True`` on the ``MeanScaleContextHead`` so the upstream LRP layout (``cat(latent_means, *prev_y_hat, y_hat)``) is recoverable inside the leaf — upstream ``lrp_transforms.{k}`` weights therefore transfer @@ -562,7 +562,7 @@ def _build_family1_latent_codec( h_mean_s: nn.Module, h_scale_s: nn.Module, ) -> HyperpriorLatentCodec: - """Assemble the Phase 3 Family 1 entropy stack: ``HyperpriorLatentCodec`` + """Assemble the Family 1 entropy stack: ``HyperpriorLatentCodec`` wrapping ``DualHyperSynthesis`` and a per-slice ``ChannelGroupsLatentCodec`` (``side_in_context=True``) whose channel contexts are :class:`MeanScaleContextHead` instances and leaves are diff --git a/compressai/models/tcm.py b/compressai/models/tcm.py index a5d8e713..89fe107c 100644 --- a/compressai/models/tcm.py +++ b/compressai/models/tcm.py @@ -233,7 +233,7 @@ def convert_upstream_tcm_state_dict( - leaves ``g_a`` / ``g_s`` keys (other than the MSA renames inside their ConvTransBlocks) untouched. 
- The Phase 3 wiring sets ``emit_mean_support=True`` on the + The wiring sets ``emit_mean_support=True`` on the :class:`MeanScaleContextHead`, so the upstream LRP layout (``cat(latent_means, *prev_y_hat, y_hat)``) is recoverable inside the leaf — upstream ``lrp_transforms.{k}`` weights therefore transfer diff --git a/examples/convert_cca_checkpoint.py b/examples/convert_cca_checkpoint.py new file mode 100644 index 00000000..8379eaa7 --- /dev/null +++ b/examples/convert_cca_checkpoint.py @@ -0,0 +1,124 @@ +"""Convert an upstream CCA checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. +``checkpoint_lambda_0.3.pth.tar`` from M. Han et al., +https://github.com/CVL-UESTC/CCA, NeurIPS 2024), translates it to +compressai's module layout, and writes a state dict that +``compressai.models.cca.CCAModel.from_state_dict`` can load directly. +Optionally reports forward-pass sanity numbers (PSNR / bpp) on a +synthetic input. + +The upstream-vs-compressai key differences (NAFBlock interior renames, +``mean_NAF_transforms`` -> ``channel_context.y{k}.mean_support_transform``, +``mean_cc_transforms.{k}`` -> ``channel_context.y{k}.mean_cc``, +``lrp_transforms.{k}`` -> ``latent_codec.y{k}.lrp_transform``, +``aux_entropymodel.*`` -> ``aux_entropy_model.inner_codec.*``, the +gaussian_conditional replication across slices, the H+G containerised +re-rooting under ``latent_codec.*``, etc.) are all handled inside +``convert_upstream_cca_state_dict``; this script is a thin CLI around it. 
+ +Example:: + + python examples/convert_cca_checkpoint.py \\ + --src candidate/CCA/checkpoint_lambda_0.3.pth.tar \\ + --dst /tmp/cca_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path + +import torch + +from compressai.models.cca import CCAModel, convert_upstream_cca_state_dict + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream CCA checkpoint (e.g. checkpoint_lambda_0.3.pth.tar).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + converted = convert_upstream_cca_state_dict(upstream) + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = CCAModel.from_state_dict(upstream) + net.eval() + print( + "variant: " + f"M={net.M}, N={net.N}, slice_sizes={net.slice_sizes}, " + f"em_hidden={net.em_hidden_channels}, em_layers={net.em_num_layers}, " + f"cca_training={net.cca_training}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = 
torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_latent_codecs.py b/tests/test_latent_codecs.py index cbf508b1..cbd03ff8 100644 --- a/tests/test_latent_codecs.py +++ b/tests/test_latent_codecs.py @@ -245,7 +245,7 @@ def test_max_support_slices_changes_forward_output(self): class TestChannelGroupsSideInContext: - """Phase 3 ``side_in_context`` mode used by Family 1 codecs.""" + """``side_in_context`` mode used by Family 1 codecs.""" def _make_family1_codec( self, diff --git a/tests/test_models.py b/tests/test_models.py index 99532a42..7c1b968a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -290,7 +290,7 @@ def test_wacnn_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] - # Phase 3 containerised state-dict layout self-check. + # Containerised state-dict layout self-check. sd_keys = set(model.state_dict().keys()) assert "latent_codec.h_a.0.weight" in sd_keys assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys @@ -411,7 +411,7 @@ def test_tcm_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] - # Phase 4 containerised state-dict layout self-check. + # Containerised state-dict layout self-check. 
sd_keys = set(model.state_dict().keys()) # Hyperprior backbone moved under latent_codec.* (TCM's h_a / h_*_s # use ResidualBlockWithStride / ResidualBlockUpsample, so the first @@ -545,6 +545,254 @@ def test_tcm_upstream_state_dict_conversion(self): assert "module.g_a.0.conv1.weight" not in converted +class TestCca: + def test_cca_forward_and_state_dict_round_trip(self): + from compressai.models.cca import CCAModel + + # Tiny variant — variable-length slices, smaller dims keep the + # NAFTransform stack cheap. Slice proportions reproduce the + # upstream layout (8/28/56/92/136 over M=320) at scale. + model = CCAModel( + latent_channels=64, + hyper_channels=48, + slice_proportions=(2, 6, 12, 18, 26), + encoder_dims=(48, 56, 64), + encoder_layers=(1, 1, 1), + em_hidden_channels=56, + em_num_layers=1, + ).eval() + x = torch.rand(1, 3, 128, 128) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + # cca_training=False -> no aux likelihoods exposed. + assert out["aux_likelihoods"] is None + + # Containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + # Hyperprior backbone moved under latent_codec.* (CCA's h_a / + # h_*_s use plain Sequentials of conv / GELU; first weight is at + # `.0.weight` rather than `.0.conv1.weight`). + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + # CCA-main uses STE quantization on z (vs. STF/TCM noise mode). + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # side_in_context=True -> channel_context covers y0..y(K-1). 
+ assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y4.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # NAFTransform support transforms (CCA-specific; absent on STF/WACNN). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.input_projection.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic / pre-refactor paths should be gone. + assert "h_a.0.weight" not in sd_keys + assert "z_entropy_bottleneck.quantiles" not in sd_keys + assert not any(k.startswith("mean_cc_transforms.") for k in sd_keys) + assert not any(k.startswith("aux_entropy_model.") for k in sd_keys) + + loaded = CCAModel.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert torch.allclose(out["likelihoods"]["y"], out_loaded["likelihoods"]["y"]) + assert torch.allclose(out["likelihoods"]["z"], out_loaded["likelihoods"]["z"]) + assert loaded.M == 64 + assert loaded.N == 48 + assert tuple(loaded.slice_sizes) == (2, 6, 12, 18, 26) + assert loaded.em_hidden_channels == 56 + assert loaded.em_num_layers == 1 + assert loaded.cca_training is False + + def test_cca_training_branch_forward_and_round_trip(self): + from compressai.models.cca import CCAModel + + model = CCAModel( + latent_channels=64, + hyper_channels=48, + slice_proportions=(2, 6, 12, 18, 26), + encoder_dims=(48, 56, 64), + encoder_layers=(1, 1, 1), + em_hidden_channels=56, + em_num_layers=1, + cca_training=True, + ).eval() + x = torch.rand(1, 3, 128, 128) + with 
torch.no_grad(): + out = model(x) + # Aux branch populates y_aux (factorised) and y_cca (Gaussian). + assert isinstance(out["aux_likelihoods"], dict) + assert set(out["aux_likelihoods"].keys()) == {"y_aux", "y_cca"} + assert out["aux_likelihoods"]["y_aux"].shape == out["likelihoods"]["y"].shape + assert out["aux_likelihoods"]["y_cca"].shape == out["likelihoods"]["y"].shape + + # Aux state-dict paths (skip-most-recent inner ChannelGroupsLatentCodec). + sd_keys = set(model.state_dict().keys()) + assert "aux_entropy_model.y_entropy_bottleneck.quantiles" in sd_keys + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_cc.0.weight" + in sd_keys + ) + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_support_transform.input_projection.weight" + in sd_keys + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.lrp_transform.0.weight" + in sd_keys + ) + + loaded = CCAModel.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert loaded.cca_training is True + assert torch.allclose( + out["aux_likelihoods"]["y_aux"], out_loaded["aux_likelihoods"]["y_aux"] + ) + assert torch.allclose( + out["aux_likelihoods"]["y_cca"], out_loaded["aux_likelihoods"]["y_cca"] + ) + + def test_cca_upstream_state_dict_conversion(self): + from compressai.models.cca import ( + _is_upstream_cca_state_dict, + convert_upstream_cca_state_dict, + ) + + # Synthetic upstream LICAutoencoder-style state_dict with one slice + # per branch covering the full path: NAFBlock interior renames, + # NAFTransform interior renames, named-part NAF -> support_transforms + # alias, top-level hyperprior + aux module rerooting, per-slice + # rerooting under channel_context / latent_codec, and the + # gaussian_conditional replication. + upstream = { + # ResidualBottleneckBlock inside g_a (conv1 should NOT be renamed + # since it's not inside a NAFBlock — checked via the NAFBlock + # detector which requires the .beta/.gamma/.dwconv.0 triple). 
+ "g_a.blocks.0.0.conv1.weight": torch.zeros(2), + # NAFBlock inside g_a (full triple present -> dwconv/sca/FFN/conv1 + # interior renames apply to this scope only). + "g_a.blocks.0.3.beta": torch.zeros(2), + "g_a.blocks.0.3.gamma": torch.zeros(2), + "g_a.blocks.0.3.dwconv.0.weight": torch.zeros(2), + "g_a.blocks.0.3.sca.1.weight": torch.zeros(2), + "g_a.blocks.0.3.FFN.0.weight": torch.zeros(2), + "g_a.blocks.0.3.conv1.weight": torch.zeros(2), + # Per-slice main entropy heads (one slice for compactness). + "mean_cc_transforms.0.0.weight": torch.zeros(2), + "scale_cc_transforms.0.0.weight": torch.zeros(2), + "lrp_transforms.0.0.weight": torch.zeros(2), + # NAFTransform interior (in_conv/out_conv -> input_projection/...). + # Triple required for the detector: .in_conv.weight, + # .out_conv.weight, .blocks.0.beta. + "mean_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "mean_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "mean_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "scale_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "scale_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "scale_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "gaussian_conditional.scale_table": torch.zeros(2), + # Hyperprior backbone (root-level -> latent_codec.*). + "h_a.0.weight": torch.zeros(2), + "h_mean_s.0.weight": torch.zeros(2), + "h_scale_s.0.weight": torch.zeros(2), + "z_entropy_bottleneck.quantiles": torch.zeros(2), + # Aux entropy module (aux_entropymodel -> aux_entropy_model, then + # the same per-slice rerooting as the main path). 
+ "aux_entropymodel.mean_cc_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.scale_cc_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.lrp_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "aux_entropymodel.gaussian_conditional.scale_table": torch.zeros(2), + "aux_entropymodel.y_entropy_bottleneck.quantiles": torch.zeros(2), + } + assert _is_upstream_cca_state_dict(upstream) + + converted = convert_upstream_cca_state_dict(upstream) + + # ResidualBottleneckBlock conv1 NOT renamed (not inside NAFBlock). + assert "g_a.blocks.0.0.conv1.weight" in converted + # NAFBlock interior renames applied at the NAFBlock scope only. + assert "g_a.blocks.0.3.beta" in converted + assert "g_a.blocks.0.3.pointwise_depthwise.0.weight" in converted + assert "g_a.blocks.0.3.channel_attention.1.weight" in converted + assert "g_a.blocks.0.3.feed_forward.0.weight" in converted + assert "g_a.blocks.0.3.project.weight" in converted + + # Hyperprior backbone moves under latent_codec. + assert "latent_codec.h_a.0.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # Per-slice main rerooting. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + # NAFTransform: in_conv -> input_projection; mean_NAF_transforms -> + # channel_context.y{k}.mean_support_transform. 
+ assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + in converted + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.input_projection.weight" + in converted + ) + # gaussian_conditional replicated under each per-slice leaf. + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + # LRP weights byte-for-byte under per-slice leaf. + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + + # Aux entropy module rerooting (aux_entropymodel -> aux_entropy_model; + # per-slice contents land under inner_codec.*). + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_cc.0.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_support_transform.input_projection.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.lrp_transform.0.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert "aux_entropy_model.y_entropy_bottleneck.quantiles" in converted + + # Old paths should be gone after conversion. 
+ assert "h_a.0.weight" not in converted + assert "z_entropy_bottleneck.quantiles" not in converted + assert "mean_cc_transforms.0.0.weight" not in converted + assert "mean_NAF_transforms.0.in_conv.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "aux_entropymodel.mean_cc_transforms.0.0.weight" not in converted + + def test_scale_table_default(): table = get_scale_table() assert SCALES_MIN == 0.11 From 0c44c75ab003c3bddfbcac454f1555e74bfa9144 Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 14:26:40 +0800 Subject: [PATCH 5/8] chore(latent_codecs,models): drop ChannelSliceLatentCodec and SliceEntropyCompressionModel Family 1 models (STF, WACNN, TCM, CCA) all migrated to ChannelGroupsLatentCodec + _slice_helpers, leaving these two scaffolding modules with no callers in production code or tests. Delete: - compressai/latent_codecs/channel_slice.py - compressai/models/_bases/ (slice_entropy.py + __init__.py; whole dir is empty) - corresponding exports from latent_codecs/__init__.py --- compressai/latent_codecs/__init__.py | 2 - compressai/latent_codecs/_slice_helpers.py | 10 +- compressai/latent_codecs/channel_slice.py | 269 --------------------- compressai/models/_bases/__init__.py | 24 -- compressai/models/_bases/slice_entropy.py | 260 -------------------- 5 files changed, 4 insertions(+), 561 deletions(-) delete mode 100644 compressai/latent_codecs/channel_slice.py delete mode 100644 compressai/models/_bases/__init__.py delete mode 100644 compressai/models/_bases/slice_entropy.py diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 1d5ea9c6..e1e8f324 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -30,7 +30,6 @@ from ._hyper_synthesis import DualHyperSynthesis from .base import LatentCodec from .channel_groups import ChannelGroupsLatentCodec -from .channel_slice import ChannelSliceLatentCodec from 
.checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -42,7 +41,6 @@ __all__ = [ "LatentCodec", "ChannelGroupsLatentCodec", - "ChannelSliceLatentCodec", "CheckerboardLatentCodec", "DualHyperSynthesis", "EntropyBottleneckLatentCodec", diff --git a/compressai/latent_codecs/_slice_helpers.py b/compressai/latent_codecs/_slice_helpers.py index 6332d1c0..d2286c00 100644 --- a/compressai/latent_codecs/_slice_helpers.py +++ b/compressai/latent_codecs/_slice_helpers.py @@ -30,12 +30,10 @@ """Channel-slice support helpers shared by Family 1 codecs. These functions support the Family 1 (pure 1-pass channel-slice) entropy -models — STF / WACNN / TCM / CCA / DCAE / MambaVC. They were previously -hosted in ``compressai.models._bases.slice_entropy``; the canonical home is -now this module so they sit alongside the latent-codec primitives that -consume them. ``_DEFAULT_NUM_SLICES_PREFIX`` reflects the post-refactor -state-dict layout used by the containerised -:class:`~compressai.latent_codecs.ChannelGroupsLatentCodec` wiring. +models — STF / WACNN / TCM / CCA / DCAE / MambaVC. They sit alongside the +latent-codec primitives that consume them. +``_DEFAULT_NUM_SLICES_PREFIX`` reflects the containerised state-dict layout +used by :class:`~compressai.latent_codecs.ChannelGroupsLatentCodec`. """ from __future__ import annotations diff --git a/compressai/latent_codecs/channel_slice.py b/compressai/latent_codecs/channel_slice.py deleted file mode 100644 index 73e32d6e..00000000 --- a/compressai/latent_codecs/channel_slice.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) 2021-2025, InterDigital Communications, Inc -# All rights reserved. 
- -# Redistribution and use in source and binary forms, with or without -# modification, are permitted (subject to the limitations in the disclaimer -# below) provided that the following conditions are met: - -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of InterDigital Communications, Inc nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. - -# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY -# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT -# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR -# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import torch -import torch.nn as nn - -from torch import Tensor - -from compressai.ans import BufferedRansEncoder, RansDecoder -from compressai.entropy_models import GaussianConditional -from compressai.ops import quantize_ste -from compressai.registry import register_module - -from .base import LatentCodec - -__all__ = [ - "ChannelSliceLatentCodec", -] - - -@register_module("ChannelSliceLatentCodec") -class ChannelSliceLatentCodec(LatentCodec): - """Channel-conditional entropy model with separate scale/mean heads and LRP. - - Splits ``y`` into equal-sized slices along the channel axis. For each - slice ``k`` the previously decoded slices (truncated to - ``max_support_slices``) are concatenated with ``latent_means`` / - ``latent_scales`` and pushed through ``cc_mean_transforms[k]`` and - ``cc_scale_transforms[k]`` to obtain ``mu`` / ``scale``. After the - Gaussian conditional step, an optional latent residual prediction - (LRP) head refines ``y_hat``. - - This is the channel-autoregressive entropy model from [Minnen2020] - with the LRP refinement variant used in [Zhu2022] (STF / WACNN), - [He2022] (ELIC) and many follow-up papers (MLIC++, TCM, ...). - - [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for - Learned Image Compression" `_, by - David Minnen and Saurabh Singh, ICIP 2020. - - [Zhu2022]: `"Transformer-based Transform Coding" - `_, by Yinhao Zhu, - Yang Yang and Taco Cohen, ICLR 2022. 
- """ - - cc_mean_transforms: nn.ModuleList - cc_scale_transforms: nn.ModuleList - lrp_transforms: nn.ModuleList - gaussian_conditional: GaussianConditional - - def __init__( - self, - cc_mean_transforms: nn.ModuleList, - cc_scale_transforms: nn.ModuleList, - lrp_transforms: Optional[nn.ModuleList] = None, - gaussian_conditional: Optional[GaussianConditional] = None, - mean_support_transforms: Optional[nn.ModuleList] = None, - scale_support_transforms: Optional[nn.ModuleList] = None, - *, - num_slices: Optional[int] = None, - max_support_slices: int = -1, - quantizer: str = "ste", - lrp_scale: float = 0.5, - **kwargs: Any, - ) -> None: - super().__init__() - self._kwargs = kwargs - - inferred_num_slices = len(cc_mean_transforms) - if num_slices is None: - num_slices = inferred_num_slices - if inferred_num_slices != num_slices: - raise ValueError( - "cc_mean_transforms must have num_slices entries " - f"(got {inferred_num_slices}, expected {num_slices})" - ) - if len(cc_scale_transforms) != num_slices: - raise ValueError("cc_scale_transforms must have num_slices entries") - if lrp_transforms is not None and len(lrp_transforms) != num_slices: - raise ValueError("lrp_transforms must have num_slices entries") - if ( - mean_support_transforms is not None - and len(mean_support_transforms) != num_slices - ): - raise ValueError("mean_support_transforms must have num_slices entries") - if ( - scale_support_transforms is not None - and len(scale_support_transforms) != num_slices - ): - raise ValueError("scale_support_transforms must have num_slices entries") - if quantizer not in ("ste", "noise"): - raise ValueError(f"unknown quantizer {quantizer!r}") - - self.num_slices = int(num_slices) - self.max_support_slices = int(max_support_slices) - self.quantizer = quantizer - self.lrp_scale = float(lrp_scale) - self.cc_mean_transforms = cc_mean_transforms - self.cc_scale_transforms = cc_scale_transforms - self.mean_support_transforms = mean_support_transforms or nn.ModuleList( - 
nn.Identity() for _ in range(num_slices) - ) - self.scale_support_transforms = scale_support_transforms or nn.ModuleList( - nn.Identity() for _ in range(num_slices) - ) - self.lrp_transforms = lrp_transforms or nn.ModuleList( - nn.Identity() for _ in range(num_slices) - ) - self.gaussian_conditional = gaussian_conditional or GaussianConditional(None) - - def _support_slices(self, y_hat_slices: Sequence[Tensor]) -> List[Tensor]: - if self.max_support_slices < 0: - return list(y_hat_slices) - return list(y_hat_slices[: self.max_support_slices]) - - def _slice_params( - self, - slice_index: int, - latent_means: Tensor, - latent_scales: Tensor, - y_hat_slices: Sequence[Tensor], - spatial_shape: Tuple[int, int], - ) -> Tuple[Tensor, Tensor, Tensor]: - support = self._support_slices(y_hat_slices) - mean_support = torch.cat([latent_means, *support], dim=1) - mean_support = self.mean_support_transforms[slice_index](mean_support) - mu = self.cc_mean_transforms[slice_index](mean_support) - mu = mu[:, :, : spatial_shape[0], : spatial_shape[1]] - scale_support = torch.cat([latent_scales, *support], dim=1) - scale_support = self.scale_support_transforms[slice_index](scale_support) - scale = self.cc_scale_transforms[slice_index](scale_support) - scale = scale[:, :, : spatial_shape[0], : spatial_shape[1]] - return mu, scale, mean_support - - def _apply_lrp( - self, slice_index: int, mean_support: Tensor, y_hat_slice: Tensor - ) -> Tensor: - lrp = self.lrp_transforms[slice_index]( - torch.cat([mean_support, y_hat_slice], dim=1) - ) - return y_hat_slice + self.lrp_scale * torch.tanh(lrp) - - def forward( - self, - y: Tensor, - latent_means: Tensor, - latent_scales: Tensor, - ) -> Dict[str, Any]: - spatial_shape = (y.shape[2], y.shape[3]) - y_hat_slices: List[Tensor] = [] - y_likelihoods_slices: List[Tensor] = [] - - for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, 
y_hat_slices, spatial_shape - ) - _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scale, means=mu) - if self.quantizer == "ste": - y_hat_slice = quantize_ste(y_slice - mu) + mu - else: - y_hat_slice = self.gaussian_conditional.quantize( - y_slice, "noise" if self.training else "dequantize", mu - ) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - y_likelihoods_slices.append(y_slice_likelihoods) - - return { - "y_hat": torch.cat(y_hat_slices, dim=1), - "likelihoods": {"y": torch.cat(y_likelihoods_slices, dim=1)}, - } - - def compress( - self, - y: Tensor, - latent_means: Tensor, - latent_scales: Tensor, - ) -> Dict[str, Any]: - spatial_shape = (y.shape[2], y.shape[3]) - cdf = self.gaussian_conditional.quantized_cdf.tolist() - cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() - offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() - encoder = BufferedRansEncoder() - symbols_list: List[int] = [] - indexes_list: List[int] = [] - y_hat_slices: List[Tensor] = [] - - for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape - ) - indexes = self.gaussian_conditional.build_indexes(scale) - y_q_slice = self.gaussian_conditional.quantize(y_slice, "symbols", mu) - y_hat_slice = y_q_slice + mu - symbols_list.extend(y_q_slice.reshape(-1).tolist()) - indexes_list.extend(indexes.reshape(-1).tolist()) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - - encoder.encode_with_indexes( - symbols_list, indexes_list, cdf, cdf_lengths, offsets - ) - return { - "strings": [encoder.flush()], - "shape": spatial_shape, - "y_hat": torch.cat(y_hat_slices, dim=1), - } - - def decompress( - self, - strings: Sequence[bytes], - shape: Tuple[int, int], - latent_means: Tensor, - latent_scales: 
Tensor, - **kwargs: Any, - ) -> Dict[str, Any]: - cdf = self.gaussian_conditional.quantized_cdf.tolist() - cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() - offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() - decoder = RansDecoder() - decoder.set_stream(strings[0]) - y_hat_slices: List[Tensor] = [] - - for slice_index in range(self.num_slices): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, y_hat_slices, shape - ) - indexes = self.gaussian_conditional.build_indexes(scale) - values = decoder.decode_stream( - indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets - ) - y_q_slice = torch.tensor(values, device=mu.device, dtype=mu.dtype).reshape( - mu.shape - ) - y_hat_slice = self.gaussian_conditional.dequantize(y_q_slice, mu) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - - return {"y_hat": torch.cat(y_hat_slices, dim=1)} diff --git a/compressai/models/_bases/__init__.py b/compressai/models/_bases/__init__.py deleted file mode 100644 index 065119c2..00000000 --- a/compressai/models/_bases/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Abstract base classes shared by multiple slice-based LIC models. - -These were historically hidden behind ``stf_support`` / ``dcae_support`` file -names which obscured the fact that they're real abstract :class:`CompressionModel` -subclasses inherited by 3-4 models each. 
-""" - -from .slice_entropy import ( - SliceEntropyCompressionModel, - infer_max_support_slices, - infer_num_slices, - lrp_support_channels, - make_entropy_transform, - slice_support_channels, -) - -__all__ = [ - "SliceEntropyCompressionModel", - "infer_max_support_slices", - "infer_num_slices", - "lrp_support_channels", - "make_entropy_transform", - "slice_support_channels", -] diff --git a/compressai/models/_bases/slice_entropy.py b/compressai/models/_bases/slice_entropy.py deleted file mode 100644 index 7159468f..00000000 --- a/compressai/models/_bases/slice_entropy.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Slice-conditional entropy backbone shared by WACNN / SymmetricalTransFormer / MambaVC. - -Promoted out of the historical ``models/stf_support.py`` so the abstract base -class is discoverable by name. Channel-counting helpers and a parameterised -entropy-transform factory live here too — they used to be duplicated across -``stf_support`` / ``ssm_support`` / ``weconvene_support``. -""" - -from __future__ import annotations - -from typing import Dict, Optional, Sequence, Tuple - -import torch.nn as nn - -from torch import Tensor - -from compressai.entropy_models import EntropyBottleneck -from compressai.latent_codecs import ChannelSliceLatentCodec -from compressai.models.utils import conv - -from ..base import CompressionModel - -__all__ = [ - "SliceEntropyCompressionModel", - "infer_max_support_slices", - "infer_num_slices", - "lrp_support_channels", - "make_entropy_transform", - "slice_support_channels", -] - - -_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.cc_mean_transforms." 
-_KEY_SUFFIX = ".0.weight" - - -def slice_support_channels( - latent_channels: int, - slice_channels: int, - index: int, - max_support_slices: int, -) -> int: - if max_support_slices < 0: - return latent_channels + slice_channels * index - return latent_channels + slice_channels * min(index, max_support_slices) - - -def lrp_support_channels( - latent_channels: int, - slice_channels: int, - index: int, - max_support_slices: int, -) -> int: - if max_support_slices < 0: - return latent_channels + slice_channels * (index + 1) - return latent_channels + slice_channels * min(index + 1, max_support_slices + 1) - - -def make_entropy_transform( - in_channels: int, - out_channels: int, - *, - widths: Sequence[int] = (224, 128), -) -> nn.Sequential: - """Stack of stride-1 3x3 convs with GELU between, used by every slice - entropy model. ``widths`` specifies hidden conv widths; defaults to the - Mamba/WeConvene 3-conv stack. Pass ``widths=(224, 176, 128, 64)`` for the - STF/WACNN 5-conv stack.""" - layers: list[nn.Module] = [] - prev = in_channels - for width in widths: - layers.append(conv(prev, width, stride=1, kernel_size=3)) - layers.append(nn.GELU()) - prev = width - layers.append(conv(prev, out_channels, stride=1, kernel_size=3)) - return nn.Sequential(*layers) - - -def infer_num_slices( - state_dict: Dict[str, Tensor], - *, - prefix: str = _DEFAULT_NUM_SLICES_PREFIX, - suffix: str = _KEY_SUFFIX, -) -> int: - slice_indices = { - int(key[len(prefix) :].split(".", 1)[0]) - for key in state_dict - if key.startswith(prefix) and key.endswith(suffix) - } - return len(slice_indices) - - -def infer_max_support_slices( - state_dict: Dict[str, Tensor], - latent_channels: int, - num_slices: int, - *, - prefix: str = _DEFAULT_NUM_SLICES_PREFIX, - suffix: str = _KEY_SUFFIX, - extra_factor: int = 1, -) -> int: - """Infer ``max_support_slices`` from the input width of the first - cc_mean transform conv. 
``extra_factor`` accounts for models like DCAE/SAAF - that prepend additional copies of the latent (``M*3 + slice_channels*N``); - pass ``extra_factor=3`` there. Slice-only models (STF/Mamba*) keep the - default ``extra_factor=1``.""" - slice_channels = latent_channels // num_slices - matching = [ - tensor.size(1) - for key, tensor in state_dict.items() - if key.startswith(prefix) and key.endswith(suffix) - ] - if not matching: - return 0 - max_input_channels = max(matching) - return max( - 0, (max_input_channels - extra_factor * latent_channels) // slice_channels - ) - - -class SliceEntropyCompressionModel(CompressionModel): - """Channel-conditional entropy backbone shared by WACNN, SymmetricalTransFormer, MambaVC. - - Subclasses must populate ``g_a``, ``g_s``, ``h_a``, ``h_mean_s`` and - ``h_scale_s``, then call :meth:`_init_slice_entropy` to wire up the - entropy bottleneck for ``z`` and the :class:`ChannelSliceLatentCodec` - for ``y``. - """ - - h_a: nn.Module - h_mean_s: nn.Module - h_scale_s: nn.Module - entropy_bottleneck: EntropyBottleneck - latent_codec: ChannelSliceLatentCodec - - def _init_slice_entropy( - self, - latent_channels: int, - entropy_bottleneck_channels: int, - num_slices: int, - max_support_slices: int, - mean_support_transforms: Optional[nn.ModuleList] = None, - scale_support_transforms: Optional[nn.ModuleList] = None, - ) -> None: - if latent_channels % num_slices != 0: - raise ValueError("latent_channels must be divisible by num_slices") - if ( - mean_support_transforms is not None - and len(mean_support_transforms) != num_slices - ): - raise ValueError("mean_support_transforms must have num_slices entries") - if ( - scale_support_transforms is not None - and len(scale_support_transforms) != num_slices - ): - raise ValueError("scale_support_transforms must have num_slices entries") - - slice_channels = latent_channels // num_slices - widths = (224, 176, 128, 64) - cc_mean_transforms = nn.ModuleList( - make_entropy_transform( - 
slice_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - cc_scale_transforms = nn.ModuleList( - make_entropy_transform( - slice_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - lrp_transforms = nn.ModuleList( - make_entropy_transform( - lrp_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - - self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) - self.latent_codec = ChannelSliceLatentCodec( - cc_mean_transforms=cc_mean_transforms, - cc_scale_transforms=cc_scale_transforms, - lrp_transforms=lrp_transforms, - mean_support_transforms=mean_support_transforms, - scale_support_transforms=scale_support_transforms, - num_slices=num_slices, - max_support_slices=max_support_slices, - quantizer="ste", - ) - - @property - def num_slices(self) -> int: - return self.latent_codec.num_slices - - @property - def max_support_slices(self) -> int: - return self.latent_codec.max_support_slices - - def _hyper_priors(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - z = self.h_a(y) - z_hat, z_likelihoods = self.entropy_bottleneck(z) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - return z, z_likelihoods, latent_means, latent_scales - - def _forward_latent_output( - self, y: Tensor - ) -> Dict[str, Dict[str, Tensor] | Tensor]: - _, z_likelihoods, latent_means, latent_scales = self._hyper_priors(y) - y_out = self.latent_codec(y, latent_means, latent_scales) - output: Dict[str, Dict[str, Tensor] | Tensor] = { - "y_hat": y_out["y_hat"], - "likelihoods": {"y": y_out["likelihoods"]["y"], "z": z_likelihoods}, - } - return output - - def _forward_latent(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - 
output = self._forward_latent_output(y) - return output["y_hat"], output["likelihoods"]["y"], output["likelihoods"]["z"] - - def _compress_latent(self, y: Tensor) -> Dict[str, object]: - z = self.h_a(y) - z_strings = self.entropy_bottleneck.compress(z) - z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - y_out = self.latent_codec.compress(y, latent_means, latent_scales) - return { - "strings": [[y_out["strings"][0]], z_strings], - "shape": z.size()[-2:], - } - - def _decompress_latent( - self, - strings: Sequence[Sequence[bytes]], - shape: Tuple[int, int], - ) -> Tensor: - if len(strings) != 2: - raise ValueError("strings must contain [y_strings, z_strings]") - - z_hat = self.entropy_bottleneck.decompress(strings[1], shape) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - y_shape = (z_hat.shape[2] * 4, z_hat.shape[3] * 4) - y_out = self.latent_codec.decompress( - strings[0], y_shape, latent_means, latent_scales - ) - return y_out["y_hat"] From f87c8c833c07f05419f6c306678d72eab65c6b4b Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sat, 9 May 2026 14:28:11 +0800 Subject: [PATCH 6/8] chore(zoo): wire cca/tcm zoo entries with lazy import Register tcm/cca in image_models and model_architectures via _LazyImport so import compressai.zoo stays timm-free; add tcm()/cca() factory functions mirroring stf()/stf_wacnn() (pretrained=True raises until weights are hosted). 
--- compressai/zoo/__init__.py | 4 ++++ compressai/zoo/image.py | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index acebc705..e3e75863 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -31,12 +31,14 @@ bmshj2018_factorized, bmshj2018_factorized_relu, bmshj2018_hyperprior, + cca, cheng2020_anchor, cheng2020_attn, mbt2018, mbt2018_mean, stf, stf_wacnn, + tcm, ) from .image_vbr import bmshj2018_hyperprior_vbr, mbt2018_mean_vbr, mbt2018_vbr from .pretrained import load_pretrained as load_state_dict @@ -52,6 +54,8 @@ "cheng2020-attn": cheng2020_attn, "stf": stf, "stf-wacnn": stf_wacnn, + "tcm": tcm, + "cca": cca, "bmshj2018-hyperprior-vbr": bmshj2018_hyperprior_vbr, "mbt2018-mean-vbr": mbt2018_mean_vbr, "mbt2018-vbr": mbt2018_vbr, diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py index f506d6bf..69085976 100644 --- a/compressai/zoo/image.py +++ b/compressai/zoo/image.py @@ -84,6 +84,8 @@ def __getattr__(self, item): "cheng2020_attn", "stf", "stf_wacnn", + "tcm", + "cca", ] model_architectures = { @@ -97,6 +99,8 @@ def __getattr__(self, item): # Resolved lazily so `compressai.zoo` is importable without `timm`. "stf": _LazyImport("compressai.models.stf", "SymmetricalTransFormer"), "stf-wacnn": _LazyImport("compressai.models.stf", "WACNN"), + "tcm": _LazyImport("compressai.models.tcm", "TCM"), + "cca": _LazyImport("compressai.models.cca", "CCAModel"), } root_url = "https://compressai.s3.amazonaws.com/models/v1" @@ -525,3 +529,43 @@ def stf_wacnn(pretrained: bool = False, progress: bool = True, **kwargs): from compressai.models.stf import WACNN return WACNN(**kwargs) + + +def tcm(pretrained: bool = False, progress: bool = True, **kwargs): + r"""TCM (Transformer-CNN Mixture) model from J. Liu, H. Sun, J. Katto: + `"Learned Image Compression with Mixed Transformer-CNN Architectures" + <https://arxiv.org/abs/2303.14978>`_, IEEE/CVF Conf. 
on Computer Vision + and Pattern Recognition (CVPR), 2023. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained TCM weights are not yet hosted on S3.") + from compressai.models.tcm import TCM + + return TCM(**kwargs) + + +def cca(pretrained: bool = False, progress: bool = True, **kwargs): + r"""CCA (Causal Context Adjustment) model from M. Han, S. Jiang, S. Li, + X. Deng, M. Xu, C. Zhu, S. Liu: `"Causal Context Adjustment Loss for + Learned Image Compression" <https://arxiv.org/abs/2410.04847>`_, NeurIPS + 2024. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained CCA weights are not yet hosted on S3.") + from compressai.models.cca import CCAModel + + return CCAModel(**kwargs) From 7ef153453abf86d3c9610170573fed04ad4cd64a Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Sun, 10 May 2026 13:51:52 +0800 Subject: [PATCH 7/8] fix(latent_codecs): channel_groups decompress allocates correct 4D buffer for 2D-shape leaves ChannelGroupsLatentCodec.decompress reconstructed the destination buffer shape with (sum(s[0] for s in shape), *shape[0][1:]), which assumes each leaf reports a 3D (C, H, W) shape (the CheckerboardLatentCodec convention). Family 1 leaves (LRPGaussianLatentCodec via GaussianConditionalLatentCodec) report a 2D (H, W) shape, which collapsed the buffer to 3D (N, sum_H, W) and triggered a broadcast RuntimeError when assigning the leaf's 4D y_hat back into the split slice (manifesting on STF/WACNN compress->decompress round-trip). 
Use self.groups for the channel total and take the trailing two dims of any per-group shape as spatial -- works for both leaf shape conventions. Adds regression coverage for both 2D and 3D leaf shapes. --- compressai/latent_codecs/channel_groups.py | 8 +- tests/test_latent_codecs.py | 89 ++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index 69670925..43de6716 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -149,8 +149,12 @@ def decompress( strings_per_group = len(strings) // len(self.groups) y_out_ = [{}] * len(self.groups) - y_shape = (sum(s[0] for s in shape), *shape[0][1:]) - y_hat = torch.zeros((n, *y_shape), device=side_params.device) + # Spatial dims are the trailing two entries of any per-group shape; + # the channel total is determined by ``self.groups`` (so this works + # for both leaves that report ``(C, H, W)`` -- e.g. CheckerboardLatentCodec -- + # and leaves that report ``(H, W)`` -- e.g. GaussianConditionalLatentCodec). + spatial = tuple(shape[0])[-2:] + y_hat = torch.zeros((n, sum(self.groups), *spatial), device=side_params.device) y_hat_ = y_hat.split(self.groups, dim=1) for k in range(len(self.groups)): diff --git a/tests/test_latent_codecs.py b/tests/test_latent_codecs.py index cbd03ff8..8b9297bd 100644 --- a/tests/test_latent_codecs.py +++ b/tests/test_latent_codecs.py @@ -313,6 +313,95 @@ def test_get_ctx_params_for_k_positive_concats_side(self): assert ctx.shape == (1, 8, 4, 4) +class TestChannelGroupsDecompressShape: + """Regression coverage for ``ChannelGroupsLatentCodec.decompress`` shape + reconstruction. + + Family 1 (STF/WACNN/TCM/CCA) leaves are :class:`LRPGaussianLatentCodec` + which inherits :class:`GaussianConditionalLatentCodec.compress` → returns + ``shape = y.shape[2:4]`` (2D ``(H, W)``). 
ELIC-style leaves use + :class:`CheckerboardLatentCodec` → returns ``shape = y_hat.shape[1:]`` + (3D ``(C, H, W)``). ``decompress`` must allocate the correct 4D + ``(N, sum_C, H, W)`` buffer in either case. + """ + + class _LeafMock2D(nn.Module): + """Mimics GaussianConditionalLatentCodec: shape=(H, W) from compress, + no real entropy coding (zeros for y_hat).""" + + def __init__(self, slice_ch): + super().__init__() + self.slice_ch = slice_ch + + def compress(self, y, ctx_params): + n = y.shape[0] + return { + "strings": [[b"" for _ in range(n)]], + "shape": tuple(y.shape[2:4]), + "y_hat": torch.zeros_like(y), + } + + def decompress(self, strings, shape, ctx_params, **kwargs): + n = len(strings[0]) + h, w = shape + return {"y_hat": torch.zeros((n, self.slice_ch, h, w))} + + class _LeafMock3D(nn.Module): + """Mimics CheckerboardLatentCodec: shape=(C, H, W) from compress.""" + + def __init__(self, slice_ch): + super().__init__() + self.slice_ch = slice_ch + + def compress(self, y, ctx_params): + n = y.shape[0] + return { + "strings": [[b"" for _ in range(n)]], + "shape": tuple(y.shape[1:]), + "y_hat": torch.zeros_like(y), + } + + def decompress(self, strings, shape, ctx_params, **kwargs): + n = len(strings[0]) + c, h, w = shape + return {"y_hat": torch.zeros((n, c, h, w))} + + def _make_codec(self, leaf_cls, groups=(4, 4, 4), side_ch=8): + K = len(groups) + return ChannelGroupsLatentCodec( + latent_codec={f"y{k}": leaf_cls(groups[k]) for k in range(K)}, + channel_context={f"y{k}": nn.Identity() for k in range(1, K)}, + groups=list(groups), + ) + + def test_decompress_with_2d_leaf_shape(self): + # Pre-fix: y_shape = (sum(s[0] for s in shape), *shape[0][1:]) + # collapsed to (sum_H, W) = (3*6, 5) -> y_hat 3D -> RuntimeError when + # assigning the 4D leaf y_hat into a 3D split slice. + groups = [4, 4, 4] + codec = self._make_codec(self._LeafMock2D, groups=groups) + # Deliberately pick H != W and H != sum(groups) so a regression in + # axis-confusion (e.g. 
sum_H instead of sum_C) surfaces as a shape + # error, not a silent wrong-shape pass. + h, w = 6, 5 + y = torch.randn(1, sum(groups), h, w) + side_params = torch.zeros(1, 8, h, w) + out_enc = codec.compress(y, side_params) + out_dec = codec.decompress(out_enc["strings"], out_enc["shape"], side_params) + assert out_dec["y_hat"].shape == (1, sum(groups), h, w) + + def test_decompress_with_3d_leaf_shape_still_works(self): + # ELIC-style path must keep working. + groups = [4, 4, 4] + codec = self._make_codec(self._LeafMock3D, groups=groups) + h, w = 6, 5 + y = torch.randn(1, sum(groups), h, w) + side_params = torch.zeros(1, 8, h, w) + out_enc = codec.compress(y, side_params) + out_dec = codec.decompress(out_enc["strings"], out_enc["shape"], side_params) + assert out_dec["y_hat"].shape == (1, sum(groups), h, w) + + class TestSliceHelpers: def test_slice_support_channels_default_use_all(self): # With max_support_slices = -1 the helper returns the full latent + k slices. From 565c6bfb7f0b997ef5fd94bddaf90bdf8a018b9d Mon Sep 17 00:00:00 2001 From: boyceyi <1473416941@qq.com> Date: Mon, 11 May 2026 10:18:52 +0800 Subject: [PATCH 8/8] fix(models): align z-leaf quantizer with upstream STE on STF/WACNN/TCM The original LIC TCM (`tcm.py:434`), WACNN (`cnn.py:152`), and SymmetricalTransFormer (`stf.py:602`) implementations all run `quantize_ste(z - z_offset) + z_offset` after the entropy bottleneck so that downstream `h_s` consumes a STE-rounded `z_hat` (likelihoods are still computed on noisy z to train the parametric prior). The Family 1 containerization in `_build_family1_latent_codec` and `_build_tcm_latent_codec` was passing `quantizer="noise"` to the `z` leaf, which silently propagated noisy `z_hat` to `h_s` during training -- a real RD-relevant deviation from the published models. Switch both build sites to `quantizer="ste"`, matching CCA-main (`cca.py:360`) which already used STE. 
Eval-mode forward and state-dict layout are unchanged (same module tree, same parameters); only training-time `z_hat` propagated to `h_s` becomes deterministic. Also drop the misleading STF/TCM-noise vs CCA-STE distinction from `compressai/latent_codecs/__init__.py` and `tests/test_models.py`, replacing it with a Family 1 invariant note: all four models use STE on z. --- compressai/latent_codecs/__init__.py | 8 ++++++-- compressai/models/stf.py | 2 +- compressai/models/tcm.py | 2 +- tests/test_models.py | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index e1e8f324..0aa4a7ef 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -93,6 +93,11 @@ # ``cat(latent_means, *prev_y_hat, y_hat)`` layout for byte-for-byte # weight transfer. # +# All Family 1 models use ``EntropyBottleneckLatentCodec(quantizer="ste")`` +# for the ``z`` leaf, mirroring the upstream +# ``quantize_ste(z - z_offset) + z_offset`` pattern: noise-based likelihoods +# during training but a STE-rounded ``z_hat`` propagated to ``h_s``. +# # Application-layer helpers in # :mod:`compressai.models._helpers.channel_slice` and # :mod:`compressai.models._helpers.channel_context` @@ -106,8 +111,7 @@ # ``support_transform_factory=SWAtten`` (independent windowed-attention # transforms per mean / scale path). # - **CCA-main**: variable-length slices (``groups=resolved_slice_sizes``), -# ``support_transform_factory=NAFTransform``, -# ``EntropyBottleneckLatentCodec(quantizer="ste")`` for the ``z`` leaf. +# ``support_transform_factory=NAFTransform``. 
# - **CCA-aux**: lives outside the hyperprior container (separate # ``ChannelGroupsLatentCodec``), uses ``support_filter`` for # skip-most-recent prior selection, and mixes diff --git a/compressai/models/stf.py b/compressai/models/stf.py index d2ed1348..a361765b 100644 --- a/compressai/models/stf.py +++ b/compressai/models/stf.py @@ -616,7 +616,7 @@ def _channel_context(_k: int, _slice_ch: int, support_ch: int) -> nn.Module: h_s=DualHyperSynthesis(h_mean_s, h_scale_s), latent_codec={ "z": EntropyBottleneckLatentCodec( - entropy_bottleneck=EntropyBottleneck(N), quantizer="noise" + entropy_bottleneck=EntropyBottleneck(N), quantizer="ste" ), "y": build_channel_slice_codec( groups=[slice_ch] * num_slices, diff --git a/compressai/models/tcm.py b/compressai/models/tcm.py index 89fe107c..7eba1e67 100644 --- a/compressai/models/tcm.py +++ b/compressai/models/tcm.py @@ -747,7 +747,7 @@ def _channel_context(_k: int, _slice_ch: int, support_ch: int) -> nn.Module: latent_codec={ "z": EntropyBottleneckLatentCodec( entropy_bottleneck=EntropyBottleneck(hyper_channels), - quantizer="noise", + quantizer="ste", ), "y": build_channel_slice_codec( groups=[slice_ch] * num_slices, diff --git a/tests/test_models.py b/tests/test_models.py index 7c1b968a..fab3d059 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -578,7 +578,8 @@ def test_cca_forward_and_state_dict_round_trip(self): assert "latent_codec.h_a.0.weight" in sd_keys assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys - # CCA-main uses STE quantization on z (vs. STF/TCM noise mode). + # All Family 1 models (STF/WACNN/TCM/CCA) use STE on z; the + # entropy_bottleneck still owns the parametric prior. assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys # side_in_context=True -> channel_context covers y0..y(K-1). assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys