Add TCM (CVPR 2023) + CCA (NeurIPS 2024) with Family-1 codec refactor#355

Open
Yiozolm wants to merge 8 commits into InterDigitalInc:master from Yiozolm:pr-tcm-cca

Conversation

Yiozolm (Contributor) commented May 9, 2026

Adds two new models from the per-model PR series tracked in #353: TCM (Liu et al., CVPR 2023) and CCA (Han et al., NeurIPS 2024).

Per the discussion in #354, this PR also delivers the refactored latent-codec abstraction I committed to ship next ("I'll include the refactored abstraction layer in the next PR"). The new shared infrastructure unifies the channel-slice topology used by STF / WACNN (added in #354), TCM (this PR), CCA (this PR), and the upcoming DCAE / MambaVC follow-ups onto upstream ChannelGroupsLatentCodec rather than the temporary ChannelSliceLatentCodec introduced in #354.

Pretrained weights are intentionally not bundled — calling pretrained=True raises a clear RuntimeError until weights are hosted on S3 (per the discussion in #353).
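
For reference, the guard in each new factory is roughly this shape (a sketch, not the exact code; the real factories mirror stf()/stf_wacnn() in compressai/zoo/image.py):

from compressai.models.tcm import TCM

def tcm(quality: int, pretrained: bool = False, progress: bool = True, **kwargs):
    # Sketch of the pretrained guard described above.
    if pretrained:
        raise RuntimeError(
            "Pretrained TCM weights are not hosted yet; convert an upstream "
            "checkpoint with examples/convert_tcm_checkpoint.py instead."
        )
    # quality would select the per-quality config once weights are hosted.
    return TCM(**kwargs)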

Summary

  • New zoo entries "tcm" and "cca" (compressai.models.tcm.TCM, compressai.models.cca.CCAModel), wired via a lazy-import _LazyImport proxy in model_architectures so import compressai.zoo stays timm-free (proxy sketched after this list).
  • New compressai.losses.cca.CCARateDistortionLoss — extends RateDistortionLoss with the auxiliary "causal context adjustment" term (NeurIPS 2024 §3.2) wired to the optional _CCAAuxEntropyModel head.
  • Family-1 latent-codec infrastructure in compressai/latent_codecs/ (see Refactor below).
  • Application-layer helpers in compressai/models/_helpers/{channel_slice,channel_context}.py — declarative factories that wire Family-1 models in ~3 calls.
  • Migration of STF / WACNN to the new infrastructure — drops the temporary ChannelSliceLatentCodec + SliceEntropyCompressionModel scaffolding from #354 with no checkpoint-format break (LRP weights are byte-for-byte transferable).
  • Checkpoint converters in examples/convert_tcm_checkpoint.py and examples/convert_cca_checkpoint.py for the published upstream weights.
  • No new hard dependencies. TCM / CCA both reuse timm.layers.LayerNorm2d (already pulled in by STF), so they live under the existing [attn] extras group set up in #354.
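
A minimal sketch of the lazy proxy pattern (hypothetical shape; the real proxy lives in compressai/zoo and may differ in detail):

import importlib

class _LazyImport:
    """Defer a model-class import until the zoo entry is instantiated,
    so `import compressai.zoo` never pulls in heavy deps like timm."""

    def __init__(self, module_name: str, attr: str):
        self._module_name = module_name
        self._attr = attr

    def __call__(self, *args, **kwargs):
        # The import happens here, on first use, not at zoo import time.
        cls = getattr(importlib.import_module(self._module_name), self._attr)
        return cls(*args, **kwargs)

# e.g. model_architectures["tcm"] = _LazyImport("compressai.models.tcm", "TCM")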

Refactor: Family-1 latent-codec abstraction

The four models targeted by this PR series so far (STF / WACNN / TCM / CCA) all follow the same outer entropy-stack shape but differ in the four annotated points (1)–(4) below. #354 absorbed the variation by introducing a dedicated ChannelSliceLatentCodec; this PR shows that all four variants fit cleanly inside upstream ChannelGroupsLatentCodec once it gains four optional kwargs, eliminating the duplicate codec class and giving Family-1 models the same wiring story as ELIC.

HyperpriorLatentCodec(
    h_a=h_a,
    h_s=DualHyperSynthesis(h_mean_s, h_scale_s),       # (1) parallel mean/scale heads
    latent_codec={
        "z": EntropyBottleneckLatentCodec(EntropyBottleneck(N), quantizer=...),
        "y": ChannelGroupsLatentCodec(                 # (2) extended with side_in_context, etc.
            latent_codec={"y0": LRPGaussianLatentCodec(...), ...},   # (3) LRP-aware leaf
            channel_context={"y0": MeanScaleContextHead(...), ...},  # (4) split mean/scale heads
            groups=[M//K]*K,
            max_support_slices=MS,
            side_in_context=True,
            support_filter=...,                        # CCA-aux skip-most-recent
            support_count_fn=...,                      # CCA-aux head-width matching
        ),
    },
)
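
Of the four points, (1) is small enough to show in full; a minimal sketch, assuming only the behaviour described in this PR (run both heads in parallel, concatenate along channels):

import torch
import torch.nn as nn

class DualHyperSynthesis(nn.Module):
    """Present two parallel hyper-synthesis heads as a single h_s.

    h_mean_s / h_scale_s are registered as real submodules, so their
    state_dict paths stay split (h_s.h_mean_s.*, h_s.h_scale_s.*).
    """

    def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module):
        super().__init__()
        self.h_mean_s = h_mean_s
        self.h_scale_s = h_scale_s

    def forward(self, z_hat: torch.Tensor) -> torch.Tensor:
        # side_params of width 2*M, consumed by the channel-context heads.
        return torch.cat([self.h_mean_s(z_hat), self.h_scale_s(z_hat)], dim=1)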

Concretely the PR adds:

| Piece | Where | What it does |
| --- | --- | --- |
| DualHyperSynthesis | latent_codecs/_hyper_synthesis.py | 25-line adapter that runs h_mean_s(z) and h_scale_s(z) in parallel and concatenates the result, so HyperpriorLatentCodec sees a single h_s. |
| LRPGaussianLatentCodec | latent_codecs/gaussian_conditional.py (~30 lines appended) | Subclass of upstream GaussianConditionalLatentCodec that adds the LRP residual prediction (y_hat += lrp_scale * tanh(lrp_transform(cat(mean_support, y_hat)))), sketched below. With mean_support_trail_channels set, the leaf reads its LRP input from a trailing block of ctx_params produced by the head's emit_mean_support mode — giving byte-for-byte weight transfer from the upstream cat(latent_means, *prev_y_hat, y_hat) layout. |
| ChannelGroupsLatentCodec extensions | latent_codecs/channel_groups.py (~50-line diff) | Four optional kwargs, all defaulting to upstream behaviour: max_support_slices (clamp the number of preceding slices used as prior), support_filter (callable to pick a custom subset of priors), support_count_fn (declare how many priors support_filter yields, so head input widths can be sized correctly), and side_in_context (route side_params from h_s through every channel_context head instead of only handing it to the leaves). ELIC and other existing users default-through to the original behaviour. |
| MeanScaleContextHead + build_mean_scale_head | models/_helpers/channel_context.py | Application-layer helper: parallel mean/scale cc stacks with optional independent support transforms per branch, plus an optional emit_mean_support="pre"/"post" mode that appends the (pre- or post-transform) mean support to the output for downstream LRP. |
| build_channel_slice_codec | models/_helpers/channel_slice.py | Application-layer factory that wires ChannelGroupsLatentCodec from groups + leaf_factory + channel_context_factory in one call. |
| _slice_helpers | latent_codecs/_slice_helpers.py | Free helpers (slice_support_channels, lrp_support_channels, make_entropy_transform, infer_num_slices, infer_max_support_slices) shared by all four models' from_state_dict machinery. |
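
The LRP refinement from the second row, as a sketch (the formula is taken verbatim from the table; the method and hook names are illustrative, not the actual subclass internals):

import torch
import torch.nn as nn

from compressai.latent_codecs import GaussianConditionalLatentCodec

class LRPGaussianLatentCodec(GaussianConditionalLatentCodec):
    """Gaussian-conditional leaf with latent residual prediction (LRP)."""

    def __init__(self, *args, lrp_transform: nn.Module, lrp_scale: float = 0.5,
                 mean_support_trail_channels: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.lrp_transform = lrp_transform
        self.lrp_scale = lrp_scale
        self.mean_support_trail_channels = mean_support_trail_channels

    def _apply_lrp(self, y_hat: torch.Tensor, ctx_params: torch.Tensor):
        if self.mean_support_trail_channels:
            # Trailing block of ctx_params carries the mean support emitted
            # by the head's emit_mean_support mode (upstream LRP layout).
            mean_support = ctx_params[:, -self.mean_support_trail_channels:]
        else:
            mean_support = ctx_params
        lrp = self.lrp_transform(torch.cat([mean_support, y_hat], dim=1))
        return y_hat + self.lrp_scale * torch.tanh(lrp)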

Per-model variation now lives entirely in the kwargs:

| Model | groups | support_transform | LRP leaf | Notes |
| --- | --- | --- | --- | --- |
| STF / WACNN | [M//10]*10 | none (Identity) | yes | 5-conv cc heads, widths=(224, 176, 128, 64). |
| TCM | [M//K]*K | SWAtten (independent per mean/scale) | yes | 3-conv cc heads, widths=(224, 128). |
| CCA-main | slice_proportions=(8,28,56,92,136) (variable-length) | NAFTransform (independent per mean/scale) | yes | Uses EntropyBottleneckLatentCodec(quantizer="ste") for z. |
| CCA-aux | same as main | NAFTransform | yes | Lives outside the HyperpriorLatentCodec tree; uses support_filter=skip_most_recent + a matching support_count_fn (sketched below). |
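
The CCA-aux support selection is small enough to show inline; the two callables below follow this PR's commit messages, and the trailing comment shows where they plug in:

def skip_most_recent(k: int, prior: list) -> list:
    """Slice k's prior excludes the immediately preceding slice's y_hat."""
    return prior[: max(k - 1, 0)]

def support_count(k: int) -> int:
    """Number of priors skip_most_recent yields at slice k, so that
    channel_context head input widths can be sized to match."""
    return max(k - 1, 0)

# Wired as:
#   ChannelGroupsLatentCodec(..., support_filter=skip_most_recent,
#                            support_count_fn=support_count)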

The __init__.py of compressai/latent_codecs/ documents this wiring story in a top-level comment block so reviewers don't need to read each model file to understand the pattern.

State-dict layout

Containerization shifts the saved keys to a single-layer latent_codec.* prefix (the HyperpriorLatentCodec's self.y / self.z are real nn.Module registrations, not nested dicts). The published upstream checkpoints round-trip via the converters below — LRP weights transfer byte-for-byte thanks to mean_support_trail_channels, and TCM's per-slice gaussian_conditional buffer is materialized by copying the single shared upstream copy K times.

latent_codec.h_a.0.weight                                # STF/WACNN/CCA: plain Conv2d   TCM: ResidualBottleneckBlock → .0.conv1.weight
latent_codec.h_s.h_mean_s.0.weight                       # one head per parallel arm of DualHyperSynthesis
latent_codec.h_s.h_scale_s.0.weight
latent_codec.z.entropy_bottleneck.quantiles
latent_codec.y.channel_context.y{k}.mean_cc.0.weight     # MeanScaleContextHead per slice
latent_codec.y.channel_context.y{k}.scale_cc.0.weight
latent_codec.y.channel_context.y{k}.mean_support_transform.<...>     # only if support_transform_factory given
latent_codec.y.latent_codec.y{k}.lrp_transform.0.weight  # LRPGaussianLatentCodec leaf
latent_codec.y.latent_codec.y{k}.gaussian_conditional.scale_table
aux_entropy_model.inner_codec.<same shape>               # CCA only
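
The TCM per-slice gaussian_conditional materialization might look roughly like this (hypothetical helper; the real logic lives in examples/convert_tcm_checkpoint.py):

def replicate_gaussian_conditional(state_dict: dict, num_slices: int) -> dict:
    """Copy the single shared upstream gaussian_conditional buffer set
    into every per-slice leaf under the new latent_codec.* layout."""
    prefix = "gaussian_conditional."
    out = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
    for key, buf in state_dict.items():
        if not key.startswith(prefix):
            continue
        suffix = key[len(prefix):]
        for k in range(num_slices):
            out[f"latent_codec.y.latent_codec.y{k}.{prefix}{suffix}"] = buf.clone()
    return out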

Commits

Six commits, designed to be reviewed independently:

| Commit | Scope | LOC |
| --- | --- | --- |
| feat(latent_codecs): add containerized infrastructure for Family 1 codecs | latent_codecs/{_hyper_synthesis, _slice_helpers, gaussian_conditional, channel_groups, __init__}.py + models/_helpers/{channel_slice, channel_context, __init__}.py + tests | +900 |
| refactor(models/stf): migrate WACNN + SymmetricalTransFormer to containerized codec | models/stf.py + examples/convert_stf_checkpoint.py updates + tests/test_models.py::TestStf | +400 |
| feat(models): add TCM with containerized codec | models/tcm.py + examples/convert_tcm_checkpoint.py + tests/test_models.py::TestTcm | +900 |
| feat(models): add CCA model and loss with containerized codec | models/cca.py + losses/cca.py + examples/convert_cca_checkpoint.py + tests/test_models.py::TestCca | +1200 |
| chore(latent_codecs,models): drop ChannelSliceLatentCodec and SliceEntropyCompressionModel | delete latent_codecs/channel_slice.py + entire models/_bases/ directory + remove exports | −561 |
| chore(zoo): wire cca/tcm zoo entries with lazy import | zoo/{__init__,image}.py factory functions + _LazyImport proxies | +48 |
| Total | 23 files | +4596 / −655 |

The cleanup commit lands after all four models are migrated, so the branch never goes through a state where STF/WACNN are broken. The refactor and migrations preserve the existing public model classes — only the internal codec-tree shape and the corresponding state-dict paths change.

License & attribution

  • compressai/models/tcm.py carries a dual-license header pointing at the upstream jmliu206/LIC_TCM (Apache-2.0) alongside the standard InterDigital BSD 3-Clause Clear license for modifications.
  • compressai/models/cca.py carries a dual-license header pointing at the upstream LabShuHangGU/CCA (MIT) alongside the standard InterDigital BSD 3-Clause Clear license for modifications. The internal _NAFBlock / _NAFTransform are derived from NAFNet (Chen et al. 2022, MIT) — happy to add per-class attribution headers if maintainers prefer.
  • compressai/losses/cca.py similarly attributes the CCA paper for the auxiliary-loss formulation.

Verified

  • pytest tests/ -q (excluding pretrained-dependent suites — the local S3 ckpt cache is corrupted with unexpected EOF, unrelated to this PR) → 213 passed, 4 skipped, 32 deselected.
  • pytest tests/test_models.py tests/test_latent_codecs.py tests/test_models_helpers.py tests/test_layers.py tests/test_init.py -q → 74 passed (3 new TestStf + 2 new TestTcm + 3 new TestCca + existing).
  • Round-trip on published upstream checkpoints (from_state_dict(strict=True) then forward + sinusoidal-image smoke):
    • WACNN cnn_0018_best.pth.tar (585 keys) — strict load OK.
    • STF stf_0018_best.pth.tar (779 keys) — strict load OK.
    • TCM 0.05.pth.tar (N=64, M=320, 1397 keys after per-slice GC copy) — strict load OK, sinusoidal PSNR 39.15 dB / total bpp 0.317.
    • TCM mse_lambda_0.05.pth.tar (N=128, M=320, 1397 keys) — strict load OK, sinusoidal PSNR 39.41 dB / total bpp 0.236.
    • CCA checkpoint_lambda_0.3.pth.tar (M=320, slice_sizes=[8,28,56,92,136], 97M params, 2384 keys with main + aux) — strict load OK, sinusoidal PSNR 50.07 dB / total bpp 0.072. Fresh-init baseline at the same config gives ~5 dB, confirming weights are participating.
  • import compressai + import compressai.zoo + import compressai.latent_codecs load zero timm modules (verified via a sys.modules snapshot diff).
  • make static-analysis (ruff format / imports / lint, fail-fast) → all 3 steps clean.
  • uv lock --check → consistent (no pyproject.toml changes in this PR).

Test plan

  • Forward + state-dict round-trip for WACNN / STF / TCM / CCA at small configs (TestStf, TestTcm, TestCca).
  • Synthetic upstream-state-dict-conversion tests for all four models, asserting the new latent_codec.* paths exist and the old top-level paths are gone (test_*_upstream_state_dict_conversion).
  • Sinusoidal-image smoke against published upstream checkpoints for WACNN / STF / TCM (two configs) / CCA — PSNR jumps from ~5 dB (fresh init) to 39–50 dB (loaded), confirming byte-for-byte weight transfer (sketched after this list).
  • Containerized ChannelGroupsLatentCodec extensions are backward-compatible with ELIC's existing usage (tests/test_models.py::TestElic still green with default kwargs).
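
The sinusoidal smoke referenced above is roughly the following (helper names are assumptions; the real tests live in tests/test_models.py):

import math
import torch

def sinusoidal_image(h: int = 256, w: int = 256) -> torch.Tensor:
    ys = torch.linspace(0, 8 * math.pi, h).view(h, 1)
    xs = torch.linspace(0, 8 * math.pi, w).view(1, w)
    img = 0.5 + 0.25 * (torch.sin(ys) + torch.sin(xs))  # values in [0, 1]
    return img.expand(3, h, w).unsqueeze(0)             # (1, 3, H, W)

@torch.no_grad()
def smoke_psnr(model: torch.nn.Module) -> float:
    x = sinusoidal_image()
    x_hat = model(x)["x_hat"].clamp(0, 1)
    mse = torch.mean((x_hat - x) ** 2)
    return float(-10 * torch.log10(mse))  # ~5 dB fresh init vs 39-50 dB loaded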

Notes for follow-up PRs (per #353)

  • DCAE (Lu et al. CVPR 2025) and SAAF (Ma et al. CVPR 2026) are next — both are Family-1 channel-slice models that should drop straight onto the infrastructure added here. They are already partially implemented on a private branch using the earlier monolithic pattern; converting them to the containerized form is the bulk of the work.
  • Generalize CCA's auxiliary entropy model into a reusable plugin for other channel-slice models. The current _CCAAuxEntropyModel is a private nn.Module inside CCAModel, but its forward signature (y, latent_means, latent_scales) only depends on latent_channels + slice_proportions — not on the host backbone — so it should plug cleanly into WACNN / STF / TCM / MLIC++ / DCAE / SAAF / Mamba-family models via a use_cca=True opt-in. The plan is to extract it into a public compressai.entropy_models.CausalContextAdjustmentEntropyModel (or upgrade to a LatentCodec variant), pair it with the existing CCARateDistortionLoss, and let host models add it in ~30 lines without touching their main entropy path. Whether this transfers the RD gains the CCA paper reports on LICAutoencoder to other backbones is an empirical question for the follow-up PR; this PR only commits to keeping the API minimal so the migration is straightforward.

Yiozolm added 6 commits May 9, 2026 14:54
feat(latent_codecs): add containerized infrastructure for Family 1 codecs

Lay the groundwork for refactoring channel-slice models (STF, WACNN, TCM,
CCA) toward an ELIC-style containerized layout where the model owns only
g_a/g_s and the latent_codec owns the entire entropy stack (h_a, h_s, z
bottleneck, per-slice channel context, per-slice leaves).

Codec primitives (compressai/latent_codecs):
  - DualHyperSynthesis(h_mean_s, h_scale_s) adapter so HyperpriorLatentCodec
    can wrap two parallel hyper-synthesis heads while keeping their
    state_dict paths split.
  - LRPGaussianLatentCodec(GaussianConditionalLatentCodec): subclass adding
    the lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat))) refinement
    used by Zhu2022 and follow-ups; lives in the same file as its base.
  - ChannelGroupsLatentCodec gains optional max_support_slices and
    support_filter parameters (defaults -1 / None preserve ELIC's
    use-all-prior behaviour). Enables STF/TCM-style support clamping and
    CCA-aux skip-most-recent selection without sibling codec classes.
  - _slice_helpers hosts make_entropy_transform, slice_support_channels,
    lrp_support_channels, infer_num_slices, infer_max_support_slices, with
    default state_dict prefixes pointing at the post-refactor layout
    (latent_codec.latent_codec.y.channel_context.yK.mean_cc.*).

Application-layer helpers (compressai/models/_helpers):
  - build_channel_slice_codec: per-slice factory that hides the
    y0..yK-1 dictionary boilerplate when assembling
    ChannelGroupsLatentCodec.
  - MeanScaleContextHead + build_mean_scale_head: independent mean_cc and
    scale_cc Sequentials with optional per-path support transforms (for
    SWAtten / NAFTransform).

No model wiring is changed in this commit; STF/WACNN continue to use the
existing _bases/slice_entropy + ChannelSliceLatentCodec path. Coverage:
28 new unit tests across tests/test_latent_codecs.py and
tests/test_models_helpers.py; tests/test_models.py::TestStf still passes;
importing compressai / compressai.zoo / compressai.latent_codecs introduces
no new dependencies.
refactor(models/stf): migrate WACNN + SymmetricalTransFormer to containerized codec

Replace the SliceEntropyCompressionModel base + monolithic
ChannelSliceLatentCodec wiring with an ELIC-style HyperpriorLatentCodec
that owns h_a, h_s, the z entropy bottleneck, and the per-slice channel
context. Both WACNN and SymmetricalTransFormer now inherit
CompressionModel directly with 5-line forward / compress / decompress
methods that delegate to self.latent_codec.

Supporting infra extensions (built on the prior containerized scaffolding
commit 51f536f):

- ChannelGroupsLatentCodec gains side_in_context: bool. Off (default) is
  the existing ELIC behaviour. On, _get_ctx_params: (a) routes side_params
  through channel_context.y0 instead of returning it raw, (b) for k>=1
  feeds cat(side_params, prev_y_hat) to channel_context.y_k, (c) skips the
  trailing cat with side_params at the leaf level. Family 1 models opt in.
- MeanScaleContextHead gains side_split (split leading 2*side_split into
  latent_means / latent_scales and route to mean_cc / scale_cc separately)
  and emit_mean_support (append cat(latent_means, prev_y_hat) to the
  output so downstream LRP can recover the upstream input layout).
- LRPGaussianLatentCodec gains mean_support_trail_channels: when > 0, the
  leaf splits ctx_params into [gaussian_params, mean_support] and feeds
  the trailing block to the LRP transform. This restores the upstream
  cat(latent_means, *prev_y_hat, y_hat) LRP input shape, so the
  per-slice lrp_transforms.{k} weights from upstream Zou et al.
  checkpoints transfer byte-for-byte.
- build_channel_slice_codec accepts side_in_context + side_channels and
  builds the y0 channel_context entry on opt-in.
- _slice_helpers.infer_num_slices auto-detects whether y0 is present in
  state_dict and adjusts the count accordingly.

Upstream checkpoint converter (convert_upstream_stf_state_dict):

- Strips DataParallel module. prefix.
- Re-roots cc_mean_transforms / cc_scale_transforms / lrp_transforms /
  gaussian_conditional / entropy_bottleneck / h_a / h_mean_s / h_scale_s
  under their new latent_codec.* paths.
- Replicates the single shared gaussian_conditional buffer set into
  per-slice leaves (driven by the discovered slice count).
- Nests upstream conv_b.<i>.attn.{qkv, proj, relative_position_*} keys
  under the WMSA wrapper level (.attn.attn.<x>) via _nest_winmsa_keys, so
  WindowAttention parameters land on the right submodule.

Verified end-to-end with the upstream Zou et al. checkpoints
candidate/cnn_0018_best.pth.tar (WACNN) and candidate/stf_0018_best.pth.tar
(SymmetricalTransFormer): both WACNN.from_state_dict and
SymmetricalTransFormer.from_state_dict succeed under strict loading and
forward pass. State-dict round-trip is exercised by an updated
tests/test_models.py::TestStf with self-checks on the new key paths.
Migrate TCM (Liu et al., CVPR 2023) to the H+G containerized entropy
stack used by STF / WACNN. The hyperprior backbone (h_a, h_mean_s,
h_scale_s, EntropyBottleneck) plus the per-slice channel-conditional
entropy heads now live under a single HyperpriorLatentCodec — matching
ELIC-pattern wiring with three Family 1 specializations:

- DualHyperSynthesis(h_mean_s, h_scale_s) cats the two parallel
  hyper-synthesis outputs into side_params of width 2*M.
- ChannelGroupsLatentCodec(side_in_context=True) routes side_params
  into every channel_context head (incl. y0).
- LRPGaussianLatentCodec(mean_support_trail_channels=...) plus the
  MeanScaleContextHead(emit_mean_support=True) recover the upstream
  cat(latent_means, *prev_y_hat, y_hat) LRP layout for byte-for-byte
  weight transfer from upstream LIC_TCM checkpoints.

TCM-specific kwargs vs STF/WACNN: 3-conv widths=(224, 128) plus
support_transform_factory=SWAtten (independent windowed-attention per
mean / scale path), mirroring upstream atten_mean[k] / atten_scale[k].

use_cca / use_auxt are intentionally not implemented in this PR:
- use_cca will be re-added once Phase 5 lands the containerized
  _CCAAuxEntropyModel.
- use_auxt (AuxT, ICLR 2025) depends on layers (WLS / iWLS / OLP) not
  in this branch and is out of scope.

Verified on the LIC_TCM 0.05.pth.tar (N=64) and mse_lambda_0.05.pth.tar
(N=128) candidate checkpoints: strict load succeeds, sinusoidal smoke
test reaches PSNR 39.15 dB / 39.41 dB respectively (vs 5.41 dB for a
fresh-init model), confirming all weights — including LRP — transfer
byte-for-byte through convert_upstream_tcm_state_dict.

Also: append a Family 1 wiring comment block to
compressai/latent_codecs/__init__.py so reviewers can see how the
ELIC-style upstream codecs compose into the STF / WACNN / TCM / CCA /
DCAE / MambaVC pattern without reading model source.

Tests:
- tests/test_models.py::TestTcm::test_tcm_forward_and_state_dict_round_trip
  — forward + new state_dict path self-check + round-trip allclose.
- tests/test_models.py::TestTcm::test_tcm_upstream_state_dict_conversion
  — synthetic upstream LIC_TCM-style state_dict, asserts MSA buffer
  reshape + per-slice / SWAtten-wrapper / hyperprior re-rooting.

make static-analysis clean. pytest tests/test_models.py
tests/test_latent_codecs.py tests/test_models_helpers.py
tests/test_layers.py tests/test_init.py: 71/71. import compressai
/ compressai.zoo / compressai.latent_codecs still trigger zero
timm imports (TCM follows the STF lazy-load convention — not
re-exported from compressai.models.__init__).
Add the Causal Context Adjustment (CCA) standalone autoencoder from
Han et al., NeurIPS 2024 (https://arxiv.org/abs/2410.04847) using the
H+G containerized entropy stack already adopted by STF / WACNN / TCM.
The hyperprior backbone (h_a, h_mean_s, h_scale_s, EntropyBottleneck)
plus the per-slice channel-conditional heads live under a single
HyperpriorLatentCodec — same Family 1 pattern, with two CCA-specific
specializations:

- Variable-length channel slices (slice_proportions defaults to the
  upstream M=320 layout (8, 28, 56, 92, 136)) instead of the equal
  slices used by STF / WACNN / TCM.
- Per-slice NAFTransform mean / scale support transforms (analogous to
  TCM's SWAtten), wired via build_mean_scale_head with
  support_transform_factory.

Auxiliary CCA branch (cca_training=True) adds a _CCAAuxEntropyModel
field that re-encodes y with skip-most-recent support selection
(support_filter=lambda k, prior: prior[: max(k - 1, 0)]) and produces
y_aux / y_cca likelihoods consumed by CCARateDistortionLoss. All
slices use LRPGaussianLatentCodec — the upstream published checkpoints
carry LRP weights for every slice; the unused last-two-slice LRPs are
benign because support_filter excludes those slices' y_hat from any
later slice's prior.

Three small infra additions to support the wiring above:

- MeanScaleContextHead.emit_mean_support extended to
  bool|"pre"|"post". CCA's upstream LRP heads consume the
  *post*-NAFTransform mean_support, while STF / TCM use the raw pre
  layout (Identity transform makes pre/post equivalent for them, so
  True still maps to "pre" for back-compat).
- build_channel_slice_codec.support_count_fn lets the caller declare
  the prior-slice count seen by each channel_context head when a
  custom support_filter selects a non-default count (CCA-aux's
  skip-most-recent: lambda k: max(k - 1, 0)).
- ChannelGroupsLatentCodec._get_ctx_params_side_in_context falls back
  to side_params alone when support_filter returns an empty list
  (CCA-aux at k=1), avoiding torch.cat() on an empty list.

CCAModel.forward in cca_training=True replays the hyperprior path to
recover latent_means / latent_scales for the aux branch instead of
piercing the HyperpriorLatentCodec abstraction — small cost, zero
interface churn for the main codec.

Verified on candidate/CCA/checkpoint_lambda_0.3.pth.tar (M=320,
slice_sizes=[8, 28, 56, 92, 136], em_hidden=224, em_layers=4,
cca_training=True; 97M params): strict-loads after
convert_upstream_cca_state_dict, sinusoidal smoke yields
PSNR 50.07 dB / total bpp 0.072 (vs ~5 dB for a fresh-init model),
confirming all weights — including LRP and aux — transfer
byte-for-byte. Upstream → compressai key count delta of +56 is
explained by replicating the single shared gaussian_conditional buffer
set across each per-slice leaf (7 buffers × 4 extra copies per branch
× 2 branches = 56), matching the channel_context.y{k} layout.

Tests:
- tests/test_models.py::TestCca::test_cca_forward_and_state_dict_round_trip
- tests/test_models.py::TestCca::test_cca_training_branch_forward_and_round_trip
- tests/test_models.py::TestCca::test_cca_upstream_state_dict_conversion

Also drop a few stray internal phase-tracking labels from existing
docstrings (stf.py / tcm.py / channel_context.py / test_latent_codecs.py
/ test_models.py) so the comments describe the wiring directly rather
than referencing project-internal sequencing.
chore(latent_codecs,models): drop ChannelSliceLatentCodec and SliceEntropyCompressionModel

Family 1 models (STF, WACNN, TCM, CCA) all migrated to ChannelGroupsLatentCodec
+ _slice_helpers, leaving these two scaffolding modules with no callers in
production code or tests. Delete:

- compressai/latent_codecs/channel_slice.py
- compressai/models/_bases/ (slice_entropy.py + __init__.py; whole dir is empty)
- corresponding exports from latent_codecs/__init__.py
Register tcm/cca in image_models and model_architectures via _LazyImport so
import compressai.zoo stays timm-free; add tcm()/cca() factory functions
mirroring stf()/stf_wacnn() (pretrained=True raises until weights are hosted).
Yiozolm (Contributor, Author) commented May 9, 2026

@fracape Thanks for merging #354! Posting this right away since most of it was already prepared.
No rush on review at all; please take whatever pace works for you.

Yiozolm marked this pull request as draft May 10, 2026 05:39
…ffer for 2D-shape leaves

ChannelGroupsLatentCodec.decompress reconstructed the destination buffer
shape with (sum(s[0] for s in shape), *shape[0][1:]), which assumes each
leaf reports a 3D (C, H, W) shape (the CheckerboardLatentCodec
convention). Family 1 leaves (LRPGaussianLatentCodec via
GaussianConditionalLatentCodec) report a 2D (H, W) shape, which collapsed
the buffer to 3D (N, sum_H, W) and triggered a broadcast RuntimeError
when assigning the leaf's 4D y_hat back into the split slice
(manifesting on STF/WACNN compress->decompress round-trip).

Use self.groups for the channel total and take the trailing two dims of
any per-group shape as spatial -- works for both leaf shape conventions.
Adds regression coverage for both 2D and 3D leaf shapes.
Yiozolm marked this pull request as ready for review May 10, 2026 06:12
The original LIC TCM (`tcm.py:434`), WACNN (`cnn.py:152`), and
SymmetricalTransFormer (`stf.py:602`) implementations all run
`quantize_ste(z - z_offset) + z_offset` after the entropy bottleneck so
that downstream `h_s` consumes a STE-rounded `z_hat` (likelihoods are
still computed on noisy z to train the parametric prior). The Family 1
containerization in `_build_family1_latent_codec` and
`_build_tcm_latent_codec` was passing `quantizer="noise"` to the `z`
leaf, which silently propagated noisy `z_hat` to `h_s` during training
-- a real RD-relevant deviation from the published models.

Switch both build sites to `quantizer="ste"`, matching CCA-main
(`cca.py:360`) which already used STE. Eval-mode forward and
state-dict layout are unchanged (same module tree, same parameters);
only training-time `z_hat` propagated to `h_s` becomes deterministic.

Also drop the misleading STF/TCM-noise vs CCA-STE distinction from
`compressai/latent_codecs/__init__.py` and `tests/test_models.py`,
replacing it with a Family 1 invariant note: all four models use STE
on z.