|
| 1 | +""" Coordinate Attention and Variants |
| 2 | +
|
| 3 | +Coordinate Attention decomposes channel attention into two 1D feature encoding processes |
| 4 | +to capture long-range dependencies with precise positional information. This module includes |
| 5 | +the original implementation along with simplified and other variants. |
| 6 | +
|
| 7 | +Papers / References: |
| 8 | +- Coordinate Attention: `Coordinate Attention for Efficient Mobile Network Design` - https://arxiv.org/abs/2103.02907 |
| 9 | +- Efficient Local Attention: `ELA: Efficient Local Attention for Deep Convolutional Neural Networks` - https://arxiv.org/abs/2403.01123 |
| 10 | +
|
| 11 | +Hacked together by / Copyright 2025 Ross Wightman |
| 12 | +""" |
| 13 | +from typing import Optional, Type, Union |
| 14 | + |
| 15 | +import torch |
| 16 | +from torch import nn |
| 17 | + |
| 18 | +from .create_act import create_act_layer |
| 19 | +from .helpers import make_divisible |
| 20 | +from .norm import GroupNorm1 |
| 21 | + |
| 22 | + |
class CoordAttn(nn.Module):
    """Coordinate Attention module for spatial feature recalibration.

    Introduced in "Coordinate Attention for Efficient Mobile Network Design" (CVPR 2021).
    Decomposes channel attention into two 1D feature encoding processes along the height and
    width axes to capture long-range dependencies with precise positional information.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 1. / 16,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2 / 3,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.Hardswish,
            norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input channels.
            rd_ratio: Reduction ratio for bottleneck channel calculation.
            rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
            rd_divisor: Divisor for making bottleneck channels divisible.
            se_factor: Applied to rd_ratio for final channel count (keeps params similar to SE).
            bias: Whether to use bias in convolution layers.
            act_layer: Activation module class for bottleneck.
            norm_layer: Normalization module class, None for no normalization.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.has_skip = has_skip
        if not rd_channels:
            # Falls back to the ratio-derived width when rd_channels is None (or 0).
            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)

        # Shared 1x1 bottleneck applied to the concatenated height/width descriptors.
        self.conv1 = nn.Conv2d(channels, rd_channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
        self.bn1 = norm_layer(rd_channels, **dd) if norm_layer is not None else nn.Identity()
        self.act = act_layer()

        # Separate 1x1 expansions producing the per-row and per-column gates.
        self.conv_h = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
        self.conv_w = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply coordinate attention to a (N, C, H, W) tensor and return a tensor of the same shape."""
        identity = x

        # Only the spatial extents are needed (to split the joint descriptor back apart);
        # unpacking all four dims still enforces a 4-D input.
        _, _, height, width = x.shape

        # Strip pooling
        x_h = x.mean(3, keepdim=True)  # (N, C, H, 1)
        x_w = x.mean(2, keepdim=True)  # (N, C, 1, W)

        # Lay the width descriptor along dim 2 so both descriptors can share one
        # conv / norm / act pass: (N, C, H + W, 1).
        x_w = x_w.transpose(-1, -2)
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [height, width], dim=2)
        x_w = x_w.transpose(-1, -2)  # back to (N, rd_channels, 1, W)

        a_h = self.gate(self.conv_h(x_h))  # (N, C, H, 1)
        a_w = self.gate(self.conv_w(x_w))  # (N, C, 1, W)

        # Both gates broadcast over the full (H, W) grid.
        out = identity * a_w * a_h
        if self.has_skip:
            out = out + identity

        return out
| 99 | + |
| 100 | + |
class SimpleCoordAttn(nn.Module):
    """Simplified Coordinate Attention variant.

    Differences from the full Coordinate Attention module:
      * linear layers instead of convolutions
      * no norm
      * additive pre-gating re-combination
    This reduces complexity while keeping the core mechanism of separate
    height and width attention.
    """

    def __init__(
            self,
            channels: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            rd_divisor: int = 8,
            se_factor: float = 2 / 3,
            bias: bool = True,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input channels.
            rd_ratio: Reduction ratio for bottleneck channel calculation.
            rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
            rd_divisor: Divisor for making bottleneck channels divisible.
            se_factor: Applied to rd_ratio for final channel count (keeps params similar to SE).
            bias: Whether to use bias in linear layers.
            act_layer: Activation module class for bottleneck.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.
        """
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.has_skip = has_skip

        # An unset (None / 0) bottleneck width is derived from the reduction ratio.
        rd_channels = rd_channels or make_divisible(
            channels * rd_ratio * se_factor,
            rd_divisor,
            round_limit=0.,
        )

        # One shared reduction, then one expansion per spatial axis.
        self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **factory_kwargs)
        self.act = act_layer()
        self.fc_h = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)
        self.fc_w = nn.Linear(rd_channels, channels, bias=bias, **factory_kwargs)

        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        shortcut = x

        # Strip pooling, moved to channels-last so nn.Linear acts on the channel dim.
        rows = x.mean(dim=3).transpose(1, 2)  # (N, H, C)
        cols = x.mean(dim=2).transpose(1, 2)  # (N, W, C)

        # Shared bottleneck projection
        rows = self.act(self.fc1(rows))  # (N, H, rd_c)
        cols = self.act(self.fc1(cols))  # (N, W, rd_c)

        # Per-axis expansion back to channels-first, with a singleton on the other axis.
        logits_h = self.fc_h(rows).transpose(1, 2).unsqueeze(-1)  # (N, C, H, 1)
        logits_w = self.fc_w(cols).transpose(1, 2).unsqueeze(-2)  # (N, C, 1, W)

        # Logits are summed (broadcast to H x W) before a single gate.
        attn = self.gate(logits_h + logits_w)
        out = shortcut * attn
        return out + shortcut if self.has_skip else out
| 174 | + |
| 175 | + |
class EfficientLocalAttn(nn.Module):
    """Efficient Local Attention.

    Lightweight alternative to Coordinate Attention that preserves spatial
    information without channel reduction. Uses 1D depthwise convolutions
    and GroupNorm for better generalization.

    Paper: https://arxiv.org/abs/2403.01123
    """

    def __init__(
            self,
            channels: int,
            kernel_size: int = 7,
            bias: bool = False,
            act_layer: Type[nn.Module] = nn.SiLU,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            norm_layer: Optional[Type[nn.Module]] = GroupNorm1,
            has_skip: bool = False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            channels: Number of input channels.
            kernel_size: Kernel size for 1D depthwise convolutions. Must be odd.
            bias: Whether to use bias in convolution layers.
            act_layer: Activation module class applied after normalization.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            norm_layer: Normalization module class, None for no normalization.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.

        Raises:
            ValueError: If kernel_size is not a positive odd integer.
        """
        # With padding = kernel_size // 2 and stride 1, only odd kernels keep the strip
        # length unchanged; an even kernel grows it by one and the attention maps no
        # longer line up with the input in forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')

        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.has_skip = has_skip

        # Depthwise (groups == channels) conv sliding along the height axis only.
        self.conv_h = nn.Conv2d(
            channels, channels,
            kernel_size=(kernel_size, 1),
            stride=1,
            padding=(kernel_size // 2, 0),
            groups=channels,
            bias=bias,
            **dd
        )
        # Depthwise conv sliding along the width axis only.
        self.conv_w = nn.Conv2d(
            channels, channels,
            kernel_size=(1, kernel_size),
            stride=1,
            padding=(0, kernel_size // 2),
            groups=channels,
            bias=bias,
            **dd
        )
        if norm_layer is not None:
            self.norm_h = norm_layer(channels, **dd)
            self.norm_w = norm_layer(channels, **dd)
        else:
            self.norm_h = nn.Identity()
            self.norm_w = nn.Identity()
        self.act = act_layer()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply height/width attention to a (N, C, H, W) tensor, returning the same shape."""
        identity = x

        # Strip pooling: (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W).
        # keepdim=True keeps the tensors 4-D so the Conv2d strips can be applied directly.
        x_h = x.mean(dim=3, keepdim=True)
        x_w = x.mean(dim=2, keepdim=True)

        # 1D conv + norm + act
        x_h = self.act(self.norm_h(self.conv_h(x_h)))  # (N, C, H, 1)
        x_w = self.act(self.norm_w(self.conv_w(x_w)))  # (N, C, 1, W)

        # Generate attention maps
        a_h = self.gate(x_h)  # (N, C, H, 1)
        a_w = self.gate(x_w)  # (N, C, 1, W)

        # Both maps broadcast over the full (H, W) grid.
        out = identity * a_h * a_w
        if self.has_skip:
            out = out + identity

        return out
| 261 | + |
| 262 | + |
class StripAttn(nn.Module):
    """Minimal Strip Attention.

    Lightweight spatial attention using strip pooling with optional learned refinement.
    """

    def __init__(
            self,
            channels: int,
            use_conv: bool = True,
            kernel_size: int = 3,
            bias: bool = False,
            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
            has_skip: bool = False,
            device=None,
            dtype=None,
            **_,
    ):
        """
        Args:
            channels: Number of input channels.
            use_conv: Whether to apply depthwise convolutions for learned spatial refinement.
            kernel_size: Kernel size for 1D depthwise convolutions when use_conv is True.
                Must be odd in that case; ignored when use_conv is False.
            bias: Whether to use bias in convolution layers.
            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
            has_skip: Whether to add residual skip connection to output.
            device: Device to place tensors on.
            dtype: Data type for tensors.
            **_: Any additional keyword arguments are accepted and ignored.

        Raises:
            ValueError: If use_conv is True and kernel_size is not a positive odd integer.
        """
        # With padding = kernel_size // 2 and stride 1, only odd kernels keep the strip
        # length unchanged; an even kernel grows it by one and the two strips no longer
        # line up with the input in forward().
        if use_conv and (kernel_size < 1 or kernel_size % 2 == 0):
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size}')

        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.has_skip = has_skip
        self.use_conv = use_conv

        if use_conv:
            # Depthwise (groups == channels) conv sliding along the height axis only.
            self.conv_h = nn.Conv2d(
                channels, channels,
                kernel_size=(kernel_size, 1),
                stride=1,
                padding=(kernel_size // 2, 0),
                groups=channels,
                bias=bias,
                **dd
            )
            # Depthwise conv sliding along the width axis only.
            self.conv_w = nn.Conv2d(
                channels, channels,
                kernel_size=(1, kernel_size),
                stride=1,
                padding=(0, kernel_size // 2),
                groups=channels,
                bias=bias,
                **dd
            )
        else:
            # Parameter-free variant: pooled strips are gated directly.
            self.conv_h = nn.Identity()
            self.conv_w = nn.Identity()

        self.gate = create_act_layer(gate_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply strip attention to a (N, C, H, W) tensor, returning the same shape."""
        identity = x

        # Strip pooling
        x_h = x.mean(dim=3, keepdim=True)  # (N, C, H, 1)
        x_w = x.mean(dim=2, keepdim=True)  # (N, C, 1, W)

        # Optional learned refinement
        x_h = self.conv_h(x_h)
        x_w = self.conv_w(x_w)

        # Combine and gate
        a_hw = self.gate(x_h + x_w)  # broadcasts to (N, C, H, W)

        out = identity * a_hw
        if self.has_skip:
            out = out + identity

        return out
| 341 | + |
0 commit comments