2.0.0 #9
…s, and losses
A code audit turned up several issues that were either silently wrong,
crashed in narrow cases (CPU-only installs, eval mode, zero-magnitude
inputs), or contradicted their own docstrings. None of the fast
(native-cfloat) modules were affected; the user-visible default paths
(`Conv1d`, `Linear`, `BatchNorm*d` in training, the polar activations,
attention) all continue to behave the same to first order, but the
hand-rolled `Slow*` variants and the eval-mode of `BatchNorm` produced
wrong numbers under common settings.
Critical correctness fixes:
- Gauss-trick bias in `SlowLinear` / `SlowConv1d/2d/3d` /
`SlowConvTranspose1d/2d/3d` was off by `-b_i * (1 + j)` whenever
`bias=True`. The previous implementation folded `b_r` into `t1`, `b_i`
into `t2`, and `b_r + b_i` into `t3`, then computed
`(t1 - t2) + j(t3 - t2 - t1)`, which cancels the bias incorrectly.
Bias is now stored as two real `nn.Parameter`s (`bias_r`, `bias_i`,
registered via `register_parameter` so they participate in
`state_dict()` and gradients), and applied to the real and imaginary
outputs after the Gauss combination. The `bias` property still
returns `torch.complex(bias_r, bias_i)` so external callers see no
change. `SlowConv*` default `bias=True`, so this affected
out-of-the-box usage.
- `_Conv.forward` / `_ConvTranspose.forward` now forward `dilation`
(and `output_padding`, for the transpose path) into the raw
`F.conv*d` / `F.conv_transpose*d` call used for `t3`. Without these,
any non-default `dilation` or `output_padding` made `t3` a different
shape from `t1, t2` and the final `torch.complex` errored out.
- Removed three unconditional debug prints in `_Conv`. One hard-coded
`torch.cuda.memory_allocated("cuda:0")` (crashed on CPU-only installs
and fired on every forward); another printed `self.conv_i.bias.type()`
from the `bias` property, which crashed with `AttributeError` when
`bias=False`. These two also leaked into any caller of
`MaskedChannelAttention1d/2d/3d`, which composes `SlowConv*` under
the hood.
- `nn.functional.batch_norm` eval-mode (`training=False` with stored
running stats) used `running_mean` directly at shape `[2, F]` and
subtracted it from `x` of shape `[2, B, F, ...]`. PyTorch's
right-aligned broadcasting matched the last two axes instead of
`[stack, F]`. The else branch now reshapes to
`[2, 1, F, 1, ..., 1]`. Also made the centering out-of-place
(`x = x - mean` rather than `x -= mean`) and fixed a duplicate-name
unpack (`v_rr, v_ir, v_ir, v_ii = ...` -> `..., _, v_ii = ...`); the
latter was benign because the off-diagonals were stored equal, but
misleading.
- `mask.PhaseSigmoid` had `pass` as its class body — instantiating it
worked, calling it raised `NotImplementedError`. Now implements the
phase-preserving sigmoid the docstring already described.
- `mask.MagMinMaxNorm` subtracted a real scalar `x_min` from a complex
tensor, which silently changed phase despite the docstring promising
it would be preserved; its `dim` constructor argument was also stored
and never used. The forward now operates on `input.abs()` and
rebuilds via `torch.polar(new_mag, input.angle())`, and honors
`self.dim`.
- `loss.SSIM.forward` with `data_range=None` first built a 4-D
`data_range` and then unconditionally ran
`data_range[:, None, None, None]`, producing a 7-D tensor that
misbroadcast downstream. The unsqueeze is now gated to the
user-supplied branch.
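The Gauss-trick bias error from the first bullet can be reproduced with scalars. The sketch below uses plain Python complex numbers rather than the library's layers; `gauss_old` and `gauss_fixed` are illustrative names, not functions from `complextorch`.

```python
def gauss_old(w: complex, x: complex, b: complex) -> complex:
    """Gauss trick with the bias folded into t1, t2, t3 (the buggy variant)."""
    t1 = w.real * x.real + b.real
    t2 = w.imag * x.imag + b.imag
    t3 = (w.real + w.imag) * (x.real + x.imag) + (b.real + b.imag)
    return complex(t1 - t2, t3 - t2 - t1)


def gauss_fixed(w: complex, x: complex, b: complex) -> complex:
    """Gauss trick on the bias-free products; bias added after combining."""
    t1 = w.real * x.real
    t2 = w.imag * x.imag
    t3 = (w.real + w.imag) * (x.real + x.imag)
    return complex(t1 - t2 + b.real, t3 - t2 - t1 + b.imag)


w, x, b = complex(2.0, -1.0), complex(0.5, 3.0), complex(0.25, 0.75)
reference = w * x + b

error_old = gauss_old(w, x, b) - reference      # -b.imag * (1 + 1j)
error_fixed = gauss_fixed(w, x, b) - reference  # zero
```

In the old variant the real part picks up `b_r - b_i` and the imaginary part loses its bias entirely, which together give the `-b_i * (1 + j)` offset.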
High-severity semantic fixes:
- Added `1e-12` clamp on the divisor in the three `/ |z|` sites that
could produce NaN on zero magnitude: `apply_complex_polar`
(phase-skip path), `PhaseSoftMax.forward`, and
`ComplexRatioMask.forward`. Smoke-tested with a zero-containing
complex tensor; no NaN.
- `_modReLU` / `modReLU` defaulted `bias=0.0` but immediately asserted
`bias < 0`. Default is now `-0.1` so `modReLU()` with no args works.
- `attention.ScaledDotProductAttention` now uses
`k.conj().transpose(-2, -1)` so the dot product is the Hermitian
inner product `Q K^H` rather than the bilinear form `Q K^T`. The
docstrings on both `ScaledDotProductAttention` and
`MultiheadAttention` previously claimed `CVSoftMax` applied softmax
to magnitude; in reality `CVSoftMax` applies softmax to real and
imaginary parts separately (`PhaseSoftMax` is the magnitude variant).
Rewrote both to describe what the code actually does.
- `nn.functional.layer_norm` now asserts `weight` and `bias` are both
set or both `None`, mirroring `batch_norm`.
- `Dropout` docstring now explicitly states that real and imaginary
parts get **independent** Bernoulli masks, and contrasts with
Trabelsi 2018's shared-mask formulation, so the choice is deliberate
rather than implicit.
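The divisor clamp from the first bullet can be sketched on a single complex scalar. `unit_phasor` is a hypothetical helper, not the library's code. One difference to keep in mind: the unclamped division yields NaN on a torch tensor, whereas plain Python raises `ZeroDivisionError`.

```python
EPS = 1e-12  # clamp value quoted in the fix above


def unit_phasor(z: complex, eps: float = EPS) -> complex:
    """Return z / |z|, clamping the divisor so that z == 0 maps to 0."""
    return z / max(abs(z), eps)


zero_out = unit_phasor(0j)      # 0j instead of an undefined 0/0
unit_out = unit_phasor(3 + 4j)  # unchanged for ordinary inputs: 0.6 + 0.8j
```

The clamp only takes effect for magnitudes below `eps`, so inputs of ordinary scale are unaffected.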
API and packaging:
- Re-export fast `ConvTranspose1d`, `ConvTranspose2d`,
`ConvTranspose3d`, and `SSIM` from `complextorch.nn`. The fast
transposes existed but were not exposed; only the `Slow` variants
were.
- Fast `ConvTranspose*` default `output_padding` is now `0`, matching
`torch.nn.ConvTranspose*`. (Previously defaulted to `1`, which
changed output shape vs. native PyTorch.)
- `CVQuadError`, `CVFourthPowError`, `CVCauchyError`, `CVLogCoshError`,
and `CVLogError` now accept a `reduction='mean' | 'sum' | 'none'`
argument, defaulting to `'mean'`. Previously they hard-coded `.sum()`,
so loss magnitude (and gradient scale) varied with batch and feature
size and was inconsistent with `SplitMSE` / `SplitL1`. The original
paper formulas are recoverable with `reduction='sum'`.
- `setup.py`: bumped `python_requires` to `>=3.10` (matches
`numpy>=2.2.0`) and the `torch` requirement to `>=2.1.0` (matches the
package's reliance on native complex kernels). Dropped
`deprecated>=1.2.18` from
`install_requires` (no longer used after the `CVTensor` removal).
Same updates applied to `requirements.txt` (which previously pinned
a CUDA-only `torch>=1.11.0+cu115`, breaking non-CUDA installs) and
`docs/requirements.txt`.
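To illustrate why the hard-coded `.sum()` made loss scale depend on batch size, here is a minimal sketch of the three reduction modes. `quad_error` is an illustrative stand-in operating on Python lists; it is not the library's `CVQuadError`, and only the `1/2` factor and the reduction names are taken from this change.

```python
def quad_error(x, y, reduction="mean"):
    """0.5 * |x - y|^2 per element, reduced according to `reduction`."""
    per_elem = [0.5 * abs(a - b) ** 2 for a, b in zip(x, y)]
    if reduction == "none":
        return per_elem
    if reduction == "sum":
        return sum(per_elem)
    if reduction == "mean":
        return sum(per_elem) / len(per_elem)
    raise ValueError(f"invalid reduction: {reduction!r}")


x = [1 + 1j, 2 + 0j]
y = [0j, 0j]

# Repeating the same data doubles the 'sum' loss but leaves 'mean' unchanged.
sum_small, sum_big = quad_error(x, y, "sum"), quad_error(x * 2, y * 2, "sum")
mean_small, mean_big = quad_error(x, y, "mean"), quad_error(x * 2, y * 2, "mean")
```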
Minor:
- `manifold.wFMConv2d` now caches `nn.Fold` by input spatial shape
(`_get_fold(input_shape)`) instead of constructing — and registering
as a submodule — a fresh `nn.Fold` on every forward.
- `manifold.wFMConv1d.forward` uses `.squeeze(-2)` to pair with its
`.unsqueeze(-2)`. The previous bare `.squeeze()` could drop a
singleton batch dim.
- `attention/eca.py`: collapsed a `.transpose(-1, -2).transpose(-1, -2)`
no-op into the surrounding `.view` calls.
Docs:
- `docs/source/index.rst`: updated release date to `05/10/2026`,
renamed the release-notes section to "Version 1.2 Release Notes",
rewrote the bullets to describe the legacy-API removal plus the
correctness fixes in this change, and corrected the long-standing
typo `dtype=torch.float` -> `dtype=torch.cfloat`.
- `README.md`: same release-notes rewrite.
…cvnn (2.0.0)
After auditing the three sibling complex-NN libraries checked out alongside
this one (Popoff's `complexPyTorch`, Nazarov's `cplxmodule`, and Levi/Fix
et al.'s `torchcvnn`), this commit lands every feature they ship that
`complextorch` did not, in a single 2.0.0 release. The library is now a
strict superset of those three on the `nn`, `transforms`, `models`,
`datasets`, and signal-processing fronts.
This bumps the version to 2.0.0 because two unavoidable changes are
backwards-incompatible: a `bias=False` -> `bias=True` default change on the
Fast `Conv*d` / `ConvTranspose*d` / `Linear` wrappers (to match `torch.nn`),
and a silent `QK^T` -> `QK^H` math fix in `MultiheadAttention` (the inner
product was using a plain transpose where the Hermitian transpose is the
standard choice for complex inner products).
New subpackages
---------------
- `complextorch.nn.init` — variance-correct complex weight initializers:
`kaiming_normal_`, `kaiming_uniform_`, `xavier_normal_`, `xavier_uniform_`,
`trabelsi_standard_` (Rayleigh-uniform polar), `trabelsi_independent_`
(semi-unitary via SVD). PyTorch's built-in init treats the real and
imaginary parts as independent real tensors, which gives the wrong
variance for the complex weight (off by a factor of two); these
produce the intended `E[|w|^2]` for both He and Glorot targets.
- `complextorch.nn.relevance` — complex Variational Dropout / Automatic
Relevance Determination. `LinearVD`, `BilinearVD`, `Conv{1,2,3}dVD`,
plus the corresponding `*ARD` variants. Adapted from
`cplxmodule.nn.relevance.complex`. Real-valued VD/ARD layers are
intentionally not ported — only the complex variants, with the `Cplx`
prefix dropped since the subpackage name already scopes the meaning.
Adds `scipy` as a runtime dependency for the exact `Ei`-based KL
(`complextorch.nn.relevance._expi.ExpiFunction`); the empirical-Bayes
`ARD` variants use a pure-torch softplus penalty and don't need scipy
at runtime.
- `complextorch.nn.masked` — fixed-mask sparsified layers (`LinearMasked`,
`BilinearMasked`, `Conv{1,2,3}dMasked`) plus the module-walking helpers
`deploy_masks`, `binarize_masks`, `is_sparse`, `named_masks`. Designed
to compose with `nn.relevance`: train with `LinearVD`, extract a
relevance mask via `compute_ard_masks(model, threshold=...)`, deploy
it onto a parallel `LinearMasked` inference model with `deploy_masks`.
- `complextorch.nn.utils.sparsity` — `SparsityStats` base + the
module-walking `named_sparsity` / `sparsity` helpers.
- `complextorch.signal` — `pwelch` (a differentiable torch port of
`scipy.signal.welch`; matches scipy to ~2e-15 in unit tests).
- `complextorch.transforms` — torchcvnn-style dataloader-stage transforms
(`LogAmplitude`, `Amplitude`, `ToReal`, `ToImaginary`, `RealImaginary`,
`Normalize` (per-channel 2x2 whitening), `RandomPhase`, `PadIfNeeded`,
`CenterCrop`, `SpatialResize`, `FFT2`, `IFFT2`, `FFTResize`, `PolSAR`,
`ToTensor`, `Unsqueeze`, `HWC2CHW`). Torch-only — the upstream's dual
numpy/torch dispatch is intentionally dropped for simplicity. Two
public functional helpers (`polsar_dict_to_array`, `rescale_intensity`).
- `complextorch.datasets` — SAR / MRI / SLC dataset surface. `SAMPLE`
(in-memory random complex chips) and `SLCDataset` (generic `.npy` /
`.pt` directory reader) are full implementations; the file-format-heavy
loaders (`PolSFDataset`, `Bretigny`, `S1SLC`, `MSTARTargets`,
`ATRNetSTAR`, `MICCAI2023`, `ALOSDataset`+CEOS readers) are present as
importable classes with `_NotImplementedDataset` stubs that raise at
instantiation with a clear pointer to the upstream torchcvnn reference
implementation. This keeps the import surface complete (tab-completion,
type checking) without dishonestly claiming the loaders work.
Optional deps gated behind `pip install complextorch[datasets]` (h5py)
and `complextorch[datasets-alos]` (rasterio).
- `complextorch.models` — first pre-built reference architecture lives
here. `ViT`, `ViTLayer`, and presets `vit_t/s/b/l/h` matching the
standard ViT family (5M-630M parameters).
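The factor-of-two claim behind `complextorch.nn.init` can be checked empirically. The sketch below is a Monte Carlo estimate with the standard library only; `mean_sq_mag` and `target_var` are illustrative names, and no `complextorch` initializer is called.

```python
import random

random.seed(0)
N = 20_000
target_var = 1.0  # desired E[|w|^2] for one complex weight


def mean_sq_mag(std_per_part: float) -> float:
    """Estimate E[|w|^2] when real and imag parts are drawn independently."""
    total = 0.0
    for _ in range(N):
        w = complex(random.gauss(0.0, std_per_part),
                    random.gauss(0.0, std_per_part))
        total += abs(w) ** 2
    return total / N


# Real-valued init applied to each part: every part gets the full variance.
naive = mean_sq_mag(target_var ** 0.5)
# Complex-aware init: the variance budget is split across the two parts.
corrected = mean_sq_mag((target_var / 2) ** 0.5)
```

As `N` grows, `naive` approaches `2 * target_var` and `corrected` approaches `target_var`.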
New nn.modules
--------------
- `nn.modules.transformer` — `TransformerEncoderLayer`, `TransformerEncoder`,
`TransformerDecoderLayer`, `TransformerDecoder`, `Transformer`.
Composed from the existing `MultiheadAttention` (which already includes
residual + LayerNorm internally) plus a feed-forward sub-block with
its own residual + LayerNorm. `softmax_on` flag is plumbed through
from `MultiheadAttention`.
- `nn.modules.rnn` — `GRUCell`, `GRU`, `LSTMCell`, `LSTM`. Cells are
built from complex `Linear` projections + `CSigmoid` + `CTanh` (not
the dual-real-RNN wrapper trick from `complexPyTorch`, which has a
parameterization quirk where `gru_re` does both `Wr.xr` and `Wr.xi`,
giving a different model from a proper complex GRU). Multi-layer
wrappers stack cells along time in Python; there is no fused cuDNN path
because cuDNN provides no complex-valued RNN kernels. Each cell accepts
`batchnorm=False`; setting it to `True` inserts `BatchNorm1d` after
every internal linear projection (Recurrent BN, Cooijmans et al. 2017).
- `nn.modules.rmsnorm` — `RMSNorm`. Complex RMS-norm; equivalent to
`torch.nn.RMSNorm` but with an optional per-feature 2x2 affine instead
of a scalar gain.
- `nn.modules.groupnorm` — `GroupNorm`. Per-group 2x2 whitening
(Trabelsi) + per-channel 2x2 affine + 2-vector bias. No running
statistics. Reuses the existing `inv_sqrtm2x2` from `nn.functional`.
- `nn.modules.upsampling` — `Upsample` (split real/imag; analogous to
`torchcvnn.nn.Upsample` and `complexPyTorch.complex_upsample`) and
`PolarUpsample` (polar form; matches the phase-preserving
`complex_upsample2`). The polar form preserves phase along smooth
regions but introduces visible discontinuities at the +/-pi phase
wrap; pick based on the smoothness profile of your data.
- `nn.modules.casting` — layout-conversion modules:
`InterleavedToComplex`, `ComplexToInterleaved`, `ConcatenatedToComplex`,
`ComplexToConcatenated`, `RealToComplex` (zero-imag lift). The
redundant `TensorToCplx` / `CplxToTensor` from `cplxmodule` are
deliberately omitted — `torch.view_as_complex` / `torch.view_as_real`
already cover those.
- `nn.modules.phase` — `PhaseShift`, a learnable per-channel /
per-feature phase rotation: `y = x * exp(j * phi)` with `phi` an
`nn.Parameter` initialized uniformly in `[-pi, pi]`.
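The `PhaseShift` formula `y = x * exp(j * phi)` is a pure rotation, which a scalar sketch makes visible. `phase_shift` below is a hypothetical free function with a fixed `phi`; the module itself stores `phi` as a learnable parameter.

```python
import cmath
import math


def phase_shift(x: complex, phi: float) -> complex:
    """Rotate x by phi radians: y = x * exp(j * phi)."""
    return x * cmath.exp(1j * phi)


x = 1 + 2j
y = phase_shift(x, math.pi / 2)  # a quarter turn, i.e. multiplication by j

magnitude_change = abs(y) - abs(x)              # ~0: magnitude is preserved
phase_change = cmath.phase(y) - cmath.phase(x)  # ~pi/2
```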
Additions to existing nn modules
--------------------------------
- `nn.modules.linear` — new `Bilinear` (complex bilinear with a
`conjugate=True/False` flag selecting Hermitian vs plain bilinear
form). `Linear` / `SlowLinear` default changed to `bias=True`.
- `nn.modules.batchnorm` — new `NaiveBatchNorm{1,2,3}d` (independent
`torch.nn.BatchNorm*d` on real and imag — the split baseline, kept
for parity with `complexPyTorch.NaiveComplexBatchNorm*d` and as a
cheap drop-in when the Trabelsi 2x2 whitening overhead isn't
warranted).
- `nn.modules.pooling` — new `MagMaxPool{1,2,3}d` (magnitude-argmax max
pool — `torch.nn.MaxPool*d` doesn't define `>` on complex, so this
is the canonical complex analogue; gathers the original complex
sample at the max-magnitude position, preserving phase) and
`AvgPool{1,2,3}d` (thin wrapper preserving the import-swap promise).
The `MagMaxPool` name is deliberate: `MaxPool*d` would suggest a
drop-in replacement for `torch.nn.MaxPool*d`, but the semantics are
meaningfully different.
- `nn.modules.dropout` — new `Dropout1d` / `Dropout2d` / `Dropout3d`
with shared real/imag mask (Trabelsi 2018) via a `view_as_real` round
trip into `F.dropout*d`. Preserves the phase of surviving entries —
the original `Dropout` (independent masks per part) is unchanged for
backwards compatibility.
- `nn.modules.loss` — new `MSELoss` matching `torch.nn.MSELoss` exactly
(no 1/2 factor). The existing `CVQuadError` is unchanged and retains
its physics-style 1/2 factor; verified numerically that
`MSELoss(x, y) == 2 * CVQuadError(x, y)` for any inputs.
- `nn.modules.activation`: seven new activations, plus a new flag on `modReLU`:
- `CVSplitELU` + alias `CELU` (split Type-A ELU)
- `CVSplitCELU` + alias `CCELU` (split Type-A `torch.nn.CELU`)
- `CVSplitGELU` + alias `CGELU` (split Type-A GELU)
- `zAbsReLU` (magnitude-thresholded with learnable threshold)
- `zLeakyReLU` (leaky version of `zReLU`)
- `Mod` (magnitude as a module — complex -> real)
- `AdaptiveModReLU` (per-channel learnable threshold modReLU)
- `modReLU` gains a `learnable: bool = False` flag — when `True`,
the threshold becomes a single trainable scalar.
- `nn.modules.attention` — `MultiheadAttention` and
`ScaledDotProductAttention` now use `k.conj().transpose(-2, -1)` for
the Hermitian inner product (was the plain transpose; a math bug).
New `softmax_on='complex'|'real'` flag: `'complex'` (default) keeps
the existing `CVSoftMax`-on-`QK^H` behavior; `'real'` applies
`torch.softmax(Re[QK^H/sqrt(d)])` for real-valued attention weights
(the formulation used by torchcvnn).
- `nn.modules.conv` — Fast `Conv*d` / `ConvTranspose*d` default changed
to `bias=True`. Class-level docstring sentence added noting the only
behavioral difference vs `torch.nn` is the default `dtype=torch.cfloat`.
- `nn.__init__.py` — re-exports `ConvTranspose1d/2d/3d` (the classes
already existed in `conv.py` but were missing from the public
surface).
- `nn.functional` — the previously private whitening helpers
`_whiten2x2_batch_norm` and `_whiten2x2_layer_norm` are now public
(`whiten2x2_batch_norm`, `whiten2x2_layer_norm`); both are added to
`__all__` alongside `batch_norm`, `layer_norm`, `inv_sqrtm2x2`.
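The `MagMaxPool` semantics described above (select by magnitude, return the original complex sample) can be sketched in one dimension. `mag_max_pool1d` is an illustrative list-based function that assumes non-overlapping windows (stride equal to kernel size); it is not the library's implementation.

```python
def mag_max_pool1d(x: list[complex], kernel_size: int) -> list[complex]:
    """Keep, per window, the complex sample with the largest magnitude."""
    return [
        max(x[i : i + kernel_size], key=abs)
        for i in range(0, len(x) - kernel_size + 1, kernel_size)
    ]


signal = [1 + 0j, -3j, 0.5 + 0.5j, -2 + 0j]
pooled = mag_max_pool1d(signal, kernel_size=2)  # [-3j, (-2+0j)]
```

Both selected samples keep their original phase, which a max over real or imaginary parts taken separately would not guarantee.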
Breaking changes
----------------
- `Linear` / `SlowLinear` and fast `Conv{1,2,3}d` /
`ConvTranspose{1,2,3}d` default to `bias=True` (was `False`),
matching `torch.nn`. Existing code that relied on the silent
`bias=False` default and intentionally wanted no bias must pass
`bias=False` explicitly.
- `MultiheadAttention` and `ScaledDotProductAttention` use the
Hermitian inner product `Q K^H` instead of `Q K^T`. This is the
standard math for complex inner products and matches all sibling
libraries; results numerically differ from prior outputs for any
complex `K`. There is no opt-out flag — the previous behavior was a
bug, not a design choice.
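The difference between the two forms shows up even for a vector paired with itself. The sketch below uses plain Python lists in place of `Q` and `K`; `dot_bilinear` and `dot_hermitian` are illustrative names.

```python
def dot_bilinear(q, k):
    """Plain transpose form: sum_i q_i * k_i (the old behavior)."""
    return sum(a * b for a, b in zip(q, k))


def dot_hermitian(q, k):
    """Hermitian form: sum_i q_i * conj(k_i) (the new behavior)."""
    return sum(a * b.conjugate() for a, b in zip(q, k))


q = [1 + 2j, 3 - 1j]
self_bilinear = dot_bilinear(q, q)    # (5-2j): complex, not a squared norm
self_hermitian = dot_hermitian(q, q)  # (15+0j): equals |q|^2
```

Only the Hermitian form gives a real, non-negative self-similarity, which is why outputs change for any `K` with a non-zero imaginary part.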
Documentation
-------------
Every new public class and function appears in the Sphinx docs (GitHub
Pages workflow under `.github/workflows/docs.yml`):
- Seven new module RST files under `docs/source/nn/modules/`:
`casting.rst`, `phase.rst`, `rmsnorm.rst`, `groupnorm.rst`,
`upsampling.rst`, `rnn.rst`, `transformer.rst`.
- Four new subpackage RST files under `docs/source/nn/`: `init.rst`,
`relevance.rst`, `masked.rst`, `utils.rst`.
- Four new top-level RST files under `docs/source/`: `signal.rst`,
`transforms.rst`, `datasets.rst`, `models.rst`.
- Toctrees updated in `docs/source/index.rst`, `docs/source/nn.rst`,
and `docs/source/nn/modules.rst`.
- `index.rst` gains a "Version 2.0 Release Notes" section enumerating
the additions and the two breaking changes.
`sphinx-build -W --keep-going` returns clean (the only warning is a
pre-existing missing-`_static` directory unrelated to this change). All
35+ new public symbols I spot-checked resolve to a built HTML page.
Packaging
---------
- Bumped `__version__` to `2.0.0`.
- Added `scipy>=1.10.0` to runtime dependencies (required by
`nn.relevance._expi.ExpiFunction`; the `ARD` variants don't need it
at forward time, but the import-time check is unavoidable).
- Added `[datasets]` extra (`h5py>=3.7`) — for MICCAI MRI.
- Added `[datasets-alos]` extra (`rasterio>=1.3`) — gated separately
because rasterio wraps GDAL, which is a system-level dep on Linux/Mac.
…ced in CI
Stand up a complete `tests/` mirror of the `complextorch/` package and wire
`pytest --cov-fail-under=100` into the existing GitHub Actions workflow, so
any future PR that drops coverage fails CI automatically.
## Infrastructure
- `pyproject.toml`: expand `[project.optional-dependencies].test` to include
`pytest>=8`, `pytest-cov>=5`, `pytest-xdist>=3`, and `hypothesis>=6`. Add
`[tool.pytest.ini_options]` with `-n auto`, `testpaths = ["tests"]`, and
`--strict-markers --strict-config -ra`. Add `[tool.coverage.run]` (line
coverage, source = `complextorch`) and `[tool.coverage.report]` with
`fail_under = 100` plus `exclude_lines` for `raise NotImplementedError`,
`pragma: no cover`, `if TYPE_CHECKING:`, `@overload`, and bare ellipsis.
- `.github/workflows/test.yml`: add `--cov-fail-under=100` and
`--cov-report=term-missing` to the pytest invocation; quote `'.[test]'`
for shell safety.
## Test suite (488 tests, mirroring the package tree)
- `tests/conftest.py`: autouse seeded RNG (`torch`, `random`, `numpy`),
`cplx(*shape)` factory, `device` fixture.
- Unit tests under `tests/nn/{,modules,modules/activation,modules/attention,
masked,relevance,utils}/`, `tests/datasets/`, `tests/models/`, and
`tests/transforms/`, one test file per source file.
- Property tests under `tests/invariants/` using Hypothesis + parametrize:
Fast/Slow conv & linear numerical equivalence (state-dict-aligned
weights), polar round-trip, casting round-trip, FFT round-trip.
- One parameterized test covers all 11 `NotImplementedError` dataset stubs
(`PolSFDataset`, `Bretigny`, `S1SLC`, `MSTARTargets`, `ATRNetSTAR`,
`MICCAI2023`, `ALOSDataset`, `VolFile`, `LeaderFile`, `TrailerFile`,
`SARImage`).
- `_expi.py` validated with `scipy.special.expi` parity at five sample
points plus `torch.autograd.gradcheck` for the analytical backward.
- Full reduction matrix (`'mean'` / `'sum'` / `'none'`) plus invalid-
reduction `ValueError` checks for every loss in `loss.py`.
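As a rough, dependency-free illustration of the polar round-trip property listed above, the sketch below uses seeded random samples instead of Hypothesis and `cmath` instead of torch. The sample count and value range are arbitrary assumptions.

```python
import cmath
import random

random.seed(0)


def polar_round_trip(z: complex) -> complex:
    """Decompose z into (magnitude, phase) and rebuild it."""
    mag, phase = cmath.polar(z)
    return cmath.rect(mag, phase)


samples = [
    complex(random.uniform(-10, 10), random.uniform(-10, 10))
    for _ in range(1000)
]
max_err = max(abs(polar_round_trip(z) - z) for z in samples)
```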
## Source fixes surfaced by writing the tests
- `loss.py`: `PerpLossSSIM.forward` was passing the complex `(x, y)` pair
to the real-only SSIM conv, which would raise at first use. Pass the
precomputed magnitudes (`mag_input`, `mag_target`) instead, matching the
cited perpendicular-loss reference.
- `masked/base.py`: drop the unreachable
`elif mask_in_missing: missing_keys.remove(mask_key)` branch in
`_load_from_state_dict` — PyTorch's `load_state_dict` hard-codes
`strict=True` when calling `_load_from_state_dict`, and `super()` only
appends to `missing_keys` under `strict=True`, so the `strict=False AND
mask_in_missing=True` precondition is unreachable. Also remove a dead
`if weight.is_complex():` check in `MaskedWeightMixin.sparsity` whose
branches returned identical values.
- `transforms/transforms.py`: simplify `_resize_spectrum` by removing the
real-input fallbacks — the helper is only ever called from `FFTResize`,
which always feeds it a complex spectrum (`fftshift(fft2(x))`).
- `rnn.py`: remove the unused `_maybe_bn` helper.
- `transforms/Normalize`: delete the no-op `broadcast` line (`x.dim() - x.dim() + 2` is always 2 and the value was never read) and tighten the validation to spell out the supported shapes — `(C, H, W)` or `(B, C, H, W)`, channel at dim `-3`. Matches the existing convention used elsewhere in `transforms.py` (e.g. `RealImaginary`).
- `nn.Dropout{1,2,3}d`: drop the unused `real_view = torch.view_as_real(input)` and the stale comment trail describing a `view_as_real`-based approach that was abandoned in favor of the manual Bernoulli mask. `_dropout_fn` is kept — it handles the non-complex fallback path.
- `nn.MagMaxPool*`: drop the unused `spatial_dims` local in `_gather_max_by_magnitude`.
- `docs/source/conf.py`: drop `sys.path.insert(0, "../../")` and the `os`/`sys` imports it required. Both supported doc-build paths (CI's `pip install .[docs]` and the documented local workflow) install `complextorch` before invoking Sphinx, so the path manipulation was a no-op. Also fixes E402 from the previously mid-file imports.
No behavior changes.
Ports the novel layers from the CVPR 2022 "Co-Domain Symmetry for
Complex-Valued Deep Learning" paper (Singhal, Xing, Yu) and adds the
missing SurReal/wFM (Chakraborty, Xing, Yu; arXiv:1910.11334) primitives
that complement the existing `wFMConv1d/2d`. All ports use native
`torch.cfloat` per the project convention.
New modules (`complextorch.nn`):
- `PhaseDivConv{1,2,3}d` — `x · conj(g(x)) / (|g(x)|² + ε)`; U(1)-invariant.
- `PhaseConjConv{1,2,3}d` — `x · conj(g(x))`; also U(1)-invariant when `g`
is C-linear (which it is with native cfloat) — see docstring note on
the divergence from the paper's "phase-mixing" claim.
- `GTReLU` — learnable complex scaling + half-plane phase mask with a
custom autograd Function whose gradient is the mask itself.
- `EquivariantPhaseReLU` — full U(1)-equivariant ReLU via channel-mean
reference phase.
- `ComplexScaling` — learnable `(α + jβ)·z`, sibling of `PhaseShift`.
- `MagBatchNorm{1,2,3}d` — magnitude-only BatchNorm (preserves phase →
U(1)-equivariant), distinct from the existing covariance-whitening
`BatchNorm{1,2,3}d`.
- `PrototypeDistance` — learnable complex prototypes + temperature.
Supports an optional reference rotation for U(1)-equivariant networks.
- `wFMReLU`, `wFMDistanceLinear` — manifold-aware ReLU and a real-valued
distance-to-Fréchet-mean head, completing the SurReal building blocks
alongside the existing `wFMConv{1,2}d`.
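The U(1)-invariance noted for `PhaseDivConv` and `PhaseConjConv` follows from `g` being C-linear, and can be checked on scalars. In this sketch `g` is a stand-in (multiplication by a fixed complex weight) and the `EPS` value is an assumption; neither comes from the library.

```python
import cmath

EPS = 1e-8  # denominator stabilizer; the value is an assumption


def g(x: complex) -> complex:
    """Stand-in for a C-linear conv: multiply by a fixed complex weight."""
    return (0.7 - 0.2j) * x


def phase_conj(x: complex) -> complex:
    return x * g(x).conjugate()


def phase_div(x: complex) -> complex:
    gx = g(x)
    return x * gx.conjugate() / (abs(gx) ** 2 + EPS)


x = 1.5 - 2.0j
rot = cmath.exp(1j * 0.9)  # global phase e^{j psi}

# A global rotation of the input cancels between x and conj(g(x)).
conj_gap = abs(phase_conj(x * rot) - phase_conj(x))
div_gap = abs(phase_div(x * rot) - phase_div(x))
```

Both gaps are zero up to rounding. If `g` were not C-linear (for example, separate real weights on the real and imaginary parts), the cancellation would no longer hold.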
Reference models (`complextorch.models`):
- `CDSInvariant` (paper "I-type"), `CDSEquivariant` ("E-type"), `CDSMSTAR`
(SAR backbone with a real-valued ResNet-lite tail).
Tests:
- New `tests/invariants/test_equivariance.py` verifies U(1) equivariance
/ invariance of every applicable module via `M(x·e^{jψ}) = M(x)·e^{jψ}`
(or `= M(x)`).
- New per-module tests: `test_phase_modulation`, `test_prototype`,
`tests/models/test_cds.py`.
- Extended existing files with tests for `ComplexScaling`, `MagBatchNorm*`,
`GTReLU`, `EquivariantPhaseReLU`, `wFMReLU`, `wFMDistanceLinear` plus a
`gradcheck`-style test for the half-plane phase autograd Function.
- Fixed a stale regex in `tests/transforms/test_transforms.py::
test_normalize_wrong_channel_dim` (matched `"channel dim"`, but the
error message had been reworded to mention `"with C=3"`).
Renames the hand-rolled real/imag-split conv/linear implementations out of
the top-level `complextorch.nn` surface and into a `complextorch.nn.gauss`
subpackage. The original `Slow*` prefix was a misleading legacy from when
these were *faster* than the naive split — since PyTorch 2.1.0's native
complex kernels they're slower than the `dtype=torch.cfloat` wrappers, and
the prefix described an accident of history rather than what the layers
actually do.
Renames (breaking, `2.0.0`):
- `complextorch.nn.SlowLinear` → `complextorch.nn.gauss.Linear`
- `complextorch.nn.SlowConv{1,2,3}d` → `complextorch.nn.gauss.Conv{1,2,3}d`
- `complextorch.nn.SlowConvTranspose{1,2,3}d` → `complextorch.nn.gauss.ConvTranspose{1,2,3}d`
The native-cfloat `Linear` / `Conv{1,2,3}d` / `ConvTranspose{1,2,3}d` keep
their top-level names. `complextorch.nn.gauss` is registered as a
subpackage alongside `init`, `relevance`, `masked`, `utils` — the prior
commit already wired the import into `complextorch/nn/__init__.py`; this
commit adds the subpackage itself plus the rest of the rename.
Internal callers:
- `MaskedChannelAttention{1,2,3}d` (six `conv_down` / `conv_up` sites in
`attention/mca.py`) switched from `cvnn.SlowConv*d` to the native
`cvnn.Conv*d`. These are 1×1 convs that gain nothing from the Gauss
trick — the original choice was a holdover.
Tests:
- `tests/nn/gauss/{test_linear,test_conv}.py` — new files mirroring the
package layout convention, with the forward / weight-property /
`bias=False` cases moved from `tests/nn/modules/`.
- `tests/nn/modules/test_{conv,linear}.py` — stripped of `Slow*` blocks;
now cover only the native cfloat wrappers.
- `tests/invariants/test_fast_slow_equivalence.py` →
`test_native_gauss_equivalence.py` with imports and helper names
updated. The property tests continue to assert that the two paths agree
to floating-point tolerance on shared weights.
Docs:
- `CLAUDE.md` "Fast vs. Slow modules" section retitled to "Native vs.
Gauss-trick modules" and rewritten to point at the new subpackage; the
legacy `Slow*` naming is called out so future readers understand the
history.
Git-rename caveat: because `modules/conv.py` and `modules/linear.py` were
*split* (Gauss-trick content into `gauss/`, fast wrappers stayed put),
git's default rename detection (≥50% similarity) sees them as "modified"
rather than "renamed". Use `git log --follow -C30%` or `git blame -C` to
trace the Gauss-trick content's history into its new location.
Replace the Sphinx + sphinx_rtd_theme + hand-maintained per-module `.rst`
stubs with the stack used by NumPy, SciPy, and PyTorch:
- PyData Sphinx Theme (3-column layout, version switcher, dark mode)
- MyST parser so authored pages are Markdown (existing `.rst` keeps working)
- sphinx-autoapi auto-generates the API tree from docstrings — adding a new
public class to `complextorch.nn.*` now requires zero doc-tree edits
- intersphinx mapping to torch / numpy / scipy / python so `:class:torch.nn.X`
cross-references resolve as real links
- myst-nb executes `examples/getting_started.md` on every build; if the
public API breaks, the docs build fails (analogue of the 100% coverage gate)
- sphinx-multiversion builds `/latest/` from `main` plus any post-migration
semver tag, with a version switcher in the navbar
- sphinx-copybutton, sphinx-design (grid cards), sphinx-sitemap, opengraph
Doc-tree restructuring:
- `docs/source/index.{rst → md}`, `installation.md`, `about.md` — rewritten
as MyST. Drops the hand-edited release-date string (`05/11/2026`); the
PyData theme renders the version from `__init__.py` automatically.
- New `docs/source/concepts/activations.md` and `concepts/native-vs-gauss.md`
promote the Type-A / Type-B / fully-complex theory and the native-cfloat
vs. Gauss-trick discussion out of the API tree into a user guide.
- New `docs/source/examples/getting_started.md` — an executed MyST-NB
notebook covering the README Conv1d demo, activation comparison, and
`pwelch` PSD.
- Deleted ~40 per-module `.rst` stubs under `docs/source/nn/` plus the
top-level `signal.rst` / `transforms.rst` / `datasets.rst` / `models.rst`
— sphinx-autoapi covers them.
Release notes:
- New `CHANGELOG.md` at the repo root (keepachangelog format) holds the
v2.0 and v1.2 notes that used to live inline in `index.rst`.
- `docs/source/changelog.md` `include`s it, so PyPI release pages and the
docs site share a single source of truth.
- `pyproject.toml` now advertises `[project.urls] Changelog = …` so it
surfaces on PyPI as well.
CI:
- `.github/workflows/docs.yml` gets a `pull_request` job that runs
`sphinx-build -W` (warnings-as-errors) as a merge gate, mirroring the
`--cov-fail-under=100` gate in `test.yml`. Broken cross-refs, missing
toctree entries, and notebook-execution failures all block merge.
- The `main` deploy job switches to `sphinx-multiversion` with
`fetch-depth: 0` + `fetch-tags`, copies a `redirect.html` so the root
URL bounces to `/latest/`, and bumps Python to 3.12 to match `test.yml`.
Source-side cleanup (opportunistic, gated by the migration):
- Convert ad-hoc Markdown links to `torch.nn.*` URLs into Sphinx
`:class:torch.nn.X` roles across `nn/gauss/conv.py`, `nn/gauss/linear.py`,
`nn/modules/{batchnorm,mask,pooling,linear}.py` — intersphinx now
auto-links them.
- Fix broken `:doc:` refs in `nn/modules/attention/{__init__,eca,mca}.py`
that pointed at the now-removed per-module RST tree; rewrite as `:mod:`
refs to the canonical Python modules.
- Extend the title underlines in `nn/modules/casting.py` and `manifold.py`
that were too short for their multi-byte (`→`, `é`) titles.
- Fold the `Attributes: is_sparse` block in `BaseMasked` into prose so the
`@property` is the single documented site.
- Expand the one-line docstrings on `ToTensor` / `Unsqueeze` / `HWC2CHW` /
`ToReal` / `ToImaginary` in `transforms/transforms.py` with Args /
behavioural notes — autoapi renders them as the canonical reference.
README:
- Drop the inline 1.2 release notes (now in `CHANGELOG.md`).
- Point the documentation link at the new GitHub Pages URL and add a
pointer to the Getting Started notebook.
Caveats:
- `sphinx-multiversion` 0.2.4 is incompatible with Sphinx ≥ 8
(`Config.read` signature changed). The docs extra is pinned to
`sphinx<8` (which pulls myst-parser <5 / myst-nb <2 with it). Migrate
to `sphinx-polyversion` to unblock newer Sphinx — left as a follow-up
noted in `pyproject.toml` and `CLAUDE.md`.
- The `sphinx-multiversion` whitelist is `^2\.[1-9]\d*\.\d+$` — historical
tags pre-2.1 won't re-render (their `conf.py` predates this migration).
Those releases stay accessible via PyPI.
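The effect of the whitelist pattern can be checked directly with Python's `re` module. The tag names below are made-up examples, not tags from the repository.

```python
import re

WHITELIST = re.compile(r"^2\.[1-9]\d*\.\d+$")

tags = ["1.2.0", "2.0.0", "2.1.0", "2.10.3", "3.0.0"]
built = [t for t in tags if WHITELIST.match(t)]  # ['2.1.0', '2.10.3']
```

As written, the pattern also excludes tags with a `v` prefix and any future `3.x` tag, so it would need widening at the next major release.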
Verified locally with `sphinx-build -W -b html docs/source docs/build/html`
— zero warnings, 111 HTML pages, intersphinx hits PyTorch's inventory, the
notebook executes and emits 14 cell input/output spans.
Apply the full `ruff check` ruleset configured in `pyproject.toml` (E/W, F, I,
B, UP, SIM, RUF, C4, PT, PIE, RET, TCH) and bring the tree to zero warnings so
the new `lint.yml` workflow and `pre-commit` hooks can land green.
Mechanical modernizations (autofixed):
- `super(Cls, self).__init__()` → `super().__init__()` (UP008)
- `Tuple[...]` / `Dict[...]` / `Optional[...]` / `Union[X, Y]` → PEP-585/604
builtins and `X | None` / `X | Y` (UP006, UP007, UP045)
- `from typing import Callable, Sequence, Iterator` → `from collections.abc`
(UP035)
- Sorted `__all__` and import blocks (RUF022, I001)
- `pytest.mark.parametrize("a, b", ...)` → tuple form (PT006)
- `dict(...)` literals in tests → `{...}` (C408)
- Dropped `pass` after docstring in alias subclasses (PIE790)
- `zip(...)` calls now pass `strict=False` to preserve current behavior (B905)
- Misc: yoda-condition flip, redundant `open(p, "r")`, unused unpacked vars,
tuple-concat → unpacking, `__setattr__` explicit `return`, etc.
Manual fixes:
- `nn/modules/linear.py`: annotate `__constants__: ClassVar[list[str]]`
(matches PyTorch's own annotation; silences RUF012 false positive on the
TorchScript convention).
- `nn/modules/upsampling.py`: rewrite `Optional[int | tuple[int, ...]]` as
`int | tuple[int, ...] | None` and drop the now-unused `Optional` import
(UP045 doesn't autofix mixed PEP-604 / `Optional` aliases).
- `nn/masked/__init__.py`: remove leftover `Dict, Tuple` imports.
- `nn/modules/batchnorm.py`: `# noqa: SIM102` on the nested `if self.training
and self.track_running_stats: / if self.num_batches_tracked is not None:`
block so the implementation stays line-for-line comparable to
`torch.nn.modules.batchnorm._BatchNorm.forward`.
- `nn/utils/sparsity.py`: rename unused `mod_name` loop var to `_mod_name`.
- `transforms/functional.py`: collapse `x.abs() if x.is_complex() else
x.abs()` to `x.abs()` (RUF034 — both branches were identical).
Config:
- `pyproject.toml`: add `allowed-confusables = ["×"]` so intentional math
notation in docstrings/comments (`2×2 whitening`, `weights × V`) passes
RUF002/RUF003.
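For context on why that character needs an allowance at all: the multiplication sign is a distinct code point that is easy to mistake for the letter `x`. A quick check with the standard library shows the difference.

```python
import unicodedata

# "\u00d7" is the multiplication sign used in the docstrings; "x" is the
# ASCII letter it can be confused with.
names = {ch: unicodedata.name(ch) for ch in ("\u00d7", "x")}
for ch, name in names.items():
    print(f"U+{ord(ch):04X}  {name}")
# U+00D7  MULTIPLICATION SIGN
# U+0078  LATIN SMALL LETTER X
```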
complextorch 2.0.0
This branch promotes `complextorch` from a small drop-in complex-NN
extension into a strict superset of the three sibling libraries in the
ecosystem (Popoff's `complexPyTorch`, Nazarov's `cplxmodule`, Levi/Fix
et al.'s `torchcvnn`), with a hard 100% line-coverage CI gate, modern
docs on GitHub Pages, an Apache 2.0 license, and a clean ruff bill of
health.
It rolls 11 commits (`0f13891..29add19`) into the `2.0.0` release.
Two of those changes are unavoidably breaking; everything else is
additive.
Summary
- init (`nn.init`), Variational Dropout / ARD (`nn.relevance`), masked
layers (`nn.masked`), RNNs (`GRU`/`LSTM`), `Transformer`, `RMSNorm`,
`GroupNorm`, `NaiveBatchNorm*d`, `MagMaxPool*d`, channel `Dropout*d`,
`Upsample` + `PolarUpsample`, layout-conversion modules
(`InterleavedToComplex` / `ConcatenatedToComplex` / `RealToComplex`,
and inverses), `PhaseShift`, `Bilinear`, eight new activations
(`CELU`/`CCELU`/`CGELU`, `zAbsReLU`/`zLeakyReLU`, `Mod`,
`AdaptiveModReLU`, learnable `modReLU`), `MSELoss`.
- `complextorch.signal` (`pwelch`), `complextorch.transforms`
(torchcvnn-style dataloader transforms), `complextorch.datasets`
(SAR / MRI surface: `SAMPLE` and `SLCDataset` are real, the heavier
readers are honest `NotImplementedError` stubs pointing at
`torchcvnn`), `complextorch.models` (`ViT` family + CDS reference
models).
- plus the missing SurReal/wFM primitives: `PhaseDivConv*d`,
`PhaseConjConv*d`, `GTReLU`, `EquivariantPhaseReLU`, `ComplexScaling`,
`MagBatchNorm*d`, `PrototypeDistance`, `wFMReLU`, `wFMDistanceLinear`,
plus reference models `CDSInvariant`, `CDSEquivariant`, `CDSMSTAR`.
- Gauss-trick bias was off by `-b_i*(1+j)` in the `Slow*` family,
`BatchNorm` eval-mode broadcast `running_mean` against the wrong
axes, attention used `QKᵀ` instead of the Hermitian `QKᴴ`,
`PhaseSigmoid` was an empty class. None of the fast (native-cfloat)
forward paths were affected.
- `complextorch.nn.Slow*` into a dedicated `complextorch.nn.gauss`
subpackage. The `Slow*` prefix was a misleading legacy: since
PyTorch 2.1.0's native complex kernels these are slower, not faster,
than the `dtype=torch.cfloat` wrappers.
- `--cov-fail-under=100` in `.github/workflows/test.yml`. 488 tests
organized in a `tests/` tree that mirrors `complextorch/` 1:1, plus
Hypothesis property tests under `tests/invariants/` (native↔gauss
equivalence, polar round-trip, casting round-trip, FFT round-trip,
U(1) equivariance).
- `.rst` stubs are gone: autoapi generates the API tree from
docstrings, so adding a new public class no longer requires touching
the docs at all. An executable Getting Started notebook fails the
build if the public API breaks. Deployed to GitHub Pages with a
version switcher and PR-level `sphinx-build -W` merge gate.
- `.github/workflows/{test,docs,lint,pypi}.yml` plus a
`.pre-commit-config.yaml`. The full ruff ruleset (E/W, F, I, B, UP,
SIM, RUF, C4, PT, PIE, RET, TCH) passes with zero warnings.
- `NOTICE` file acknowledging upstream code from the audited siblings.
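The `-b_i*(1+j)` bias error quoted in the summary can be reproduced with scalar arithmetic. The sketch below is plain Python rather than the library's code: all function names are hypothetical, and the "folded" variant follows the description in the correctness-fix commit message (`b_r` into `t1`, `b_i` into `t2`, `b_r + b_i` into `t3`).

```python
def gauss_product(w: complex, x: complex) -> complex:
    """Complex product w * x using three real multiplications."""
    t1 = w.real * x.real
    t2 = w.imag * x.imag
    t3 = (w.real + w.imag) * (x.real + x.imag)
    return complex(t1 - t2, t3 - t2 - t1)


def affine_bias_after(w: complex, x: complex, b: complex) -> complex:
    """Add the bias to the real and imaginary outputs after combining."""
    y = gauss_product(w, x)
    return complex(y.real + b.real, y.imag + b.imag)


def affine_bias_folded(w: complex, x: complex, b: complex) -> complex:
    """Fold the bias into t1, t2, t3 before combining (the old behaviour)."""
    t1 = w.real * x.real + b.real
    t2 = w.imag * x.imag + b.imag
    t3 = (w.real + w.imag) * (x.real + x.imag) + (b.real + b.imag)
    return complex(t1 - t2, t3 - t2 - t1)


# Arbitrary illustrative values.
w, x, b = complex(0.5, -1.25), complex(2.0, 0.75), complex(0.3, 0.7)
reference = w * x + b

fixed_error = affine_bias_after(w, x, b) - reference
folded_error = affine_bias_folded(w, x, b) - reference
expected_error = -b.imag * (1 + 1j)

print(abs(fixed_error) < 1e-12)                    # True
print(abs(folded_error - expected_error) < 1e-12)  # True
```

In the folded variant the imaginary part loses the bias entirely (`b_r + b_i - b_i - b_r = 0`) and the real part picks up `b_r - b_i`, which is where the `-b_i*(1+j)` offset comes from.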
Breaking changes
- `Linear` / fast `Conv{1,2,3}d` / fast `ConvTranspose{1,2,3}d` now
default to `bias=True` to match `torch.nn`. Pass `bias=False`
explicitly if you relied on the old default.
- `MultiheadAttention` / `ScaledDotProductAttention` use the
Hermitian inner product `QKᴴ` instead of `QKᵀ`. There is no
opt-out; the prior behaviour was a math bug. A new
`softmax_on='complex'|'real'` flag selects between the existing
`CVSoftMax`-on-complex semantics (default, preserves behaviour)
and the real-valued softmax-on-`Re[QKᴴ/√d]` formulation used by
`torchcvnn`.
- `complextorch.nn.SlowLinear` / `SlowConv{1,2,3}d` /
`SlowConvTranspose{1,2,3}d` were renamed to
`complextorch.nn.gauss.Linear` / `gauss.Conv{1,2,3}d` /
`gauss.ConvTranspose{1,2,3}d`. The top-level `Slow*` names are
gone; mechanical find-and-replace covers all existing call sites.
- `scipy>=1.10.0` (required by the `Ei`-based KL in
`nn.relevance._expi`).
Test plan
- `pytest --cov=complextorch --cov-report=term-missing --cov-fail-under=100`
(488 tests, 100% line coverage, mirrors the CI gate)
- `sphinx-build -W -b html docs/source docs/build/html` (zero
warnings, 111 HTML pages, executable notebook runs cleanly)
- `ruff check .` (zero warnings on the full configured ruleset)
- `tests/invariants/`: native↔gauss numerical equivalence, polar
round-trip, casting round-trip, FFT round-trip, U(1) equivariance of
every applicable CDS module
- `complextorch.nn.__init__` has a docstring and resolves under autoapi
Notes for reviewers
- `9c7d447` ("Expand feature surface to match complexPyTorch,
cplxmodule, and torchcvnn"). The two breaking changes are scoped to
that commit.
- `de3b8e4` is a pure rename refactor: Gauss-trick content moved out
of `modules/conv.py` and `modules/linear.py` into a new `gauss/`
subpackage. Because the files were split rather than moved, git's
default rename detection won't follow the history; use
`git log --follow -C30%` or `git blame -C` to trace the Gauss-trick
content's history into its new location.
- `503a9ae` deletes ~40 hand-maintained `.rst` stubs; sphinx-autoapi
generates equivalent (and richer) pages from docstrings on every
build. The Sphinx 7 pin in `pyproject.toml` is documented (see the
caveat there); it's tied to `sphinx-multiversion` 0.2.4 being
incompatible with Sphinx 8+ and is tracked as a follow-up.
- `0f13891` is the correctness-fix audit. The descriptions in its
commit message double as test cases; every fix has a corresponding
regression test in the suite added by `36f0f18`.
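As an illustration of the kind of property a regression test can pin down, the attention change listed under Breaking changes (`QKᴴ` instead of `QKᵀ`) has a simple scalar analogue. The sketch below is not taken from the test suite; the helper names are hypothetical and it uses plain Python lists instead of tensors.

```python
import cmath


def score_transpose(q: list[complex], k: list[complex]) -> complex:
    """Unconjugated dot product, the analogue of one entry of Q Kᵀ."""
    return sum(a * b for a, b in zip(q, k))


def score_hermitian(q: list[complex], k: list[complex]) -> complex:
    """Conjugated dot product, the analogue of one entry of Q Kᴴ."""
    return sum(a * b.conjugate() for a, b in zip(q, k))


q = [complex(1.0, 2.0), complex(-0.5, 0.25)]

# Rotate every component by the same global phase (a U(1) action).
phase = cmath.exp(0.9j)
q_rot = [phase * a for a in q]

self_h = score_hermitian(q, q)
self_h_rot = score_hermitian(q_rot, q_rot)
self_t = score_transpose(q, q)
self_t_rot = score_transpose(q_rot, q_rot)

# Hermitian self-score is the squared norm: real and phase-invariant.
print(abs(self_h.imag) < 1e-12, abs(self_h_rot - self_h) < 1e-12)
# The unconjugated score picks up a factor of phase**2 instead.
print(abs(self_t_rot - phase**2 * self_t) < 1e-12)
```

The Hermitian form is the one that behaves like an inner product: a vector's score against itself is its squared magnitude, and a common phase applied to both arguments cancels.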