diff --git a/compressai/layers/attn/__init__.py b/compressai/layers/attn/__init__.py index 331df45b..8319c710 100644 --- a/compressai/layers/attn/__init__.py +++ b/compressai/layers/attn/__init__.py @@ -1,3 +1,13 @@ +from .dictionary import ( + ConvolutionalGLU, + ConvWithDW, + DenseBlock, + DWConv, + MultiScaleAggregation, + MultiScaleDictionaryCrossAttentionGLU, + Scale, + SpatialAttentionModule, +) from .swin import ( WMSA, ConvTransBlock, @@ -16,9 +26,17 @@ __all__ = [ "ConvTransBlock", + "ConvWithDW", + "ConvolutionalGLU", + "DWConv", + "DenseBlock", + "MultiScaleAggregation", + "MultiScaleDictionaryCrossAttentionGLU", "PatchMerging", "PatchSplit", "SWAtten", + "Scale", + "SpatialAttentionModule", "SwinBlock", "WMSA", "WinNoShiftAttention", diff --git a/compressai/layers/attn/dictionary.py b/compressai/layers/attn/dictionary.py new file mode 100644 index 00000000..3a19940a --- /dev/null +++ b/compressai/layers/attn/dictionary.py @@ -0,0 +1,246 @@ +"""Dictionary-based multi-scale cross-attention building blocks. + +These layers implement the entropy-side cross-attention used by the DCAE / +SAAF families, which factor a learned per-image dictionary +(``dt: nn.Parameter`` of shape ``(dict_num, dictionary_dim)``) shared across +slices and cross-attended by every channel-context head. They were lifted +from the upstream DCAE reference implementation (Lu et al., CVPR 2025); the +SAAF entropy stack (Ma et al., CVPR 2026) reuses the exact same blocks. + +Adapted from the dictionary-entropy implementation released alongside the +DCAE / SAAF papers; transformer/attention plumbing follows their public +PyTorch sources. 
class Scale(nn.Module):
    """Learnable per-channel gain applied along the last axis (residual gate).

    Args:
        dim: Length of the gain vector (number of channels).
        init_value: Value every gain entry starts at.
        trainable: When ``False`` the gain is created with
            ``requires_grad=False``.
    """

    def __init__(
        self, dim: int, init_value: float = 1.0, trainable: bool = True
    ) -> None:
        super().__init__()
        initial_gain = init_value * torch.ones(dim)
        self.scale = nn.Parameter(initial_gain, requires_grad=trainable)

    def forward(self, input_tensor: Tensor) -> Tensor:
        # ``scale`` has shape ``(dim,)`` and broadcasts over the trailing axis.
        return input_tensor * self.scale


class DWConv(nn.Module):
    """3x3 depthwise convolution for channel-last ``(b, h, w, c)`` inputs."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # ``groups == dim`` gives every channel its own 3x3 filter.
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim
        )

    def forward(self, input_tensor: Tensor) -> Tensor:
        # ``nn.Conv2d`` wants channel-first data, so convert there and back.
        channel_first = rearrange(input_tensor, "b h w c -> b c h w")
        filtered = self.dwconv(channel_first)
        return rearrange(filtered, "b c h w -> b h w c")


class ConvolutionalGLU(nn.Module):
    """Gated MLP whose value branch passes through a depthwise convolution.

    ``fc1`` produces two halves along the last axis: the first half is
    filtered by :class:`DWConv` and activated, the second half is used as a
    multiplicative gate. ``fc2`` projects the gated result to
    ``out_features``.

    Args:
        in_features: Size of the last axis of the input.
        hidden_features: Nominal hidden width; each of the two halves gets
            ``hidden_features // 2`` channels. Falls back to ``in_features``.
        out_features: Size of the last axis of the output. Falls back to
            ``in_features``.
        act_layer: Activation class instantiated with no arguments.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        branch_width = (hidden_features or in_features) // 2
        # One projection feeds both the value and the gate branch.
        self.fc1 = nn.Linear(in_features, 2 * branch_width)
        self.dwconv = DWConv(branch_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(branch_width, out_features)

    def forward(self, input_tensor: Tensor) -> Tensor:
        value, gate = self.fc1(input_tensor).chunk(2, dim=-1)
        activated = self.act(self.dwconv(value))
        return self.fc2(activated * gate)
class ConvWithDW(nn.Module):
    """Channel-first block: 1x1 conv, GELU, depthwise 3x3 conv, GELU, 1x1 conv.

    Args:
        input_dim: Channels of the input feature map.
        output_dim: Channels produced by the first 1x1 convolution and kept
            by every later stage.
    """

    def __init__(self, input_dim: int = 320, output_dim: int = 320) -> None:
        super().__init__()
        self.in_trans = nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=True)
        self.act1 = nn.GELU()
        # Depthwise: one 3x3 filter per channel.
        self.dw_conv = nn.Conv2d(
            output_dim,
            output_dim,
            kernel_size=3,
            padding=1,
            groups=output_dim,
            bias=True,
        )
        self.act2 = nn.GELU()
        self.out_trans = nn.Conv2d(output_dim, output_dim, kernel_size=1, bias=True)

    def forward(self, input_tensor: Tensor) -> Tensor:
        output = input_tensor
        stages = (self.in_trans, self.act1, self.dw_conv, self.act2, self.out_trans)
        for stage in stages:
            output = stage(output)
        return output


class DenseBlock(nn.Module):
    """Chain of ``GELU -> ConvWithDW`` stages with a concatenating read-out.

    Each stage consumes the previous stage's output. The block input and all
    ``layer_num`` stage outputs are concatenated along the channel axis and
    reduced back to ``dim`` channels by a 1x1 convolution.
    """

    def __init__(self, dim: int = 320, layer_num: int = 3) -> None:
        super().__init__()
        self.layer_num = layer_num
        stages = [
            nn.Sequential(nn.GELU(), ConvWithDW(dim, dim)) for _ in range(layer_num)
        ]
        self.conv_layers = nn.ModuleList(stages)
        self.proj = nn.Conv2d(dim * (layer_num + 1), dim, kernel_size=1, bias=True)

    def forward(self, input_tensor: Tensor) -> Tensor:
        current = input_tensor
        collected = [current]
        for stage in self.conv_layers:
            current = stage(current)
            collected.append(current)
        return self.proj(torch.cat(collected, dim=1))


class SpatialAttentionModule(nn.Module):
    """Single-channel spatial gate built from channel-wise mean and max maps.

    The two pooled maps are stacked, convolved with a ``kernel_size`` filter
    (``kernel_size // 2`` padding, no bias) and squashed by a sigmoid.
    """

    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor: Tensor) -> Tensor:
        mean_map = input_tensor.mean(dim=1, keepdim=True)
        max_map = input_tensor.max(dim=1, keepdim=True).values
        pooled = torch.cat([mean_map, max_map], dim=1)
        return self.sigmoid(self.conv1(pooled))
class MultiScaleAggregation(nn.Module):
    """1x1 conv followed by a :class:`DenseBlock`, gated by spatial attention.

    Takes and returns channel-last ``(b, h, w, c)`` tensors; the
    convolutions run channel-first internally.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.s = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.spatial_atte = SpatialAttentionModule()
        self.dense = DenseBlock(dim)

    def forward(self, input_tensor: Tensor) -> Tensor:
        features = rearrange(input_tensor, "b h w c -> b c h w")
        features = self.dense(self.s(features))
        # The attention map is computed from the features it then gates.
        gated = features * self.spatial_atte(features)
        return rearrange(gated, "b c h w -> b h w c")


class MultiScaleDictionaryCrossAttentionGLU(nn.Module):
    """Cross-attend a feature map against a dictionary of learned entries.

    ``input_tensor`` is a ``(B, input_dim, H, W)`` feature map and
    ``dictionary`` a ``(B, dict_num, dictionary_dim)`` tensor. Every spatial
    position forms a query; keys are a linear projection of the
    layer-normalised dictionary and values are the layer-normalised
    dictionary itself. The attended result goes through a residual
    projection and a :class:`ConvolutionalGLU` MLP, and comes back as
    ``(B, output_dim, H, W)``.

    Args:
        input_dim: Channels of ``input_tensor``.
        output_dim: Channels of the returned tensor.
        mlp_rate: Hidden-width multiplier handed to the GLU MLP.
        head_num: Number of attention heads.
        qkv_bias: Whether the linear projections carry a bias.
        dictionary_dim: Width of the dictionary entries; defaults to
            ``32 * head_num``.

    Raises:
        ValueError: If the dictionary width is not divisible by ``head_num``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        mlp_rate: int = 4,
        head_num: int = 20,
        qkv_bias: bool = True,
        dictionary_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        dict_dim = dictionary_dim if dictionary_dim else 32 * head_num
        if dict_dim % head_num:
            raise ValueError("dictionary_dim must be divisible by head_num")

        self.head_num = head_num
        # One learnable logit multiplier per head (initialised to 1).
        self.scale = nn.Parameter(torch.ones(head_num, 1, 1))
        self.x_trans = nn.Linear(input_dim, dict_dim, bias=qkv_bias)
        self.ln_scale = nn.LayerNorm(dict_dim)
        self.msa = MultiScaleAggregation(dict_dim)
        self.lnx = nn.LayerNorm(dict_dim)
        self.q_trans = nn.Linear(dict_dim, dict_dim, bias=qkv_bias)
        self.dict_ln = nn.LayerNorm(dict_dim)
        self.k = nn.Linear(dict_dim, dict_dim, bias=qkv_bias)
        self.linear = nn.Linear(dict_dim, dict_dim, bias=qkv_bias)
        self.ln_mlp = nn.LayerNorm(dict_dim)
        self.mlp = ConvolutionalGLU(dict_dim, mlp_rate * dict_dim)
        self.output_trans = nn.Sequential(nn.Linear(dict_dim, output_dim))
        self.softmax = nn.Softmax(dim=-1)
        self.res_scale_1 = Scale(dict_dim, init_value=1.0)
        self.res_scale_2 = Scale(dict_dim, init_value=1.0)
        self.res_scale_3 = Scale(dict_dim, init_value=1.0)

    def forward(self, input_tensor: Tensor, dictionary: Tensor) -> Tensor:
        batch_size, _, height, width = input_tensor.size()
        heads = self.head_num

        # Embed to the dictionary width and mix spatially (residual branch 1).
        tokens = self.x_trans(rearrange(input_tensor, "b c h w -> b h w c"))
        tokens = self.msa(self.ln_scale(tokens)) + self.res_scale_1(tokens)
        residual = tokens

        # Queries: one token per spatial position, channels split into heads.
        query = rearrange(
            self.q_trans(self.lnx(tokens)),
            "b h w (e c) -> b e (h w) c",
            e=heads,
        )

        # Keys are projected; values are the normalised dictionary itself.
        normalised = self.dict_ln(dictionary)
        key = rearrange(self.k(normalised), "b n (e c) -> b e n c", e=heads)
        value = rearrange(normalised, "b n (e c) -> b e n c", e=heads)

        head_scale = self.scale.to(device=query.device, dtype=query.dtype)
        logits = torch.einsum("benc,bedc->bend", query, key) * head_scale
        weights = self.softmax(logits)
        attended = torch.einsum("bend,bedc->benc", weights, value)
        attended = rearrange(
            attended, "b e (h w) c -> b h w (e c)", h=height, w=width
        )

        # Residual branches 2 and 3, then projection to the output width.
        fused = self.linear(attended) + self.res_scale_2(residual)
        fused = self.mlp(self.ln_mlp(fused)) + self.res_scale_3(fused)
        projected = self.output_trans(fused)
        return rearrange(projected, "b h w c -> b c h w", b=batch_size)
# Optional dependency: keep the module importable when it is absent and
# remember the original error so it can be chained into the later failure.
try:
    from pytorch_wavelets import DWTForward, DWTInverse
except ModuleNotFoundError as error:
    DWTForward = None  # type: ignore[assignment]
    DWTInverse = None  # type: ignore[assignment]
    _PYTORCH_WAVELETS_IMPORT_ERROR = error
else:
    _PYTORCH_WAVELETS_IMPORT_ERROR = None


def is_pytorch_wavelets_available() -> bool:
    """Return ``True`` when the optional ``pytorch_wavelets`` package is importable."""
    return DWTForward is not None and DWTInverse is not None


def _require_pytorch_wavelets() -> None:
    """Raise :class:`ModuleNotFoundError` if ``pytorch_wavelets`` is missing.

    The error is chained to the import failure recorded at module import.
    """
    if is_pytorch_wavelets_available():
        return
    raise ModuleNotFoundError(
        "Wavelet layers require the optional dependency `pytorch_wavelets`. "
        "Install it via `pip install compressai[wavelet]`."
    ) from _PYTORCH_WAVELETS_IMPORT_ERROR


class DWT2D(nn.Module):
    """Single-level DWT wrapper that channel-concatenates the four subbands.

    The output stacks the low-pass band followed by the three high-pass
    bands along the channel axis, i.e. ``4 * input_channels`` channels.

    Args:
        wave: Wavelet name forwarded to ``DWTForward``.
        mode: Padding mode forwarded to ``DWTForward``.

    Raises:
        ModuleNotFoundError: If ``pytorch_wavelets`` is not installed.
    """

    def __init__(self, wave: str = "haar", mode: str = "zero") -> None:
        super().__init__()
        _require_pytorch_wavelets()
        self.transform = DWTForward(J=1, wave=wave, mode=mode)

    def forward(self, input_tensor: Tensor) -> Tensor:
        lowpass, highpass_pyramid = self.transform(input_tensor)
        # ``J=1`` yields exactly one high-pass level; the unpack enforces it.
        [highpass] = highpass_pyramid
        subbands = (
            lowpass,
            highpass[:, :, 0, ...],
            highpass[:, :, 1, ...],
            highpass[:, :, 2, ...],
        )
        return torch.cat(subbands, dim=1)


class IDWT2D(nn.Module):
    """Inverse counterpart of :class:`DWT2D` matching its channel layout.

    Args:
        wave: Wavelet name forwarded to ``DWTInverse``.
        mode: Padding mode forwarded to ``DWTInverse``.

    Raises:
        ModuleNotFoundError: If ``pytorch_wavelets`` is not installed.
    """

    def __init__(self, wave: str = "haar", mode: str = "zero") -> None:
        super().__init__()
        _require_pytorch_wavelets()
        self.inverse = DWTInverse(wave=wave, mode=mode)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Rebuild the signal from four channel-concatenated subbands.

        Raises:
            ValueError: If the channel count is zero or not a multiple of 4,
                so it cannot be split into one low-pass and three high-pass
                groups of equal width.
        """
        channels = input_tensor.size(1)
        if channels == 0 or channels % 4 != 0:
            # Without this check ``chunk`` yields uneven or too few pieces and
            # the failure shows up as an opaque unpack / ``stack`` error.
            raise ValueError(
                "IDWT2D expects a channel count divisible by 4 "
                f"(low-pass + 3 high-pass subbands); got {channels}."
            )
        lowpass, band_lh, band_hl, band_hh = input_tensor.chunk(4, dim=1)
        # Re-insert the subband axis at dim=2, the position DWT2D indexes.
        highpass = torch.stack((band_lh, band_hl, band_hh), dim=2)
        return self.inverse((lowpass, [highpass]))


# Aliases kept for parity with the upstream AuxT release.
DWT_2D = DWT2D
IDWT_2D = IDWT2D
+- :class:`WLS` / :class:`iWLS` — wavelet-based analysis / synthesis + blocks pairing a :class:`~compressai.layers.wave.DWT2D` / + :class:`~compressai.layers.wave.IDWT2D` with learnable per-subband + scaling and an :class:`OLP` channel mixer. Lazily imports + :mod:`compressai.layers.wave` so :func:`aux_loss` and the side-branch + helpers below stay importable without ``pytorch_wavelets``. +- Side-branch builders + walker (:func:`build_wls_branch`, + :func:`build_iwls_branch`, :func:`forward_with_auxt`, + :class:`AuxTTransform`, :func:`compute_analysis_aux_positions`, + :func:`compute_synthesis_aux_positions`) for hosts that integrate AuxT + as a parallel chain wrapped around ``g_a`` / ``g_s`` (TCM ``use_auxt`` and + any future model with the same six-stage config). +- :func:`aux_loss` — generic OLP regulariser aggregator used by both + TCM-style (side-branch) and SAAF-style (integral) hosts. +- State-dict utilities (:func:`has_auxt_state`, + :func:`is_auxt_wavelet_buffer_key`, + :func:`is_auxt_upstream_wavelet_buffer_key`, + :func:`normalize_upstream_auxt_key`) that any host's + ``from_state_dict`` / ``convert_upstream_*_state_dict`` can reuse. + +The wavelet-only :class:`~compressai.layers.wave.DWT2D` / +:class:`~compressai.layers.wave.IDWT2D` wrappers are kept under +:mod:`compressai.layers.wave` because they are generic +``pytorch_wavelets`` adapters that future non-AuxT models (e.g. +WeConvene) may want to reuse. 
# ---------------------------------------------------------------------------
# AuxT primitives — OLP, WLS, iWLS
# ---------------------------------------------------------------------------


class OLP(nn.Module):
    """Linear projection with an auxiliary orthogonality regulariser.

    ``forward`` is a plain :class:`nn.Linear` from ``in_features`` to
    ``out_dim``. :meth:`loss` is the MSE between the weight's Gram matrix and
    the identity, taken on the smaller side of the weight so the target is
    ``min(in_features, out_dim)`` square.
    """

    def __init__(self, in_features: int, out_dim: int, bias: bool = True) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_dim, bias=bias)
        self.in_dim = in_features
        self.out_dim = out_dim
        # Constant regression target; non-persistent keeps it out of checkpoints.
        self.register_buffer(
            "identity_matrix",
            torch.eye(min(in_features, out_dim)),
            persistent=False,
        )

    def loss(self) -> Tensor:
        """Return ``MSE(gram, I)`` for the current projection weight."""
        weight = self.linear.weight  # (out_dim, in_dim)
        if self.in_dim > self.out_dim:
            gram = weight @ weight.t()  # (out_dim, out_dim)
        else:
            gram = weight.t() @ weight  # (in_dim, in_dim)
        target = self.identity_matrix.to(device=gram.device, dtype=gram.dtype)
        return F.mse_loss(gram, target)

    def forward(self, input_tensor: Tensor) -> Tensor:
        return self.linear(input_tensor)


def _make_scaling_factors(channels: int) -> Tensor:
    """Initial log-scales, shape ``(1, 1, 4 * channels)``.

    The first three channel groups start at ``0.5`` and the fourth at ``0``.
    """
    return torch.cat(
        (
            torch.full((1, 1, channels), 0.5),
            torch.full((1, 1, channels), 0.5),
            torch.full((1, 1, channels), 0.5),
            torch.zeros((1, 1, channels)),
        ),
        dim=2,
    )


class WLS(nn.Module):
    r"""Wavelet Linear Scaling (analysis) block from Li et al., ICLR 2025.

    Applies a 2D discrete wavelet transform, multiplies every coefficient
    channel by ``exp(scaling_factors)`` and mixes channels with an
    :class:`OLP`.

    :class:`compressai.layers.wave.DWT2D` is imported inside ``__init__`` so
    this module stays importable without the ``pytorch_wavelets`` extra;
    constructing a :class:`WLS` triggers the dependency check.
    """

    def __init__(self, in_dim: int, out_dim: int, wave: str = "haar") -> None:
        super().__init__()
        from compressai.layers.wave import DWT2D

        self.dwt = DWT2D(wave=wave)
        self.olp = OLP(in_dim * 4, out_dim)
        self.scaling_factors = nn.Parameter(_make_scaling_factors(in_dim))

    def forward(self, input_tensor: Tensor) -> Tensor:
        coefficients = self.dwt(input_tensor)
        batch_size, _, height, width = coefficients.shape
        # (B, C, H, W) -> (B, H*W, C). ``flatten`` copies only when the layout
        # requires it, so non-contiguous inputs are accepted as well.
        tokens = coefficients.flatten(2).transpose(1, 2)
        tokens = tokens * torch.exp(self.scaling_factors)
        tokens = self.olp(tokens)
        grid = tokens.reshape(batch_size, height, width, -1)
        return grid.permute(0, 3, 1, 2).contiguous()


class iWLS(nn.Module):
    r"""Inverse Wavelet Linear Scaling (synthesis) block from Li et al.,
    ICLR 2025.

    Mirror of :class:`WLS`: mixes channels with an :class:`OLP`, divides by
    ``exp(scaling_factors)`` and reconstructs the spatial signal with the
    inverse 2D DWT.
    """

    def __init__(self, in_dim: int, out_dim: int, wave: str = "haar") -> None:
        super().__init__()
        from compressai.layers.wave import IDWT2D

        self.idwt = IDWT2D(wave=wave)
        self.olp = OLP(in_dim, out_dim * 4)
        self.scaling_factors = nn.Parameter(_make_scaling_factors(out_dim))

    def forward(self, input_tensor: Tensor) -> Tensor:
        batch_size, _, height, width = input_tensor.shape
        # ``view`` raised on non-contiguous inputs (e.g. a spatially cropped
        # latent); ``flatten`` handles them by copying when necessary.
        tokens = input_tensor.flatten(2).transpose(1, 2)
        tokens = self.olp(tokens)
        tokens = tokens / torch.exp(self.scaling_factors)
        grid = tokens.reshape(batch_size, height, width, -1)
        return self.idwt(grid.permute(0, 3, 1, 2).contiguous())


# ---------------------------------------------------------------------------
# Side-branch builders + walker (TCM-style integration)
# ---------------------------------------------------------------------------


def build_wls_branch(N: int, M: int) -> nn.ModuleList:
    """Standard 4-layer ``AuxT_enc`` analysis branch.

    Channel layout:

    - ``WLS(3, 2N)`` — 3-channel image -> AuxT working width
    - ``WLS(2N, 2N)`` x 2 — interior stages
    - ``WLS(2N, M)`` — final stage emits ``M`` channels
    """
    return nn.ModuleList(
        [
            WLS(3, 2 * N),
            WLS(2 * N, 2 * N),
            WLS(2 * N, 2 * N),
            WLS(2 * N, M),
        ]
    )


def build_iwls_branch(N: int, M: int) -> nn.ModuleList:
    """Standard 4-layer ``AuxT_dec`` synthesis branch.

    Mirror of :func:`build_wls_branch`; the final ``iWLS(2N, 3)`` emits three
    channels.
    """
    return nn.ModuleList(
        [
            iWLS(M, 2 * N),
            iWLS(2 * N, 2 * N),
            iWLS(2 * N, 2 * N),
            iWLS(2 * N, 3),
        ]
    )


def forward_with_auxt(
    transform: nn.Sequential,
    auxiliary_layers: Optional[nn.ModuleList],
    merge_positions: Sequence[int],
    input_tensor: Tensor,
) -> Tensor:
    """Walk ``transform`` layer by layer, adding the AuxT chain at merge points.

    After the layer whose index equals ``merge_positions[i]``, the auxiliary
    chain advances by one stage (``auxiliary_layers[i]`` applied to the
    previous auxiliary output, starting from ``input_tensor``) and its output
    is added to the main path. Positions are consumed in order, so they are
    expected to be increasing.

    When ``auxiliary_layers is None`` this is just ``transform(input_tensor)``.

    Raises:
        RuntimeError: If there are fewer merge positions than auxiliary
            layers, or if not every auxiliary layer was merged during the
            walk.
    """
    if auxiliary_layers is None:
        return transform(input_tensor)

    if len(merge_positions) < len(auxiliary_layers):
        raise RuntimeError(
            "AuxT merge positions do not match auxiliary depth "
            f"(merge_positions has {len(merge_positions)} entries; "
            f"auxiliary_layers has {len(auxiliary_layers)})."
        )

    output = input_tensor
    auxiliary = input_tensor
    aux_index = 0
    for layer_index, layer in enumerate(transform):
        output = layer(output)
        if (
            aux_index < len(auxiliary_layers)
            and layer_index == merge_positions[aux_index]
        ):
            auxiliary = auxiliary_layers[aux_index](auxiliary)
            output = output + auxiliary
            aux_index += 1
    if aux_index != len(auxiliary_layers):
        raise RuntimeError(
            "AuxT merge positions do not match auxiliary depth "
            f"(merged {aux_index} of {len(auxiliary_layers)})."
        )
    return output


class AuxTTransform(nn.Module):
    """Bundle a transform with its AuxT side branch behind one ``forward``.

    The wrapped modules are registered as ``transform`` and
    ``auxiliary_layers``; ``merge_positions`` is frozen into a tuple.
    ``forward`` delegates to :func:`forward_with_auxt`.
    """

    transform: nn.Module
    auxiliary_layers: nn.ModuleList

    def __init__(
        self,
        transform: nn.Module,
        auxiliary_layers: nn.ModuleList,
        merge_positions: Sequence[int],
    ) -> None:
        super().__init__()
        self.transform = transform
        self.auxiliary_layers = auxiliary_layers
        self.merge_positions = tuple(merge_positions)

    def forward(self, input_tensor: Tensor) -> Tensor:
        return forward_with_auxt(
            self.transform,
            self.auxiliary_layers,
            self.merge_positions,
            input_tensor,
        )


def compute_analysis_aux_positions(
    config: Sequence[int],
) -> Tuple[int, int, int, int]:
    """Layer indices in ``g_a`` where ``AuxT_enc[i]`` outputs are summed in.

    Accumulates ``config[0:3]`` plus one extra layer per boundary; with
    ``config = (2, 2, 2, 2, 2, 2)`` the result is ``(0, 3, 6, 9)``.
    """
    return (
        0,
        config[0] + 1,
        config[0] + config[1] + 2,
        config[0] + config[1] + config[2] + 3,
    )


def compute_synthesis_aux_positions(
    config: Sequence[int],
) -> Tuple[int, int, int, int]:
    """Mirror of :func:`compute_analysis_aux_positions` for ``g_s``.

    Uses ``config[3:]`` so the positions are
    ``(c3, c3+c4+1, c3+c4+c5+2, c3+c4+c5+3)``; with
    ``config = (2, 2, 2, 2, 2, 2)`` this is ``(2, 5, 8, 9)``.
    """
    return (
        config[3],
        config[3] + config[4] + 1,
        config[3] + config[4] + config[5] + 2,
        config[3] + config[4] + config[5] + 3,
    )


# ---------------------------------------------------------------------------
# OLP regulariser aggregation (works for TCM-style and SAAF-style hosts)
# ---------------------------------------------------------------------------


def aux_loss(model: nn.Module) -> Tensor:
    """Sum :meth:`OLP.loss` over every :class:`OLP` in ``model``'s submodule
    tree, returning a 0-d :class:`Tensor`.

    When no :class:`OLP` is present the result is a zero on the device and
    dtype of the first model parameter, or a default-dtype CPU zero if the
    model has no parameters at all, so callers can add it unconditionally.
    """
    losses = [module.loss() for module in model.modules() if isinstance(module, OLP)]
    if losses:
        return torch.stack(losses).sum()
    # A bare ``next(...)`` raised StopIteration on parameter-free models.
    parameter = next(model.parameters(), None)
    if parameter is None:
        return torch.zeros(())
    return torch.zeros((), device=parameter.device, dtype=parameter.dtype)
def has_auxt_state(state_dict: Dict[str, Tensor]) -> bool:
    """Return ``True`` if any key lives under an AuxT wrapper prefix.

    The prefixes checked are ``g_a.auxiliary_layers.`` and
    ``g_s.auxiliary_layers.``; upstream ``AuxT_enc.*`` / ``AuxT_dec.*`` keys
    are not recognised here and must be converted first.
    """
    native_prefixes = ("g_a.auxiliary_layers.", "g_s.auxiliary_layers.")
    return any(key.startswith(native_prefixes) for key in state_dict)


def is_auxt_wavelet_buffer_key(key: str) -> bool:
    """Match DWT/IDWT buffer paths in the native AuxT wrapper layout.

    A key matches when it sits under ``g_a.auxiliary_layers.`` and contains
    ``.dwt.transform.``, or sits under ``g_s.auxiliary_layers.`` and contains
    ``.idwt.inverse.``. Hosts doing a strict ``load_state_dict`` can use this
    to tolerate such keys being absent from a checkpoint.
    """
    markers_by_prefix = (
        ("g_a.auxiliary_layers.", ".dwt.transform."),
        ("g_s.auxiliary_layers.", ".idwt.inverse."),
    )
    for prefix, marker in markers_by_prefix:
        if key.startswith(prefix):
            return marker in key
    return False


def is_auxt_upstream_wavelet_buffer_key(key: str) -> bool:
    """Match wavelet buffer names in the upstream AuxT key layout.

    Matches ``AuxT_enc.*`` keys containing ``.dwt.`` whose last component is
    ``w_ll`` / ``w_lh`` / ``w_hl`` / ``w_hh``, and ``AuxT_dec.*`` keys
    containing ``.idwt.`` whose last component is ``filters``. Convert
    scripts are expected to drop these keys.
    """
    leaf = key.rsplit(".", 1)[-1]
    if key.startswith("AuxT_enc."):
        return ".dwt." in key and leaf in {"w_ll", "w_lh", "w_hl", "w_hh"}
    if key.startswith("AuxT_dec."):
        return ".idwt." in key and leaf == "filters"
    return False


def normalize_upstream_auxt_key(key: str) -> Optional[str]:
    """Translate an upstream AuxT key to the native wrapper layout.

    ``AuxT_enc.`` becomes ``g_a.auxiliary_layers.`` and ``AuxT_dec.`` becomes
    ``g_s.auxiliary_layers.``; in the remainder every ``.OLP.`` is lowered to
    ``.olp.``. Returns ``None`` for keys with neither prefix.
    """
    prefix_map = (
        ("AuxT_enc.", "g_a.auxiliary_layers."),
        ("AuxT_dec.", "g_s.auxiliary_layers."),
    )
    for upstream_prefix, native_prefix in prefix_map:
        if key.startswith(upstream_prefix):
            remainder = key[len(upstream_prefix) :]
            return native_prefix + remainder.replace(".OLP.", ".olp.")
    return None
class SharedDictionary(nn.Module):
    """Holds the learned dictionary tensor cross-attended by DCAE / SAAF heads.

    Meant to be owned once by the model; heads receive a callable (normally
    the bound :meth:`expand_for`) as a plain attribute, so the ``dt``
    parameter is registered under a single state-dict path.
    """

    dt: nn.Parameter

    def __init__(self, dict_num: int, dictionary_dim: int) -> None:
        super().__init__()
        self.dt = nn.Parameter(torch.randn(dict_num, dictionary_dim))

    def expand_for(self, batch_size: int) -> Tensor:
        """Broadcast ``dt`` to ``(batch_size, dict_num, dictionary_dim)``.

        ``expand`` returns a view, so no per-batch copy is made.
        """
        return self.dt.unsqueeze(0).expand(batch_size, -1, -1)


class DictionaryMeanScaleContextHead(nn.Module):
    """Channel-context head with shared dictionary cross-attention.

    Forward flow::

        dictionary = dictionary_provider(batch_size)
        dict_info  = cross_attention(x, dictionary)
        support    = cat([x, dict_info], dim=1)
        mean       = mean_cc(support)
        scale      = scale_cc(support)
        out        = cat([scale, mean], dim=1)
        if emit_mean_support: out = cat([out, support], dim=1)

    Args:
        cross_attention: Module called as ``cross_attention(x, dictionary)``.
        mean_cc: Module mapping ``support`` to the predicted means.
        scale_cc: Module mapping ``support`` to the predicted scales.
        dictionary_provider: Callable returning the dictionary tensor for a
            given batch size. Stored as a plain attribute, so whatever it
            references is not registered as a submodule of this head.
        emit_mean_support: Append ``support`` after ``[scale, mean]``.
    """

    cross_attention: "MultiScaleDictionaryCrossAttentionGLU"
    mean_cc: nn.Module
    scale_cc: nn.Module

    def __init__(
        self,
        cross_attention: "MultiScaleDictionaryCrossAttentionGLU",
        mean_cc: nn.Module,
        scale_cc: nn.Module,
        *,
        dictionary_provider: Callable[[int], Tensor],
        emit_mean_support: bool = False,
    ) -> None:
        super().__init__()
        self.cross_attention = cross_attention
        self.mean_cc = mean_cc
        self.scale_cc = scale_cc
        # Plain Python attribute (Callable) — not registered as a submodule,
        # which keeps the shared parameter out of this head's state-dict.
        self._dictionary_provider = dictionary_provider
        self.emit_mean_support = bool(emit_mean_support)

    def forward(self, x: Tensor) -> Tensor:
        dictionary = self._dictionary_provider(x.size(0))
        dict_info = self.cross_attention(x, dictionary)
        support = torch.cat([x, dict_info], dim=1)
        mean = self.mean_cc(support)
        scale = self.scale_cc(support)
        # Scales come first, then means, then (optionally) the raw support.
        parts = [scale, mean]
        if self.emit_mean_support:
            parts.append(support)
        return torch.cat(parts, dim=1)


def build_dictionary_mean_scale_head(
    slice_ch: int,
    support_ch: int,
    *,
    shared_dictionary: SharedDictionary,
    dict_output_ch: int,
    cross_attention_kwargs: Optional[dict] = None,
    widths: Sequence[int] = (224, 128),
    emit_mean_support: bool = False,
) -> DictionaryMeanScaleContextHead:
    """Construct a :class:`DictionaryMeanScaleContextHead`.

    Parameters
    ----------
    slice_ch
        Output channel count of each ``mean_cc`` / ``scale_cc`` transform.
    support_ch
        Channel count of the tensor the head receives; also the
        cross-attention ``input_dim``.
    shared_dictionary
        :class:`SharedDictionary` owned by the model. The head keeps only
        its bound ``expand_for`` method, so ``dt`` is not duplicated in the
        state-dict.
    dict_output_ch
        Output channel count of the cross-attention; ``mean_cc`` /
        ``scale_cc`` therefore see ``support_ch + dict_output_ch`` channels.
    cross_attention_kwargs
        Extra kwargs for :class:`MultiScaleDictionaryCrossAttentionGLU`.
        ``dictionary_dim`` defaults to ``shared_dictionary.dt.size(1)``
        unless supplied here.
    widths
        Hidden widths forwarded to ``make_entropy_transform``.
    emit_mean_support
        Forwarded to the head; appends ``support`` to its output.

    Notes
    -----
    The provider is the bound method ``shared_dictionary.expand_for`` rather
    than a locally defined function. ``copy.deepcopy`` returns plain
    functions unchanged, so a local closure would leave a deep-copied model
    (for example an EMA copy) reading the original model's ``dt``; a bound
    method is copied together with its instance and follows the copy. A
    local function also cannot be pickled with the whole module.
    """
    attention_options = {
        "dictionary_dim": shared_dictionary.dt.size(1),
        **(cross_attention_kwargs or {}),
    }
    cross_attention = MultiScaleDictionaryCrossAttentionGLU(
        input_dim=support_ch,
        output_dim=dict_output_ch,
        **attention_options,
    )
    cc_in_ch = support_ch + dict_output_ch
    mean_cc = make_entropy_transform(cc_in_ch, slice_ch, widths=widths)
    scale_cc = make_entropy_transform(cc_in_ch, slice_ch, widths=widths)

    return DictionaryMeanScaleContextHead(
        cross_attention=cross_attention,
        mean_cc=mean_cc,
        scale_cc=scale_cc,
        dictionary_provider=shared_dictionary.expand_for,
        emit_mean_support=emit_mean_support,
    )
class _DualHyperSynthesis(nn.Module):
    """Run two hyper-synthesis branches on ``z_hat`` and stack them on dim 1.

    The mean branch output comes first, the scale branch output second.
    """

    h_mean_s: nn.Module
    h_scale_s: nn.Module

    def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None:
        super().__init__()
        self.h_mean_s, self.h_scale_s = h_mean_s, h_scale_s

    def forward(self, z_hat: Tensor) -> Tensor:
        branches = (self.h_mean_s, self.h_scale_s)
        return torch.cat([branch(z_hat) for branch in branches], dim=1)


class _LRPGaussianLatentCodec(GaussianConditionalLatentCodec):
    """Gaussian-conditional leaf that adds a bounded residual to ``y_hat``.

    After the parent codec has produced ``out["y_hat"]``, the value is replaced
    by ``y_hat + lrp_scale * tanh(lrp_transform(cat([mean_support, y_hat])))``.

    When ``mean_support_trail_channels`` is positive, that many trailing
    channels of ``ctx_params`` are split off as ``mean_support`` and only the
    leading channels reach the parent codec. Otherwise ``ctx_params`` serves
    both purposes unchanged.
    """

    lrp_transform: nn.Module

    def __init__(
        self,
        lrp_transform: nn.Module,
        *,
        lrp_scale: float = 0.5,
        mean_support_trail_channels: int = 0,
        **gc_kwargs: Any,
    ) -> None:
        super().__init__(**gc_kwargs)
        self.lrp_transform = lrp_transform
        self.lrp_scale = float(lrp_scale)
        self.mean_support_trail_channels = int(mean_support_trail_channels)

    def _split_ctx_params(self, ctx_params: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(gaussian_params, mean_support)`` sliced along dim 1."""
        trail = self.mean_support_trail_channels
        if trail > 0:
            return ctx_params[:, :-trail], ctx_params[:, -trail:]
        # No trailing block configured: both roles use the full tensor.
        return ctx_params, ctx_params

    def _apply_lrp(self, mean_support: Tensor, y_hat: Tensor) -> Tensor:
        """Add the tanh-bounded, ``lrp_scale``-weighted residual to ``y_hat``."""
        lrp_input = torch.cat([mean_support, y_hat], dim=1)
        residual = self.lrp_scale * torch.tanh(self.lrp_transform(lrp_input))
        return y_hat + residual

    def _with_lrp(self, out: Dict[str, Any], mean_support: Tensor) -> Dict[str, Any]:
        """Overwrite ``out["y_hat"]`` with its LRP-refined value and return ``out``."""
        out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"])
        return out

    def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
        gaussian_params, mean_support = self._split_ctx_params(ctx_params)
        return self._with_lrp(super().forward(y, gaussian_params), mean_support)

    def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
        gaussian_params, mean_support = self._split_ctx_params(ctx_params)
        return self._with_lrp(super().compress(y, gaussian_params), mean_support)

    def decompress(
        self,
        strings: List[List[bytes]],
        shape: Tuple[int, ...],
        ctx_params: Tensor,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        gaussian_params, mean_support = self._split_ctx_params(ctx_params)
        decoded = super().decompress(strings, shape, gaussian_params, **kwargs)
        return self._with_lrp(decoded, mean_support)


class _SideContextChannelGroupsLatentCodec(ChannelGroupsLatentCodec):
    """Channel-groups codec whose every head, including ``y0``, sees ``side_params``.

    Head ``y{k}`` receives ``side_params`` alone when slice ``k`` has no
    support slices, and ``merge_params(side_params, merge_y(*support))``
    otherwise.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if "y0" not in self.channel_context:
            raise ValueError("side-parameter channel groups require channel_context.y0")

    def _get_ctx_params(
        self, k: int, side_params: Tensor, y_hat_: List[Tensor]
    ) -> Tensor:
        if k == 0:
            # The first slice never has decoded predecessors.
            return self.channel_context["y0"](side_params)

        head = self.channel_context[f"y{k}"]
        support = [y_hat_[index] for index in self.support_slices[k]]
        if support:
            return head(self.merge_params(side_params, self.merge_y(*support)))
        return head(side_params)


# ---------------------------------------------------------------------------
# DCAE-private g_a / g_s building blocks
# (Inlined from the upstream DCAE source rather than lifted to compressai/layers/
# because they are not reused by other models in the PR series.)
# ---------------------------------------------------------------------------


class _ResidualBottleneckBlockWithStride(nn.Module):
    """Stride-2 5x5 ``conv`` followed by three residual-bottleneck blocks."""

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # Creation order is kept so random initialisation is reproducible.
        self.conv = conv(in_ch, out_ch, kernel_size=5, stride=2)
        self.res1 = ResidualBottleneckBlock(out_ch, out_ch)
        self.res2 = ResidualBottleneckBlock(out_ch, out_ch)
        self.res3 = ResidualBottleneckBlock(out_ch, out_ch)

    def forward(self, input_tensor: Tensor) -> Tensor:
        return self.res3(self.res2(self.res1(self.conv(input_tensor))))


class _ResidualBottleneckBlockWithUpsample(nn.Module):
    """Three residual-bottleneck blocks followed by a stride-2 5x5 ``deconv``."""

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # Creation order is kept so random initialisation is reproducible.
        self.res1 = ResidualBottleneckBlock(in_ch, in_ch)
        self.res2 = ResidualBottleneckBlock(in_ch, in_ch)
        self.res3 = ResidualBottleneckBlock(in_ch, in_ch)
        self.conv = deconv(in_ch, out_ch, kernel_size=5, stride=2)

    def forward(self, input_tensor: Tensor) -> Tensor:
        return self.conv(self.res3(self.res2(self.res1(input_tensor))))
self-attention with optional cyclic shift. + + Lifted verbatim from the upstream DCAE source. ``type`` is ``"W"`` for a + plain window-attention pass or ``"SW"`` for a shifted-window pass. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + head_dim: int, + window_size: int, + type: str, + ) -> None: + super().__init__() + if type not in {"W", "SW"}: + raise ValueError(f"Unsupported attention type: {type}") + if input_dim % head_dim != 0: + raise ValueError("input_dim must be divisible by head_dim") + + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(input_dim, 3 * input_dim, bias=True) + relative_position = torch.zeros( + self.n_heads, 2 * window_size - 1, 2 * window_size - 1 + ) + nn.init.trunc_normal_(relative_position, std=0.02) + self.relative_position_params = nn.Parameter(relative_position) + self.linear = nn.Linear(input_dim, output_dim) + + def generate_mask( + self, + height_windows: int, + width_windows: int, + window_size: int, + shift: int, + ) -> Tensor: + attention_mask = torch.zeros( + height_windows, + width_windows, + window_size, + window_size, + window_size, + window_size, + dtype=torch.bool, + device=self.relative_position_params.device, + ) + if self.type == "W": + return attention_mask + + split = window_size - shift + attention_mask[-1, :, :split, :, split:, :] = True + attention_mask[-1, :, split:, :, :split, :] = True + attention_mask[:, -1, :, :split, :, split:] = True + attention_mask[:, -1, :, split:, :, :split] = True + return rearrange(attention_mask, "h w p1 p2 p3 p4 -> 1 1 (h w) (p1 p2) (p3 p4)") + + def relative_embedding(self) -> Tensor: + coords = torch.stack( + torch.meshgrid( + torch.arange( + self.window_size, device=self.relative_position_params.device + ), + torch.arange( + self.window_size, 
class _ResScaleConvolutionGateBlock(nn.Module):
    """Channel-last transformer block: windowed attention, then a ConvolutionalGLU MLP.

    Each of the two residual branches scales its skip path with a learnable
    :class:`Scale` before adding the (optionally drop-path'd) branch output.

    ``output_dim`` and ``input_resolution`` are accepted for signature
    compatibility and discarded; the block always maps ``input_dim`` channels
    to ``input_dim`` channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        head_dim: int,
        window_size: int,
        drop_path: float,
        type: str = "W",
        input_resolution: Optional[Tuple[int, int]] = None,
    ) -> None:
        del output_dim, input_resolution
        super().__init__()
        # Submodules are created in a fixed order so state-dict keys and random
        # initialisation stay reproducible.
        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = _WMSA(input_dim, input_dim, head_dim, window_size, type)
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = ConvolutionalGLU(input_dim, input_dim * 4)
        self.res_scale_1 = Scale(input_dim, init_value=1.0)
        self.res_scale_2 = Scale(input_dim, init_value=1.0)

    def forward(self, input_tensor: Tensor) -> Tensor:
        attention_skip = self.res_scale_1(input_tensor)
        attention_branch = self.drop_path(self.msa(self.ln1(input_tensor)))
        hidden = attention_skip + attention_branch

        mlp_skip = self.res_scale_2(hidden)
        mlp_branch = self.drop_path(self.mlp(self.ln2(hidden)))
        return mlp_skip + mlp_branch


class _SwinBlockWithConvMulti(nn.Module):
    """``block_num`` attention blocks alternating W / SW, then a 3x3 conv with a skip.

    Even-indexed blocks are built with ``type="W"``, odd-indexed ones with
    ``type="SW"``. Unrecognised keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        head_dim: int,
        window_size: int,
        drop_path: float,
        block: type[nn.Module] = _ResScaleConvolutionGateBlock,
        block_num: int = 2,
        **kwargs,
    ) -> None:
        del kwargs
        super().__init__()
        stages = []
        for index in range(block_num):
            attention_kind = "SW" if index % 2 else "W"
            stages.append(
                block(
                    input_dim,
                    input_dim,
                    head_dim,
                    window_size,
                    drop_path,
                    type=attention_kind,
                )
            )
        self.layers = nn.ModuleList(stages)
        self.block_num = block_num
        self.conv = conv(input_dim, output_dim, 3, 1)
        self.window_size = window_size

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Apply the block stack on a window-aligned copy, then crop back.

        The input is padded by ``pad_to_window_multiple``, processed
        channel-last, and converted back to channel-first. The result is
        convolved, added to the zero-padded input, and cropped to the input's
        spatial size.
        """
        padded, pad_height, pad_width = pad_to_window_multiple(
            input_tensor, self.window_size
        )

        tokens = rearrange(padded, "b c h w -> b h w c")
        for layer in self.layers:
            tokens = layer(tokens)
        features = rearrange(tokens, "b h w c -> b c h w")

        # The skip path is padded on the right/bottom by the same amounts
        # (assumes pad_to_window_multiple pads those sides too).
        skip = F.pad(input_tensor, (0, pad_width, 0, pad_height))
        result = self.conv(features) + skip

        needs_crop = pad_height > 0 or pad_width > 0
        if needs_crop:
            height, width = input_tensor.size(2), input_tensor.size(3)
            result = result[:, :, :height, :width]
        return result.contiguous()
+ """ + + def __init__( + self, + head_dim: Optional[Sequence[int]] = None, + N: int = 192, + M: int = 320, + hyper_channels: int = 192, + num_slices: int = 5, + max_support_slices: int = 5, + feature_dims: Optional[Sequence[int]] = None, + block_num: Optional[Sequence[int]] = None, + dict_num: int = 128, + dict_head_num: int = 20, + dictionary_dim: Optional[int] = None, + window_size: int = 8, + hyper_window_size: int = 4, + hyper_head_dim: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + head_dim = tuple(head_dim or (8, 16, 32, 32, 16, 8)) + feature_dims = tuple(feature_dims or (96, 144, 256)) + block_num = tuple(block_num or (1, 2, 12)) + dictionary_dim = dictionary_dim or 32 * dict_head_num + if len(head_dim) != 6: + raise ValueError("head_dim must have six entries") + if len(feature_dims) != 3: + raise ValueError("feature_dims must have three entries") + if len(block_num) != 3: + raise ValueError("block_num must have three entries") + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + + self.N = int(N) + self.M = int(M) + self.hyper_channels = int(hyper_channels) + self.num_slices = int(num_slices) + self.max_support_slices = int(max_support_slices) + self.head_dim = head_dim + self.feature_dims = feature_dims + self.block_num = block_num + self.dict_num = int(dict_num) + self.dict_head_num = int(dict_head_num) + self.dictionary_dim = int(dictionary_dim) + self.window_size = int(window_size) + self.hyper_window_size = int(hyper_window_size) + self.hyper_head_dim = int(hyper_head_dim) + + slice_channels = M // num_slices + input_image_channel = 3 + output_image_channel = 3 + + # ----- g_a / g_s ----- + self.g_a = nn.Sequential( + _ResidualBottleneckBlockWithStride(input_image_channel, feature_dims[0]), + _SwinBlockWithConvMulti( + feature_dims[0], + feature_dims[0], + head_dim[0], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[0], + ), + 
_ResidualBottleneckBlockWithStride(feature_dims[0], feature_dims[1]), + _SwinBlockWithConvMulti( + feature_dims[1], + feature_dims[1], + head_dim[1], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[1], + ), + _ResidualBottleneckBlockWithStride(feature_dims[1], feature_dims[2]), + _SwinBlockWithConvMulti( + feature_dims[2], + feature_dims[2], + head_dim[2], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[2], + ), + conv(feature_dims[2], M, kernel_size=5, stride=2), + ) + self.g_s = nn.Sequential( + deconv(M, feature_dims[2], kernel_size=5, stride=2), + _SwinBlockWithConvMulti( + feature_dims[2], + feature_dims[2], + head_dim[3], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[2], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[2], feature_dims[1]), + _SwinBlockWithConvMulti( + feature_dims[1], + feature_dims[1], + head_dim[4], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[1], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[1], feature_dims[0]), + _SwinBlockWithConvMulti( + feature_dims[0], + feature_dims[0], + head_dim[5], + self.window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=block_num[0], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[0], output_image_channel), + ) + + # ----- h_a / h_mean_s / h_scale_s ----- + h_a = nn.Sequential( + _ResidualBottleneckBlockWithStride(M, N), + _SwinBlockWithConvMulti( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=1, + ), + conv(N, hyper_channels, kernel_size=3, stride=2), + ) + + # NOTE: upstream DCAE used h_z_s1 for *scales* and h_z_s2 for *means*; + # _DualHyperSynthesis(h_mean_s, h_scale_s) flips this to the + # (means, scales) ordering shared with STF / TCM / CCA. The convert + # script renames h_z_s2 -> h_s.h_mean_s and h_z_s1 -> h_s.h_scale_s. 
+ h_mean_s = nn.Sequential( + deconv(hyper_channels, N, kernel_size=3, stride=2), + _SwinBlockWithConvMulti( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=1, + ), + _ResidualBottleneckBlockWithUpsample(N, M), + ) + h_scale_s = nn.Sequential( + deconv(hyper_channels, N, kernel_size=3, stride=2), + _SwinBlockWithConvMulti( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + _ResScaleConvolutionGateBlock, + block_num=1, + ), + _ResidualBottleneckBlockWithUpsample(N, M), + ) + + # ----- Shared dictionary ----- + self.shared_dictionary = SharedDictionary( + dict_num=self.dict_num, dictionary_dim=self.dictionary_dim + ) + + # ----- Latent codec ----- + cross_attention_kwargs = { + "head_num": self.dict_head_num, + "mlp_rate": 4, + "qkv_bias": True, + "dictionary_dim": self.dictionary_dim, + } + + widths = (224, 128) + groups = [slice_channels] * num_slices + + def support_count(k: int) -> int: + return k if max_support_slices < 0 else min(k, max_support_slices) + + # mean_support_trail = support tensor produced by the head: + # cat(input(2M + slice_ch*support_count), dict_info(M)). + def mean_support_ch(k: int) -> int: + return 2 * M + slice_channels * support_count(k) + M + + support_slices = [list(range(support_count(k))) for k in range(num_slices)] + + # Side-parameter channel-groups wiring, inlined ELIC-style (mirrors + # STF / TCM). channel_context covers y0..y(K-1); each head sees + # cat(side_params(2M), *prev_y_hat) and emits cat(scale, mean, + # mean_support) for the LRP-aware leaf. The dictionary cross-attention + # head (DCAE's distinctive piece) replaces STF's plain mean/scale head. 
+ channel_context = { + f"y{k}": build_dictionary_mean_scale_head( + slice_ch=slice_channels, + support_ch=2 * M + slice_channels * support_count(k), + shared_dictionary=self.shared_dictionary, + dict_output_ch=M, + cross_attention_kwargs=cross_attention_kwargs, + widths=widths, + emit_mean_support=True, + ) + for k in range(num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + lrp_support_channels(2 * M, slice_channels, k, max_support_slices) + + M, + slice_channels, + widths=widths, + ), + lrp_scale=0.5, + mean_support_trail_channels=mean_support_ch(k), + chunks=("scales", "means"), + quantizer="ste", + ) + for k in range(num_slices) + } + + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(hyper_channels), + quantizer="noise", + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=groups, + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ), + }, + ) + + def forward(self, x: Tensor) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + y = self.g_a(x) + out = self.latent_codec(y) + return { + "x_hat": self.g_s(out["y_hat"]), + "likelihoods": out["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + return self.latent_codec.compress(y) + + def decompress( + self, strings: Sequence[Sequence[bytes]], shape: Sequence[int] + ) -> Dict[str, Tensor]: + out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(out["y_hat"]).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "DCAE": + kwargs = _infer_config_from_state_dict(state_dict) + net = cls(**kwargs) + net.load_state_dict(state_dict) + return net + + +# --------------------------------------------------------------------------- +# Architecture inference helpers 
# ---------------------------------------------------------------------------
# Architecture inference helpers (state_dict -> hyperparameters)
# ---------------------------------------------------------------------------


def _infer_stage_block_num(state_dict: Dict[str, Tensor], prefix: str) -> int:
    """Count keys of the form ``{prefix}...ln1.weight``.

    Each attention block owns exactly one ``ln1`` LayerNorm, so the count is
    the ``block_num`` of the stage whose layers live under ``prefix``.
    """
    return sum(
        1
        for key in state_dict
        if key.startswith(prefix) and key.endswith(".ln1.weight")
    )


def _infer_attention_head_dim(
    state_dict: Dict[str, Tensor], prefix: str, channel_count: int
) -> int:
    """Recover ``head_dim`` from ``{prefix}.layers.0.msa.relative_position_params``.

    The first axis of that tensor is the number of heads, so
    ``head_dim = channel_count // n_heads``.

    Raises
    ------
    KeyError
        If the relative-position tensor is missing.
    ValueError
        If ``channel_count`` is not a multiple of the number of heads.
    """
    key = f"{prefix}.layers.0.msa.relative_position_params"
    if key not in state_dict:
        raise KeyError(f"missing {key} for head-dim inference")
    n_heads = state_dict[key].size(0)
    if channel_count % n_heads != 0:
        raise ValueError(
            f"channel_count {channel_count} not divisible by n_heads {n_heads} at {prefix}"
        )
    return channel_count // n_heads


def _infer_window_size(state_dict: Dict[str, Tensor], prefix: str) -> int:
    """Recover ``window_size`` from ``{prefix}.layers.0.msa.relative_position_params``.

    The table has ``2 * window_size - 1`` entries along axis 1, hence
    ``window_size = (entries + 1) // 2``.

    Raises
    ------
    KeyError
        If the relative-position tensor is missing.
    ValueError
        If the table size is even and so cannot be ``2 * window_size - 1``.
    """
    key = f"{prefix}.layers.0.msa.relative_position_params"
    if key not in state_dict:
        raise KeyError(f"missing {key} for window-size inference")
    relative_dim = state_dict[key].size(1)
    if relative_dim % 2 == 0:
        raise ValueError(
            f"relative_position_params has even spatial dim {relative_dim} at {prefix}"
        )
    return (relative_dim + 1) // 2


def _infer_max_support_slices(
    state_dict: Dict[str, Tensor], M: int, num_slices: int
) -> int:
    """Recover ``max_support_slices`` from the widest ``mean_cc`` input conv.

    ``mean_cc.0.weight`` of head ``k`` has ``3 * M + slice_ch * support_count(k)``
    input channels: ``2 * M`` side parameters, ``M`` dictionary channels, and
    one ``slice_ch`` block per supporting slice. The widest head therefore
    reveals the largest support count in use.

    A count of ``0`` is returned as-is. Clamping it to ``1`` would rebuild a
    checkpoint trained with ``max_support_slices=0`` using wider context
    heads, and ``load_state_dict`` would then fail on the shape mismatch.

    With a single slice no head has any support, so any value is equivalent
    and ``1`` is returned.

    Raises
    ------
    ValueError
        If the widest input width is not ``3 * M`` plus a whole number of
        slices.
    """
    if num_slices <= 1:
        return 1

    slice_ch = M // num_slices
    widest = max(
        state_dict[f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight"].size(1)
        for k in range(num_slices)
    )
    extra = widest - 3 * M
    if slice_ch <= 0 or extra < 0 or extra % slice_ch != 0:
        raise ValueError(
            f"cannot infer max_support_slices: widest mean_cc input {widest} "
            f"is inconsistent with M={M} and num_slices={num_slices}"
        )
    return extra // slice_ch


def _infer_config_from_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, object]:
    """Recover DCAE constructor kwargs from a containerized state_dict.

    Every hyperparameter is read from a tensor shape or from the set of keys
    present, so the result can be passed straight to ``DCAE(**kwargs)`` before
    ``load_state_dict``.
    """
    feature_dims = tuple(
        state_dict[f"g_a.{index}.conv.weight"].size(0) for index in (0, 2, 4)
    )
    # The first h_a conv maps M latent channels to N hidden channels.
    h_a_weight = state_dict["latent_codec.h_a.0.conv.weight"]
    N = h_a_weight.size(0)
    M = h_a_weight.size(1)
    hyper_channels = state_dict["latent_codec.z.entropy_bottleneck.quantiles"].size(0)

    num_slices = infer_num_slices(state_dict)
    max_support_slices = _infer_max_support_slices(state_dict, M, num_slices)

    block_num = tuple(
        _infer_stage_block_num(state_dict, f"g_a.{index}.layers.")
        for index in (1, 3, 5)
    )
    # g_s mirrors g_a, so its attention stages see the feature widths reversed.
    head_dim = tuple(
        _infer_attention_head_dim(state_dict, prefix, channels)
        for prefix, channels in (
            ("g_a.1", feature_dims[0]),
            ("g_a.3", feature_dims[1]),
            ("g_a.5", feature_dims[2]),
            ("g_s.1", feature_dims[2]),
            ("g_s.3", feature_dims[1]),
            ("g_s.5", feature_dims[0]),
        )
    )

    dt = state_dict["shared_dictionary.dt"]
    dict_num = dt.size(0)
    dictionary_dim = dt.size(1)
    # NOTE(review): assumes the cross-attention module keeps one ``scale`` entry
    # per head along axis 0; that module is defined outside this section.
    dict_head_num = state_dict[
        "latent_codec.y.channel_context.y0.cross_attention.scale"
    ].size(0)

    return dict(
        head_dim=head_dim,
        N=N,
        M=M,
        hyper_channels=hyper_channels,
        num_slices=num_slices,
        max_support_slices=max_support_slices,
        feature_dims=feature_dims,
        block_num=block_num,
        dict_num=dict_num,
        dict_head_num=dict_head_num,
        dictionary_dim=dictionary_dim,
        window_size=_infer_window_size(state_dict, "g_a.1"),
        hyper_window_size=_infer_window_size(state_dict, "latent_codec.h_a.1"),
        hyper_head_dim=_infer_attention_head_dim(state_dict, "latent_codec.h_a.1", N),
    )
class _DualHyperSynthesis(nn.Module):
    """Run two hyper-synthesis branches on ``z_hat`` and stack them on dim 1.

    The mean branch output comes first, the scale branch output second.
    """

    h_mean_s: nn.Module
    h_scale_s: nn.Module

    def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None:
        super().__init__()
        # Assigned left to right, so ``h_mean_s`` is registered first.
        self.h_mean_s, self.h_scale_s = h_mean_s, h_scale_s

    def forward(self, z_hat: Tensor) -> Tensor:
        branches = (self.h_mean_s, self.h_scale_s)
        return torch.cat([branch(z_hat) for branch in branches], dim=1)
def _group_count(channels: int, max_groups: int = 8) -> int:
    """Return the largest divisor of ``channels`` that is ``<= max_groups``.

    Used to size GroupNorm groups inside :class:`_DenoisingAsRegularizer`.
    Falls back to ``1`` when no candidate exists (e.g. a non-positive
    ``channels`` or ``max_groups``).
    """
    upper_bound = min(max_groups, channels)
    # Scan candidates from largest to smallest; the first divisor wins.
    candidates = (
        candidate
        for candidate in range(upper_bound, 0, -1)
        if channels % candidate == 0
    )
    return next(candidates, 1)
class _InverseAdaptiveFrequencyBlock(nn.Module):
    """Synthesis-side counterpart of :class:`_AdaptiveFrequencyBlock`.

    Used inside ``aux_dec``. Each spatial position of the input is passed
    through an :class:`OLP` built as ``OLP(in_dim, out_dim)``, and a
    residual scaled by ``0.1`` is added on top of the OLP output.

    NOTE(review): ``freq_attn`` ends in ``Softmax(dim=1)`` over exactly four
    channels, and :meth:`forward` then averages those four channels with
    ``mean(dim=1)``. A softmax sums to one, so that mean is ``0.25`` at every
    position regardless of the input; the block therefore evaluates to
    ``output + 0.1 * 0.25 * output`` (up to floating-point rounding). The
    attention branch cannot currently reweight anything. The code is kept
    as-is so outputs stay in line with the reference implementation --
    confirm with upstream whether this is intended.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.olp = OLP(in_dim, out_dim)
        # Bottleneck width of the attention branch, never below 4 channels.
        mid_dim = max(in_dim // 4, 4)
        # 1x1 conv -> GELU -> 1x1 conv producing 4 maps, normalised per pixel.
        self.freq_attn = nn.Sequential(
            nn.Conv2d(in_dim, mid_dim, 1),
            nn.GELU(),
            nn.Conv2d(mid_dim, 4, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, input_tensor: Tensor) -> Tensor:
        # Input is unpacked as (batch, channels, height, width).
        batch_size, _, height, width = input_tensor.shape
        frequency_weights = self.freq_attn(input_tensor)
        # Flatten the spatial grid so the OLP sees (batch, H*W, channels).
        output = input_tensor.flatten(2).permute(0, 2, 1)
        output = self.olp(output)
        # Back to (batch, channels, H, W); the channel count is inferred (-1).
        output = output.permute(0, 2, 1).view(batch_size, -1, height, width)
        # Mean over the 4 softmax channels -> constant 0.25 (see class note).
        enhanced = output * frequency_weights.mean(dim=1, keepdim=True)
        return output + 0.1 * enhanced
class _CrossSparseWindowAttention(nn.Module):
    """SAAF-specific windowed attention with shared global tokens.

    Differs from :class:`compressai.layers.attn.swin.WMSA` in two ways:

    1. Each window blends its local self-attention output with a second
       attention pass against ``num_global_tokens`` learned tokens (offset
       by the window's mean token). The blend weight ``global_alpha`` is a
       *constant buffer* fixed at ``0.25``; it is not a trainable parameter.
    2. The positional bias is a flat ``relative_position_bias_table`` of
       shape ``((2 * window_size - 1) ** 2, n_heads)`` indexed through the
       precomputed ``relative_position_index`` buffer.

    Both buffers are registered with the default ``persistent=True`` and so
    appear in ``state_dict()``.

    Args:
        input_dim: Channel count of the incoming tokens.
        output_dim: Channel count produced by the final projection.
        head_dim: Channels per attention head; must divide ``input_dim``.
        window_size: Side length of the square attention window.
        num_global_tokens: Number of learned global tokens per window.

    Raises:
        ValueError: If ``input_dim`` is not divisible by ``head_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        head_dim: int,
        window_size: int,
        num_global_tokens: int = 2,
    ) -> None:
        super().__init__()
        if input_dim % head_dim != 0:
            raise ValueError("input_dim must be divisible by head_dim")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        # NOTE: creation order of the layers below is kept unchanged so that
        # seeded initialisation draws the same random numbers as before.
        self.embedding_layer = nn.Linear(input_dim, 3 * input_dim, bias=True)
        self.num_global_tokens = num_global_tokens
        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, input_dim))
        nn.init.trunc_normal_(self.global_tokens, std=0.02)
        self.global_kv = nn.Linear(input_dim, input_dim * 2, bias=False)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Map every (query, key) pair inside a window to a row of the flat
        # bias table: shift both offsets to be non-negative, then combine
        # them as row * (2 * window_size - 1) + col.
        coords = torch.stack(
            torch.meshgrid(
                torch.arange(window_size),
                torch.arange(window_size),
                indexing="ij",
            )
        )
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        self.register_buffer("relative_position_index", relative_coords.sum(-1))
        self.linear = nn.Linear(input_dim, output_dim)
        self.register_buffer("global_alpha", torch.tensor(0.25))

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Apply local + global window attention.

        Args:
            input_tensor: Channel-last tensor unpacked as
                ``(batch, height, width, channels)``; ``height`` and
                ``width`` must be multiples of ``window_size``.

        Returns:
            Tensor of shape ``(batch, height, width, output_dim)``.

        Raises:
            ValueError: If the spatial size is not a multiple of
                ``window_size``.
        """
        batch_size, height, width, channels = input_tensor.shape
        window_size = self.window_size
        if height % window_size != 0 or width % window_size != 0:
            raise ValueError(
                f"spatial size ({height}, {width}) must be a multiple of "
                f"window_size {window_size}"
            )
        height_windows = height // window_size
        width_windows = width // window_size
        num_windows = height_windows * width_windows
        tokens_per_window = window_size * window_size

        # Partition into non-overlapping windows: (B * nW, ws * ws, C).
        windows = input_tensor.view(
            batch_size,
            height_windows,
            window_size,
            width_windows,
            window_size,
            channels,
        )
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(
            batch_size * num_windows,
            tokens_per_window,
            channels,
        )

        qkv = self.embedding_layer(windows).reshape(
            batch_size * num_windows,
            tokens_per_window,
            3,
            self.n_heads,
            self.head_dim,
        )
        qkv = qkv.permute(2, 0, 3, 1, 4).contiguous()
        query, key, value = qkv[0], qkv[1], qkv[2]
        # Shared by the local and the global attention pass.
        scaled_query = query * self.scale

        # ---- local window self-attention with relative position bias ----
        similarity = torch.einsum("bhpc,bhqc->bhpq", scaled_query, key)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(tokens_per_window, tokens_per_window, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        similarity = similarity + relative_position_bias.unsqueeze(0)
        probabilities = similarity.softmax(dim=-1)
        output_local = torch.einsum("bhij,bhjc->bhic", probabilities, value)

        # ---- attention against learned global tokens ----
        # Each window's tokens are the learned tokens plus that window's mean.
        global_tokens = self.global_tokens.expand(batch_size * num_windows, -1, -1)
        global_tokens = global_tokens + windows.mean(dim=1, keepdim=True)
        global_kv = self.global_kv(global_tokens).reshape(
            batch_size * num_windows,
            self.num_global_tokens,
            2,
            self.n_heads,
            self.head_dim,
        )
        global_kv = global_kv.permute(2, 0, 3, 1, 4).contiguous()
        key_global, value_global = global_kv[0], global_kv[1]
        similarity_global = torch.einsum("bhpc,bhgc->bhpg", scaled_query, key_global)
        probabilities_global = similarity_global.softmax(dim=-1)
        output_global = torch.einsum(
            "bhpg,bhgc->bhpc", probabilities_global, value_global
        )

        mixed = (
            1 - self.global_alpha
        ) * output_local + self.global_alpha * output_global
        mixed = mixed.transpose(1, 2).reshape(
            batch_size * num_windows, tokens_per_window, channels
        )
        projected = self.linear(mixed)

        # Undo the window partition. The projection emits ``output_dim``
        # channels, so reshape with that width rather than the input width.
        projected = projected.view(
            batch_size,
            height_windows,
            width_windows,
            window_size,
            window_size,
            self.output_dim,
        )
        projected = projected.permute(0, 1, 3, 2, 4, 5).contiguous()
        return projected.view(batch_size, height, width, self.output_dim)
Counterpart of DCAE's ``_SwinBlockWithConvMulti``.""" + + def __init__( + self, + input_dim: int, + output_dim: int, + head_dim: int, + window_size: int, + drop_path: float, + block: type[nn.Module] = _SpatialAttentionLayer, + block_num: int = 2, + **kwargs, + ) -> None: + del kwargs + super().__init__() + self.layers = nn.ModuleList( + block(input_dim, input_dim, head_dim, window_size, drop_path) + for _ in range(block_num) + ) + self.block_num = block_num + self.conv = conv(input_dim, output_dim, 3, 1) + self.window_size = window_size + + def forward(self, input_tensor: Tensor) -> Tensor: + output, pad_height, pad_width = pad_to_window_multiple( + input_tensor, self.window_size + ) + output = rearrange(output, "b c h w -> b h w c") + for layer in self.layers: + output = layer(output) + output = rearrange(output, "b h w c -> b c h w") + output = self.conv(output) + F.pad(input_tensor, (0, pad_width, 0, pad_height)) + if pad_height > 0 or pad_width > 0: + output = output[:, :, : input_tensor.size(2), : input_tensor.size(3)] + return output.contiguous() + + +# --------------------------------------------------------------------------- +# SAAF model +# --------------------------------------------------------------------------- + + +@register_model("saaf") +class SAAF(CompressionModel): + """SAAF model (Ma et al., CVPR 2026). + + Containerized entropy stack identical to DCAE (see + :class:`compressai.models.dcae.DCAE`); SAAF differs only in the + ``g_a`` / ``g_s`` building blocks, the parallel ``aux_enc`` / + ``aux_dec`` AuxT chain (each block carrying an :class:`OLP`), and the + training-only :class:`_DenoisingAsRegularizer` ``diffusion_prior`` head. 
+ """ + + def __init__( + self, + head_dim: Optional[Sequence[int]] = None, + N: int = 192, + M: int = 320, + hyper_channels: int = 192, + num_slices: int = 5, + max_support_slices: int = 5, + feature_dims: Optional[Sequence[int]] = None, + block_num: Optional[Sequence[int]] = None, + dict_num: int = 128, + dict_head_num: int = 20, + dictionary_dim: Optional[int] = None, + window_size: int = 8, + hyper_window_size: int = 4, + hyper_head_dim: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + head_dim = tuple(head_dim or (8, 16, 32, 32, 16, 8)) + feature_dims = tuple(feature_dims or (96, 144, 256)) + block_num = tuple(block_num or (1, 2, 12)) + dictionary_dim = dictionary_dim or 32 * dict_head_num + if len(head_dim) != 6: + raise ValueError("head_dim must have six entries") + if len(feature_dims) != 3: + raise ValueError("feature_dims must have three entries") + if len(block_num) != 3: + raise ValueError("block_num must have three entries") + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + + self.N = int(N) + self.M = int(M) + self.hyper_channels = int(hyper_channels) + self.num_slices = int(num_slices) + self.max_support_slices = int(max_support_slices) + self.head_dim = head_dim + self.feature_dims = feature_dims + self.block_num = block_num + self.dict_num = int(dict_num) + self.dict_head_num = int(dict_head_num) + self.dictionary_dim = int(dictionary_dim) + self.window_size = int(window_size) + self.hyper_window_size = int(hyper_window_size) + self.hyper_head_dim = int(hyper_head_dim) + + slice_channels = M // num_slices + input_image_channel = 3 + output_image_channel = 3 + + # ----- g_a / g_s (SAAF-specific spatial-attention stages) ----- + self.m_down1 = [ + _SpatialAttentionBlock( + feature_dims[0], + feature_dims[0], + head_dim[0], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[0], + ), + _ResidualBottleneckBlockWithStride(feature_dims[0], feature_dims[1]), + ] + self.m_down2 
= [ + _SpatialAttentionBlock( + feature_dims[1], + feature_dims[1], + head_dim[1], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[1], + ), + _ResidualBottleneckBlockWithStride(feature_dims[1], feature_dims[2]), + ] + self.m_down3 = [ + _SpatialAttentionBlock( + feature_dims[2], + feature_dims[2], + head_dim[2], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[2], + ), + conv(feature_dims[2], M, kernel_size=5, stride=2), + ] + self.g_a = nn.Sequential( + _ResidualBottleneckBlockWithStride(input_image_channel, feature_dims[0]), + *self.m_down1, + *self.m_down2, + *self.m_down3, + ) + self.aux_enc = nn.ModuleList( + [ + _AdaptiveFrequencyBlock(input_image_channel, feature_dims[0]), + _AdaptiveFrequencyBlock(feature_dims[0], feature_dims[1]), + _AdaptiveFrequencyBlock(feature_dims[1], feature_dims[2]), + _AdaptiveFrequencyBlock(feature_dims[2], M), + ] + ) + + self.m_up1 = [ + _SpatialAttentionBlock( + feature_dims[2], + feature_dims[2], + head_dim[3], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[2], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[2], feature_dims[1]), + ] + self.m_up2 = [ + _SpatialAttentionBlock( + feature_dims[1], + feature_dims[1], + head_dim[4], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[1], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[1], feature_dims[0]), + ] + self.m_up3 = [ + _SpatialAttentionBlock( + feature_dims[0], + feature_dims[0], + head_dim[5], + self.window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=block_num[0], + ), + _ResidualBottleneckBlockWithUpsample(feature_dims[0], output_image_channel), + ] + self.g_s = nn.Sequential( + deconv(M, feature_dims[2], kernel_size=5, stride=2), + *self.m_up1, + *self.m_up2, + *self.m_up3, + ) + self.aux_dec = nn.ModuleList( + [ + _InverseAdaptiveFrequencyBlock(M, feature_dims[2]), + 
_InverseAdaptiveFrequencyBlock(feature_dims[2], feature_dims[1]), + _InverseAdaptiveFrequencyBlock(feature_dims[1], feature_dims[0]), + _InverseAdaptiveFrequencyBlock(feature_dims[0], output_image_channel), + ] + ) + + # ----- h_a / h_mean_s / h_scale_s (same SAAF blocks, hyper config) ----- + h_a = nn.Sequential( + _ResidualBottleneckBlockWithStride(M, N), + _SpatialAttentionBlock( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=1, + ), + conv(N, hyper_channels, kernel_size=3, stride=2), + ) + + # NOTE: upstream SAAF (like DCAE) uses h_z_s1 for *scales* and + # h_z_s2 for *means*; _DualHyperSynthesis(h_mean_s, h_scale_s) + # flips this. The convert script renames h_z_s2 -> h_s.h_mean_s + # and h_z_s1 -> h_s.h_scale_s. + h_mean_s = nn.Sequential( + deconv(hyper_channels, N, kernel_size=3, stride=2), + _SpatialAttentionBlock( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=1, + ), + _ResidualBottleneckBlockWithUpsample(N, M), + ) + h_scale_s = nn.Sequential( + deconv(hyper_channels, N, kernel_size=3, stride=2), + _SpatialAttentionBlock( + N, + N, + hyper_head_dim, + hyper_window_size, + 0.0, + block=_SpatialAttentionLayer, + block_num=1, + ), + _ResidualBottleneckBlockWithUpsample(N, M), + ) + + # ----- Shared dictionary + diffusion prior ----- + self.shared_dictionary = SharedDictionary( + dict_num=self.dict_num, dictionary_dim=self.dictionary_dim + ) + self.diffusion_prior = _DenoisingAsRegularizer( + latent_dim=M, hyper_channels=hyper_channels + ) + + # ----- Latent codec (identical to DCAE wiring) ----- + cross_attention_kwargs = { + "head_num": self.dict_head_num, + "mlp_rate": 4, + "qkv_bias": True, + "dictionary_dim": self.dictionary_dim, + } + + widths = (224, 128) + groups = [slice_channels] * num_slices + + def support_count(k: int) -> int: + return k if max_support_slices < 0 else min(k, max_support_slices) + + # mean_support_trail = support tensor 
produced by the head: + # cat(input(2M + slice_ch*support_count), dict_info(M)). + def mean_support_ch(k: int) -> int: + return 2 * M + slice_channels * support_count(k) + M + + support_slices = [list(range(support_count(k))) for k in range(num_slices)] + + # Side-parameter channel-groups wiring, inlined ELIC-style (mirrors + # STF / TCM / DCAE). channel_context covers y0..y(K-1); each head sees + # cat(side_params(2M), *prev_y_hat) and emits cat(scale, mean, + # mean_support) for the LRP-aware leaf. The dictionary cross-attention + # head is the DCAE/SAAF-distinctive piece. + channel_context = { + f"y{k}": build_dictionary_mean_scale_head( + slice_ch=slice_channels, + support_ch=2 * M + slice_channels * support_count(k), + shared_dictionary=self.shared_dictionary, + dict_output_ch=M, + cross_attention_kwargs=cross_attention_kwargs, + widths=widths, + emit_mean_support=True, + ) + for k in range(num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + lrp_support_channels(2 * M, slice_channels, k, max_support_slices) + + M, + slice_channels, + widths=widths, + ), + lrp_scale=0.5, + mean_support_trail_channels=mean_support_ch(k), + chunks=("scales", "means"), + quantizer="ste", + ) + for k in range(num_slices) + } + + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(hyper_channels), + quantizer="noise", + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=groups, + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ), + }, + ) + + @staticmethod + def _merge_features(main: Tensor, auxiliary: Tensor) -> Tensor: + """Sum ``auxiliary`` into ``main``, bilinear-interpolating to match + spatial size when the AuxT branch hasn't downsampled yet.""" + if auxiliary.shape[2:] != main.shape[2:]: + auxiliary = F.interpolate( + 
auxiliary, size=main.shape[2:], mode="bilinear", align_corners=False + ) + return main + auxiliary + + def _encode(self, x: Tensor) -> Tensor: + """Main + AuxT analysis: walk ``g_a`` stage-by-stage, summing + ``aux_enc[i]`` after every stage boundary.""" + y_main = self.g_a[0](x) + y_aux = self.aux_enc[0](x) + y_main = self._merge_features(y_main, y_aux) + + for index, stage in enumerate( + (self.m_down1, self.m_down2, self.m_down3), start=1 + ): + for layer in stage: + y_main = layer(y_main) + y_aux = self.aux_enc[index](y_aux) + y_main = self._merge_features(y_main, y_aux) + return y_main + + def _decode(self, y_hat: Tensor) -> Tensor: + """Main + AuxT synthesis: mirror of :meth:`_encode`.""" + x_main = self.g_s[0](y_hat) + x_aux = self.aux_dec[0](y_hat) + x_main = self._merge_features(x_main, x_aux) + + for index, stage in enumerate((self.m_up1, self.m_up2, self.m_up3), start=1): + for layer in stage: + x_main = layer(x_main) + x_aux = self.aux_dec[index](x_aux) + x_main = self._merge_features(x_main, x_aux) + return x_main + + def aux_loss(self) -> Tensor: + """Auxiliary entropy-bottleneck loss plus OLP regulariser.""" + return super().aux_loss() + _aggregate_aux_loss(self) + + def forward(self, x: Tensor) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + y = self._encode(x) + out = self.latent_codec(y) + diffusion_loss = torch.zeros((), device=x.device, dtype=x.dtype) + if self.training: + # Reproduce upstream's z_hat-from-rounded-medians path so the + # diffusion prior conditions on the same hyper latent the + # entropy stack uses. Pulling z out of the latent codec keeps + # the regulariser independent of the codec's noise/STE choice. 
    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "SAAF":
        """Build a :class:`SAAF` whose architecture is inferred from weights.

        Constructor arguments are recovered with
        ``_infer_config_from_state_dict`` and the weights are then loaded
        non-strictly.

        Raises:
            RuntimeError: If ``state_dict`` has unexpected keys, or is
                missing any key other than the tolerated
                ``relative_position_index`` / ``global_alpha`` buffers.
        """
        kwargs = _infer_config_from_state_dict(state_dict)
        net = cls(**kwargs)
        # ``_CrossSparseWindowAttention`` registers ``relative_position_index``
        # and ``global_alpha`` through ``register_buffer`` with the default
        # ``persistent=True``, so they are part of ``net.state_dict()`` (which
        # is what lets them be enumerated below). Neither holds trained
        # weights -- the index is derived from ``window_size`` and the alpha
        # is the constant 0.25 -- so a checkpoint that lacks them is still
        # accepted and the freshly constructed values are kept.
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        allowed_missing = {
            key
            for key in net.state_dict()
            if key.endswith("relative_position_index") or key.endswith("global_alpha")
        }
        # Anything missing beyond the tolerated buffers is a real mismatch.
        missing_keys = set(incompatible_keys.missing_keys) - allowed_missing
        if missing_keys or incompatible_keys.unexpected_keys:
            raise RuntimeError(
                "Unexpected incompatibility while loading SAAF state_dict: "
                f"missing={sorted(missing_keys)}, "
                f"unexpected={sorted(incompatible_keys.unexpected_keys)}"
            )
        return net
int(round(flat_dim**0.5)) + if side * side != flat_dim or side % 2 == 0: + raise ValueError( + f"relative_position_bias_table has unexpected length {flat_dim} at {prefix}" + ) + return (side + 1) // 2 + + +def _infer_config_from_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, object]: + """Recover SAAF constructor kwargs from a containerized state_dict.""" + feature_dims = ( + state_dict["g_a.0.conv.weight"].size(0), + state_dict["g_a.2.conv.weight"].size(0), + state_dict["g_a.4.conv.weight"].size(0), + ) + N = state_dict["latent_codec.h_a.0.conv.weight"].size(0) + M = state_dict["latent_codec.h_a.0.conv.weight"].size(1) + hyper_channels = state_dict["latent_codec.z.entropy_bottleneck.quantiles"].size(0) + num_slices = infer_num_slices(state_dict) + slice_ch = M // num_slices + if num_slices > 1: + widest = max( + state_dict[f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight"].size(1) + for k in range(num_slices) + ) + support_count = (widest - 3 * M) // slice_ch + max_support_slices = max(support_count, 1) + else: + max_support_slices = 1 + + block_num = ( + _infer_stage_block_num(state_dict, "g_a.1.layers."), + _infer_stage_block_num(state_dict, "g_a.3.layers."), + _infer_stage_block_num(state_dict, "g_a.5.layers."), + ) + head_dim = ( + _infer_attention_head_dim(state_dict, "g_a.1", feature_dims[0]), + _infer_attention_head_dim(state_dict, "g_a.3", feature_dims[1]), + _infer_attention_head_dim(state_dict, "g_a.5", feature_dims[2]), + _infer_attention_head_dim(state_dict, "g_s.1", feature_dims[2]), + _infer_attention_head_dim(state_dict, "g_s.3", feature_dims[1]), + _infer_attention_head_dim(state_dict, "g_s.5", feature_dims[0]), + ) + + dt = state_dict["shared_dictionary.dt"] + dict_num = dt.size(0) + dictionary_dim = dt.size(1) + dict_head_num = state_dict[ + "latent_codec.y.channel_context.y0.cross_attention.scale" + ].size(0) + + return dict( + head_dim=head_dim, + N=N, + M=M, + hyper_channels=hyper_channels, + num_slices=num_slices, + 
def _transform_prefix(state_dict: Dict[str, Tensor], prefix: str) -> str:
    """Return the key prefix under which ``prefix``'s layers are stored.

    If any key in ``state_dict`` starts with ``"{prefix}.transform."`` the
    wrapped form ``"{prefix}.transform"`` is returned; otherwise ``prefix``
    is returned unchanged. (Presumably the wrapped layout is the one
    produced when the transform is wrapped by ``AuxTTransform`` -- confirm
    against that class.)
    """
    wrapped_prefix = f"{prefix}.transform"
    needle = f"{wrapped_prefix}."
    for key in state_dict:
        if key.startswith(needle):
            return wrapped_prefix
    return prefix
Optional[List[int]]: def _infer_head_dims(state_dict: Dict[str, Tensor], N: int) -> Optional[List[int]]: head_dims: List[int] = [] for prefix in ("g_a", "g_s"): + transform_prefix = _transform_prefix(state_dict, prefix) for group in _infer_stage_groups(state_dict, prefix): if not group: continue table_key = ( - f"{prefix}.{group[0]}.trans_block.msa.attn.relative_position_bias_table" + f"{transform_prefix}.{group[0]}." + "trans_block.msa.attn.relative_position_bias_table" ) if table_key not in state_dict: return None @@ -207,6 +230,8 @@ def _infer_head_dims(state_dict: Dict[str, Tensor], N: int) -> Optional[List[int def _infer_hyper_head_dim(state_dict: Dict[str, Tensor], N: int, default: int) -> int: for key in ( + "latent_codec.h_a.1.trans_block.msa.attn.relative_position_bias_table", + "latent_codec.h_s.h_mean_s.1.trans_block.msa.attn.relative_position_bias_table", "h_a.1.trans_block.msa.attn.relative_position_bias_table", "h_mean_s.1.trans_block.msa.attn.relative_position_bias_table", ): @@ -290,6 +315,7 @@ def __init__( window_size: int = 8, hyper_window_size: int = 4, hyper_head_dim: int = 32, + use_auxt: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -491,10 +517,37 @@ def swatten_factory(c_in: int, c_out: int) -> nn.Module: ), }, ) + if use_auxt: + self.g_a = AuxTTransform( + self.g_a, + build_wls_branch(N, M), + compute_analysis_aux_positions(config), + ) + self.g_s = AuxTTransform( + self.g_s, + build_iwls_branch(N, M), + compute_synthesis_aux_positions(config), + ) + + @property + def use_auxt(self) -> bool: + """``True`` when the AuxT side branch was constructed.""" + return isinstance(self.g_a, AuxTTransform) and isinstance( + self.g_s, AuxTTransform + ) + + def aux_loss(self) -> Tensor: + """Auxiliary entropy-bottleneck loss plus AuxT OLP regulariser. + + The AuxT term is zero when ``use_auxt=False``, so callers can + unconditionally add the returned scalar to the training objective. 
+ """ + return super().aux_loss() + _aggregate_aux_loss(self) @classmethod def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "TCM": - N = state_dict["g_a.0.conv1.weight"].size(0) // 2 + g_a_prefix = _transform_prefix(state_dict, "g_a") + N = state_dict[f"{g_a_prefix}.0.conv1.weight"].size(0) // 2 M = state_dict["latent_codec.h_a.0.conv1.weight"].size(1) config = _infer_stage_depths(state_dict) or [2, 2, 2, 2, 2, 2] head_dim = _infer_head_dims(state_dict, N) or [8, 16, 32, 32, 16, 8] @@ -512,13 +565,20 @@ def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "TCM": num_slices=num_slices, max_support_slices=max_support_slices, hyper_head_dim=_infer_hyper_head_dim(state_dict, N, 32), + use_auxt=has_auxt_state(state_dict), ) # ConvTransBlock's WindowAttention registers # ``relative_position_index`` as a non-persistent buffer, so it is - # absent from saved state dicts. Tolerate the missing keys. + # absent from saved state dicts. AuxT WLS / iWLS rely on + # ``pytorch_wavelets`` for the DWT / IDWT kernels which may or may + # not be persisted depending on the saving model's pytorch_wavelets + # version. Tolerate both as missing keys. 
incompatible_keys = net.load_state_dict(state_dict, strict=False) allowed_missing = { - key for key in net.state_dict() if key.endswith("relative_position_index") + key + for key in net.state_dict() + if key.endswith("relative_position_index") + or is_auxt_wavelet_buffer_key(key) } missing_keys = set(incompatible_keys.missing_keys) - allowed_missing if missing_keys or incompatible_keys.unexpected_keys: diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index e3e75863..17f53136 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -34,8 +34,10 @@ cca, cheng2020_anchor, cheng2020_attn, + dcae, mbt2018, mbt2018_mean, + saaf, stf, stf_wacnn, tcm, @@ -52,6 +54,8 @@ "mbt2018": mbt2018, "cheng2020-anchor": cheng2020_anchor, "cheng2020-attn": cheng2020_attn, + "dcae": dcae, + "saaf": saaf, "stf": stf, "stf-wacnn": stf_wacnn, "tcm": tcm, diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py index 69085976..120c619f 100644 --- a/compressai/zoo/image.py +++ b/compressai/zoo/image.py @@ -82,6 +82,8 @@ def __getattr__(self, item): "mbt2018_mean", "cheng2020_anchor", "cheng2020_attn", + "dcae", + "saaf", "stf", "stf_wacnn", "tcm", @@ -96,6 +98,8 @@ def __getattr__(self, item): "mbt2018": JointAutoregressiveHierarchicalPriors, "cheng2020-anchor": Cheng2020Anchor, "cheng2020-attn": Cheng2020Attention, + "dcae": _LazyImport("compressai.models.dcae", "DCAE"), + "saaf": _LazyImport("compressai.models.saaf", "SAAF"), # Resolved lazily so `compressai.zoo` is importable without `timm`. "stf": _LazyImport("compressai.models.stf", "SymmetricalTransFormer"), "stf-wacnn": _LazyImport("compressai.models.stf", "WACNN"), @@ -491,6 +495,46 @@ def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwa ) +def dcae(pretrained: bool = False, progress: bool = True, **kwargs): + r"""DCAE model from J. Lu, L. Zhang, X. Zhou, M. Li, W. Li, S. Gu: + `"Learned Image Compression with Dictionary-based Entropy Model" + `_, IEEE/CVF Conf. 
on Computer + Vision and Pattern Recognition (CVPR), 2025. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained DCAE weights are not yet hosted on S3.") + from compressai.models.dcae import DCAE + + return DCAE(**kwargs) + + +def saaf(pretrained: bool = False, progress: bool = True, **kwargs): + r"""SAAF model from H. Ma, X. Shi, H. Sun, X. Yue, X. Liu, G. Wang, + W. Cai: "Learned Image Compression via Sparse Attention and Adaptive + Frequency", IEEE/CVF Conf. on Computer Vision and Pattern Recognition + (CVPR), 2026. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained SAAF weights are not yet hosted on S3.") + from compressai.models.saaf import SAAF + + return SAAF(**kwargs) + + def stf(pretrained: bool = False, progress: bool = True, **kwargs): r"""Symmetrical TransFormer (STF) model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image diff --git a/examples/convert_dcae_checkpoint.py b/examples/convert_dcae_checkpoint.py new file mode 100644 index 00000000..f1170259 --- /dev/null +++ b/examples/convert_dcae_checkpoint.py @@ -0,0 +1,345 @@ +"""Convert an upstream DCAE checkpoint to compressai layout. + +Loads the published DCAE weight file (e.g. ``0.0018checkpoint_best.pth.tar`` +from the upstream DCAE release accompanying Lu et al., CVPR 2025), translates +it to compressai's containerized module layout, and writes a state dict that +``compressai.models.dcae.DCAE.from_state_dict`` can load directly. Optionally +reports forward-pass sanity numbers (PSNR / bpp) on a synthetic input. 
+ +The upstream-vs-compressai key differences (top-level ``dt`` -> shared +dictionary submodule, ``cc_mean_transforms.{k}`` / +``cc_scale_transforms.{k}`` / ``lrp_transforms.{k}`` / +``dt_cross_attention.{k}`` ModuleLists -> per-slice +``latent_codec.y.channel_context.y{k}`` / ``latent_codec.y.latent_codec.y{k}`` +entries, the means/scales swap on the leading 2*M input channels of the +first conv / linear weights, the ``h_z_s2``/``h_z_s1`` -> means/scales +swap, the H+G containerized re-rooting under ``latent_codec.*``, etc.) +are all handled inside ``convert_upstream_dcae_state_dict``; this script +is a thin CLI around it. + +Example:: + + python examples/convert_dcae_checkpoint.py \\ + --src candidate/DCAE/0.0018checkpoint_best.pth.tar \\ + --dst /tmp/dcae_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path +from typing import Dict + +import torch + +from torch import Tensor + +from compressai.models.dcae import DCAE + +# ---------------------------------------------------------------------------- +# Upstream → compressai state-dict conversion. +# +# Lives here (not in compressai/models/dcae.py) so the model module stays a +# clean compressai-native definition — ``DCAE.from_state_dict`` only loads +# already-converted state dicts. Run this script once to translate a +# published upstream checkpoint into compressai layout, then load the result +# via ``from_state_dict``. 
+# ---------------------------------------------------------------------------- + + +def _is_upstream_layout(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream DCAE has top-level ``dt`` and ``dt_cross_attention.0.*`` + (with or without a ``module.`` ``DataParallel`` prefix).""" + return any(key in ("dt", "module.dt") for key in state_dict) and any( + k.startswith("dt_cross_attention.") + or k.startswith("module.dt_cross_attention.") + for k in state_dict + ) + + +def _swap_first_2m_in_channels(weight: Tensor, m: int) -> Tensor: + """Swap the first ``m`` and second ``m`` slices along ``dim=1``. + + Upstream DCAE assembles its query as + ``cat([latent_scales(m), latent_means(m), *prev_y_hat])`` (scales first), + whereas the containerized wiring uses ``cat([latent_means(m), + latent_scales(m), *prev_y_hat])`` (means first). Permuting the leading + 2*m input channels of the first conv / linear weight in cc_mean, + cc_scale, lrp_transforms, and cross_attention.x_trans rebases the + upstream weights to the new input order with no retraining. + """ + if weight.dim() < 2 or weight.size(1) < 2 * m: + return weight + permuted = weight.clone() + permuted[:, :m] = weight[:, m : 2 * m] + permuted[:, m : 2 * m] = weight[:, :m] + return permuted + + +def convert_upstream_dcae_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Convert an upstream DCAE checkpoint to the containerized layout. + + The upstream DCAE source follows the master-era monolithic pattern + (model owns ``h_a`` / ``h_z_s1`` / ``h_z_s2`` / ``entropy_bottleneck`` / + ``gaussian_conditional`` / per-slice ModuleLists for cc heads / LRP / + cross-attention plus a top-level ``dt`` Parameter). 
This function + rewrites the keys to the containerized layout used by + :class:`~compressai.models.dcae.DCAE` post-refactor: + + - ``dt`` -> ``shared_dictionary.dt`` + - ``dt_cross_attention.{k}.*`` -> + ``latent_codec.y.channel_context.y{k}.cross_attention.*`` + - ``cc_mean_transforms.{k}.*`` -> + ``latent_codec.y.channel_context.y{k}.mean_cc.*`` + - ``cc_scale_transforms.{k}.*`` -> + ``latent_codec.y.channel_context.y{k}.scale_cc.*`` + - ``lrp_transforms.{k}.*`` -> + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*`` + - ``gaussian_conditional.*`` -> fanned out to K copies under + ``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*`` + - ``h_a.*`` -> ``latent_codec.h_a.*`` + - ``h_z_s2.*`` -> ``latent_codec.h_s.h_mean_s.*`` (means) + - ``h_z_s1.*`` -> ``latent_codec.h_s.h_scale_s.*`` (scales) + - ``entropy_bottleneck.*`` -> ``latent_codec.z.entropy_bottleneck.*`` + + Additionally permutes the leading 2*M input channels of the first + conv / linear weight in ``cc_mean`` / ``cc_scale`` / ``lrp_transform`` / + ``cross_attention.x_trans`` to swap upstream's + ``[scales, means, ...]`` ordering to the containerized + ``[means, scales, ...]`` ordering — see :func:`_swap_first_2m_in_channels`. + """ + if not _is_upstream_layout(state_dict): + return state_dict + + # Strip the ``module.`` ``DataParallel`` prefix if present. + state_dict = { + (key[len("module.") :] if key.startswith("module.") else key): value + for key, value in state_dict.items() + } + + # First pass: figure out M and num_slices from the input dict so we know + # how many cc / lrp / cross_attention slots to fan out and how to size + # the means/scales swap. 
+ if "h_a.0.conv.weight" in state_dict: + m = state_dict["h_a.0.conv.weight"].size(1) + else: + # Fallback: gaussian_conditional doesn't expose M; rely on dt_cross_attention.0.linear.weight + m = state_dict["dt_cross_attention.0.linear.weight"].size(0) + cc_indices = sorted( + { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("cc_mean_transforms.") + } + ) + if not cc_indices: + raise ValueError("cannot infer num_slices from upstream state_dict") + num_slices = max(cc_indices) + 1 + + out: Dict[str, Tensor] = {} + + # Helper to rewrite a per-slice ModuleList key like "cc_mean_transforms.0.0.weight" + # to a containerized destination, optionally swapping first-2M channels in the + # first sub-module (index "0"). + def _reroot_modlist( + src_prefix: str, dst_prefix: str, *, swap_first_conv: bool + ) -> None: + for key, value in state_dict.items(): + if not key.startswith(src_prefix): + continue + tail = key[ + len(src_prefix) : + ] # e.g., "0.0.weight" (slice 0, sub 0, "weight") + parts = tail.split(".", 2) + if len(parts) < 2: + continue + slice_idx, sub_idx = parts[0], parts[1] + sub_tail = parts[2] if len(parts) > 2 else "" + new_value = value + if swap_first_conv and sub_idx == "0" and sub_tail == "weight": + new_value = _swap_first_2m_in_channels(value, m) + dst_key = f"{dst_prefix.format(slice_idx=slice_idx)}.{sub_idx}" + if sub_tail: + dst_key = f"{dst_key}.{sub_tail}" + out[dst_key] = new_value + + _reroot_modlist( + "cc_mean_transforms.", + "latent_codec.y.channel_context.y{slice_idx}.mean_cc", + swap_first_conv=True, + ) + _reroot_modlist( + "cc_scale_transforms.", + "latent_codec.y.channel_context.y{slice_idx}.scale_cc", + swap_first_conv=True, + ) + _reroot_modlist( + "lrp_transforms.", + "latent_codec.y.latent_codec.y{slice_idx}.lrp_transform", + swap_first_conv=True, + ) + + # cross_attention is a ModuleList of MultiScaleDictionaryCrossAttentionGLU; its + # x_trans.weight is the only Linear that consumes the [scales, means, ...] input. 
+ for key, value in state_dict.items(): + if not key.startswith("dt_cross_attention."): + continue + # Tail layout: "{slice_idx}.{remainder}". + tail = key[len("dt_cross_attention.") :] + slice_idx, remainder = tail.split(".", 1) + new_value = value + if remainder == "x_trans.weight": + new_value = _swap_first_2m_in_channels(value, m) + dst_key = ( + f"latent_codec.y.channel_context.y{slice_idx}.cross_attention.{remainder}" + ) + out[dst_key] = new_value + + # Fan out the single shared gaussian_conditional to K per-slice copies. + for key, value in state_dict.items(): + if not key.startswith("gaussian_conditional."): + continue + suffix = key[len("gaussian_conditional.") :] + for slice_idx in range(num_slices): + out[ + f"latent_codec.y.latent_codec.y{slice_idx}.gaussian_conditional.{suffix}" + ] = value + + # Top-level renames. + top_level_renames = { + "dt": "shared_dictionary.dt", + "h_a.": "latent_codec.h_a.", + "h_z_s1.": "latent_codec.h_s.h_scale_s.", + "h_z_s2.": "latent_codec.h_s.h_mean_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", + } + for key, value in state_dict.items(): + if key in ("dt",) or key.startswith( + ( + "h_a.", + "h_z_s1.", + "h_z_s2.", + "entropy_bottleneck.", + ) + ): + for src, dst in top_level_renames.items(): + if key == src: + out[dst] = value + break + if src.endswith(".") and key.startswith(src): + out[dst + key[len(src) :]] = value + break + elif key.startswith( + ( + "cc_mean_transforms.", + "cc_scale_transforms.", + "lrp_transforms.", + "dt_cross_attention.", + "gaussian_conditional.", + ) + ): + # Already handled above. + continue + else: + # Carry-over keys not part of the entropy / hyperprior sections + # (e.g., g_a / g_s parameters). + out[key] = value + + return out + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream DCAE checkpoint (e.g. 
0.0018checkpoint_best.pth.tar).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + if _is_upstream_layout(upstream): + converted = convert_upstream_dcae_state_dict(upstream) + else: + converted = upstream + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = DCAE.from_state_dict(converted) + net.eval() + print( + "variant: " + f"N={net.N}, M={net.M}, hyper_channels={net.hyper_channels}, " + f"num_slices={net.num_slices}, max_support_slices={net.max_support_slices}, " + f"feature_dims={tuple(net.feature_dims)}, block_num={tuple(net.block_num)}, " + f"dict_num={net.dict_num}, dict_head_num={net.dict_head_num}, " + f"dictionary_dim={net.dictionary_dim}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr 
= -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/convert_saaf_checkpoint.py b/examples/convert_saaf_checkpoint.py new file mode 100644 index 00000000..0e930742 --- /dev/null +++ b/examples/convert_saaf_checkpoint.py @@ -0,0 +1,317 @@ +"""Convert an upstream SAAF checkpoint to compressai layout. + +Loads the published SAAF weight file (e.g. ``mse_0.0018.pth`` from the +upstream SAAF release accompanying Ma et al., CVPR 2026), translates it +to compressai's containerized module layout, and writes a state dict +that ``compressai.models.saaf.SAAF.from_state_dict`` can load directly. +Optionally reports forward-pass sanity numbers (PSNR / bpp) on a +synthetic input. + +The upstream-vs-compressai key differences mirror DCAE's converter +exactly (top-level ``dt`` -> shared dictionary submodule, +``cc_mean_transforms.{k}`` / ``cc_scale_transforms.{k}`` / +``lrp_transforms.{k}`` / ``dt_cross_attention.{k}`` ModuleLists -> +per-slice ``latent_codec.y.channel_context.y{k}`` / +``latent_codec.y.latent_codec.y{k}`` entries, the means/scales swap on +the first 2*M input channels of the first conv / linear weights, the +``h_z_s2`` / ``h_z_s1`` -> means/scales swap, the H+G containerized +re-rooting under ``latent_codec.*``, etc.) — all handled inside +``convert_upstream_saaf_state_dict``; this script is a thin CLI around +it. The SAAF-specific ``aux_enc`` / ``aux_dec`` / ``diffusion_prior`` +keys pass through unchanged. 
+ +Example:: + + python examples/convert_saaf_checkpoint.py \\ + --src candidate/SAAF/mse_0.0018.pth \\ + --dst /tmp/saaf_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path +from typing import Dict + +import torch + +from torch import Tensor + +from compressai.models.saaf import SAAF + +# ---------------------------------------------------------------------------- +# Upstream → compressai state-dict conversion. +# +# Lives here (not in compressai/models/saaf.py) so the model module stays a +# clean compressai-native definition — ``SAAF.from_state_dict`` only loads +# already-converted state dicts. Run this script once to translate a +# published upstream checkpoint into compressai layout, then load the result +# via ``from_state_dict``. +# ---------------------------------------------------------------------------- + + +def _is_upstream_layout(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream SAAF has top-level ``dt`` and ``dt_cross_attention.0.*`` + (with or without a ``module.`` ``DataParallel`` prefix).""" + return any(key in ("dt", "module.dt") for key in state_dict) and any( + k.startswith("dt_cross_attention.") + or k.startswith("module.dt_cross_attention.") + for k in state_dict + ) + + +def _swap_first_2m_in_channels(weight: Tensor, m: int) -> Tensor: + """Swap the leading ``m`` and second ``m`` slices along ``dim=1``. + + Same means/scales swap as DCAE: upstream uses scales-before-means, + the containerized wiring uses means-before-scales. Permuting the + leading 2*m input channels of the first conv / linear weight in + cc_mean, cc_scale, lrp_transforms, and cross_attention.x_trans rebases + the upstream weights to the new input order with no retraining. 
+ """ + if weight.dim() < 2 or weight.size(1) < 2 * m: + return weight + permuted = weight.clone() + permuted[:, :m] = weight[:, m : 2 * m] + permuted[:, m : 2 * m] = weight[:, :m] + return permuted + + +def convert_upstream_saaf_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Convert an upstream SAAF checkpoint to the containerized layout. + + Mirrors :func:`convert_upstream_dcae_state_dict` one-to-one — SAAF and + DCAE share the same entropy stack (top-level ``dt``, per-slice + ``dt_cross_attention`` / ``cc_mean_transforms`` / ``cc_scale_transforms`` + / ``lrp_transforms`` ModuleLists, single shared ``gaussian_conditional``, + and ``h_z_s2`` (means) / ``h_z_s1`` (scales) hyper-synthesis heads). The + SAAF-specific ``aux_enc`` / ``aux_dec`` / ``diffusion_prior`` keys pass + through unchanged. + + See :func:`_swap_first_2m_in_channels` for the means/scales swap on the + first 2M input channels of the first conv / linear weights inside + ``cross_attention``, ``cc_mean``, ``cc_scale``, and ``lrp_transform``. + """ + if not _is_upstream_layout(state_dict): + return state_dict + + # Strip the ``module.`` ``DataParallel`` prefix if present. + state_dict = { + (key[len("module.") :] if key.startswith("module.") else key): value + for key, value in state_dict.items() + } + # Upstream OLP persists its ``identity_matrix`` buffer; compressai's + # :class:`compressai.models._helpers.auxt.OLP` registers it with + # ``persistent=False`` and rebuilds it at construction. Drop the + # upstream copies so strict load_state_dict succeeds. 
+ state_dict = { + key: value + for key, value in state_dict.items() + if not key.endswith(".olp.identity_matrix") + } + + if "h_a.0.conv.weight" in state_dict: + m = state_dict["h_a.0.conv.weight"].size(1) + else: + m = state_dict["dt_cross_attention.0.linear.weight"].size(0) + cc_indices = sorted( + { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("cc_mean_transforms.") + } + ) + if not cc_indices: + raise ValueError("cannot infer num_slices from upstream state_dict") + num_slices = max(cc_indices) + 1 + + out: Dict[str, Tensor] = {} + + def _reroot_modlist( + src_prefix: str, dst_prefix: str, *, swap_first_conv: bool + ) -> None: + for key, value in state_dict.items(): + if not key.startswith(src_prefix): + continue + tail = key[len(src_prefix) :] + parts = tail.split(".", 2) + if len(parts) < 2: + continue + slice_idx, sub_idx = parts[0], parts[1] + sub_tail = parts[2] if len(parts) > 2 else "" + new_value = value + if swap_first_conv and sub_idx == "0" and sub_tail == "weight": + new_value = _swap_first_2m_in_channels(value, m) + dst_key = f"{dst_prefix.format(slice_idx=slice_idx)}.{sub_idx}" + if sub_tail: + dst_key = f"{dst_key}.{sub_tail}" + out[dst_key] = new_value + + _reroot_modlist( + "cc_mean_transforms.", + "latent_codec.y.channel_context.y{slice_idx}.mean_cc", + swap_first_conv=True, + ) + _reroot_modlist( + "cc_scale_transforms.", + "latent_codec.y.channel_context.y{slice_idx}.scale_cc", + swap_first_conv=True, + ) + _reroot_modlist( + "lrp_transforms.", + "latent_codec.y.latent_codec.y{slice_idx}.lrp_transform", + swap_first_conv=True, + ) + + for key, value in state_dict.items(): + if not key.startswith("dt_cross_attention."): + continue + tail = key[len("dt_cross_attention.") :] + slice_idx, remainder = tail.split(".", 1) + new_value = value + if remainder == "x_trans.weight": + new_value = _swap_first_2m_in_channels(value, m) + dst_key = ( + f"latent_codec.y.channel_context.y{slice_idx}.cross_attention.{remainder}" + ) + 
out[dst_key] = new_value + + for key, value in state_dict.items(): + if not key.startswith("gaussian_conditional."): + continue + suffix = key[len("gaussian_conditional.") :] + for slice_idx in range(num_slices): + out[ + f"latent_codec.y.latent_codec.y{slice_idx}.gaussian_conditional.{suffix}" + ] = value + + top_level_renames = { + "dt": "shared_dictionary.dt", + "h_a.": "latent_codec.h_a.", + "h_z_s1.": "latent_codec.h_s.h_scale_s.", + "h_z_s2.": "latent_codec.h_s.h_mean_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", + } + handled_prefixes = ( + "cc_mean_transforms.", + "cc_scale_transforms.", + "lrp_transforms.", + "dt_cross_attention.", + "gaussian_conditional.", + ) + for key, value in state_dict.items(): + if key in ("dt",) or key.startswith( + ("h_a.", "h_z_s1.", "h_z_s2.", "entropy_bottleneck.") + ): + for src, dst in top_level_renames.items(): + if key == src: + out[dst] = value + break + if src.endswith(".") and key.startswith(src): + out[dst + key[len(src) :]] = value + break + elif key.startswith(handled_prefixes): + continue + else: + # SAAF-specific carry-through keys: g_a / g_s / aux_enc / + # aux_dec / diffusion_prior — no rename needed. + out[key] = value + + return out + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream SAAF checkpoint (e.g. mse_0.0018.pth).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." 
+ ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + if _is_upstream_layout(upstream): + converted = convert_upstream_saaf_state_dict(upstream) + else: + converted = upstream + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = SAAF.from_state_dict(converted) + net.eval() + print( + "variant: " + f"N={net.N}, M={net.M}, hyper_channels={net.hyper_channels}, " + f"num_slices={net.num_slices}, max_support_slices={net.max_support_slices}, " + f"feature_dims={tuple(net.feature_dims)}, block_num={tuple(net.block_num)}, " + f"dict_num={net.dict_num}, dict_head_num={net.dict_head_num}, " + f"dictionary_dim={net.dictionary_dim}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: 
PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/convert_tcm_checkpoint.py b/examples/convert_tcm_checkpoint.py index d8ad9771..63ec6b6c 100644 --- a/examples/convert_tcm_checkpoint.py +++ b/examples/convert_tcm_checkpoint.py @@ -34,6 +34,10 @@ from torch import Tensor +from compressai.models._helpers.auxt import ( + is_auxt_upstream_wavelet_buffer_key, + normalize_upstream_auxt_key, +) from compressai.models.tcm import TCM # ---------------------------------------------------------------------------- @@ -209,9 +213,26 @@ def convert_upstream_tcm_state_dict( # alias ``atten_mean`` / ``atten_scale`` to the canonical # ``mean_support_transforms`` / ``scale_support_transforms`` names so the # per-slice rerooting in Pass 2 only has to handle one form. + has_auxt = any( + (key[len("module.") :] if key.startswith("module.") else key).startswith( + ("AuxT_enc.", "AuxT_dec.") + ) + for key in state_dict + ) + cleaned: Dict[str, Tensor] = {} for key, value in state_dict.items(): new_key = key[len("module.") :] if key.startswith("module.") else key + # Drop the upstream LIC_TCM custom DWT/IDWT kernel buffers — the + # ``pytorch_wavelets``-backed :class:`compressai.layers.wave.DWT2D` + # / :class:`IDWT2D` regenerate their kernels at construction. + if is_auxt_upstream_wavelet_buffer_key(new_key): + continue + # Upstream stores the OLP submodule as ``.OLP.`` (PascalCase to match + # the class name); compressai uses ``.olp.`` (lower attribute name). 
+ normalized = normalize_upstream_auxt_key(new_key) + if normalized is not None: + new_key = normalized new_key, value = _rename_msa_keys(new_key, value) wrapper = _UPSTREAM_SWATTEN_WRAPPER.match(new_key) if wrapper: @@ -226,6 +247,13 @@ def convert_upstream_tcm_state_dict( new_key = new_key.replace(".ln2.", ".norm2.") new_key = new_key.replace(".mlp.0.", ".mlp.fc1.") new_key = new_key.replace(".mlp.2.", ".mlp.fc2.") + if ( + has_auxt + and (new_key.startswith("g_a.") or new_key.startswith("g_s.")) + and ".auxiliary_layers." not in new_key + ): + root, tail = new_key.split(".", 1) + new_key = f"{root}.transform.{tail}" if ".msa.output_proj." in new_key: _ensure_identity_attention_projection(cleaned, new_key, value) cleaned[new_key] = value diff --git a/pyproject.toml b/pyproject.toml index da77a0b6..75b051bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,9 @@ pointcloud = [ attn = [ "timm", ] +wavelet = [ + "pytorch_wavelets", +] # NOTE: Temporarily duplicated from [project.optional-dependencies] until # pip supports installing [dependency-groups]. 
diff --git a/tests/test_layers.py b/tests/test_layers.py index 821de761..c763fcd6 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -204,6 +204,97 @@ def test_AttentionBlock(): layer(torch.rand(1, 8, 4, 4)) +class TestMultiScaleDictionaryCrossAttentionGLU: + @staticmethod + def test_forward_shape(): + from compressai.layers.attn.dictionary import ( + MultiScaleDictionaryCrossAttentionGLU, + ) + + mod = MultiScaleDictionaryCrossAttentionGLU( + input_dim=192, + output_dim=320, + head_num=4, + dictionary_dim=128, + ) + x = torch.randn(2, 192, 4, 4) + dictionary = torch.randn(2, 16, 128) + out = mod(x, dictionary) + assert out.shape == (2, 320, 4, 4) + + @staticmethod + def test_state_dict_round_trip(): + from compressai.layers.attn.dictionary import ( + MultiScaleDictionaryCrossAttentionGLU, + ) + + mod = MultiScaleDictionaryCrossAttentionGLU( + input_dim=192, + output_dim=320, + head_num=4, + dictionary_dim=128, + ) + mod2 = MultiScaleDictionaryCrossAttentionGLU( + input_dim=192, + output_dim=320, + head_num=4, + dictionary_dim=128, + ) + mod2.load_state_dict(mod.state_dict(), strict=True) + x = torch.randn(2, 192, 4, 4) + dictionary = torch.randn(2, 16, 128) + assert torch.allclose(mod(x, dictionary), mod2(x, dictionary)) + + @staticmethod + def test_dictionary_dim_default_matches_head_num(): + from compressai.layers.attn.dictionary import ( + MultiScaleDictionaryCrossAttentionGLU, + ) + + # Default dictionary_dim = 32 * head_num + mod = MultiScaleDictionaryCrossAttentionGLU( + input_dim=64, output_dim=128, head_num=4 + ) + x = torch.randn(1, 64, 2, 2) + dictionary = torch.randn(1, 8, 128) + out = mod(x, dictionary) + assert out.shape == (1, 128, 2, 2) + + @staticmethod + def test_dictionary_dim_must_divide_head_num(): + from compressai.layers.attn.dictionary import ( + MultiScaleDictionaryCrossAttentionGLU, + ) + + with pytest.raises(ValueError, match="divisible"): + MultiScaleDictionaryCrossAttentionGLU( + input_dim=32, output_dim=64, head_num=3, 
dictionary_dim=128 + ) + + +class TestWavelet: + @staticmethod + def test_dwt_idwt_round_trip(): + pytest.importorskip("pytorch_wavelets") + from compressai.layers.wave import DWT2D, IDWT2D + + dwt = DWT2D(wave="haar") + idwt = IDWT2D(wave="haar") + x = torch.randn(2, 3, 16, 16) + sub = dwt(x) + # 4 subbands -> output channels = 4 * input + assert sub.shape == (2, 12, 8, 8) + rec = idwt(sub) + assert rec.shape == x.shape + assert (rec - x).abs().max().item() < 1e-5 + + @staticmethod + def test_is_pytorch_wavelets_available_returns_bool(): + from compressai.layers.wave import is_pytorch_wavelets_available + + assert isinstance(is_pytorch_wavelets_available(), bool) + + class TestQReLU: @staticmethod def test_QReLU(): diff --git a/tests/test_models.py b/tests/test_models.py index ff9ef958..66c4013a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -568,6 +568,543 @@ def test_tcm_upstream_state_dict_conversion(self): assert "lrp_transforms.0.0.weight" not in converted assert "module.g_a.0.conv1.weight" not in converted + def test_tcm_use_auxt_default_false(self): + from compressai.models.tcm import TCM + + model = TCM( + N=32, + M=64, + hyper_channels=48, + num_slices=4, + max_support_slices=2, + ) + assert model.use_auxt is False + assert not hasattr(model.g_a, "auxiliary_layers") + assert not hasattr(model.g_s, "auxiliary_layers") + # aux_loss remains a scalar without AuxT, carrying the base + # EntropyBottleneck auxiliary loss. 
+ loss = model.aux_loss() + assert loss.dim() == 0 + assert torch.isfinite(loss) + + def test_tcm_use_auxt_construction_and_forward(self): + pytest.importorskip("pytorch_wavelets") + from compressai.models.tcm import TCM + + model = TCM( + N=32, + M=64, + hyper_channels=48, + num_slices=4, + max_support_slices=2, + use_auxt=True, + ).eval() + assert model.use_auxt is True + assert len(model.g_a.auxiliary_layers) == 4 + assert len(model.g_s.auxiliary_layers) == 4 + # Default config (2,2,2,2,2,2) -> 10-layer g_a / g_s with merge + # positions (0, 3, 6, 9) and (2, 5, 8, 9) respectively. + assert model.g_a.merge_positions == (0, 3, 6, 9) + assert model.g_s.merge_positions == (2, 5, 8, 9) + + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + + # Aggregated OLP regulariser is a finite scalar > 0 with random init. + loss = model.aux_loss() + assert loss.dim() == 0 + assert torch.isfinite(loss) and loss.item() > 0 + + def test_tcm_use_auxt_state_dict_round_trip(self): + pytest.importorskip("pytorch_wavelets") + from compressai.models.tcm import TCM + + model = TCM( + N=32, + M=64, + hyper_channels=48, + num_slices=4, + max_support_slices=2, + use_auxt=True, + ).eval() + sd = model.state_dict() + # AuxT submodule paths are present. + auxt_keys = { + k + for k in sd + if k.startswith(("g_a.auxiliary_layers.", "g_s.auxiliary_layers.")) + } + assert any(".olp.linear.weight" in k for k in auxt_keys) + assert any(".scaling_factors" in k for k in auxt_keys) + assert "g_a.transform.0.conv1.weight" in sd + + loaded = TCM.from_state_dict(sd).eval() + # use_auxt is auto-detected from the wrapper auxiliary keys. 
+ assert loaded.use_auxt is True + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + assert torch.allclose(model(x)["x_hat"], loaded(x)["x_hat"]) + + def test_tcm_convert_strips_upstream_wavelet_buffers_and_renames_olp(self): + convert_upstream_tcm_state_dict = _load_convert_fn( + "convert_tcm_checkpoint.py", "convert_upstream_tcm_state_dict" + ) + + # Synthetic upstream LIC_TCM-with-AuxT key set: minimal entropy + # backbone keys to drive num_slices inference, plus AuxT keys with + # the upstream-style ``.OLP.`` submodule and ``w_*`` / ``filters`` + # custom DWT/IDWT kernel buffers that should get dropped. + upstream = { + "module.g_a.0.conv1.weight": torch.zeros(2), + "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.cc_scale_transforms.0.0.weight": torch.zeros(2), + "module.lrp_transforms.0.0.weight": torch.zeros(2), + "module.gaussian_conditional.scale_table": torch.zeros(2), + "module.h_a.0.conv1.weight": torch.zeros(2), + "module.h_mean_s.0.conv.weight": torch.zeros(2), + "module.h_scale_s.0.conv.weight": torch.zeros(2), + "module.entropy_bottleneck.quantiles": torch.zeros(2), + # AuxT keys: .OLP. should be renamed to .olp., w_*/filters dropped. + "module.AuxT_enc.0.OLP.linear.weight": torch.zeros(8, 12), + "module.AuxT_enc.0.OLP.linear.bias": torch.zeros(8), + "module.AuxT_enc.0.scaling_factors": torch.zeros(1, 1, 12), + "module.AuxT_enc.0.dwt.w_ll": torch.zeros(2, 2), + "module.AuxT_enc.0.dwt.w_lh": torch.zeros(2, 2), + "module.AuxT_enc.0.dwt.w_hl": torch.zeros(2, 2), + "module.AuxT_enc.0.dwt.w_hh": torch.zeros(2, 2), + "module.AuxT_dec.0.OLP.linear.weight": torch.zeros(12, 8), + "module.AuxT_dec.0.idwt.filters": torch.zeros(4, 4), + } + converted = convert_upstream_tcm_state_dict(upstream) + + # ``.OLP.`` -> ``.olp.`` rename, ``module.`` prefix gone. 
+ assert "g_a.auxiliary_layers.0.olp.linear.weight" in converted + assert "g_a.auxiliary_layers.0.olp.linear.bias" in converted + assert "g_s.auxiliary_layers.0.olp.linear.weight" in converted + # scaling_factors carries through. + assert "g_a.auxiliary_layers.0.scaling_factors" in converted + # With AuxT present, the main transform becomes the wrapped transform. + assert "g_a.transform.0.conv1.weight" in converted + assert "g_a.0.conv1.weight" not in converted + # Upstream-LIC_TCM-specific DWT/IDWT kernel buffers dropped. + for suffix in ("w_ll", "w_lh", "w_hl", "w_hh"): + assert not any( + k.endswith(suffix) for k in converted + ), f"upstream DWT buffer {suffix} should have been dropped" + assert not any(k.endswith(".idwt.filters") for k in converted) + # Upstream-style PascalCase OLP keys should be gone. + assert "g_a.auxiliary_layers.0.OLP.linear.weight" not in converted + + +class TestDcae: + def _tiny_kwargs(self): + return dict( + N=64, + M=80, + hyper_channels=64, + num_slices=4, + max_support_slices=2, + feature_dims=(48, 64, 80), + block_num=(1, 1, 2), + head_dim=(8, 8, 8, 8, 8, 8), + dict_num=8, + dict_head_num=4, + dictionary_dim=32, + window_size=4, + hyper_window_size=2, + hyper_head_dim=8, + ) + + def test_dcae_forward_and_state_dict_round_trip(self): + from compressai.models.dcae import DCAE + + model = DCAE(**self._tiny_kwargs()).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + sd_keys = set(model.state_dict().keys()) + # Shared dictionary lives at the model level (single state-dict path). + assert "shared_dictionary.dt" in sd_keys + assert sum(1 for k in sd_keys if k.endswith(".dt")) == 1 + # Hyperprior backbone moved under latent_codec.* (DCAE h_a wraps a + # ResidualBottleneckBlockWithStride: outermost weight is .conv). 
+ assert "latent_codec.h_a.0.conv.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # Side-parameter channel context covers y0..y(K-1). + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + # DCAE-specific dictionary cross-attention head. + assert ( + "latent_codec.y.channel_context.y0.cross_attention.x_trans.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y3.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic paths should be gone. + assert "dt" not in sd_keys + assert not any(k.startswith("dt_cross_attention.") for k in sd_keys) + assert not any(k.startswith("cc_mean_transforms.") for k in sd_keys) + assert not any(k.startswith("h_z_s1.") for k in sd_keys) + assert not any(k.startswith("h_z_s2.") for k in sd_keys) + + loaded = DCAE.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert loaded.N == 64 + assert loaded.M == 80 + assert loaded.hyper_channels == 64 + assert loaded.num_slices == 4 + assert loaded.max_support_slices == 2 + assert loaded.dict_num == 8 + assert loaded.dict_head_num == 4 + assert loaded.dictionary_dim == 32 + + def test_dcae_upstream_state_dict_conversion(self): + convert_upstream_dcae_state_dict = _load_convert_fn( + "convert_dcae_checkpoint.py", "convert_upstream_dcae_state_dict" + ) + + # Synthetic upstream DCAE-style state_dict: top-level dt + per-slice + # ModuleLists for cc_mean / cc_scale / lrp / dt_cross_attention, + # 
single shared gaussian_conditional, model-owned hyperprior with + # h_z_s1 (scales) / h_z_s2 (means). + m = 80 + # num_slices=4 -> slice_ch = m // 4 = 20 (used inline in shape calculations below) + # cc_mean.0 first conv input width = M*3 + slice_ch * 0 = 240 + # cc_mean.1 first conv input width = M*3 + slice_ch * 1 = 260 + # lrp.0 first conv input width = M*3 + 0 + slice_ch = 260 + # cross_attention.0 input width = M*2 + 0 = 160 + # cross_attention.1 input width = M*2 + slice_ch = 180 + upstream = { + # Top-level dictionary tensor. + "dt": torch.zeros(8, 32), + # Per-slice dt_cross_attention: x_trans is the only Linear that + # consumes the original [scales, means, ...] input order. + "dt_cross_attention.0.x_trans.weight": torch.arange(32 * 160) + .float() + .reshape(32, 160), + "dt_cross_attention.0.scale": torch.zeros(4, 1, 1), + "dt_cross_attention.1.x_trans.weight": torch.arange(32 * 180) + .float() + .reshape(32, 180), + "dt_cross_attention.1.scale": torch.zeros(4, 1, 1), + # Per-slice cc_mean / cc_scale: first conv has the means/scales swap. + "cc_mean_transforms.0.0.weight": torch.arange(64 * 240) + .float() + .reshape(64, 240, 1, 1), + "cc_mean_transforms.1.0.weight": torch.zeros(64, 260, 1, 1), + "cc_scale_transforms.0.0.weight": torch.arange(64 * 240) + .float() + .reshape(64, 240, 1, 1), + "cc_scale_transforms.1.0.weight": torch.zeros(64, 260, 1, 1), + # Per-slice LRP: first conv also has the means/scales swap. + "lrp_transforms.0.0.weight": torch.arange(64 * 260) + .float() + .reshape(64, 260, 1, 1), + "lrp_transforms.1.0.weight": torch.zeros(64, 280, 1, 1), + # Single shared gaussian_conditional (gets fanned out per slice). + "gaussian_conditional.scale_table": torch.zeros(2), + # Model-owned hyperprior backbone. 
+ "h_a.0.conv.weight": torch.zeros(64, 80, 5, 5), + "h_z_s1.0.weight": torch.zeros(64, 64, 3, 3), # scales (originally h_z_s1) + "h_z_s2.0.weight": torch.zeros(64, 64, 3, 3), # means (originally h_z_s2) + "entropy_bottleneck.quantiles": torch.zeros(64, 1, 3), + # g_a / g_s carry through unchanged. + "g_a.0.conv.weight": torch.zeros(48, 3, 5, 5), + } + converted = convert_upstream_dcae_state_dict(upstream) + + # Top-level dt -> shared_dictionary.dt. + assert "shared_dictionary.dt" in converted + assert "dt" not in converted + + # Per-slice cross_attention re-rooted; x_trans.weight has its first 2*M + # input channels (dim=1) swapped (means/scales reorder). + assert ( + "latent_codec.y.channel_context.y0.cross_attention.x_trans.weight" + in converted + ) + original = upstream["dt_cross_attention.0.x_trans.weight"] + swapped = converted[ + "latent_codec.y.channel_context.y0.cross_attention.x_trans.weight" + ] + # Swap should leave the trailing channels (>=2*M) unchanged but flip + # the leading [0:M] and [M:2M] blocks. + assert torch.equal(swapped[:, :m], original[:, m : 2 * m]) + assert torch.equal(swapped[:, m : 2 * m], original[:, :m]) + # cross_attention scale (head_num,1,1) carries through unchanged. + assert "latent_codec.y.channel_context.y0.cross_attention.scale" in converted + + # Per-slice cc_mean / cc_scale re-rooted with means/scales swap on first conv. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + original_cc = upstream["cc_mean_transforms.0.0.weight"] + swapped_cc = converted["latent_codec.y.channel_context.y0.mean_cc.0.weight"] + assert torch.equal(swapped_cc[:, :m], original_cc[:, m : 2 * m]) + assert torch.equal(swapped_cc[:, m : 2 * m], original_cc[:, :m]) + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + + # Per-slice LRP re-rooted with means/scales swap on first conv. 
+ assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + original_lrp = upstream["lrp_transforms.0.0.weight"] + swapped_lrp = converted["latent_codec.y.latent_codec.y0.lrp_transform.0.weight"] + assert torch.equal(swapped_lrp[:, :m], original_lrp[:, m : 2 * m]) + assert torch.equal(swapped_lrp[:, m : 2 * m], original_lrp[:, :m]) + + # gaussian_conditional fanned out per slice (the synthetic upstream dict only + # carries per-slice indices 0 and 1, so y0 and y1 are checked). + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + + # Hyperprior backbone moved under latent_codec.*; h_z_s2 -> h_mean_s, + # h_z_s1 -> h_scale_s (originally swapped on the upstream side). + assert "latent_codec.h_a.0.conv.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted # was h_z_s2 + assert "latent_codec.h_s.h_scale_s.0.weight" in converted # was h_z_s1 + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # g_a / g_s carry through unchanged. + assert "g_a.0.conv.weight" in converted + + # Old root-level paths should be gone after conversion. + assert "h_a.0.conv.weight" not in converted + assert "h_z_s1.0.weight" not in converted + assert "h_z_s2.0.weight" not in converted + assert "entropy_bottleneck.quantiles" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "dt_cross_attention.0.x_trans.weight" not in converted + assert "gaussian_conditional.scale_table" not in converted + + +class TestSaaf: + def _tiny_kwargs(self): + # Same shape as TestDcae for direct cross-comparison. 
+ return dict( + N=64, + M=80, + hyper_channels=64, + num_slices=4, + max_support_slices=2, + feature_dims=(48, 64, 80), + block_num=(1, 1, 2), + head_dim=(8, 8, 8, 8, 8, 8), + dict_num=8, + dict_head_num=4, + dictionary_dim=32, + window_size=4, + hyper_window_size=2, + hyper_head_dim=8, + ) + + def test_saaf_forward_and_state_dict_round_trip(self): + from compressai.models.saaf import SAAF + + model = SAAF(**self._tiny_kwargs()).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + # diffusion_loss is always present in the output dict; zero in eval mode. + assert "diffusion_loss" in out + assert out["diffusion_loss"].dim() == 0 + assert out["diffusion_loss"].item() == 0.0 + + sd_keys = set(model.state_dict().keys()) + # Shared dictionary lives at the model level (single state-dict path). + assert "shared_dictionary.dt" in sd_keys + assert sum(1 for k in sd_keys if k.endswith(".dt")) == 1 + # Hyperprior backbone moved under latent_codec.* (SAAF h_a wraps a + # ResidualBottleneckBlockWithStride: outermost weight is .conv). + assert "latent_codec.h_a.0.conv.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # Side-parameter channel context covers y0..y(K-1). + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + # Dictionary cross-attention head (shared with DCAE). + assert ( + "latent_codec.y.channel_context.y0.cross_attention.x_trans.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). 
+ assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # SAAF-specific: aux_enc / aux_dec each carry an OLP per stage, + # and diffusion_prior holds the noise predictor. + assert "aux_enc.0.olp.linear.weight" in sd_keys + assert "aux_dec.3.olp.linear.weight" in sd_keys + assert "diffusion_prior.noise_predictor.0.weight" in sd_keys + # Old monolithic paths should be gone. + assert "dt" not in sd_keys + assert not any(k.startswith("dt_cross_attention.") for k in sd_keys) + assert not any(k.startswith("cc_mean_transforms.") for k in sd_keys) + assert not any(k.startswith("h_z_s1.") for k in sd_keys) + assert not any(k.startswith("h_z_s2.") for k in sd_keys) + + loaded = SAAF.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert loaded.N == 64 + assert loaded.M == 80 + assert loaded.dict_num == 8 + + def test_saaf_aux_loss_is_nonzero_scalar(self): + from compressai.models.saaf import SAAF + + # SAAF integrates AuxT unconditionally (every _AdaptiveFrequencyBlock / + # _InverseAdaptiveFrequencyBlock carries an OLP), so aux_loss is + # always a non-trivial scalar — unlike TCM where use_auxt=False + # leaves only the base EntropyBottleneck auxiliary loss. + model = SAAF(**self._tiny_kwargs()).eval() + loss = model.aux_loss() + assert loss.dim() == 0 + assert torch.isfinite(loss) + assert loss.item() > 0 + + def test_saaf_diffusion_loss_active_in_training_mode(self): + from compressai.models.saaf import SAAF + + model = SAAF(**self._tiny_kwargs()).train() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["diffusion_loss"].dim() == 0 + # Random init + random noise -> finite, non-zero scalar. 
+ assert torch.isfinite(out["diffusion_loss"]) + assert out["diffusion_loss"].item() > 0 + + def test_saaf_upstream_state_dict_conversion(self): + convert_upstream_saaf_state_dict = _load_convert_fn( + "convert_saaf_checkpoint.py", "convert_upstream_saaf_state_dict" + ) + + # Synthetic upstream SAAF-style state_dict — same entropy backbone + # as DCAE plus SAAF-specific aux_enc / aux_dec / diffusion_prior + # keys that must pass through unchanged. + m = 80 + # num_slices=4 -> slice_ch = m // 4 = 20 (used inline below) + upstream = { + "dt": torch.zeros(8, 32), + "dt_cross_attention.0.x_trans.weight": torch.arange(32 * 160) + .float() + .reshape(32, 160), + "dt_cross_attention.0.scale": torch.zeros(4, 1, 1), + "cc_mean_transforms.0.0.weight": torch.arange(64 * 240) + .float() + .reshape(64, 240, 1, 1), + "cc_mean_transforms.1.0.weight": torch.zeros(64, 260, 1, 1), + "cc_scale_transforms.0.0.weight": torch.arange(64 * 240) + .float() + .reshape(64, 240, 1, 1), + "lrp_transforms.0.0.weight": torch.arange(64 * 260) + .float() + .reshape(64, 260, 1, 1), + "gaussian_conditional.scale_table": torch.zeros(2), + "h_a.0.conv.weight": torch.zeros(64, 80, 5, 5), + "h_z_s1.0.weight": torch.zeros(64, 64, 3, 3), # scales + "h_z_s2.0.weight": torch.zeros(64, 64, 3, 3), # means + "entropy_bottleneck.quantiles": torch.zeros(64, 1, 3), + "g_a.0.conv.weight": torch.zeros(48, 3, 5, 5), + # SAAF-specific keys that should pass through unchanged. + "aux_enc.0.olp.linear.weight": torch.zeros(48, 3), + "aux_enc.0.freq_weights": torch.zeros(4), + "aux_dec.3.olp.linear.weight": torch.zeros(3, 48), + "diffusion_prior.noise_predictor.0.weight": torch.zeros(80, 80, 3, 3), + } + converted = convert_upstream_saaf_state_dict(upstream) + + # Top-level dt -> shared_dictionary.dt. + assert "shared_dictionary.dt" in converted + assert "dt" not in converted + + # Means/scales swap on cross_attention.x_trans (first 2*M cols). 
+ original = upstream["dt_cross_attention.0.x_trans.weight"] + swapped = converted[ + "latent_codec.y.channel_context.y0.cross_attention.x_trans.weight" + ] + assert torch.equal(swapped[:, :m], original[:, m : 2 * m]) + assert torch.equal(swapped[:, m : 2 * m], original[:, :m]) + + # Same swap on cc_mean and lrp_transform first conv weights. + for src_key, dst_key in ( + ( + "cc_mean_transforms.0.0.weight", + "latent_codec.y.channel_context.y0.mean_cc.0.weight", + ), + ( + "lrp_transforms.0.0.weight", + "latent_codec.y.latent_codec.y0.lrp_transform.0.weight", + ), + ): + original_w = upstream[src_key] + swapped_w = converted[dst_key] + assert torch.equal(swapped_w[:, :m], original_w[:, m : 2 * m]) + assert torch.equal(swapped_w[:, m : 2 * m], original_w[:, :m]) + + # gaussian_conditional fanned out per slice. + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + + # h_z_s2 -> h_mean_s, h_z_s1 -> h_scale_s renames (DCAE convention). + assert "latent_codec.h_a.0.conv.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # SAAF-specific aux_enc / aux_dec / diffusion_prior pass through + # without any rename. + assert "aux_enc.0.olp.linear.weight" in converted + assert "aux_enc.0.freq_weights" in converted + assert "aux_dec.3.olp.linear.weight" in converted + assert "diffusion_prior.noise_predictor.0.weight" in converted + + # g_a / g_s pass through unchanged. + assert "g_a.0.conv.weight" in converted + + # Old root-level entropy paths gone. 
+ assert "h_a.0.conv.weight" not in converted + assert "h_z_s1.0.weight" not in converted + assert "h_z_s2.0.weight" not in converted + assert "entropy_bottleneck.quantiles" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "dt_cross_attention.0.x_trans.weight" not in converted + assert "gaussian_conditional.scale_table" not in converted + class TestCca: def test_cca_forward_and_state_dict_round_trip(self): diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py index 2b54f0bc..ba4ab71b 100644 --- a/tests/test_models_helpers.py +++ b/tests/test_models_helpers.py @@ -27,6 +27,7 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytest import torch import torch.nn as nn @@ -189,3 +190,287 @@ def test_infer_max_support_slices_new_path(self): } # extra_factor=1 is the default single latent_means concat. assert infer_max_support_slices(sd, latent_channels=64, num_slices=8) == 2 + + +class TestSharedDictionary: + def test_dt_shape_and_state_dict_path(self): + from compressai.models._helpers.dictionary_context import SharedDictionary + + shared = SharedDictionary(dict_num=16, dictionary_dim=64) + assert shared.dt.shape == (16, 64) + assert list(shared.state_dict().keys()) == ["dt"] + + def test_expand_for_broadcasts_without_copy(self): + from compressai.models._helpers.dictionary_context import SharedDictionary + + shared = SharedDictionary(dict_num=8, dictionary_dim=32) + out = shared.expand_for(4) + assert out.shape == (4, 8, 32) + # All B copies share storage with the underlying dt + assert out.data_ptr() == shared.dt.data_ptr() + + +class TestBuildDictionaryMeanScaleHead: + def _build(self, *, emit_mean_support=False): + from compressai.models._helpers.dictionary_context import ( + SharedDictionary, + build_dictionary_mean_scale_head, + ) + + # Tiny config: M=32, slice_ch=8, 
support_count=2 + m = 32 + slice_ch = 8 + support_count = 2 + support_ch = 2 * m + slice_ch * support_count + shared = SharedDictionary(dict_num=8, dictionary_dim=64) + head = build_dictionary_mean_scale_head( + slice_ch=slice_ch, + support_ch=support_ch, + shared_dictionary=shared, + dict_output_ch=m, + cross_attention_kwargs={"head_num": 4, "mlp_rate": 2}, + widths=(16,), + emit_mean_support=emit_mean_support, + ) + return shared, head, m, slice_ch, support_ch + + def test_forward_shape_no_emit(self): + shared, head, m, slice_ch, support_ch = self._build(emit_mean_support=False) + x = torch.randn(2, support_ch, 4, 4) + out = head(x) + # Output: cat([scale, mean]) → 2 * slice_ch + assert out.shape == (2, 2 * slice_ch, 4, 4) + + def test_forward_shape_with_emit_mean_support(self): + shared, head, m, slice_ch, support_ch = self._build(emit_mean_support=True) + x = torch.randn(2, support_ch, 4, 4) + out = head(x) + # Output: cat([scale, mean, support]) where support = cat([x, dict_info(M)]) + expected = 2 * slice_ch + (support_ch + m) + assert out.shape == (2, expected, 4, 4) + + def test_dt_not_duplicated_in_head_state_dict(self): + shared, head, *_ = self._build() + head_keys = list(head.state_dict().keys()) + assert all( + "dt" not in k for k in head_keys + ), f"dt leaked into head.state_dict: {[k for k in head_keys if 'dt' in k]}" + + def test_dt_appears_once_in_container_state_dict(self): + from compressai.models._helpers.dictionary_context import ( + SharedDictionary, + build_dictionary_mean_scale_head, + ) + + m, slice_ch, support_count = 32, 8, 2 + support_ch = 2 * m + slice_ch * support_count + + class _Container(nn.Module): + def __init__(self): + super().__init__() + self.shared_dictionary = SharedDictionary(dict_num=8, dictionary_dim=64) + self.heads = nn.ModuleDict( + { + f"y{k}": build_dictionary_mean_scale_head( + slice_ch=slice_ch, + support_ch=support_ch, + shared_dictionary=self.shared_dictionary, + dict_output_ch=m, + 
cross_attention_kwargs={"head_num": 4, "mlp_rate": 2}, + widths=(16,), + ) + for k in range(3) + } + ) + + container = _Container() + dt_keys = [k for k in container.state_dict() if k.endswith(".dt")] + assert dt_keys == [ + "shared_dictionary.dt" + ], f"expected single shared_dictionary.dt path, got: {dt_keys}" + + +class TestOLP: + @staticmethod + def test_forward_shape_square(): + from compressai.models._helpers.auxt import OLP + + m = OLP(8, 8) + out = m(torch.randn(2, 8)) + assert out.shape == (2, 8) + + @staticmethod + def test_loss_returns_scalar_for_each_aspect_ratio(): + from compressai.models._helpers.auxt import OLP + + for in_dim, out_dim in [(8, 8), (16, 4), (4, 16)]: + m = OLP(in_dim, out_dim) + loss = m.loss() + assert loss.dim() == 0, f"OLP({in_dim}, {out_dim}).loss() must be scalar" + assert torch.isfinite(loss) + + @staticmethod + def test_state_dict_round_trip(): + from compressai.models._helpers.auxt import OLP + + m = OLP(8, 8) + m2 = OLP(8, 8) + m2.load_state_dict(m.state_dict(), strict=True) + x = torch.randn(2, 8) + assert torch.allclose(m(x), m2(x)) + + +class TestWLSiWLS: + @staticmethod + def test_wls_iwls_shapes_and_round_trip(): + pytest.importorskip("pytorch_wavelets") + from compressai.models._helpers.auxt import WLS, iWLS + + wls = WLS(in_dim=3, out_dim=8) + iwls = iWLS(in_dim=8, out_dim=3) + x = torch.randn(2, 3, 16, 16) + y = wls(x) + # WLS halves spatial size (DWT) and produces out_dim channels. + assert y.shape == (2, 8, 8, 8) + z = iwls(y) + assert z.shape == x.shape + + # state_dict round-trip on WLS. + wls2 = WLS(in_dim=3, out_dim=8) + wls2.load_state_dict(wls.state_dict(), strict=True) + assert torch.allclose(wls(x), wls2(x)) + + @staticmethod + def test_aux_loss_returns_zero_when_no_olp_present(): + import torch.nn as _nn + + from compressai.models._helpers.auxt import aux_loss + + # A toy model with no OLP submodules — aux_loss should return a 0-d + # zero Tensor so callers can unconditionally add it to the objective. 
+ model = _nn.Sequential(_nn.Linear(8, 8)) + loss = aux_loss(model) + assert loss.dim() == 0 + assert loss.item() == 0.0 + + @staticmethod + def test_aux_loss_aggregates_olp_modules(): + import torch.nn as _nn + + from compressai.models._helpers.auxt import OLP, aux_loss + + # Two OLPs at different positions in the tree — aux_loss should + # equal the sum of their individual losses. + class _Container(_nn.Module): + def __init__(self): + super().__init__() + self.a = OLP(8, 8) + self.b = OLP(16, 4) + + c = _Container() + expected = c.a.loss() + c.b.loss() + assert torch.allclose(aux_loss(c), expected) + + +class TestForwardWithAuxt: + @staticmethod + def test_collapses_to_transform_when_aux_layers_none(): + import torch.nn as _nn + + from compressai.models._helpers.auxt import forward_with_auxt + + transform = _nn.Sequential(_nn.Conv2d(3, 4, 1), _nn.Conv2d(4, 5, 1)) + x = torch.randn(2, 3, 4, 4) + with torch.no_grad(): + assert torch.allclose( + forward_with_auxt(transform, None, (), x), transform(x) + ) + + @staticmethod + def test_sums_auxt_at_merge_positions(): + import torch.nn as _nn + + from compressai.models._helpers.auxt import forward_with_auxt + + # transform: 4 layers, all identity Conv2d-style (1x1, weight=I). + def _identity_conv(ch): + conv = _nn.Conv2d(ch, ch, 1, bias=False) + with torch.no_grad(): + conv.weight.copy_(torch.eye(ch).view(ch, ch, 1, 1)) + return conv + + transform = _nn.Sequential(*(_identity_conv(3) for _ in range(4))) + # AuxT branch with 2 layers, also identity. Merge at position 1 and 3. + aux = _nn.ModuleList([_identity_conv(3), _identity_conv(3)]) + x = torch.randn(1, 3, 2, 2) + out = forward_with_auxt(transform, aux, (1, 3), x) + # After identity transform + 2 identity AuxT additions, output = 3 * x + # (x at start + AuxT[0]=x at pos 1 + AuxT[1]=AuxT[0]=x at pos 3). 
+ assert torch.allclose(out, 3 * x) + + @staticmethod + def test_raises_when_merge_positions_underrun_aux_depth(): + import torch.nn as _nn + + from compressai.models._helpers.auxt import forward_with_auxt + + transform = _nn.Sequential(_nn.Conv2d(3, 3, 1), _nn.Conv2d(3, 3, 1)) + aux = _nn.ModuleList([_nn.Conv2d(3, 3, 1), _nn.Conv2d(3, 3, 1)]) + x = torch.randn(1, 3, 2, 2) + # Only 1 merge position for 2 aux layers -> mismatch. + with pytest.raises(RuntimeError, match="merge positions"): + forward_with_auxt(transform, aux, (0,), x) + + +class TestAuxtStateDictHelpers: + @staticmethod + def test_has_auxt_state(): + from compressai.models._helpers.auxt import has_auxt_state + + assert has_auxt_state( + {"g_a.auxiliary_layers.0.olp.linear.weight": torch.zeros(2)} + ) + assert has_auxt_state( + {"g_s.auxiliary_layers.3.scaling_factors": torch.zeros(2)} + ) + assert not has_auxt_state({"g_a.0.weight": torch.zeros(2)}) + + @staticmethod + def test_is_auxt_wavelet_buffer_key(): + from compressai.models._helpers.auxt import is_auxt_wavelet_buffer_key + + assert is_auxt_wavelet_buffer_key("g_a.auxiliary_layers.0.dwt.transform.h0_col") + assert is_auxt_wavelet_buffer_key("g_s.auxiliary_layers.0.idwt.inverse.g0_col") + assert not is_auxt_wavelet_buffer_key( + "g_a.auxiliary_layers.0.olp.linear.weight" + ) + assert not is_auxt_wavelet_buffer_key("g_a.0.weight") + + @staticmethod + def test_is_auxt_upstream_wavelet_buffer_key(): + from compressai.models._helpers.auxt import ( + is_auxt_upstream_wavelet_buffer_key, + ) + + for suffix in ("w_ll", "w_lh", "w_hl", "w_hh"): + assert is_auxt_upstream_wavelet_buffer_key(f"AuxT_enc.0.dwt.{suffix}") + assert is_auxt_upstream_wavelet_buffer_key("AuxT_dec.0.idwt.filters") + assert not is_auxt_upstream_wavelet_buffer_key( + "AuxT_enc.0.dwt.transform.h0_col" + ) + assert not is_auxt_upstream_wavelet_buffer_key("AuxT_enc.0.olp.linear.weight") + + @staticmethod + def test_normalize_upstream_auxt_key_renames_pascal_olp(): + from 
compressai.models._helpers.auxt import normalize_upstream_auxt_key + + assert ( + normalize_upstream_auxt_key("AuxT_enc.0.OLP.linear.weight") + == "g_a.auxiliary_layers.0.olp.linear.weight" + ) + assert ( + normalize_upstream_auxt_key("AuxT_dec.3.OLP.linear.bias") + == "g_s.auxiliary_layers.3.olp.linear.bias" + ) + # Returns None for non-AuxT keys so callers can use a single check. + assert normalize_upstream_auxt_key("g_a.0.weight") is None diff --git a/uv.lock b/uv.lock index 7f3aa8ea..68898bc9 100644 --- a/uv.lock +++ b/uv.lock @@ -682,6 +682,9 @@ tutorials = [ { name = "ipywidgets" }, { name = "jupyter" }, ] +wavelet = [ + { name = "pytorch-wavelets" }, +] [package.dev-dependencies] dev = [ @@ -737,6 +740,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'test'" }, { name = "pytorch-msssim" }, + { name = "pytorch-wavelets", marker = "extra == 'wavelet'" }, { name = "ruff", marker = "extra == 'dev'", specifier = "==0.8.6" }, { name = "scipy" }, { name = "setuptools", specifier = ">=68" }, @@ -753,7 +757,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.0.0" }, { name = "wheel", specifier = ">=0.32.0" }, ] -provides-extras = ["test", "dev", "doc", "tutorials", "pointcloud", "attn"] +provides-extras = ["test", "dev", "doc", "tutorials", "pointcloud", "attn", "wavelet"] [package.metadata.requires-dev] dev = [ @@ -3690,6 +3694,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/8c/856047f955acc30179e9255fdc488059ca22f0938519523d53494f7cfee8/pytorch_msssim-1.0.0-py3-none-any.whl", hash = "sha256:0b4b7bbf7035fe9dc8084244237aac13b1f104852c45b63a7e9fab4363bede54", size = 7744, upload-time = "2023-05-25T17:15:55.809Z" }, ] +[[package]] +name = "pytorch-wavelets" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < 
'3.10'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "six" }, + { name = "torch", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/e3/1553046e97926b859edcab05d74b5b7543327c92b37c58919a3e49ff6151/pytorch_wavelets-1.3.0.tar.gz", hash = "sha256:8b5c63f87c2bb36e6b342a7bb294926bda5cd974614fb4848deab6ec2792f56f", size = 1029927, upload-time = "2023-04-13T06:26:52.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/65/b7da80705dc679999ef77c06ded71e060d4ea14ed80111c104223e130cc1/pytorch_wavelets-1.3.0-py3-none-any.whl", hash = "sha256:e4f8635872370d8de640ef7548edef4ef60d5c565553c463210e939c7901ee69", size = 54879, upload-time = "2023-04-13T06:26:45.63Z" }, +] + [[package]] name = "pytz" version = "2025.2"