Add WACNN and SymmetricalTransFormer (STF, CVPR 2022) #354
fracape merged 6 commits into InterDigitalInc:master
Conversation
New compressai.layers.attn subpackage with Swin primitives needed by transformer-based learned image compression models (STF / WACNN in this PR, follow-up InterDigitalInc#353). Module layout:

- compressai/layers/attn/swin.py: WindowAttention, WMSA, SwinBlock, SWAtten, ConvTransBlock, WinNoShiftAttention, WinResidualUnit, PatchMerging, PatchSplit + window_partition / window_reverse / pad_to_window_multiple / build_window_attention_mask helpers.
- compressai/layers/attn/inference.py: infer_swatten_* helpers for from_state_dict.
- compressai/layers/attn/__init__.py: single re-export surface.

Implementation reuses timm.models.swin_transformer where possible:

- WindowAttention is a thin subclass that promotes timm's relative_position_index buffer from persistent=False to True so released checkpoints round-trip under strict mode, plus accepts the historical qk_scale kwarg.
- window_partition / window_reverse are square-window adapters around the timm helpers.
- WMSA / _WinBasedAttention / WinNoShiftAttention all take an output_proj=True/False switch: True (default) keeps the Linear projection used by SwinBlock / SWAtten elsewhere in CompressAI; False drops it, so the same WinNoShiftAttention class serves the STF / WACNN topology in this PR (which has no projection there) without a separate private copy.

The root compressai/layers/__init__.py only appends `from .attn import *`, so existing call sites keep working.
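For readers unfamiliar with the persistent-buffer trick, a minimal sketch of such a subclass is shown below. It assumes the current `timm.models.swin_transformer.WindowAttention` constructor and simply ignores `qk_scale`; the real class in `swin.py` may handle it differently.

```python
from timm.models.swin_transformer import WindowAttention as _TimmWindowAttention


class WindowAttention(_TimmWindowAttention):
    """Thin subclass: same forward as timm, but relative_position_index is
    saved in state_dict() so strict-mode checkpoint loading round-trips."""

    def __init__(self, *args, qk_scale=None, **kwargs):
        # `qk_scale` is accepted only for compatibility with the historical
        # STF signature; recent timm derives the scale internally, so it is
        # ignored here (an assumption of this sketch, not the PR's code).
        super().__init__(*args, **kwargs)
        index = self.relative_position_index
        # Delete the non-persistent buffer and re-register it as persistent.
        del self.relative_position_index
        self.register_buffer("relative_position_index", index, persistent=True)
```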
Channel-conditional slice-entropy machinery shared by STF and WACNN in this PR, factored out so other slice-conditional models added in follow-up PRs (CCA, TCM, ...) can reuse it.

- compressai/latent_codecs/channel_slice.py: ChannelSliceLatentCodec implements equal-sized channel slicing (Minnen 2020 / He 2022) — cc_mean_transforms / cc_scale_transforms per slice + LRP head + optional mean / scale support transforms. Sibling of the existing ChannelGroupsLatentCodec.
- compressai/models/_bases/slice_entropy.py: SliceEntropyCompressionModel collects the recurring "build an entropy bottleneck for z plus a ChannelSliceLatentCodec for y" plumbing used by every channel-slice model, plus the from_state_dict helpers (infer_num_slices, infer_max_support_slices, slice/LRP support-channel arithmetic, make_entropy_transform). Subclasses populate g_a / g_s / h_a / h_mean_s / h_scale_s, then call self._init_slice_entropy(...).
- compressai/models/_bases/__init__.py: re-export surface.
- compressai/latent_codecs/__init__.py: export ChannelSliceLatentCodec.
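A minimal sketch of the per-slice conditioning loop this codec factors out. It is illustrative only: the attribute names follow the upstream STF/WACNN code rather than the final `ChannelSliceLatentCodec` API, and the straight-through rounding and spatial-cropping details are simplified.

```python
import torch


def code_slices(y, latent_means, latent_scales, *, num_slices,
                cc_mean_transforms, cc_scale_transforms, lrp_transforms,
                gaussian_conditional):
    """Equal-size channel-slice conditioning (Minnen 2020 / He 2022 style)."""
    y_slices = y.chunk(num_slices, dim=1)        # equal-sized channel slices
    y_hat_slices, y_likelihoods = [], []
    for i, y_slice in enumerate(y_slices):
        # Each slice is conditioned on the hyperprior output plus all
        # previously decoded slices (channel-conditional context).
        mean_support = torch.cat([latent_means] + y_hat_slices, dim=1)
        mu = cc_mean_transforms[i](mean_support)
        scale_support = torch.cat([latent_scales] + y_hat_slices, dim=1)
        scale = cc_scale_transforms[i](scale_support)
        _, y_slice_likelihood = gaussian_conditional(y_slice, scale, means=mu)
        y_likelihoods.append(y_slice_likelihood)
        # The real code uses a straight-through rounding estimator here.
        y_hat_slice = torch.round(y_slice - mu) + mu
        # Latent residual prediction (LRP) refines the dequantized slice.
        lrp = lrp_transforms[i](torch.cat([mean_support, y_hat_slice], dim=1))
        y_hat_slices.append(y_hat_slice + 0.5 * torch.tanh(lrp))
    return torch.cat(y_hat_slices, dim=1), torch.cat(y_likelihoods, dim=1)
```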
feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022

Adds the WACNN (CNN backbone) and SymmetricalTransFormer (transformer backbone) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (https://arxiv.org/abs/2203.08450). Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).

What is included:

- compressai/models/stf.py: WACNN, SymmetricalTransFormer, and a convert_upstream_stf_state_dict helper that strips the DataParallel "module." prefix and re-roots cc_mean_transforms / cc_scale_transforms / lrp_transforms / gaussian_conditional under latent_codec.*, so released checkpoints from the upstream repo load via WACNN.from_state_dict / SymmetricalTransFormer.from_state_dict. Builds on the WinNoShiftAttention (output_proj=False) primitive from the previous "feat(layers)" commit and uses timm.models.swin_transformer.SwinTransformerBlock with always_partition=True / dynamic_mask=True for the transformer stages (with the per-block relative_position_index buffer promoted to persistent so the upstream key list survives strict load).
- compressai/models/__init__.py: export the two new model classes.
- compressai/zoo/{__init__,image}.py: register "stf" and "stf-wacnn" in image_models with thin pretrained=False factory functions; pretrained=True raises a clear RuntimeError until weights are hosted on S3 by the maintainers (per InterDigitalInc#353).
- examples/convert_stf_checkpoint.py: CLI wrapper around the upstream-state-dict converter, with an optional smoke test on a synthetic image.
- tests/test_models.py: TestStf class — forward + state_dict round-trip for both backbones, plus a unit test for the convert_upstream_stf_state_dict helper.
- pyproject.toml: add timm to the runtime dependencies (used by the Swin building blocks committed earlier in this PR for DropPath / Mlp / SwinTransformerBlock / WindowAttention).

Pretrained weights are intentionally not bundled. State-dict round-trip diff is 0.0 for both WACNN (405 keys) and STF (315 keys); pytest tests/test_models.py tests/test_layers.py tests/test_init.py = 32 passed (3 new TestStf + 29 existing).
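The key remapping the converter performs is roughly the following (a sketch, not the exact implementation in `compressai/models/stf.py`):

```python
def convert_upstream_stf_state_dict(state_dict):
    """Remap upstream Googolxx/STF checkpoint keys to the CompressAI layout."""
    slice_prefixes = ("cc_mean_transforms", "cc_scale_transforms",
                      "lrp_transforms", "gaussian_conditional")
    remapped = {}
    for key, value in state_dict.items():
        # 1. Strip the DataParallel "module." prefix.
        if key.startswith("module."):
            key = key[len("module."):]
        # 2. Re-root the slice-entropy submodules under latent_codec.* so the
        #    keys line up with the ChannelSliceLatentCodec parameter names.
        if key.startswith(slice_prefixes):
            key = "latent_codec." + key
        remapped[key] = value
    return remapped
```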
I have reused the swin-transformer code from timm.
PR InterDigitalInc#354 CI failures:

- static_analysis: 6 files needed ruff format and 5 files needed import sort.
- tests: uv.lock was stale after adding `timm` to dependencies.
This PR correctly keeps modifications to existing code minimal, other than some additional exports. To avoid introducing a new required dependency (timm), …
I think I prefer the first. Additionally, the newly introduced functions in the zoo could import the model lazily:

```python
def stf(pretrained: bool = False, progress: bool = True, **kwargs):
    ...
    from compressai.models.stf import SymmetricalTransFormer
    ...
```

Then, move the … After that, it looks good to merge.
Addresses PR InterDigitalInc#354 review feedback from @YodaEmbedding: avoid making timm a required dependency just to ship STF/WACNN.

- pyproject.toml: move `timm` from `[project.dependencies]` into a new `[project.optional-dependencies]` `attn` group (`pip install compressai[attn]`). Named `attn` rather than `stf` so follow-up window-attention models (CCA, TCM, MLIC++, …) can share it.
- compressai/models/__init__.py: drop `from .stf import *`.
- compressai/layers/__init__.py: drop `from .attn import *`. Together these break the eager import chain so plain `import compressai` no longer touches `compressai.models.stf` or `compressai.layers.attn`, and therefore does not require timm.
- compressai/zoo/image.py: lazy-import the model classes inside `stf()` and `stf_wacnn()`, and resolve the two `model_architectures` entries via a small `_LazyImport` proxy so `import compressai.zoo` is also timm-free. The proxy forwards `__call__` and attribute access so existing call sites (`model_architectures[arch](*cfg, **kw)`, `....from_state_dict(...)`) keep working.
- tests/test_models.py, examples/convert_stf_checkpoint.py: import the STF symbols from their canonical module path (`compressai.models.stf`) instead of the package root.

Verified: `import compressai; import compressai.zoo` triggers zero timm imports and does not load `compressai.models.stf` or `compressai.layers.attn`. Existing 32 tests still pass.
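For reference, a proxy along the lines of the `_LazyImport` described above could look like the sketch below. It assumes only the behaviour stated in the commit message (deferred `importlib` import, forwarding of `__call__` and attribute access); it is not the exact code in `compressai/zoo/image.py`.

```python
import importlib


class _LazyImport:
    """Defer importing a model class until it is first called or inspected."""

    def __init__(self, module_name, attr_name):
        self._module_name = module_name
        self._attr_name = attr_name
        self._target = None

    def _resolve(self):
        if self._target is None:
            module = importlib.import_module(self._module_name)
            self._target = getattr(module, self._attr_name)
        return self._target

    def __call__(self, *args, **kwargs):
        # Keeps model_architectures[arch](*cfg, **kw) working.
        return self._resolve()(*args, **kwargs)

    def __getattr__(self, name):
        # Keeps model_architectures[arch].from_state_dict(...) working.
        return getattr(self._resolve(), name)


# Added to the existing registry in compressai/zoo/image.py, e.g.:
# model_architectures["stf"] = _LazyImport("compressai.models.stf", "SymmetricalTransFormer")
# model_architectures["stf-wacnn"] = _LazyImport("compressai.models.stf", "WACNN")
```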
I will submit the following code in a similar style.
The three `infer_swatten_*` state-dict introspection helpers in `compressai/layers/attn/inference.py` were intended for MambaIC / MambaVC `from_state_dict`, but neither model is part of this PR. No code in this PR (stf.py, swin.py, tests) calls them. Removing per "don't add infrastructure for hypothetical future PRs" — the helpers can be reintroduced alongside the first model that actually uses them.
I'm trying to design a better abstraction layer for the latent codec of subsequent models, since the original authors all have different coding styles. My suggestion is to merge this PR first, and then I'll include the refactored abstraction layer in the next PR.
…compressai.layers.attn

Adds compressai/layers/attn/dictionary.py (~250 lines) with the building blocks DCAE (Lu et al., CVPR 2025) and SAAF (Ma et al., CVPR 2026) share for their channel-context heads:

- MutiScaleDictionaryCrossAttentionGLU — main entropy head, cross-attends per-slice support against a shared dictionary tensor
- MultiScaleAggregation, ConvolutionalGLU, DenseBlock, ConvWithDW, SpatialAttentionModule, DWConv, Scale — supporting modules

Lifted verbatim from the upstream DCAE reference implementation (the SAAF paper reuses identical entropy-side blocks). No new dependencies — einops is already a hard dep, and timm.layers.DropPath is already pulled in via the [attn] extras from the WACNN/STF PR (InterDigitalInc#354).
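For orientation only, the core idea of "cross-attend per-slice support against a shared dictionary tensor" can be sketched as below. This is far simpler than the actual `MutiScaleDictionaryCrossAttentionGLU` (no multi-scale aggregation, GLU, or depthwise convolutions) and uses hypothetical names.

```python
import torch
import torch.nn as nn


class DictionaryCrossAttention(nn.Module):
    """Illustrative sketch: slice support tokens query a learned dictionary."""

    def __init__(self, dim: int, dict_size: int = 64, num_heads: int = 8):
        super().__init__()
        # Shared, learned dictionary of `dict_size` entries of width `dim`.
        self.dictionary = nn.Parameter(torch.randn(dict_size, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, support: torch.Tensor) -> torch.Tensor:
        # support: (B, N, C) per-slice support tokens (e.g. a flattened feature map).
        b = support.shape[0]
        dictionary = self.dictionary.unsqueeze(0).expand(b, -1, -1)
        # Queries come from the slice support, keys/values from the dictionary.
        out, _ = self.attn(support, dictionary, dictionary, need_weights=False)
        return out
```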
Adds WACNN and SymmetricalTransFormer (STF) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (arXiv:2203.08450).
Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).
This is the first installment of the per-model PR series proposed in #353. Pretrained weights are intentionally not bundled — calling
`pretrained=True` raises a clear `RuntimeError` until weights are hosted on S3 (per the discussion in #353).

Summary

- Two new image zoo entries, `"stf"` and `"stf-wacnn"` (`compressai.models.SymmetricalTransFormer` and `compressai.models.WACNN`).
- `compressai.layers.attn` subpackage with the Swin window-based attention building blocks the two models depend on. Reuses `timm.models.swin_transformer` wherever the implementation is generic to avoid vendoring a parallel Swin stack — see Reuse of timm below.
- `ChannelSliceLatentCodec` + `SliceEntropyCompressionModel` base — designed to be reused by the channel-conditional models in follow-up PRs (CCA, TCM, …).
- A checkpoint converter, `examples/convert_stf_checkpoint.py`, that loads the published `stf_<bpp>_best.pth.tar` / `cnn_<bpp>_best.pth.tar` files from the upstream repo and writes them in CompressAI layout.
- `timm` added to `dependencies` (the Swin building blocks reuse `DropPath`, `Mlp`, `trunc_normal_`, `WindowAttention`, `SwinTransformerBlock`, `window_partition`, `window_reverse` from it).

Reuse of timm
Rather than vendor a full Swin stack inside CompressAI, the implementation in this PR delegates to `timm.models.swin_transformer` everywhere the upstream STF code matches the Swin reference. This kept the diff focused on the genuinely STF-specific pieces and shaved ~280 lines from an earlier vendored draft.

- `WindowAttention`: thin subclass of `timm.models.swin_transformer.WindowAttention` that promotes the `relative_position_index` buffer from `persistent=False` to `True` (so released checkpoints load under strict mode) and accepts the historical `qk_scale` kwarg. ~15 lines instead of a ~50-line reimplementation.
- `SwinTransformerBlock` (used inside `_STFBasicLayer`): uses `timm.models.swin_transformer.SwinTransformerBlock(always_partition=True, dynamic_mask=True)` directly. After construction we promote each block's `attn.relative_position_index` to persistent so per-block keys round-trip strict-mode. Avoids reimplementing the cyclic-shift / pad / window-attn / unpad / unshift forward path.
- `window_partition` / `window_reverse`: square-window adapters around the timm helpers, which expect `Tuple[int, int]` whereas STF passes `int` (sketched just after this list).
- `DropPath`, `Mlp`, `trunc_normal_`: `timm.layers` versions used directly.
- `WMSA` / `WinNoShiftAttention` (the STF-specific dual-branch sigmoid-gated attention block): takes an `output_proj=True/False` switch so a single class serves both the STF / WACNN topology in this PR (no projection) and the projection-bearing variant used by other window-attention CompressAI models. No private `_STF*` duplicate is kept.
- Remaining STF-specific pieces (`SwinBlock`, `SWAtten`, `ConvTransBlock`, `_PatchEmbed`, `_WinBasedAttention`, `WinResidualUnit`, `pad_to_window_multiple`, `build_window_attention_mask`): kept vendored in CompressAI (`timm` does not expose an equivalent), so vendoring keeps the API stable across `timm` releases.
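A sketch of those square-window adapters, assuming the tuple-based helper signatures exposed by current `timm.models.swin_transformer` (not necessarily the exact code in `swin.py`):

```python
from timm.models.swin_transformer import window_partition as _timm_window_partition
from timm.models.swin_transformer import window_reverse as _timm_window_reverse


def window_partition(x, window_size: int):
    # x: (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    return _timm_window_partition(x, (window_size, window_size))


def window_reverse(windows, window_size: int, H: int, W: int):
    # Inverse of window_partition; returns (B, H, W, C).
    return _timm_window_reverse(windows, (window_size, window_size), H, W)
```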
The dependency on `timm.models.swin_transformer.*` is deliberate (the file lives under `timm.models.*` rather than `timm.layers.*`, so it is not part of timm's stability promise). If maintainers prefer to insulate CompressAI from `timm` model-internals, the subclass / wrapper pattern makes it a small, self-contained ~120-line revert. Happy to do that on request.

Commits
Three commits, designed to be reviewed independently:
1. `feat(layers): add Swin window-based attention building blocks`
   - `compressai/layers/attn/{swin,inference,__init__}.py` + tiny re-export in `layers/__init__.py`
2. `feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base`
   - `compressai/latent_codecs/channel_slice.py` + `compressai/models/_bases/{slice_entropy,__init__}.py` + re-export
3. `feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022`
   - `compressai/models/stf.py` + zoo / converter / smoke tests + `timm` in `pyproject.toml`

License & attribution
`compressai/models/stf.py` carries a dual-license header noting the upstream source URL and Apache-2.0 license alongside the standard InterDigital BSD 3-Clause Clear License for the modifications. The Swin building blocks in `compressai/layers/attn/swin.py` are a mix of timm subclasses / wrappers (covered by timm's Apache-2.0) and STF-derived classes (also Apache-2.0); happy to add per-file attribution headers there as well if maintainers prefer.

Verified
- `pytest tests/test_models.py tests/test_layers.py tests/test_init.py` → 32 passed (3 new `TestStf` + 29 existing).
- `WACNN.from_state_dict(model.state_dict())` round-trip → `x_hat` diff = 0.0 (405 keys); see the sketch after this list.
- `SymmetricalTransFormer.from_state_dict(model.state_dict())` round-trip → `x_hat` diff = 0.0 (315 keys).
- `convert_upstream_stf_state_dict` correctly re-roots `module.cc_*` / `module.gaussian_conditional` keys under `latent_codec.*` so the published `Googolxx/STF` checkpoints load via `from_state_dict`.
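The round-trip check above amounts to something like the following sketch (it assumes the default constructors work without arguments):

```python
import torch

from compressai.models.stf import WACNN

net = WACNN().eval()
net2 = WACNN.from_state_dict(net.state_dict()).eval()

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    # Identical parameters plus deterministic eval-mode quantization should
    # give a bitwise-identical reconstruction.
    diff = (net(x)["x_hat"] - net2(x)["x_hat"]).abs().max().item()
print(diff)  # expected: 0.0
```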
Test plan

- Forward + state-dict round-trip tests for both backbones (`TestStf`).
- Ran `examples/convert_stf_checkpoint.py` against an upstream `cnn_<bpp>_best.pth.tar` checkpoint locally (`x_hat` diff = 0 between original and converted state dict in eval mode).
- Whether `timm` being moved into hard `dependencies` is acceptable (alternative: keep `[stf]` extras group).
- Whether the dependency on `timm.models.swin_transformer.*` (model-internal API) is acceptable, vs. vendoring a CompressAI copy. Reverting is a small isolated change if preferred.
- … (`models/stf.py`), I will add them.

Notes for follow-up PRs (per #353)
The next PR will add CCA + TCM together — both reuse `ChannelSliceLatentCodec` from this PR, and CCA contributes a `CausalContextAdjustmentEntropyModel` that TCM can opt into. After that, the remaining license-clear models (`InvCompress`, `MLIC++`, `HPCM`, `SAAF`, `DCAE`, `GLIC`, `TIC`, `TinyLIC`, `ShiftLIC`) follow one or two at a time, each PR layering on top of what's already merged.