diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index ceada0b1..82a41947 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -29,6 +29,7 @@ from .base import LatentCodec from .channel_groups import ChannelGroupsLatentCodec +from .channel_slice import ChannelSliceLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -40,6 +41,7 @@ __all__ = [ "LatentCodec", "ChannelGroupsLatentCodec", + "ChannelSliceLatentCodec", "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", diff --git a/compressai/latent_codecs/channel_slice.py b/compressai/latent_codecs/channel_slice.py new file mode 100644 index 00000000..73e32d6e --- /dev/null +++ b/compressai/latent_codecs/channel_slice.py @@ -0,0 +1,269 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.ans import BufferedRansEncoder, RansDecoder +from compressai.entropy_models import GaussianConditional +from compressai.ops import quantize_ste +from compressai.registry import register_module + +from .base import LatentCodec + +__all__ = [ + "ChannelSliceLatentCodec", +] + + +@register_module("ChannelSliceLatentCodec") +class ChannelSliceLatentCodec(LatentCodec): + """Channel-conditional entropy model with separate scale/mean heads and LRP. + + Splits ``y`` into equal-sized slices along the channel axis. 
For each
+    slice ``k`` the previously decoded slices (truncated to
+    ``max_support_slices``) are concatenated with ``latent_means`` /
+    ``latent_scales`` and pushed through ``cc_mean_transforms[k]`` and
+    ``cc_scale_transforms[k]`` to obtain ``mu`` / ``scale``. After the
+    Gaussian conditional step, an optional latent residual prediction
+    (LRP) head refines ``y_hat``.
+
+    This is the channel-autoregressive entropy model from [Minnen2020]
+    with the LRP refinement variant used in [Zhu2022] (STF / WACNN),
+    [He2022] (ELIC) and many follow-up papers (MLIC++, TCM, ...).
+
+    [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for
+    Learned Image Compression" <https://arxiv.org/abs/2007.08739>`_, by
+    David Minnen and Saurabh Singh, ICIP 2020.
+
+    [Zhu2022]: `"Transformer-based Transform Coding"
+    <https://openreview.net/forum?id=IDwN6xjHnK8>`_, by Yinhao Zhu,
+    Yang Yang and Taco Cohen, ICLR 2022.
+    """
+
+    cc_mean_transforms: nn.ModuleList
+    cc_scale_transforms: nn.ModuleList
+    lrp_transforms: nn.ModuleList
+    gaussian_conditional: GaussianConditional
+
+    def __init__(
+        self,
+        cc_mean_transforms: nn.ModuleList,
+        cc_scale_transforms: nn.ModuleList,
+        lrp_transforms: Optional[nn.ModuleList] = None,
+        gaussian_conditional: Optional[GaussianConditional] = None,
+        mean_support_transforms: Optional[nn.ModuleList] = None,
+        scale_support_transforms: Optional[nn.ModuleList] = None,
+        *,
+        num_slices: Optional[int] = None,
+        max_support_slices: int = -1,
+        quantizer: str = "ste",
+        lrp_scale: float = 0.5,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        self._kwargs = kwargs
+
+        inferred_num_slices = len(cc_mean_transforms)
+        if num_slices is None:
+            num_slices = inferred_num_slices
+        if inferred_num_slices != num_slices:
+            raise ValueError(
+                "cc_mean_transforms must have num_slices entries "
+                f"(got {inferred_num_slices}, expected {num_slices})"
+            )
+        if len(cc_scale_transforms) != num_slices:
+            raise ValueError("cc_scale_transforms must have num_slices entries")
+        if lrp_transforms is not None and len(lrp_transforms) != num_slices:
+            raise ValueError("lrp_transforms must have num_slices entries")
+        if (
+            mean_support_transforms is not None
+            and len(mean_support_transforms) != num_slices
+        ):
+            raise ValueError("mean_support_transforms must have num_slices entries")
+        if (
+            scale_support_transforms is not None
+            and len(scale_support_transforms) != num_slices
+        ):
+            raise ValueError("scale_support_transforms must have num_slices entries")
+        if quantizer not in ("ste", "noise"):
+            raise ValueError(f"unknown quantizer {quantizer!r}")
+
+        self.num_slices = int(num_slices)
+        self.max_support_slices = int(max_support_slices)
+        self.quantizer = quantizer
+        self.lrp_scale = float(lrp_scale)
+        self.cc_mean_transforms = cc_mean_transforms
+        self.cc_scale_transforms = cc_scale_transforms
+        self.mean_support_transforms = mean_support_transforms or nn.ModuleList(
+            nn.Identity() for _ in range(num_slices)
+        )
+        self.scale_support_transforms = scale_support_transforms or nn.ModuleList(
+            nn.Identity() for _ in range(num_slices)
+        )
+        self.lrp_transforms = lrp_transforms or nn.ModuleList(
+            nn.Identity() for _ in range(num_slices)
+        )
+        self.gaussian_conditional = gaussian_conditional or GaussianConditional(None)
+
+    def _support_slices(self, y_hat_slices: Sequence[Tensor]) -> List[Tensor]:
+        if self.max_support_slices < 0:
+            return list(y_hat_slices)
+        return list(y_hat_slices[: self.max_support_slices])
+
+    def _slice_params(
+        self,
+        slice_index: int,
+        latent_means: Tensor,
+        latent_scales: Tensor,
+        y_hat_slices: Sequence[Tensor],
+        spatial_shape: Tuple[int, int],
+    ) 
-> Tuple[Tensor, Tensor, Tensor]: + support = self._support_slices(y_hat_slices) + mean_support = torch.cat([latent_means, *support], dim=1) + mean_support = self.mean_support_transforms[slice_index](mean_support) + mu = self.cc_mean_transforms[slice_index](mean_support) + mu = mu[:, :, : spatial_shape[0], : spatial_shape[1]] + scale_support = torch.cat([latent_scales, *support], dim=1) + scale_support = self.scale_support_transforms[slice_index](scale_support) + scale = self.cc_scale_transforms[slice_index](scale_support) + scale = scale[:, :, : spatial_shape[0], : spatial_shape[1]] + return mu, scale, mean_support + + def _apply_lrp( + self, slice_index: int, mean_support: Tensor, y_hat_slice: Tensor + ) -> Tensor: + lrp = self.lrp_transforms[slice_index]( + torch.cat([mean_support, y_hat_slice], dim=1) + ) + return y_hat_slice + self.lrp_scale * torch.tanh(lrp) + + def forward( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Any]: + spatial_shape = (y.shape[2], y.shape[3]) + y_hat_slices: List[Tensor] = [] + y_likelihoods_slices: List[Tensor] = [] + + for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape + ) + _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scale, means=mu) + if self.quantizer == "ste": + y_hat_slice = quantize_ste(y_slice - mu) + mu + else: + y_hat_slice = self.gaussian_conditional.quantize( + y_slice, "noise" if self.training else "dequantize", mu + ) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + y_likelihoods_slices.append(y_slice_likelihoods) + + return { + "y_hat": torch.cat(y_hat_slices, dim=1), + "likelihoods": {"y": torch.cat(y_likelihoods_slices, dim=1)}, + } + + def compress( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Any]: + spatial_shape = (y.shape[2], y.shape[3]) + cdf = self.gaussian_conditional.quantized_cdf.tolist() + cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() + offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() + encoder = BufferedRansEncoder() + symbols_list: List[int] = [] + indexes_list: List[int] = [] + y_hat_slices: List[Tensor] = [] + + for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape + ) + indexes = self.gaussian_conditional.build_indexes(scale) + y_q_slice = self.gaussian_conditional.quantize(y_slice, "symbols", mu) + y_hat_slice = y_q_slice + mu + symbols_list.extend(y_q_slice.reshape(-1).tolist()) + indexes_list.extend(indexes.reshape(-1).tolist()) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + + encoder.encode_with_indexes( + symbols_list, indexes_list, cdf, cdf_lengths, offsets + ) + return { + "strings": [encoder.flush()], + "shape": spatial_shape, + "y_hat": torch.cat(y_hat_slices, dim=1), + } + + def decompress( + self, + strings: Sequence[bytes], + shape: Tuple[int, int], + latent_means: Tensor, + latent_scales: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + cdf = self.gaussian_conditional.quantized_cdf.tolist() + cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() + offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() + decoder = RansDecoder() + 
decoder.set_stream(strings[0]) + y_hat_slices: List[Tensor] = [] + + for slice_index in range(self.num_slices): + mu, scale, mean_support = self._slice_params( + slice_index, latent_means, latent_scales, y_hat_slices, shape + ) + indexes = self.gaussian_conditional.build_indexes(scale) + values = decoder.decode_stream( + indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets + ) + y_q_slice = torch.tensor(values, device=mu.device, dtype=mu.dtype).reshape( + mu.shape + ) + y_hat_slice = self.gaussian_conditional.dequantize(y_q_slice, mu) + y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) + y_hat_slices.append(y_hat_slice) + + return {"y_hat": torch.cat(y_hat_slices, dim=1)} diff --git a/compressai/layers/__init__.py b/compressai/layers/__init__.py index 0362981c..58ec2402 100644 --- a/compressai/layers/__init__.py +++ b/compressai/layers/__init__.py @@ -30,3 +30,7 @@ from .basic import * from .gdn import * from .layers import * + +# Window-based attention layers in `.attn` depend on `timm` (optional via the +# `[attn]` extras). Not re-exported here so that `import compressai` works +# without `timm` — import them via `from compressai.layers.attn import ...`. diff --git a/compressai/layers/attn/__init__.py b/compressai/layers/attn/__init__.py new file mode 100644 index 00000000..331df45b --- /dev/null +++ b/compressai/layers/attn/__init__.py @@ -0,0 +1,39 @@ +from .swin import ( + WMSA, + ConvTransBlock, + PatchMerging, + PatchSplit, + SWAtten, + SwinBlock, + WindowAttention, + WinNoShiftAttention, + WinResidualUnit, + build_window_attention_mask, + pad_to_window_multiple, + window_partition, + window_reverse, +) + +__all__ = [ + "ConvTransBlock", + "PatchMerging", + "PatchSplit", + "SWAtten", + "SwinBlock", + "WMSA", + "WinNoShiftAttention", + "WinResidualUnit", + "WindowAttention", + "build_window_attention_mask", + "pad_to_window_multiple", + "window_partition", + "window_reverse", +] + + +def __getattr__(name): + if name == "Win_noShift_Attention": + from .swin import Win_noShift_Attention as _alias + + return _alias + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/compressai/layers/attn/swin.py b/compressai/layers/attn/swin.py new file mode 100644 index 00000000..0dc3b06c --- /dev/null +++ b/compressai/layers/attn/swin.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import DropPath, Mlp +from timm.models.swin_transformer import ( + WindowAttention as _TimmWindowAttention, +) +from timm.models.swin_transformer import ( + window_partition as _timm_window_partition, +) +from timm.models.swin_transformer import ( + window_reverse as _timm_window_reverse, +) +from torch import Tensor + +from ..layers import AttentionBlock, ResidualBlock, conv1x1, conv3x3 + +__all__ = [ + "ConvTransBlock", + "PatchMerging", + "PatchSplit", + "SWAtten", + "SwinBlock", + "WMSA", + "WinNoShiftAttention", + "WinResidualUnit", + "WindowAttention", + "build_window_attention_mask", + "pad_to_window_multiple", + "window_partition", + "window_reverse", +] + + +def window_partition(input_tensor: Tensor, window_size: int) -> Tensor: + """Square-window adapter around timm's ``window_partition``. + + timm uses ``Tuple[int, int]`` for the window size; the STF / WACNN models + in compressai always use square windows, so this thin wrapper keeps the + ``window_size: int`` call-site convention while delegating to timm. 
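+
+    A shape-only sketch of the round trip (``2 x 2`` windows per image, so
+    ``2 * 2 * 2 = 8`` windows for a batch of two):
+
+    .. code-block:: python
+
+        x = torch.randn(2, 16, 16, 32)  # (B, H, W, C)
+        windows = window_partition(x, 8)  # (num_windows * B, 8, 8, C) = (8, 8, 8, 32)
+        assert torch.equal(window_reverse(windows, 8, 16, 16), x)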
+ """ + return _timm_window_partition(input_tensor, (window_size, window_size)) + + +def window_reverse( + windows: Tensor, + window_size: int, + height: int, + width: int, +) -> Tensor: + """Square-window adapter around timm's ``window_reverse`` (see + :func:`window_partition` for the rationale).""" + return _timm_window_reverse(windows, (window_size, window_size), height, width) + + +def build_window_attention_mask( + height: int, + width: int, + window_size: int, + shift_size: int, + device: torch.device, +) -> Optional[Tensor]: + if shift_size == 0: + return None + + img_mask = torch.zeros((1, height, width, 1), device=device) + h_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + w_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + + count = 0 + for h_index in h_slices: + for w_index in w_slices: + img_mask[:, h_index, w_index, :] = count + count += 1 + + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.view(-1, window_size * window_size) + attention_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0)) + return attention_mask.masked_fill(attention_mask == 0, float(0.0)) + + +def pad_to_window_multiple( + input_tensor: Tensor, + window_size: Union[int, Tuple[int, int]], + *, + layout: str = "BCHW", +) -> Tuple[Tensor, int, int]: + """Right/bottom-pad a 4D tensor so its spatial dims are multiples of + ``window_size``. + + Args: + input_tensor: 4D tensor in either ``BCHW`` or ``BHWC`` layout. + window_size: ``int`` (square window) or ``(window_h, window_w)``. + layout: ``"BCHW"`` (default, PyTorch convention) or ``"BHWC"`` + (Swin / FTIC token-major layout). + + Returns: + ``(padded_tensor, pad_h, pad_w)``, where ``pad_h`` / ``pad_w`` are + the bottom / right padding widths added to the height / width + dimension respectively. + """ + if isinstance(window_size, int): + win_h = win_w = int(window_size) + else: + win_h, win_w = (int(s) for s in window_size) + + if layout == "BCHW": + height, width = input_tensor.shape[-2], input_tensor.shape[-1] + elif layout == "BHWC": + height, width = input_tensor.shape[1], input_tensor.shape[2] + else: + raise ValueError(f"layout must be 'BCHW' or 'BHWC', got {layout!r}") + + pad_h = (win_h - height % win_h) % win_h + pad_w = (win_w - width % win_w) % win_w + if pad_h == 0 and pad_w == 0: + return input_tensor, 0, 0 + + if layout == "BCHW": + # F.pad on BCHW: (W_left, W_right, H_left, H_right) + return F.pad(input_tensor, (0, pad_w, 0, pad_h)), pad_h, pad_w + # F.pad on BHWC: (C_left, C_right, W_left, W_right, H_left, H_right) + return F.pad(input_tensor, (0, 0, 0, pad_w, 0, pad_h)), pad_h, pad_w + + +class WindowAttention(_TimmWindowAttention): + """timm ``WindowAttention`` with two minor tweaks for compressai: + + 1. ``relative_position_index`` is re-registered as a *persistent* buffer + so released compressai checkpoints (which include this tensor) load + under ``strict=True``. timm registers it as ``persistent=False``. + 2. The constructor accepts an optional ``qk_scale`` to keep STF's + (and CompressAI's) call-site convention; timm always derives the + scale from ``head_dim``. + + Forward / state-dict layout otherwise match timm exactly, including + the optional fused-attention path. 
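+
+    A minimal usage sketch (shapes are illustrative; ``mask`` is optional
+    and follows timm's ``(num_windows, N, N)`` convention):
+
+    .. code-block:: python
+
+        attn = WindowAttention(dim=96, window_size=8, num_heads=6)
+        tokens = torch.randn(4, 64, 96)  # (num_windows * B, 8 * 8, C)
+        out = attn(tokens)  # same shape as the input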
+ """ + + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + window_size=window_size, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + if qk_scale is not None: + self.scale = qk_scale + # Promote the index buffer to persistent so checkpoint round-trip + # works without filtering keys at load time. + index = self.relative_position_index + del self._buffers["relative_position_index"] + self.register_buffer("relative_position_index", index, persistent=True) + + +class WMSA(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + type: str = "W", + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + output_proj: bool = True, + ) -> None: + super().__init__() + if type not in {"W", "SW"}: + raise ValueError(f"Unsupported attention type: {type}") + if input_dim % head_dim != 0: + raise ValueError("`input_dim` must be divisible by `head_dim`.") + + self.window_size = window_size + self.shift_size = 0 if type == "W" else window_size // 2 + self.attn = WindowAttention( + dim=input_dim, + window_size=window_size, + num_heads=input_dim // head_dim, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + # ``output_proj=False`` mirrors the STF / WACNN topology, which feeds + # the WindowAttention output straight back into the downstream block + # without an extra Linear projection. Set ``True`` (default) for the + # SwinBlock / SWAtten variant used by the rest of CompressAI. + self.output_proj = ( + nn.Linear(input_dim, output_dim or input_dim) + if output_proj + else nn.Identity() + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + _, height, width, _ = input_tensor.shape + output, pad_height, pad_width = pad_to_window_multiple( + input_tensor, + self.window_size, + layout="BHWC", + ) + padded_height, padded_width = output.shape[1], output.shape[2] + + if self.shift_size > 0: + mask = build_window_attention_mask( + padded_height, + padded_width, + self.window_size, + self.shift_size, + output.device, + ) + output = torch.roll( + output, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2), + ) + else: + mask = None + + windows = window_partition(output, self.window_size) + windows = windows.view( + -1, + self.window_size * self.window_size, + windows.shape[-1], + ) + windows = self.attn(windows, mask=mask) + windows = windows.view( + -1, + self.window_size, + self.window_size, + windows.shape[-1], + ) + output = window_reverse(windows, self.window_size, padded_height, padded_width) + + if self.shift_size > 0: + output = torch.roll( + output, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2), + ) + if pad_height > 0 or pad_width > 0: + output = output[:, :height, :width, :].contiguous() + return self.output_proj(output) + + +class Block(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + drop_path: float, + type: str = "W", + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + output_dim = output_dim or input_dim + self.norm1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, type=type) + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + self.norm2 = 
nn.LayerNorm(input_dim) + self.mlp = Mlp( + in_features=input_dim, + hidden_features=int(input_dim * mlp_ratio), + out_features=output_dim, + ) + self.residual_proj = ( + nn.Linear(input_dim, output_dim) + if input_dim != output_dim + else nn.Identity() + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor + self.drop_path(self.msa(self.norm1(input_tensor))) + residual = self.residual_proj(output) + return residual + self.drop_path(self.mlp(self.norm2(output))) + + +class SwinBlock(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: Optional[int], + head_dim: int, + window_size: int, + drop_path: float, + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + output_dim = output_dim or input_dim + self.block_1 = Block( + input_dim, + input_dim, + head_dim, + window_size, + drop_path, + type="W", + mlp_ratio=mlp_ratio, + ) + self.block_2 = Block( + input_dim, + output_dim, + head_dim, + window_size, + drop_path, + type="SW", + mlp_ratio=mlp_ratio, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor.permute(0, 2, 3, 1).contiguous() + output = self.block_1(output) + output = self.block_2(output) + return output.permute(0, 3, 1, 2).contiguous() + + +class ConvTransBlock(nn.Module): + def __init__( + self, + conv_dim: int, + trans_dim: int, + head_dim: int, + window_size: int, + drop_path: float, + type: str = "W", + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + if type not in {"W", "SW"}: + raise ValueError(f"Unsupported attention type: {type}") + + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.conv1_1 = nn.Conv2d(conv_dim + trans_dim, conv_dim + trans_dim, 1) + self.conv1_2 = nn.Conv2d(conv_dim + trans_dim, conv_dim + trans_dim, 1) + self.conv_block = ResidualBlock(conv_dim, conv_dim) + self.trans_block = Block( + trans_dim, + trans_dim, + head_dim, + window_size, + drop_path, + type=type, + mlp_ratio=mlp_ratio, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + mixed = self.conv1_1(input_tensor) + conv_tensor, trans_tensor = torch.split( + mixed, + (self.conv_dim, self.trans_dim), + dim=1, + ) + conv_tensor = self.conv_block(conv_tensor) + conv_tensor + trans_tensor = trans_tensor.permute(0, 2, 3, 1).contiguous() + trans_tensor = self.trans_block(trans_tensor) + trans_tensor = trans_tensor.permute(0, 3, 1, 2).contiguous() + output = torch.cat((conv_tensor, trans_tensor), dim=1) + return input_tensor + self.conv1_2(output) + + +class SWAtten(AttentionBlock): + def __init__( + self, + input_dim: int, + output_dim: int, + head_dim: int, + window_size: int, + drop_path: float, + inter_dim: Optional[int] = 192, + ) -> None: + hidden_dim = inter_dim or input_dim + super().__init__(N=hidden_dim) + self.in_conv = ( + conv1x1(input_dim, hidden_dim) if inter_dim is not None else nn.Identity() + ) + self.out_conv = ( + conv1x1(hidden_dim, output_dim) if inter_dim is not None else nn.Identity() + ) + self.non_local_block = SwinBlock( + hidden_dim, + hidden_dim, + head_dim, + window_size, + drop_path, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.in_conv(input_tensor) + identity = output + non_local = self.non_local_block(output) + output = self.conv_a(output) * torch.sigmoid(self.conv_b(non_local)) + output = output + identity + return self.out_conv(output) + + +class WinResidualUnit(nn.Module): + """1x1 -> 3x3 -> 1x1 GELU residual unit; bottleneck width is half the + input channels. 
Used inside :class:`WinNoShiftAttention`.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.conv = nn.Sequential( + conv1x1(channels, channels // 2), + nn.GELU(), + conv3x3(channels // 2, channels // 2), + nn.GELU(), + conv1x1(channels // 2, channels), + ) + self.act = nn.GELU() + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.act(self.conv(input_tensor) + input_tensor) + + +class _WinBasedAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + window_size: int, + shift_size: int, + drop_path: float, + output_proj: bool = True, + ) -> None: + super().__init__() + attention_type = "SW" if shift_size > 0 else "W" + self.attn = WMSA( + input_dim=dim, + output_dim=dim, + head_dim=dim // num_heads, + window_size=window_size, + type=attention_type, + output_proj=output_proj, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + def forward(self, input_tensor: Tensor) -> Tensor: + output = input_tensor.permute(0, 2, 3, 1).contiguous() + output = self.attn(output) + output = output.permute(0, 3, 1, 2).contiguous() + return input_tensor + self.drop_path(output) + + +class WinNoShiftAttention(nn.Module): + """Sigmoid-gated dual-branch window attention block, used by STF / WACNN + and (with ``output_proj=True``) by other window-attention CompressAI + models. ``output_proj=False`` reproduces the STF / WACNN topology in which + the WindowAttention output feeds straight back into the block without + an additional Linear projection.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + window_size: int = 8, + shift_size: int = 0, + drop_path: float = 0.0, + output_proj: bool = True, + ) -> None: + super().__init__() + self.conv_a = nn.Sequential( + WinResidualUnit(dim), + WinResidualUnit(dim), + WinResidualUnit(dim), + ) + self.conv_b = nn.Sequential( + _WinBasedAttention( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + drop_path=drop_path, + output_proj=output_proj, + ), + WinResidualUnit(dim), + WinResidualUnit(dim), + WinResidualUnit(dim), + conv1x1(dim, dim), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return input_tensor + self.conv_a(input_tensor) * torch.sigmoid( + self.conv_b(input_tensor) + ) + + +class PatchMerging(nn.Module): + def __init__(self, dim: int, norm_layer: type[nn.Module] = nn.LayerNorm) -> None: + super().__init__() + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, input_tensor: Tensor, height: int, width: int) -> Tensor: + batch_size, length, channels = input_tensor.shape + if length != height * width: + raise ValueError("Input feature has wrong size.") + + output = input_tensor.view(batch_size, height, width, channels) + if height % 2 == 1 or width % 2 == 1: + output = F.pad(output, (0, 0, 0, width % 2, 0, height % 2)) + + x0 = output[:, 0::2, 0::2, :] + x1 = output[:, 1::2, 0::2, :] + x2 = output[:, 0::2, 1::2, :] + x3 = output[:, 1::2, 1::2, :] + output = torch.cat([x0, x1, x2, x3], dim=-1) + output = output.view(batch_size, -1, 4 * channels) + return self.reduction(self.norm(output)) + + +class PatchSplit(nn.Module): + def __init__(self, dim: int, norm_layer: type[nn.Module] = nn.LayerNorm) -> None: + super().__init__() + self.reduction = nn.Linear(dim, dim * 2, bias=False) + self.norm = norm_layer(dim) + self.shuffle = nn.PixelShuffle(2) + + def forward(self, input_tensor: Tensor, height: int, width: int) -> Tensor: + batch_size, length, channels 
= input_tensor.shape + if length != height * width: + raise ValueError("Input feature has wrong size.") + + output = self.reduction(self.norm(input_tensor)) + output = output.permute(0, 2, 1).contiguous() + output = output.view(batch_size, 2 * channels, height, width) + output = self.shuffle(output) + output = output.permute(0, 2, 3, 1).contiguous() + return output.view(batch_size, 4 * length, -1) + + +def __getattr__(name): + if name == "Win_noShift_Attention": + import warnings + + warnings.warn( + "Win_noShift_Attention is deprecated; use WinNoShiftAttention instead.", + DeprecationWarning, + stacklevel=2, + ) + return WinNoShiftAttention + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/compressai/models/__init__.py b/compressai/models/__init__.py index 79112b89..097354d1 100644 --- a/compressai/models/__init__.py +++ b/compressai/models/__init__.py @@ -33,3 +33,8 @@ from .sensetime import * from .vbr import * from .waseda import * + +# Models in `.stf` (WACNN, SymmetricalTransFormer) depend on `timm`, which is +# an optional extras dependency (`pip install compressai[attn]`). They are not +# re-exported here so that `import compressai` works without `timm` installed — +# import them directly via `from compressai.models.stf import ...`. diff --git a/compressai/models/_bases/__init__.py b/compressai/models/_bases/__init__.py new file mode 100644 index 00000000..065119c2 --- /dev/null +++ b/compressai/models/_bases/__init__.py @@ -0,0 +1,24 @@ +"""Abstract base classes shared by multiple slice-based LIC models. + +These were historically hidden behind ``stf_support`` / ``dcae_support`` file +names which obscured the fact that they're real abstract :class:`CompressionModel` +subclasses inherited by 3-4 models each. +""" + +from .slice_entropy import ( + SliceEntropyCompressionModel, + infer_max_support_slices, + infer_num_slices, + lrp_support_channels, + make_entropy_transform, + slice_support_channels, +) + +__all__ = [ + "SliceEntropyCompressionModel", + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] diff --git a/compressai/models/_bases/slice_entropy.py b/compressai/models/_bases/slice_entropy.py new file mode 100644 index 00000000..7159468f --- /dev/null +++ b/compressai/models/_bases/slice_entropy.py @@ -0,0 +1,260 @@ +"""Slice-conditional entropy backbone shared by WACNN / SymmetricalTransFormer / MambaVC. + +Promoted out of the historical ``models/stf_support.py`` so the abstract base +class is discoverable by name. Channel-counting helpers and a parameterised +entropy-transform factory live here too — they used to be duplicated across +``stf_support`` / ``ssm_support`` / ``weconvene_support``. +""" + +from __future__ import annotations + +from typing import Dict, Optional, Sequence, Tuple + +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ChannelSliceLatentCodec +from compressai.models.utils import conv + +from ..base import CompressionModel + +__all__ = [ + "SliceEntropyCompressionModel", + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] + + +_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.cc_mean_transforms." 
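+# Together with a slice index, these match state-dict keys such as
+# "latent_codec.cc_mean_transforms.3.0.weight" (first conv of slice 3).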
+_KEY_SUFFIX = ".0.weight" + + +def slice_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * index + return latent_channels + slice_channels * min(index, max_support_slices) + + +def lrp_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * (index + 1) + return latent_channels + slice_channels * min(index + 1, max_support_slices + 1) + + +def make_entropy_transform( + in_channels: int, + out_channels: int, + *, + widths: Sequence[int] = (224, 128), +) -> nn.Sequential: + """Stack of stride-1 3x3 convs with GELU between, used by every slice + entropy model. ``widths`` specifies hidden conv widths; defaults to the + Mamba/WeConvene 3-conv stack. Pass ``widths=(224, 176, 128, 64)`` for the + STF/WACNN 5-conv stack.""" + layers: list[nn.Module] = [] + prev = in_channels + for width in widths: + layers.append(conv(prev, width, stride=1, kernel_size=3)) + layers.append(nn.GELU()) + prev = width + layers.append(conv(prev, out_channels, stride=1, kernel_size=3)) + return nn.Sequential(*layers) + + +def infer_num_slices( + state_dict: Dict[str, Tensor], + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _KEY_SUFFIX, +) -> int: + slice_indices = { + int(key[len(prefix) :].split(".", 1)[0]) + for key in state_dict + if key.startswith(prefix) and key.endswith(suffix) + } + return len(slice_indices) + + +def infer_max_support_slices( + state_dict: Dict[str, Tensor], + latent_channels: int, + num_slices: int, + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _KEY_SUFFIX, + extra_factor: int = 1, +) -> int: + """Infer ``max_support_slices`` from the input width of the first + cc_mean transform conv. ``extra_factor`` accounts for models like DCAE/SAAF + that prepend additional copies of the latent (``M*3 + slice_channels*N``); + pass ``extra_factor=3`` there. Slice-only models (STF/Mamba*) keep the + default ``extra_factor=1``.""" + slice_channels = latent_channels // num_slices + matching = [ + tensor.size(1) + for key, tensor in state_dict.items() + if key.startswith(prefix) and key.endswith(suffix) + ] + if not matching: + return 0 + max_input_channels = max(matching) + return max( + 0, (max_input_channels - extra_factor * latent_channels) // slice_channels + ) + + +class SliceEntropyCompressionModel(CompressionModel): + """Channel-conditional entropy backbone shared by WACNN, SymmetricalTransFormer, MambaVC. + + Subclasses must populate ``g_a``, ``g_s``, ``h_a``, ``h_mean_s`` and + ``h_scale_s``, then call :meth:`_init_slice_entropy` to wire up the + entropy bottleneck for ``z`` and the :class:`ChannelSliceLatentCodec` + for ``y``. 
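+
+    A minimal subclass sketch (``MyModel`` and its transforms are
+    hypothetical placeholders, not a real architecture):
+
+    .. code-block:: python
+
+        class MyModel(SliceEntropyCompressionModel):
+            def __init__(self, N=192, M=320, **kwargs):
+                super().__init__(**kwargs)
+                self.g_a = ...  # image -> y (M channels)
+                self.g_s = ...  # y_hat -> image
+                self.h_a = ...  # y -> z (N channels)
+                self.h_mean_s = ...  # z_hat -> latent_means
+                self.h_scale_s = ...  # z_hat -> latent_scales
+                self._init_slice_entropy(M, N, num_slices=10, max_support_slices=5)
+
+            def forward(self, x):
+                y = self.g_a(x)
+                out = self._forward_latent_output(y)
+                return {
+                    "x_hat": self.g_s(out["y_hat"]),
+                    "likelihoods": out["likelihoods"],
+                }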
+ """ + + h_a: nn.Module + h_mean_s: nn.Module + h_scale_s: nn.Module + entropy_bottleneck: EntropyBottleneck + latent_codec: ChannelSliceLatentCodec + + def _init_slice_entropy( + self, + latent_channels: int, + entropy_bottleneck_channels: int, + num_slices: int, + max_support_slices: int, + mean_support_transforms: Optional[nn.ModuleList] = None, + scale_support_transforms: Optional[nn.ModuleList] = None, + ) -> None: + if latent_channels % num_slices != 0: + raise ValueError("latent_channels must be divisible by num_slices") + if ( + mean_support_transforms is not None + and len(mean_support_transforms) != num_slices + ): + raise ValueError("mean_support_transforms must have num_slices entries") + if ( + scale_support_transforms is not None + and len(scale_support_transforms) != num_slices + ): + raise ValueError("scale_support_transforms must have num_slices entries") + + slice_channels = latent_channels // num_slices + widths = (224, 176, 128, 64) + cc_mean_transforms = nn.ModuleList( + make_entropy_transform( + slice_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + cc_scale_transforms = nn.ModuleList( + make_entropy_transform( + slice_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + lrp_transforms = nn.ModuleList( + make_entropy_transform( + lrp_support_channels( + latent_channels, slice_channels, index, max_support_slices + ), + slice_channels, + widths=widths, + ) + for index in range(num_slices) + ) + + self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) + self.latent_codec = ChannelSliceLatentCodec( + cc_mean_transforms=cc_mean_transforms, + cc_scale_transforms=cc_scale_transforms, + lrp_transforms=lrp_transforms, + mean_support_transforms=mean_support_transforms, + scale_support_transforms=scale_support_transforms, + num_slices=num_slices, + max_support_slices=max_support_slices, + quantizer="ste", + ) + + @property + def num_slices(self) -> int: + return self.latent_codec.num_slices + + @property + def max_support_slices(self) -> int: + return self.latent_codec.max_support_slices + + def _hyper_priors(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + z = self.h_a(y) + z_hat, z_likelihoods = self.entropy_bottleneck(z) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + return z, z_likelihoods, latent_means, latent_scales + + def _forward_latent_output( + self, y: Tensor + ) -> Dict[str, Dict[str, Tensor] | Tensor]: + _, z_likelihoods, latent_means, latent_scales = self._hyper_priors(y) + y_out = self.latent_codec(y, latent_means, latent_scales) + output: Dict[str, Dict[str, Tensor] | Tensor] = { + "y_hat": y_out["y_hat"], + "likelihoods": {"y": y_out["likelihoods"]["y"], "z": z_likelihoods}, + } + return output + + def _forward_latent(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + output = self._forward_latent_output(y) + return output["y_hat"], output["likelihoods"]["y"], output["likelihoods"]["z"] + + def _compress_latent(self, y: Tensor) -> Dict[str, object]: + z = self.h_a(y) + z_strings = self.entropy_bottleneck.compress(z) + z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + y_out = self.latent_codec.compress(y, latent_means, latent_scales) + return { + "strings": [[y_out["strings"][0]], 
z_strings], + "shape": z.size()[-2:], + } + + def _decompress_latent( + self, + strings: Sequence[Sequence[bytes]], + shape: Tuple[int, int], + ) -> Tensor: + if len(strings) != 2: + raise ValueError("strings must contain [y_strings, z_strings]") + + z_hat = self.entropy_bottleneck.decompress(strings[1], shape) + latent_means = self.h_mean_s(z_hat) + latent_scales = self.h_scale_s(z_hat) + y_shape = (z_hat.shape[2] * 4, z_hat.shape[3] * 4) + y_out = self.latent_codec.decompress( + strings[0], y_shape, latent_means, latent_scales + ) + return y_out["y_hat"] diff --git a/compressai/models/stf.py b/compressai/models/stf.py new file mode 100644 index 00000000..5074d5b6 --- /dev/null +++ b/compressai/models/stf.py @@ -0,0 +1,645 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/Googolxx/STF +# (originally distributed under the Apache License 2.0). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import math + +from typing import Dict, Optional, Sequence, Tuple, Type + +import torch +import torch.nn as nn + +from timm.layers import DropPath, Mlp +from timm.models.swin_transformer import SwinTransformerBlock as _TimmSwinBlock +from torch import Tensor + +from compressai.layers import GDN, conv1x1, conv3x3, subpel_conv3x3 +from compressai.layers.attn import ( + PatchMerging, + PatchSplit, + WinNoShiftAttention, +) +from compressai.models._bases import ( + SliceEntropyCompressionModel, + infer_max_support_slices, + infer_num_slices, +) +from compressai.models.utils import conv, deconv +from compressai.registry import register_model + +__all__ = [ + "SymmetricalTransFormer", + "WACNN", + "convert_upstream_stf_state_dict", +] + + +# ---------------------------------------------------------------------------- +# STF building blocks +# (formerly compressai/layers/lic/stf.py; private to the WACNN / SymmetricalTransFormer models) +# ---------------------------------------------------------------------------- + + +class _STFBasicLayer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float | Sequence[float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + downsample: Optional[Type[nn.Module]] = None, + ) -> None: + del qk_scale # timm SwinTransformerBlock derives scale from head_dim + super().__init__() + drop_path_values = ( + list(drop_path) + if isinstance(drop_path, Sequence) + and not isinstance(drop_path, (str, bytes)) + else [float(drop_path)] * depth + ) + self.window_size = window_size + self.shift_size = window_size // 2 + self.blocks = nn.ModuleList( + [ + _TimmSwinBlock( + dim=dim, + input_resolution=(0, 0), # ignored when always_partition=True + num_heads=num_heads, + window_size=window_size, + shift_size=0 if index % 2 == 0 else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=drop, + attn_drop=attn_drop, + drop_path=drop_path_values[index], + norm_layer=norm_layer, + always_partition=True, # keep configured window/shift even if input is small + dynamic_mask=True, + ) + for index in range(depth) + ] + ) + self.downsample = ( + downsample(dim=dim, norm_layer=norm_layer) if downsample else None + ) + + # Released STF checkpoints carry `attn.relative_position_index` per block + # (the upstream WindowAttention registers it as a persistent buffer). + # timm's WindowAttention uses persistent=False, so promote it here so + # strict-mode state_dict loading round-trips without filtering keys. 
+
+        for block in self.blocks:
+            index = block.attn.relative_position_index
+            del block.attn._buffers["relative_position_index"]
+            block.attn.register_buffer(
+                "relative_position_index", index, persistent=True
+            )
+
+    def forward(
+        self, input_tensor: Tensor, height: int, width: int
+    ) -> tuple[Tensor, int, int]:
+        batch_size, length, channels = input_tensor.shape
+        if length != height * width:
+            raise ValueError("input feature has wrong size")
+        x = input_tensor.view(batch_size, height, width, channels)
+        for block in self.blocks:
+            x = block(x)
+        x = x.reshape(batch_size, height * width, channels)
+
+        if self.downsample is None:
+            return x, height, width
+
+        x = self.downsample(x, height, width)
+        if isinstance(self.downsample, PatchMerging):
+            return x, (height + 1) // 2, (width + 1) // 2
+        return x, height * 2, width * 2
+
+
+class _PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 4,
+        in_chans: int = 3,
+        embed_dim: int = 96,
+        norm_layer: Optional[Type[nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        self.patch_size = (patch_size, patch_size)
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.proj = nn.Conv2d(
+            in_chans,
+            embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer is not None else None
+
+    def forward(self, input_tensor: Tensor) -> Tensor:
+        _, _, height, width = input_tensor.size()
+        if width % self.patch_size[1] != 0:
+            input_tensor = nn.functional.pad(
+                input_tensor,
+                (0, self.patch_size[1] - width % self.patch_size[1]),
+            )
+        if height % self.patch_size[0] != 0:
+            input_tensor = nn.functional.pad(
+                input_tensor,
+                (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]),
+            )
+
+        output = self.proj(input_tensor)
+        if self.norm is None:
+            return output
+
+        out_height, out_width = output.size(2), output.size(3)
+        output = output.flatten(2).transpose(1, 2)
+        output = self.norm(output)
+        return output.transpose(1, 2).view(-1, self.embed_dim, out_height, out_width)
+
+
+# ----------------------------------------------------------------------------
+# STF / WACNN models
+# ----------------------------------------------------------------------------
+
+
+_UPSTREAM_LATENT_CODEC_PREFIXES = (
+    "cc_mean_transforms",
+    "cc_scale_transforms",
+    "lrp_transforms",
+    "gaussian_conditional",
+)
+
+
+def convert_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """Translate a candidate ``STF`` / ``WACNN`` state dict into compressai layout.
+
+    Upstream checkpoints (e.g. ``stf_0018_best.pth.tar`` / ``cnn_0018_best.pth.tar``
+    from `Zou et al. 2022 <https://arxiv.org/abs/2203.08450>`_) are saved from a
+    ``DataParallel``-wrapped module and place the channel-conditional entropy
+    transforms at the model root. compressai houses those transforms (plus the
+    Gaussian conditional) under ``latent_codec.*``. This helper:
+
+    - strips the leading ``module.`` prefix added by ``DataParallel``;
+    - re-roots ``cc_mean_transforms`` / ``cc_scale_transforms`` /
+      ``lrp_transforms`` / ``gaussian_conditional`` under ``latent_codec.``;
+    - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / ``syn_layers``
+      / ``end_conv`` / ``h_a`` / ``h_mean_s`` / ``h_scale_s`` /
+      ``entropy_bottleneck`` keys unchanged.
+
+    The returned dict can be loaded by :meth:`WACNN.from_state_dict` or
+    :meth:`SymmetricalTransFormer.from_state_dict`. Both ``from_state_dict``
+    entry points auto-detect the upstream layout and call this helper, so
+    direct invocation is only needed when persisting the converted dict.
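+
+    A typical round trip (the checkpoint path is illustrative, and whether
+    the weights are nested under a ``"state_dict"`` key is an assumption
+    about the upstream file):
+
+    .. code-block:: python
+
+        import torch
+
+        ckpt = torch.load("stf_0018_best.pth.tar", map_location="cpu")
+        state_dict = ckpt.get("state_dict", ckpt)  # unwrap if nested
+        state_dict = convert_upstream_stf_state_dict(state_dict)
+        net = SymmetricalTransFormer.from_state_dict(state_dict)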
+ """ + converted: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = key[len("module.") :] if key.startswith("module.") else key + head = new_key.split(".", 1)[0] + if head in _UPSTREAM_LATENT_CODEC_PREFIXES: + new_key = "latent_codec." + new_key + converted[new_key] = value + return converted + + +def _is_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream checkpoints either carry a ``module.`` prefix or + place ``cc_mean_transforms`` at the root instead of under ``latent_codec``. + """ + for key in state_dict: + if key.startswith("module."): + return True + if key.startswith("cc_mean_transforms.") or key.startswith( + "gaussian_conditional." + ): + return True + return False + + +@register_model("stf-wacnn") +class WACNN(SliceEntropyCompressionModel): + r"""WACNN model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the + Details: Window-based Attention for Image Compression" + `_, IEEE/CVF Conf. on Computer Vision + and Pattern Recognition (CVPR), 2022. + + CNN-based variant that inserts window-based attention modules + (:class:`compressai.layers.attn.WinNoShiftAttention` with + ``output_proj=False``) inside the analysis/synthesis transforms, paired + with a Minnen2020-style channel-wise autoregressive entropy model. + + Args: + N (int): Number of channels in the hyperprior backbone. + M (int): Number of channels in the latent representation. + num_slices (int): Number of channel slices for the entropy model. + """ + + def __init__( + self, + N: int = 192, + M: int = 320, + num_slices: int = 10, + max_support_slices: int = 5, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.g_a = nn.Sequential( + conv(3, N, kernel_size=5, stride=2), + GDN(N), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + WinNoShiftAttention( + dim=N, num_heads=8, window_size=8, shift_size=4, output_proj=False + ), + conv(N, N, kernel_size=5, stride=2), + GDN(N), + conv(N, M, kernel_size=5, stride=2), + WinNoShiftAttention( + dim=M, num_heads=8, window_size=4, shift_size=2, output_proj=False + ), + ) + self.g_s = nn.Sequential( + WinNoShiftAttention( + dim=M, num_heads=8, window_size=4, shift_size=2, output_proj=False + ), + deconv(M, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + WinNoShiftAttention( + dim=N, num_heads=8, window_size=8, shift_size=4, output_proj=False + ), + deconv(N, N, kernel_size=5, stride=2), + GDN(N, inverse=True), + deconv(N, 3, kernel_size=5, stride=2), + ) + self.h_a = nn.Sequential( + conv3x3(M, M), + nn.GELU(), + conv3x3(M, 288), + nn.GELU(), + conv3x3(288, 256, stride=2), + nn.GELU(), + conv3x3(256, 224), + nn.GELU(), + conv3x3(224, N, stride=2), + ) + self.h_mean_s = nn.Sequential( + conv3x3(N, N), + nn.GELU(), + subpel_conv3x3(N, 224, 2), + nn.GELU(), + conv3x3(224, 256), + nn.GELU(), + subpel_conv3x3(256, 288, 2), + nn.GELU(), + conv3x3(288, M), + ) + self.h_scale_s = nn.Sequential( + conv3x3(N, N), + nn.GELU(), + subpel_conv3x3(N, 224, 2), + nn.GELU(), + conv3x3(224, 256), + nn.GELU(), + subpel_conv3x3(256, 288, 2), + nn.GELU(), + conv3x3(288, M), + ) + self._init_slice_entropy( + M, + N, + num_slices, + max_support_slices, + ) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y = self.g_a(x) + latent_output = self._forward_latent_output(y) + return { + "x_hat": self.g_s(latent_output["y_hat"]), + "likelihoods": latent_output["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + return 
self._compress_latent(self.g_a(x)) + + def decompress( + self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int] + ) -> Dict[str, Tensor]: + return {"x_hat": self.g_s(self._decompress_latent(strings, shape)).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": + if _is_upstream_stf_state_dict(state_dict): + state_dict = convert_upstream_stf_state_dict(state_dict) + N = state_dict["g_a.0.weight"].size(0) + M = state_dict["g_a.7.weight"].size(0) + num_slices = infer_num_slices(state_dict) or 10 + max_support_slices = infer_max_support_slices(state_dict, M, num_slices) + net = cls( + N=N, + M=M, + num_slices=num_slices, + max_support_slices=max_support_slices, + ) + net.load_state_dict(state_dict) + return net + + +@register_model("stf") +class SymmetricalTransFormer(SliceEntropyCompressionModel): + r"""Symmetrical Transformer model (STF) from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" `_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Transformer-based companion of :class:`WACNN` that builds the + analysis/synthesis transforms with stacked Swin-style basic layers and a + channel-wise autoregressive entropy model. + + Args: + embed_dim (int): Patch-embedding dimension. + num_slices (int): Number of channel slices for the entropy model. + """ + + def __init__( + self, + pretrain_img_size: int = 256, + patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 48, + depths: Optional[Sequence[int]] = None, + num_heads: Optional[Sequence[int]] = None, + window_size: int = 4, + num_slices: int = 12, + max_support_slices: Optional[int] = None, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.2, + norm_layer: type[nn.Module] = nn.LayerNorm, + patch_norm: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + depths = list(depths or [2, 2, 6, 2]) + num_heads = list(num_heads or [3, 6, 12, 24]) + if len(depths) != len(num_heads): + raise ValueError("depths and num_heads must have the same length") + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.patch_embed = _PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if patch_norm else None, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [value.item() for value in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers = nn.ModuleList() + for layer_index in range(self.num_layers): + self.layers.append( + _STFBasicLayer( + dim=int(embed_dim * 2**layer_index), + depth=depths[layer_index], + num_heads=num_heads[layer_index], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[ + sum(depths[:layer_index]) : sum(depths[: layer_index + 1]) + ], + norm_layer=norm_layer, + downsample=None + if layer_index == self.num_layers - 1 + else PatchMerging, + ) + ) + + reversed_depths = list(reversed(depths)) + reversed_heads = list(reversed(num_heads)) + self.syn_layers = nn.ModuleList() + for layer_index in range(self.num_layers): + self.syn_layers.append( + _STFBasicLayer( + dim=int(embed_dim * 2 ** (self.num_layers - 1 - layer_index)), + depth=reversed_depths[layer_index], + num_heads=reversed_heads[layer_index], + 
window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[ + sum(reversed_depths[:layer_index]) : sum( + reversed_depths[: layer_index + 1] + ) + ], + norm_layer=norm_layer, + downsample=None + if layer_index == self.num_layers - 1 + else PatchSplit, + ) + ) + + self.end_conv = nn.Sequential( + nn.Conv2d( + embed_dim, embed_dim * patch_size**2, kernel_size=5, stride=1, padding=2 + ), + nn.PixelShuffle(patch_size), + nn.Conv2d(embed_dim, 3, kernel_size=3, stride=1, padding=1), + ) + + latent_channels = int(embed_dim * 2 ** (self.num_layers - 1)) + bottleneck_channels = latent_channels // 2 + self.h_a = nn.Sequential( + conv3x3(latent_channels, latent_channels), + nn.GELU(), + conv3x3(latent_channels, latent_channels - embed_dim), + nn.GELU(), + conv3x3( + latent_channels - embed_dim, latent_channels - 2 * embed_dim, stride=2 + ), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - 3 * embed_dim), + nn.GELU(), + conv3x3(latent_channels - 3 * embed_dim, bottleneck_channels, stride=2), + ) + self.h_mean_s = nn.Sequential( + conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), + nn.GELU(), + subpel_conv3x3( + latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2 + ), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), + nn.GELU(), + conv3x3(latent_channels, latent_channels), + ) + self.h_scale_s = nn.Sequential( + conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), + nn.GELU(), + subpel_conv3x3( + latent_channels - 3 * embed_dim, latent_channels - 2 * embed_dim, 2 + ), + nn.GELU(), + conv3x3(latent_channels - 2 * embed_dim, latent_channels - embed_dim), + nn.GELU(), + subpel_conv3x3(latent_channels - embed_dim, latent_channels, 2), + nn.GELU(), + conv3x3(latent_channels, latent_channels), + ) + self._init_slice_entropy( + latent_channels, + bottleneck_channels, + num_slices, + num_slices // 2 if max_support_slices is None else max_support_slices, + ) + + def _analysis_transform(self, x: Tensor) -> Tuple[Tensor, int, int]: + output = self.patch_embed(x) + height, width = output.size(2), output.size(3) + output = self.pos_drop(output.flatten(2).transpose(1, 2)) + for layer in self.layers: + output, height, width = layer(output, height, width) + channels = self.embed_dim * 2 ** (self.num_layers - 1) + output = ( + output.view(-1, height, width, channels).permute(0, 3, 1, 2).contiguous() + ) + return output, height, width + + def _synthesis_transform(self, y_hat: Tensor, height: int, width: int) -> Tensor: + channels = self.embed_dim * 2 ** (self.num_layers - 1) + output = ( + y_hat.permute(0, 2, 3, 1).contiguous().view(-1, height * width, channels) + ) + for layer in self.syn_layers: + output, height, width = layer(output, height, width) + output = ( + output.view(-1, height, width, self.embed_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + return self.end_conv(output) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y, height, width = self._analysis_transform(x) + latent_output = self._forward_latent_output(y) + return { + "x_hat": self._synthesis_transform(latent_output["y_hat"], height, width), + "likelihoods": latent_output["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + y, _, _ = self._analysis_transform(x) + return self._compress_latent(y) + + def decompress( + self, 
strings: Sequence[Sequence[bytes]], shape: Tuple[int, int]
+ ) -> Dict[str, Tensor]:
+ y_hat = self._decompress_latent(strings, shape)
+ height, width = y_hat.shape[2:]
+ return {"x_hat": self._synthesis_transform(y_hat, height, width).clamp_(0, 1)}
+
+ @classmethod
+ def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "SymmetricalTransFormer":
+ if _is_upstream_stf_state_dict(state_dict):
+ state_dict = convert_upstream_stf_state_dict(state_dict)
+ patch_size = state_dict["patch_embed.proj.weight"].size(2)
+ embed_dim = state_dict["patch_embed.proj.weight"].size(0)
+ layer_indices = sorted(
+ {
+ int(key.split(".")[1])
+ for key in state_dict
+ if key.startswith("layers.") and ".blocks." in key
+ }
+ )
+ depths = [
+ len(
+ {
+ int(key.split(".")[3])
+ for key in state_dict
+ if key.startswith(f"layers.{layer_index}.blocks.")
+ }
+ )
+ for layer_index in layer_indices
+ ]
+ num_heads = [
+ state_dict[
+ f"layers.{layer_index}.blocks.0.attn.relative_position_bias_table"
+ ].size(1)
+ for layer_index in layer_indices
+ ]
+ table_size = state_dict[
+ "layers.0.blocks.0.attn.relative_position_bias_table"
+ ].size(0)
+ window_size = (math.isqrt(table_size) + 1) // 2
+ num_slices = infer_num_slices(state_dict) or 12
+ latent_channels = embed_dim * 2 ** (len(depths) - 1)
+ max_support_slices = infer_max_support_slices(
+ state_dict, latent_channels, num_slices
+ )
+
+ net = cls(
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ depths=depths,
+ num_heads=num_heads,
+ window_size=window_size,
+ num_slices=num_slices,
+ max_support_slices=max_support_slices,
+ )
+ net.load_state_dict(state_dict)
+ return net
diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py
index 5c56bee7..acebc705 100644
--- a/compressai/zoo/__init__.py
+++ b/compressai/zoo/__init__.py
@@ -35,6 +35,8 @@
 cheng2020_attn,
 mbt2018,
 mbt2018_mean,
+ stf,
+ stf_wacnn,
 )
 from .image_vbr import bmshj2018_hyperprior_vbr, mbt2018_mean_vbr, mbt2018_vbr
 from .pretrained import load_pretrained as load_state_dict
@@ -48,6 +50,8 @@
 "mbt2018": mbt2018,
 "cheng2020-anchor": cheng2020_anchor,
 "cheng2020-attn": cheng2020_attn,
+ "stf": stf,
+ "stf-wacnn": stf_wacnn,
 "bmshj2018-hyperprior-vbr": bmshj2018_hyperprior_vbr,
 "mbt2018-mean-vbr": mbt2018_mean_vbr,
 "mbt2018-vbr": mbt2018_vbr,
diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py
index e0c34492..f506d6bf 100644
--- a/compressai/zoo/image.py
+++ b/compressai/zoo/image.py
@@ -41,6 +41,39 @@
 from .pretrained import load_pretrained
+
+class _LazyImport:
+ """Defer ``from module import name`` until the attribute is actually used.
+
+ Lets ``model_architectures`` reference classes that live in modules with
+ optional dependencies (e.g. ``compressai.models.stf`` needs ``timm``)
+ without importing them at zoo-import time. Forwards ``__call__`` and
+ attribute access to the resolved class so existing call sites
+ (``model_architectures[arch](...)`` and ``....from_state_dict(...)``)
+ keep working. 
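+
+ Hypothetical usage sketch (mirrors the ``model_architectures`` entries
+ below; the N/M values are illustrative)::
+
+ lazy = _LazyImport("compressai.models.stf", "WACNN")
+ net = lazy(N=192, M=320)  # first call triggers the import
+ net = lazy.from_state_dict(state_dict)  # attribute access resolves it too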
+ """ + + __slots__ = ("_module", "_name", "_resolved") + + def __init__(self, module: str, name: str): + self._module = module + self._name = name + self._resolved = None + + def _load(self): + if self._resolved is None: + import importlib + + self._resolved = getattr(importlib.import_module(self._module), self._name) + return self._resolved + + def __call__(self, *args, **kwargs): + return self._load()(*args, **kwargs) + + def __getattr__(self, item): + return getattr(self._load(), item) + + __all__ = [ "bmshj2018_factorized", "bmshj2018_factorized_relu", @@ -49,6 +82,8 @@ "mbt2018_mean", "cheng2020_anchor", "cheng2020_attn", + "stf", + "stf_wacnn", ] model_architectures = { @@ -59,6 +94,9 @@ "mbt2018": JointAutoregressiveHierarchicalPriors, "cheng2020-anchor": Cheng2020Anchor, "cheng2020-attn": Cheng2020Attention, + # Resolved lazily so `compressai.zoo` is importable without `timm`. + "stf": _LazyImport("compressai.models.stf", "SymmetricalTransFormer"), + "stf-wacnn": _LazyImport("compressai.models.stf", "WACNN"), } root_url = "https://compressai.s3.amazonaws.com/models/v1" @@ -447,3 +485,43 @@ def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwa return _load_model( "cheng2020-attn", metric, quality, pretrained, progress, **kwargs ) + + +def stf(pretrained: bool = False, progress: bool = True, **kwargs): + r"""Symmetrical TransFormer (STF) model from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" `_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained STF weights are not yet hosted on S3.") + from compressai.models.stf import SymmetricalTransFormer + + return SymmetricalTransFormer(**kwargs) + + +def stf_wacnn(pretrained: bool = False, progress: bool = True, **kwargs): + r"""WACNN model from R. Zou, C. Song, Z. Zhang: + `"The Devil Is in the Details: Window-based Attention for Image + Compression" `_, IEEE/CVF Conf. on + Computer Vision and Pattern Recognition (CVPR), 2022. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained WACNN weights are not yet hosted on S3.") + from compressai.models.stf import WACNN + + return WACNN(**kwargs) diff --git a/examples/convert_stf_checkpoint.py b/examples/convert_stf_checkpoint.py new file mode 100644 index 00000000..914857be --- /dev/null +++ b/examples/convert_stf_checkpoint.py @@ -0,0 +1,137 @@ +"""Convert an upstream STF / WACNN checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. ``stf_0018_best.pth.tar`` or +``cnn_0018_best.pth.tar`` from the STF repo), translates it to compressai's +module layout, and writes a state dict that +``compressai.models.SymmetricalTransFormer.from_state_dict`` / +``compressai.models.WACNN.from_state_dict`` can load directly. Optionally +reports forward-pass sanity numbers (PSNR / bpp) on a synthetic input. 
+ +Example:: + + python examples/convert_stf_checkpoint.py \\ + --src candidate/STF/stf_0018_best.pth.tar \\ + --arch stf \\ + --dst /tmp/stf_compressai.pth \\ + --smoke + + python examples/convert_stf_checkpoint.py \\ + --src candidate/STF/cnn_0018_best.pth.tar \\ + --arch wacnn \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path + +import torch + +from compressai.models.stf import ( + WACNN, + SymmetricalTransFormer, + convert_upstream_stf_state_dict, +) + +_ARCHES = {"stf": SymmetricalTransFormer, "wacnn": WACNN} + + +def _detect_arch(state_dict: dict) -> str: + keys = state_dict.keys() + if any("patch_embed" in k for k in keys): + return "stf" + if any(k.endswith("g_a.0.weight") for k in keys): + return "wacnn" + raise SystemExit("could not auto-detect arch; pass --arch {stf,wacnn} explicitly") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream checkpoint (e.g. stf_0018_best.pth.tar).", + ) + parser.add_argument( + "--arch", + choices=sorted(_ARCHES), + default=None, + help="Architecture to instantiate. Auto-detected from key names if omitted.", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + converted = convert_upstream_stf_state_dict(upstream) + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + arch = args.arch or _detect_arch(upstream) + cls = _ARCHES[arch] + net = cls.from_state_dict(upstream) + net.eval() + print(f"loaded {arch.upper()}: {sum(p.numel() for p in net.parameters()):,} params") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index dc6d2cc9..da77a0b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,9 @@ pointcloud = [ # "pointops-yoda", # Please install via uv pip install pointops-yoda --no-build-isolation "pyntcloud-yoda", ] +attn = [ + "timm", +] # 
NOTE: Temporarily duplicated from [project.optional-dependencies] until # pip supports installing [dependency-groups]. diff --git a/tests/test_models.py b/tests/test_models.py index c69ae7d5..c23b1865 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -278,6 +278,63 @@ def test_scale_space_flow(self): assert z_likelihoods_shape[3] == x[1].shape[3] / 2**7 +class TestStf: + def test_wacnn_forward_and_state_dict_round_trip(self): + from compressai.models.stf import WACNN + + model = WACNN(N=64, M=128, num_slices=4, max_support_slices=2).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + loaded = WACNN.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + + def test_symmetrical_transformer_forward_and_state_dict_round_trip(self): + from compressai.models.stf import SymmetricalTransFormer + + model = SymmetricalTransFormer( + embed_dim=24, + depths=(1, 1, 1, 1), + num_heads=(2, 2, 2, 2), + num_slices=4, + max_support_slices=2, + ).eval() + x = torch.rand(1, 3, 128, 128) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + loaded = SymmetricalTransFormer.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + + def test_stf_upstream_state_dict_conversion(self): + from compressai.models.stf import ( + convert_upstream_stf_state_dict, + ) + + upstream = { + "module.g_a.0.weight": torch.zeros(2), + "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.gaussian_conditional.scale_table": torch.zeros(2), + "module.h_a.0.weight": torch.zeros(2), + } + converted = convert_upstream_stf_state_dict(upstream) + assert "g_a.0.weight" in converted + assert "latent_codec.cc_mean_transforms.0.0.weight" in converted + assert "latent_codec.gaussian_conditional.scale_table" in converted + assert "h_a.0.weight" in converted + + def test_scale_table_default(): table = get_scale_table() assert SCALES_MIN == 0.11 diff --git a/uv.lock b/uv.lock index ae4ce819..7f3aa8ea 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.8, <14" resolution-markers = [ "python_full_version >= '3.12'", @@ -179,6 +179,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl", hash = "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3", size = 13857, upload-time = "2023-01-13T06:42:52.336Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + 
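# (reviewer note) The lockfile entries added in this hunk are the transitive
+# closure of the new ``attn`` extra: installing ``compressai[attn]`` pulls in
+# timm, which in turn needs huggingface-hub, safetensors, hf-xet, typer, etc.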
[[package]] name = "anyio" version = "4.5.2" @@ -332,7 +341,8 @@ name = "black" version = "24.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, + { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "click", version = "8.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "mypy-extensions" }, { name = "packaging" }, { name = "pathspec" }, @@ -560,14 +570,35 @@ wheels = [ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.9.*'", + "python_full_version < '3.9'", +] dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "python_full_version < '3.10' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593, upload-time = "2024-12-21T18:38:44.339Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188, upload-time = "2024-12-21T18:38:41.666Z" }, ] +[[package]] +name = "click" +version = "8.3.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "colorama", marker = "python_full_version >= '3.10' and sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bb/63/f9e1ea081ce35720d8b92acde70daaedace594dc93b693c869e0d5910718/click-8.3.3.tar.gz", hash = "sha256:398329ad4837b2ff7cbe1dd166a4c0f8900c3ca3a218de04466f38f6497f18a2", size = 328061, upload-time = "2026-04-22T15:11:27.506Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl", hash = "sha256:a2bf429bb3033c89fa4936ffb35d5cb471e3719e1f3c8a7c3fff0b8314305613", size = 110502, upload-time = "2026-04-22T15:11:25.044Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -619,6 +650,9 @@ dependencies = [ ] [package.optional-dependencies] +attn = [ + { name = "timm" }, +] dev = [ { name = "black" }, { name = "flake8" }, @@ -708,6 +742,7 @@ requires-dist = [ { name = "setuptools", specifier = ">=68" }, { name = "sphinx", marker = "extra == 'doc'", specifier = "==4.3.0" }, { name = "sphinx-book-theme", marker = "extra == 'doc'", specifier = "==1.0.1" }, + { name = "timm", marker = "extra == 'attn'" }, { name = "tomli", specifier = ">=2.2.1" }, { name = "torch", specifier = ">=1.13.1" }, { name = "torch", marker = "python_full_version >= '3.12'", specifier = ">=2.6.0" }, @@ -718,7 +753,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.0.0" }, { name = "wheel", specifier = ">=0.32.0" }, ] -provides-extras = ["test", "dev", "doc", "tutorials", "pointcloud"] +provides-extras = ["test", "dev", "doc", "tutorials", "pointcloud", "attn"] [package.metadata.requires-dev] dev = [ @@ -1380,6 +1415,38 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" }, ] +[[package]] +name = "hf-xet" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/d8/5c06fc76461418326a7decf8367480c35be11a41fd938633929c60a9ec6b/hf_xet-1.5.0.tar.gz", hash = "sha256:e0fb0a34d9f406eed88233e829a67ec016bec5af19e480eac65a233ea289a948", size = 837196, upload-time = "2026-05-06T06:18:15.583Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/9b/6912c99070915a4f28119e3c5b52a9abd1eec0ad5cb293b8c967a0c6f5a2/hf_xet-1.5.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:7d70fe2ce97b9db73b9c9b9c81fe3693640aec83416a966c446afea54acfae3c", size = 4023383, upload-time = "2026-05-06T06:17:53.947Z" }, + { url = "https://files.pythonhosted.org/packages/0f/6d/9563cfde59b5d8128a9c7ec972a087f4c782e4f7bac5a85234edfd5d5e49/hf_xet-1.5.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:73a0dae8c71de3b0633a45c73f4a4a5ed09e94b43441d82981a781d4f12baa42", size = 3792751, upload-time = "2026-05-06T06:17:51.791Z" }, + { url = "https://files.pythonhosted.org/packages/07/a5/ed5a0cf35b49a0571af5a8f53416dad1877a718c021c9937c3a53cb45781/hf_xet-1.5.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a60290ec57e9b71767fba7c3645ddafdd0759974b540441510c629c6db6db24a", size = 4456058, upload-time = "2026-05-06T06:17:40.735Z" }, + { url = "https://files.pythonhosted.org/packages/60/fb/3ae8bf2a7a37a4197d0195d7247fd25b3952e15cb8a599e285dfaa6f52b3/hf_xet-1.5.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e5de0f6deada0dada870bb376a11bcd1f08abf3a968a6d118f33e72d1b1eb480", size = 4250783, upload-time = "2026-05-06T06:17:38.412Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9b/8bae40d4d91525085137196e84eb0ed49cf65b5e96e5c3ecdadd8bd0fac2/hf_xet-1.5.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c799d49f1a5544a0ef7591c0ee75e0d6b93d6f56dc7a4979f59f7518d2872216", size = 4445594, upload-time = "2026-05-06T06:18:04.219Z" }, + { url = "https://files.pythonhosted.org/packages/13/59/c74efbbd4e8728172b2cc72a2bc014d2947a4b7bdced932fbd3f5da1a4e5/hf_xet-1.5.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2baea1b0b989e5c152fe81425f7745ddc8901280ba3d97c98d8cdece7b706c60", size = 4663995, upload-time = "2026-05-06T06:18:06.1Z" }, + { url = "https://files.pythonhosted.org/packages/73/32/8e1e0410af64cda9b139d1dcebdc993a8ff9c8c7c0e2696ae356d75ccc0d/hf_xet-1.5.0-cp313-cp313t-win_amd64.whl", hash = "sha256:526345b3ed45f374f6317349df489167606736c876241ba984105afe7fd4839d", size = 3966608, upload-time = "2026-05-06T06:18:19.74Z" }, + { url = "https://files.pythonhosted.org/packages/fc/34/a8febc8f4edbea8b3e21b02ebc8b628679b84ba7e45cde624a7736b51500/hf_xet-1.5.0-cp313-cp313t-win_arm64.whl", hash = "sha256:786d28e2eb8315d5035544b9d137b4a842d600c434bb91bf7d0d953cce906ad4", size = 3796946, upload-time = "2026-05-06T06:18:17.568Z" }, + { url = "https://files.pythonhosted.org/packages/2a/20/8fc8996afe5815fa1a6be8e9e5c02f24500f409d599e905800d498a4e14d/hf_xet-1.5.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:872d5601e6deea30d15865ede55d29eac6daf5a534ab417b99b6ef6b076dd96c", size = 4023495, upload-time = "2026-05-06T06:18:01.94Z" }, + { url = 
"https://files.pythonhosted.org/packages/32/6a/93d84463c00cecb561a7508aa6303e35ee2894294eac14245526924415fe/hf_xet-1.5.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:9929561f5abf4581c8ea79587881dfef6b8abb2a0d8a51915936fc2a614f4e73", size = 3792731, upload-time = "2026-05-06T06:18:00.021Z" }, + { url = "https://files.pythonhosted.org/packages/9d/5a/8ec8e0c863b382d00b3c2e2af6ded6b06371be617144a625903a6d562f4b/hf_xet-1.5.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f7b7bbae318e583a86fb21e5a4a175d6721d628a2874f4bd022d0e660c32a682", size = 4456738, upload-time = "2026-05-06T06:17:49.574Z" }, + { url = "https://files.pythonhosted.org/packages/c5/ca/f7effa1a67717da2bcc6b6c28f71c6ca648c77acaec4e2c32f40cbe16d85/hf_xet-1.5.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:cf7b2dc6f31a4ea754bb50f74cde482dcf5d366d184076d8530b9872787f3761", size = 4251622, upload-time = "2026-05-06T06:17:47.096Z" }, + { url = "https://files.pythonhosted.org/packages/65/f2/19247dba3e231cf77dec59ddfb878f00057635ff773d099c9b59d37812c3/hf_xet-1.5.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8dbcbab554c9ef158ef2c991545c3e970ddd8cc7acdcd0a78c5a41095dab4ded", size = 4445667, upload-time = "2026-05-06T06:18:11.983Z" }, + { url = "https://files.pythonhosted.org/packages/7f/64/6f116801a3bcfb6f59f5c251f48cadc47ea54026441c4a385079286a94fa/hf_xet-1.5.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5906bf7718d3636dc13402914736abe723492cb730f744834f5f5b67d3a12702", size = 4664619, upload-time = "2026-05-06T06:18:13.771Z" }, + { url = "https://files.pythonhosted.org/packages/5c/e8/069542d37946ed08669b127e1496fa99e78196d71de8d41eda5e9f1b7a58/hf_xet-1.5.0-cp314-cp314t-win_amd64.whl", hash = "sha256:5f3dc2248fc01cc0a00cd392ab497f1ca373fcbc7e3f2da1f452480b384e839e", size = 3966802, upload-time = "2026-05-06T06:18:28.162Z" }, + { url = "https://files.pythonhosted.org/packages/f9/91/fc6fdec27b14d04e88c386ac0a0129732b53fa23f7c4a78f4b83a039c567/hf_xet-1.5.0-cp314-cp314t-win_arm64.whl", hash = "sha256:b285cea1b5bab46b758772716ba8d6854a1a0310fed1c249d678a8b38601e5a0", size = 3797168, upload-time = "2026-05-06T06:18:26.287Z" }, + { url = "https://files.pythonhosted.org/packages/3d/fb/69ff198a82cae7eb1a69fb84d93b3a3e4816564d76817fe541ddc96874eb/hf_xet-1.5.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:dad0dc84e941b8ba3c860659fe1fdc35c049d47cce293f003287757e971a8f56", size = 4030814, upload-time = "2026-05-06T06:17:57.933Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ff/edcc2b40162bef3ff78e14ab637e5f3b89243d6aee72f5949d3bb6a5af83/hf_xet-1.5.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:fd6e5a9b0fdac4ed03ed45ef79254a655b1aaab514a02202617fbf643f5fdf7a", size = 3798444, upload-time = "2026-05-06T06:17:55.79Z" }, + { url = "https://files.pythonhosted.org/packages/49/4d/103f76b04310e5e57656696cc184690d20c466af0bca3ca88f8c8ea5d4f3/hf_xet-1.5.0-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3531b1823a0e6d77d80f9ed15ca0e00f0d115094f8ac033d5cae88f4564cc949", size = 4465986, upload-time = "2026-05-06T06:17:44.886Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a2/546f47f464737b3edbab6f8ddb57f2599b93d2cbb66f06abb475ccb48651/hf_xet-1.5.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:9a0ee58cd18d5ea799f7ed11290bbccbe56bdd8b1d97ca74b9cc49a3945d7a3b", size = 4259865, upload-time = "2026-05-06T06:17:42.639Z" }, + { url = 
"https://files.pythonhosted.org/packages/95/7f/1be593c1f28613be2e196473481cd81bfc5910795e30a34e8f744f6cac4f/hf_xet-1.5.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1e60df5a42e9bed8628b6416af2cba4cba57ae9f02de226a06b020d98e1aab18", size = 4459835, upload-time = "2026-05-06T06:18:08.026Z" }, + { url = "https://files.pythonhosted.org/packages/aa/b2/703569fc881f3284487e68cda7b42179978480da3c438042a6bbbb4a671c/hf_xet-1.5.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4b35549ce62601b84da4ff9b24d970032ace3d4430f52d91bcbb26c901d6c690", size = 4672414, upload-time = "2026-05-06T06:18:09.864Z" }, + { url = "https://files.pythonhosted.org/packages/af/37/1b6def445c567286b50aa3b33828158e135b1be44938dde59f11382a500c/hf_xet-1.5.0-cp37-abi3-win_amd64.whl", hash = "sha256:2806c7c17b4d23f8d88f7c4814f838c3b6150773fe339c20af23e1cfaf2797e4", size = 3977238, upload-time = "2026-05-06T06:18:23.621Z" }, + { url = "https://files.pythonhosted.org/packages/62/94/3b66b148778ee100dcfd69c2ca22b57b41b44d3063ceec934f209e9184ce/hf_xet-1.5.0-cp37-abi3-win_arm64.whl", hash = "sha256:b6c9df403040248c76d808d3e047d64db2d923bae593eb244c41e425cf6cd7be", size = 3806916, upload-time = "2026-05-06T06:18:21.7Z" }, +] + [[package]] name = "httpcore" version = "1.0.7" @@ -1408,6 +1475,76 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "huggingface-hub" +version = "0.36.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +dependencies = [ + { name = "filelock", marker = "python_full_version < '3.9'" }, + { name = "fsspec", marker = "python_full_version < '3.9'" }, + { name = "hf-xet", marker = "(python_full_version < '3.9' and platform_machine == 'aarch64') or (python_full_version < '3.9' and platform_machine == 'amd64') or (python_full_version < '3.9' and platform_machine == 'arm64') or (python_full_version < '3.9' and platform_machine == 'x86_64')" }, + { name = "packaging", marker = "python_full_version < '3.9'" }, + { name = "pyyaml", marker = "python_full_version < '3.9'" }, + { name = "requests", marker = "python_full_version < '3.9'" }, + { name = "tqdm", marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7c/b7/8cb61d2eece5fb05a83271da168186721c450eb74e3c31f7ef3169fa475b/huggingface_hub-0.36.2.tar.gz", hash = "sha256:1934304d2fb224f8afa3b87007d58501acfda9215b334eed53072dd5e815ff7a", size = 649782, upload-time = "2026-02-06T09:24:13.098Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/af/48ac8483240de756d2438c380746e7130d1c6f75802ef22f3c6d49982787/huggingface_hub-0.36.2-py3-none-any.whl", hash = "sha256:48f0c8eac16145dfce371e9d2d7772854a4f591bcb56c9cf548accf531d54270", size = 566395, upload-time = "2026-02-06T09:24:11.133Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "filelock", marker = "python_full_version == '3.9.*'" }, + { name = "fsspec", marker = "python_full_version == '3.9.*'" }, + { name = "hf-xet", marker = "(python_full_version == '3.9.*' and platform_machine == 
'AMD64') or (python_full_version == '3.9.*' and platform_machine == 'aarch64') or (python_full_version == '3.9.*' and platform_machine == 'amd64') or (python_full_version == '3.9.*' and platform_machine == 'arm64') or (python_full_version == '3.9.*' and platform_machine == 'x86_64')" }, + { name = "httpx", marker = "python_full_version == '3.9.*'" }, + { name = "packaging", marker = "python_full_version == '3.9.*'" }, + { name = "pyyaml", marker = "python_full_version == '3.9.*'" }, + { name = "tqdm", marker = "python_full_version == '3.9.*'" }, + { name = "typer", version = "0.23.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "typing-extensions", marker = "python_full_version == '3.9.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/2a/a847fd02261cd051da218baf99f90ee7c7040c109a01833db4f838f25256/huggingface_hub-1.8.0.tar.gz", hash = "sha256:c5627b2fd521e00caf8eff4ac965ba988ea75167fad7ee72e17f9b7183ec63f3", size = 735839, upload-time = "2026-03-25T16:01:28.152Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/ae/8a3a16ea4d202cb641b51d2681bdd3d482c1c592d7570b3fa264730829ce/huggingface_hub-1.8.0-py3-none-any.whl", hash = "sha256:d3eb5047bd4e33c987429de6020d4810d38a5bef95b3b40df9b17346b7f353f2", size = 625208, upload-time = "2026-03-25T16:01:26.603Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "filelock", marker = "python_full_version >= '3.10'" }, + { name = "fsspec", marker = "python_full_version >= '3.10'" }, + { name = "hf-xet", marker = "(python_full_version >= '3.10' and platform_machine == 'AMD64') or (python_full_version >= '3.10' and platform_machine == 'aarch64') or (python_full_version >= '3.10' and platform_machine == 'amd64') or (python_full_version >= '3.10' and platform_machine == 'arm64') or (python_full_version >= '3.10' and platform_machine == 'x86_64')" }, + { name = "httpx", marker = "python_full_version >= '3.10'" }, + { name = "packaging", marker = "python_full_version >= '3.10'" }, + { name = "pyyaml", marker = "python_full_version >= '3.10'" }, + { name = "tqdm", marker = "python_full_version >= '3.10'" }, + { name = "typer", version = "0.25.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/40/43109e943fd718b0ccd0cd61eb4f1c347df22bf81f5874c6f22adf44bcff/huggingface_hub-1.14.0.tar.gz", hash = "sha256:d6d2c9cd6be1d02ae9ec6672d5587d10a427f377db688e82528f426a041622c2", size = 782365, upload-time = "2026-05-06T14:14:34.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/a5/33b49ba7bea7c41bb37f74ec0f8beea0831e052330196633fe2c77516ea6/huggingface_hub-1.14.0-py3-none-any.whl", hash = "sha256:efe075535c62e130b30e836b138e13785f6f043d1f0539e0a39aa411a99e90b8", size = 661479, upload-time = "2026-05-06T14:14:32.029Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1954,6 +2091,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8c/95/4a103776c265d13b3d2cd24fb0494d4e04ea435a8ef97e1b2c026d43250b/kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = 
"sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0", size = 55811, upload-time = "2024-09-04T09:06:53.078Z" }, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "mdurl", marker = "python_full_version == '3.9.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596, upload-time = "2023-06-03T06:41:14.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528, upload-time = "2023-06-03T06:41:11.019Z" }, +] + +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "mdurl", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "markupsafe" version = "2.1.5" @@ -2170,6 +2339,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mistune" version = "3.1.3" @@ -3769,6 +3947,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9", size = 4242, upload-time = "2019-10-28T16:00:13.976Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py", version = "3.0.0", 
source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "markdown-it-py", version = "4.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pygments", marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "rpds-py" version = "0.20.1" @@ -3904,6 +4096,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/f8/3765e053acd07baa055c96b2065c7fab91f911b3c076dfea71006666f5b0/ruff-0.8.6-py3-none-win_arm64.whl", hash = "sha256:7d7fc2377a04b6e04ffe588caad613d0c460eb2ecba4c0ccbbfe2bc973cbc162", size = 9149556, upload-time = "2025-01-04T12:22:57.173Z" }, ] +[[package]] +name = "safetensors" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +sdist = { url = "https://files.pythonhosted.org/packages/71/7e/2d5d6ee7b40c0682315367ec7475693d110f512922d582fef1bd4a63adc3/safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965", size = 67210, upload-time = "2025-02-26T09:15:13.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/ae/88f6c49dbd0cc4da0e08610019a3c78a7d390879a919411a410a1876d03a/safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073", size = 436917, upload-time = "2025-02-26T09:15:03.702Z" }, + { url = "https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7", size = 418419, upload-time = "2025-02-26T09:15:01.765Z" }, + { url = "https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467", size = 459493, upload-time = "2025-02-26T09:14:51.812Z" }, + { url = "https://files.pythonhosted.org/packages/df/5c/bf2cae92222513cc23b3ff85c4a1bb2811a2c3583ac0f8e8d502751de934/safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e", size = 472400, upload-time = "2025-02-26T09:14:53.549Z" }, + { url = "https://files.pythonhosted.org/packages/58/11/7456afb740bd45782d0f4c8e8e1bb9e572f1bf82899fb6ace58af47b4282/safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d", size = 522891, upload-time = "2025-02-26T09:14:55.717Z" }, + { url = 
"https://files.pythonhosted.org/packages/57/3d/fe73a9d2ace487e7285f6e157afee2383bd1ddb911b7cb44a55cf812eae3/safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9", size = 537694, upload-time = "2025-02-26T09:14:57.036Z" }, + { url = "https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a", size = 471642, upload-time = "2025-02-26T09:15:00.544Z" }, + { url = "https://files.pythonhosted.org/packages/ce/20/1fbe16f9b815f6c5a672f5b760951e20e17e43f67f231428f871909a37f6/safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d", size = 502241, upload-time = "2025-02-26T09:14:58.303Z" }, + { url = "https://files.pythonhosted.org/packages/5f/18/8e108846b506487aa4629fe4116b27db65c3dde922de2c8e0cc1133f3f29/safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b", size = 638001, upload-time = "2025-02-26T09:15:05.79Z" }, + { url = "https://files.pythonhosted.org/packages/82/5a/c116111d8291af6c8c8a8b40628fe833b9db97d8141c2a82359d14d9e078/safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff", size = 734013, upload-time = "2025-02-26T09:15:07.892Z" }, + { url = "https://files.pythonhosted.org/packages/7d/ff/41fcc4d3b7de837963622e8610d998710705bbde9a8a17221d85e5d0baad/safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135", size = 670687, upload-time = "2025-02-26T09:15:09.979Z" }, + { url = "https://files.pythonhosted.org/packages/40/ad/2b113098e69c985a3d8fbda4b902778eae4a35b7d5188859b4a63d30c161/safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04", size = 643147, upload-time = "2025-02-26T09:15:11.185Z" }, + { url = "https://files.pythonhosted.org/packages/0a/0c/95aeb51d4246bd9a3242d3d8349c1112b4ee7611a4b40f0c5c93b05f001d/safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace", size = 296677, upload-time = "2025-02-26T09:15:16.554Z" }, + { url = "https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11", size = 308878, upload-time = "2025-02-26T09:15:14.99Z" }, +] + +[[package]] +name = "safetensors" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" }, + { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" }, + { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" }, + { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" }, + { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" }, + { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" }, + { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" }, + { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = 
"2025-11-19T15:18:40.162Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, + { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, + { url = "https://files.pythonhosted.org/packages/a7/6a/4d08d89a6fcbe905c5ae68b8b34f0791850882fc19782d0d02c65abbdf3b/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737", size = 492430, upload-time = "2025-11-19T15:18:11.884Z" }, + { url = "https://files.pythonhosted.org/packages/dd/29/59ed8152b30f72c42d00d241e58eaca558ae9dbfa5695206e2e0f54c7063/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd", size = 503977, upload-time = "2025-11-19T15:18:17.523Z" }, + { url = "https://files.pythonhosted.org/packages/d3/0b/4811bfec67fa260e791369b16dab105e4bae82686120554cc484064e22b4/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2", size = 623890, upload-time = "2025-11-19T15:18:22.666Z" }, + { url = "https://files.pythonhosted.org/packages/58/5b/632a58724221ef03d78ab65062e82a1010e1bef8e8e0b9d7c6d7b8044841/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3", size = 531885, upload-time = "2025-11-19T15:18:27.146Z" }, + { url = "https://files.pythonhosted.org/packages/94/60/13ccb63ea85bfe2e4fe6af602cf1272155f048906556d5ec8509da9dba42/safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b95a3fa7b3abb9b5b0e07668e808364d0d40f6bbbf9ae0faa8b5b210c97b140", size = 492627, upload-time = "2025-11-19T15:18:14.661Z" }, + { url = "https://files.pythonhosted.org/packages/2e/2b/e2fde0d6334439908b0b0c4cba18b8ad76ea6a03b569d4a3388f423b4046/safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cfdead2f57330d76aa7234051dadfa7d4eedc0e5a27fd08e6f96714a92b00f09", size = 503861, upload-time = "2025-11-19T15:18:19.418Z" }, + { url = "https://files.pythonhosted.org/packages/f0/71/566e3dd559a9cef1b4775c239daae09e6b6a32ca8b45eb1db9a4dfa1ba81/safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc92bc2db7b45bda4510e4f51c59b00fe80b2d6be88928346e4294ce1c2abe7c", size = 623577, upload-time = "2025-11-19T15:18:24.275Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/fc/3035c5c30c8a5a82c31c6b2ad6f8bcd45ea2ddd9a8088840406bcf997413/safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6999421eb8ba9df4450a16d9184fcb7bef26240b9f98e95401f17af6c2210b71", size = 532524, upload-time = "2025-11-19T15:18:29.334Z" }, +] + [[package]] name = "scipy" version = "1.10.1" @@ -3998,6 +4251,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/42/be1c7bbdd83e1bfb160c94b9cafd8e25efc7400346cf7ccdbdb452c467fa/setuptools-68.0.0-py3-none-any.whl", hash = "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f", size = 804037, upload-time = "2023-06-19T15:53:03.089Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -4169,6 +4431,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" }, ] +[[package]] +name = "timm" +version = "1.0.26" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub", version = "0.36.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "huggingface-hub", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "huggingface-hub", version = "1.14.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pyyaml" }, + { name = "safetensors", version = "0.5.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "safetensors", version = "0.7.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "torch", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "torchvision", version = "0.17.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7b/1e/e924b3b2326a856aaf68586f9c52a5fc81ef45715eca408393b68c597e0e/timm-1.0.26.tar.gz", hash = "sha256:f66f082f2f381cf68431c22714c8b70f723837fa2a185b155961eab90f2d5b10", size = 2419859, upload-time = "2026-03-23T18:12:10.272Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/6f/e9/bebf3d50e3fc847378988235f87c37ad3ac26d386041ab915d15e92025cd/timm-1.0.26-py3-none-any.whl", hash = "sha256:985c330de5ccc3a2aa0224eb7272e6a336084702390bb7e3801f3c91603d3683", size = 2568766, upload-time = "2026-03-23T18:12:08.062Z" }, +] + [[package]] name = "tinycss2" version = "1.2.1" @@ -4525,6 +4808,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/b7/1dec8433ac604c061173d0589d99217fe7bf90a70bdc375e745d044b8aad/triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7", size = 170580176, upload-time = "2025-10-13T16:38:31.14Z" }, ] +[[package]] +name = "typer" +version = "0.23.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "annotated-doc", marker = "python_full_version == '3.9.*'" }, + { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "rich", marker = "python_full_version == '3.9.*'" }, + { name = "shellingham", marker = "python_full_version == '3.9.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/ae/93d16574e66dfe4c2284ffdaca4b0320ade32858cb2cc586c8dd79f127c5/typer-0.23.2.tar.gz", hash = "sha256:a99706a08e54f1aef8bb6a8611503808188a4092808e86addff1828a208af0de", size = 120162, upload-time = "2026-02-16T18:52:40.354Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2c/dee705c427875402200fe779eb8a3c00ccb349471172c41178336e9599cc/typer-0.23.2-py3-none-any.whl", hash = "sha256:e9c8dc380f82450b3c851a9b9d5a0edf95d1d6456ae70c517d8b06a50c7a9978", size = 56834, upload-time = "2026-02-16T18:52:39.308Z" }, +] + +[[package]] +name = "typer" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "annotated-doc", marker = "python_full_version >= '3.10'" }, + { name = "click", version = "8.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "rich", marker = "python_full_version >= '3.10'" }, + { name = "shellingham", marker = "python_full_version >= '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e4/51/9aed62104cea109b820bbd6c14245af756112017d309da813ef107d42e7e/typer-0.25.1.tar.gz", hash = "sha256:9616eb8853a09ffeabab1698952f33c6f29ffdbceb4eaeecf571880e8d7664cc", size = 122276, upload-time = "2026-04-30T19:32:16.964Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/f9/2b3ff4e56e5fa7debfaf9eb135d0da96f3e9a1d5b27222223c7296336e5f/typer-0.25.1-py3-none-any.whl", hash = "sha256:75caa44ed46a03fb2dab8808753ffacdbfea88495e74c85a28c5eefcf5f39c89", size = 58409, upload-time = "2026-04-30T19:32:18.271Z" }, +] + [[package]] name = "types-python-dateutil" version = "2.9.0.20241206"