Skip to content

2.0.0#9

Merged
josiahwsmith10 merged 14 commits into
mainfrom
2.0.0
May 11, 2026
Merged

2.0.0#9
josiahwsmith10 merged 14 commits into
mainfrom
2.0.0

Conversation

@josiahwsmith10
Copy link
Copy Markdown
Owner

@josiahwsmith10 josiahwsmith10 commented May 11, 2026

complextorch 2.0.0

This branch promotes complextorch from a small drop-in complex-NN
extension into a strict superset of the three sibling libraries in the
ecosystem (Popoff's complexPyTorch, Nazarov's cplxmodule, Levi/Fix
et al.'s torchcvnn), with a hard 100% line-coverage CI gate, modern
docs on GitHub Pages, an Apache 2.0 license, and a clean ruff bill of
health.

It rolls 11 commits — 0f13891..29add19 — into the 2.0.0 release.
Two of those changes are unavoidably breaking; everything else is
additive.

Summary

  • New feature surface matching the audited siblings: complex weight
    init (nn.init), Variational Dropout / ARD (nn.relevance), masked
    layers (nn.masked), RNNs (GRU/LSTM), Transformer, RMSNorm,
    GroupNorm, NaiveBatchNorm*d, MagMaxPool*d, channel Dropout*d,
    Upsample + PolarUpsample, layout-conversion modules
    (InterleavedToComplex / ConcatenatedToComplex / RealToComplex,
    and inverses), PhaseShift, Bilinear, eight new activations
    (CELU/CCELU/CGELU, zAbsReLU/zLeakyReLU, Mod, AdaptiveModReLU,
    learnable modReLU), MSELoss.
  • New subpackages: complextorch.signal (pwelch),
    complextorch.transforms (torchcvnn-style dataloader transforms),
    complextorch.datasets (SAR / MRI surface — SAMPLE and
    SLCDataset are real, the heavier readers are honest
    NotImplementedError stubs pointing at torchcvnn),
    complextorch.models (ViT family + CDS reference models).
  • Co-Domain-Symmetry layers ported from Singhal/Xing/Yu (CVPR 2022)
    plus the missing SurReal/wFM primitives: PhaseDivConv*d,
    PhaseConjConv*d, GTReLU, EquivariantPhaseReLU,
    ComplexScaling, MagBatchNorm*d, PrototypeDistance, wFMReLU,
    wFMDistanceLinear, plus reference models CDSInvariant,
    CDSEquivariant, CDSMSTAR.
  • Latent correctness fixes that an audit surfaced — most notable:
    Gauss-trick bias was off by -b_i*(1+j) in the Slow* family,
    BatchNorm eval-mode broadcast running_mean against the wrong
    axes, attention used QKᵀ instead of the Hermitian QKᴴ,
    PhaseSigmoid was an empty class. None of the fast (native-cfloat)
    forward paths were affected.
  • Gauss-trick layers moved from top-level complextorch.nn.Slow*
    into a dedicated complextorch.nn.gauss subpackage. The Slow*
    prefix was a misleading legacy — since PyTorch 2.1.0's native
    complex kernels these are slower, not faster, than the
    dtype=torch.cfloat wrappers.
  • 100% test coverage, hard-enforced via --cov-fail-under=100 in
    .github/workflows/test.yml. 488 tests organized in a tests/
    tree that mirrors complextorch/ 1:1, plus Hypothesis property
    tests under tests/invariants/ (native↔gauss equivalence, polar
    round-trip, casting round-trip, FFT round-trip, U(1) equivariance).
  • Modern docs stack (PyData Sphinx Theme + MyST + sphinx-autoapi
    • myst-nb + sphinx-multiversion). The per-module .rst stubs are
      gone — autoapi generates the API tree from docstrings, so adding a
      new public class no longer requires touching the docs at all. An
      executable Getting Started notebook fails the build if the public
      API breaks. Deployed to GitHub Pages with a version switcher and
      PR-level sphinx-build -W merge gate.
  • CI and tooling: .github/workflows/{test,docs,lint,pypi}.yml
    plus a .pre-commit-config.yaml. The full ruff ruleset
    (E/W, F, I, B, UP, SIM, RUF, C4, PT, PIE, RET, TCH) passes with
    zero warnings.
  • License switched to Apache 2.0 (was MIT) with a NOTICE file
    acknowledging upstream code from the audited siblings.

Breaking changes

  1. Linear / fast Conv{1,2,3}d / fast ConvTranspose{1,2,3}d now
    default to bias=True to match torch.nn. Pass bias=False
    explicitly if you relied on the old default.
  2. MultiheadAttention / ScaledDotProductAttention use the
    Hermitian inner product QKᴴ instead of QKᵀ. There is no
    opt-out — the prior behaviour was a math bug. A new
    softmax_on='complex'|'real' flag selects between the existing
    CVSoftMax-on-complex semantics (default, preserves behaviour)
    and the real-valued softmax-on-Re[QKᴴ/√d] formulation used by
    torchcvnn.
  3. complextorch.nn.SlowLinear / SlowConv{1,2,3}d /
    SlowConvTranspose{1,2,3}d were renamed to
    complextorch.nn.gauss.Linear / gauss.Conv{1,2,3}d /
    gauss.ConvTranspose{1,2,3}d. The top-level Slow* names are
    gone; mechanical find-and-replace covers all existing call sites.
  4. New runtime dependency on scipy>=1.10.0 (required by the
    Ei-based KL in nn.relevance._expi).

Test plan

  • pytest --cov=complextorch --cov-report=term-missing --cov-fail-under=100 (488 tests, 100% line coverage, mirrors the CI gate)
  • sphinx-build -W -b html docs/source docs/build/html (zero warnings, 111 HTML pages, executable notebook runs cleanly)
  • ruff check . (zero warnings on the full configured ruleset)
  • Hypothesis property tests under tests/invariants/: native↔gauss numerical equivalence, polar round-trip, casting round-trip, FFT round-trip, U(1) equivariance of every applicable CDS module
  • Spot-check that every public class re-exported from complextorch.nn.__init__ has a docstring and resolves under autoapi

Notes for reviewers

  • The single big additive commit is 9c7d447 ("Expand feature surface
    to match complexPyTorch, cplxmodule, and torchcvnn"). The two
    breaking changes are scoped to that commit.
  • de3b8e4 is a pure rename refactor — Gauss-trick content moved out
    of modules/conv.py and modules/linear.py into a new gauss/
    subpackage. Because the files were split rather than moved, git's
    default rename detection won't follow the history; use
    git log --follow -C30% or git blame -C to trace the Gauss-trick
    content's history into its new location.
  • 503a9ae deletes ~40 hand-maintained .rst stubs; sphinx-autoapi
    generates equivalent (and richer) pages from docstrings on every
    build. The Sphinx 7 pin in pyproject.toml is documented — see the
    caveat there; it's tied to sphinx-multiversion 0.2.4 being
    incompatible with Sphinx 8+ and is tracked as a follow-up.
  • 0f13891 is the correctness-fix audit. The descriptions in its
    commit message double as test cases; every fix has a corresponding
    regression test in the suite added by 36f0f18.

…s, and losses

A code audit turned up several issues that were either silently wrong,
crashed in narrow cases (CPU-only installs, eval mode, zero-magnitude
inputs), or contradicted their own docstrings. None of the fast
(native-cfloat) modules were affected; the user-visible default paths
(`Conv1d`, `Linear`, `BatchNorm*d` in training, the polar activations,
attention) all continue to behave the same to first order, but the
hand-rolled `Slow*` variants and the eval-mode of `BatchNorm` produced
wrong numbers under common settings.

Critical correctness fixes:

- Gauss-trick bias in `SlowLinear` / `SlowConv1d/2d/3d` /
  `SlowConvTranspose1d/2d/3d` was off by `-b_i * (1 + j)` whenever
  `bias=True`. The previous implementation folded `b_r` into `t1`, `b_i`
  into `t2`, and `b_r + b_i` into `t3`, then computed
  `(t1 - t2) + j(t3 - t2 - t1)`, which cancels the bias incorrectly.
  Bias is now stored as two real `nn.Parameter`s (`bias_r`, `bias_i`,
  registered via `register_parameter` so they participate in
  `state_dict()` and gradients), and applied to the real and imaginary
  outputs after the Gauss combination. The `bias` property still
  returns `torch.complex(bias_r, bias_i)` so external callers see no
  change. `SlowConv*` default `bias=True`, so this affected
  out-of-the-box usage.

- `_Conv.forward` / `_ConvTranspose.forward` now forward `dilation`
  (and `output_padding`, for the transpose path) into the raw
  `F.conv*d` / `F.conv_transpose*d` call used for `t3`. Without these,
  any non-default `dilation` or `output_padding` made `t3` a different
  shape from `t1, t2` and the final `torch.complex` errored out.

- Removed three unconditional debug prints in `_Conv`. One hard-coded
  `torch.cuda.memory_allocated("cuda:0")` (crashed on CPU-only installs
  and fired on every forward); another printed `self.conv_i.bias.type()`
  from the `bias` property, which crashed with `AttributeError` when
  `bias=False`. Both also leaked into any caller of
  `MaskedChannelAttention1d/2d/3d`, which composes `SlowConv*` under
  the hood.

- `nn.functional.batch_norm` eval-mode (`training=False` with stored
  running stats) used `running_mean` directly at shape `[2, F]` and
  subtracted it from `x` of shape `[2, B, F, ...]`. PyTorch's
  right-aligned broadcasting matched the last two axes instead of
  `[stack, F]`. The else branch now reshapes to
  `[2, 1, F, 1, ..., 1]`. Also made the centering out-of-place
  (`x = x - mean` rather than `x -= mean`) and fixed a duplicate-name
  unpack (`v_rr, v_ir, v_ir, v_ii = ...` -> `..., _, v_ii = ...`); the
  latter was benign because the off-diagonals were stored equal, but
  misleading.

- `mask.PhaseSigmoid` had `pass` as its class body — instantiating it
  worked, calling it raised `NotImplementedError`. Now implements the
  phase-preserving sigmoid the docstring already described.

- `mask.MagMinMaxNorm` subtracted a real scalar `x_min` from a complex
  tensor, which silently changed phase despite the docstring promising
  it would be preserved; its `dim` constructor argument was also stored
  and never used. The forward now operates on `input.abs()` and
  rebuilds via `torch.polar(new_mag, input.angle())`, and honors
  `self.dim`.

- `loss.SSIM.forward` with `data_range=None` first built a 4-D
  `data_range` and then unconditionally ran
  `data_range[:, None, None, None]`, producing a 7-D tensor that
  misbroadcast downstream. The unsqueeze is now gated to the
  user-supplied branch.

High-severity semantic fixes:

- Added `1e-12` clamp on the divisor in the three `/ |z|` sites that
  could produce NaN on zero magnitude: `apply_complex_polar`
  (phase-skip path), `PhaseSoftMax.forward`, and
  `ComplexRatioMask.forward`. Smoke-tested with a zero-containing
  complex tensor; no NaN.

- `_modReLU` / `modReLU` defaulted `bias=0.0` but immediately asserted
  `bias < 0`. Default is now `-0.1` so `modReLU()` with no args works.

- `attention.ScaledDotProductAttention` now uses
  `k.conj().transpose(-2, -1)` so the dot product is the Hermitian
  inner product `Q K^H` rather than the bilinear form `Q K^T`. The
  docstrings on both `ScaledDotProductAttention` and
  `MultiheadAttention` previously claimed `CVSoftMax` applied softmax
  to magnitude; in reality `CVSoftMax` applies softmax to real and
  imaginary parts separately (`PhaseSoftMax` is the magnitude variant).
  Rewrote both to describe what the code actually does.

- `nn.functional.layer_norm` now asserts `weight` and `bias` are both
  set or both `None`, mirroring `batch_norm`.

- `Dropout` docstring now explicitly states that real and imaginary
  parts get **independent** Bernoulli masks, and contrasts with
  Trabelsi 2018's shared-mask formulation, so the choice is deliberate
  rather than implicit.

API and packaging:

- Re-export fast `ConvTranspose1d`, `ConvTranspose2d`,
  `ConvTranspose3d`, and `SSIM` from `complextorch.nn`. The fast
  transposes existed but were not exposed; only the `Slow` variants
  were.

- Fast `ConvTranspose*` default `output_padding` is now `0`, matching
  `torch.nn.ConvTranspose*`. (Previously defaulted to `1`, which
  changed output shape vs. native PyTorch.)

- `CVQuadError`, `CVFourthPowError`, `CVCauchyError`, `CVLogCoshError`,
  and `CVLogError` now accept a `reduction='mean' | 'sum' | 'none'`
  argument, defaulting to `'mean'`. Previously they hard-coded `.sum()`,
  so loss magnitude (and gradient scale) varied with batch and feature
  size and was inconsistent with `SplitMSE` / `SplitL1`. The original
  paper formulas are recoverable with `reduction='sum'`.

- `setup.py`: bumped `python_requires` to `>=3.10` (matches
  `numpy>=2.2.0`) and `torch>=2.1.0` (matches the package's reliance on
  native complex kernels). Dropped `deprecated>=1.2.18` from
  `install_requires` (no longer used after the `CVTensor` removal).
  Same updates applied to `requirements.txt` (which previously pinned
  a CUDA-only `torch>=1.11.0+cu115`, breaking non-CUDA installs) and
  `docs/requirements.txt`.

Minor:

- `manifold.wFMConv2d` now caches `nn.Fold` by input spatial shape
  (`_get_fold(input_shape)`) instead of constructing — and registering
  as a submodule — a fresh `nn.Fold` on every forward.

- `manifold.wFMConv1d.forward` uses `.squeeze(-2)` to pair with its
  `.unsqueeze(-2)`. The previous bare `.squeeze()` could drop a
  singleton batch dim.

- `attention/eca.py`: collapsed a `.transpose(-1, -2).transpose(-1, -2)`
  no-op into the surrounding `.view` calls.

Docs:

- `docs/source/index.rst`: updated release date to `05/10/2026`,
  renamed the release-notes section to "Version 1.2 Release Notes",
  rewrote the bullets to describe the legacy-API removal plus the
  correctness fixes in this change, and corrected the long-standing
  typo `dtype=torch.float` -> `dtype=torch.cfloat`.

- `README.md`: same release-notes rewrite.
…cvnn (2.0.0)

After auditing the three sibling complex-NN libraries checked out alongside
this one (Popoff's `complexPyTorch`, Nazarov's `cplxmodule`, and Levi/Fix
et al.'s `torchcvnn`), this commit lands every feature they ship that
`complextorch` did not, in a single 2.0.0 release. The library is now a
strict superset of those three on the `nn`, `transforms`, `models`,
`datasets`, and signal-processing fronts.

This bumps the version to 2.0.0 because two unavoidable changes are
backwards-incompatible: a `bias=False` -> `bias=True` default change on the
Fast `Conv*d` / `ConvTranspose*d` / `Linear` wrappers (to match `torch.nn`),
and a silent `QK^T` -> `QK^H` math fix in `MultiheadAttention` (the inner
product was using a plain transpose where the Hermitian transpose is the
standard choice for complex inner products).

New subpackages
---------------

- `complextorch.nn.init` — variance-correct complex weight initializers:
  `kaiming_normal_`, `kaiming_uniform_`, `xavier_normal_`, `xavier_uniform_`,
  `trabelsi_standard_` (Rayleigh-uniform polar), `trabelsi_independent_`
  (semi-unitary via SVD). PyTorch's built-in init treats the real and
  imaginary parts as independent real tensors, which gives the wrong
  variance for the complex magnitude (off by a factor of two); these
  produce the intended `Var(|w|^2)` for both He and Glorot targets.

- `complextorch.nn.relevance` — complex Variational Dropout / Automatic
  Relevance Determination. `LinearVD`, `BilinearVD`, `Conv{1,2,3}dVD`,
  plus the corresponding `*ARD` variants. Adapted from
  `cplxmodule.nn.relevance.complex`. Real-valued VD/ARD layers are
  intentionally not ported — only the complex variants, with the `Cplx`
  prefix dropped since the subpackage name already scopes the meaning.
  Adds `scipy` as a runtime dependency for the exact `Ei`-based KL
  (`complextorch.nn.relevance._expi.ExpiFunction`); the empirical-Bayes
  `ARD` variants use a pure-torch softplus penalty and don't need scipy
  at runtime.

- `complextorch.nn.masked` — fixed-mask sparsified layers (`LinearMasked`,
  `BilinearMasked`, `Conv{1,2,3}dMasked`) plus the module-walking helpers
  `deploy_masks`, `binarize_masks`, `is_sparse`, `named_masks`. Designed
  to compose with `nn.relevance`: train with `LinearVD`, extract a
  relevance mask via `compute_ard_masks(model, threshold=...)`, deploy
  it onto a parallel `LinearMasked` inference model with `deploy_masks`.

- `complextorch.nn.utils.sparsity` — `SparsityStats` base + the
  module-walking `named_sparsity` / `sparsity` helpers.

- `complextorch.signal` — `pwelch` (a differentiable torch port of
  `scipy.signal.welch`; matches scipy to ~2e-15 in unit tests).

- `complextorch.transforms` — torchcvnn-style dataloader-stage transforms
  (`LogAmplitude`, `Amplitude`, `ToReal`, `ToImaginary`, `RealImaginary`,
  `Normalize` (per-channel 2x2 whitening), `RandomPhase`, `PadIfNeeded`,
  `CenterCrop`, `SpatialResize`, `FFT2`, `IFFT2`, `FFTResize`, `PolSAR`,
  `ToTensor`, `Unsqueeze`, `HWC2CHW`). Torch-only — the upstream's dual
  numpy/torch dispatch is intentionally dropped for simplicity. Two
  public functional helpers (`polsar_dict_to_array`, `rescale_intensity`).

- `complextorch.datasets` — SAR / MRI / SLC dataset surface. `SAMPLE`
  (in-memory random complex chips) and `SLCDataset` (generic `.npy` /
  `.pt` directory reader) are full implementations; the file-format-heavy
  loaders (`PolSFDataset`, `Bretigny`, `S1SLC`, `MSTARTargets`,
  `ATRNetSTAR`, `MICCAI2023`, `ALOSDataset`+CEOS readers) are present as
  importable classes with `_NotImplementedDataset` stubs that raise at
  instantiation with a clear pointer to the upstream torchcvnn reference
  implementation. This keeps the import surface complete (tab-completion,
  type checking) without dishonestly claiming the loaders work.
  Optional deps gated behind `pip install complextorch[datasets]` (h5py)
  and `complextorch[datasets-alos]` (rasterio).

- `complextorch.models` — first pre-built reference architecture lives
  here. `ViT`, `ViTLayer`, and presets `vit_t/s/b/l/h` matching the
  standard ViT family (5M-630M parameters).

New nn.modules
--------------

- `nn.modules.transformer` — `TransformerEncoderLayer`, `TransformerEncoder`,
  `TransformerDecoderLayer`, `TransformerDecoder`, `Transformer`.
  Composed from the existing `MultiheadAttention` (which already includes
  residual + LayerNorm internally) plus a feed-forward sub-block with
  its own residual + LayerNorm. `softmax_on` flag is plumbed through
  from `MultiheadAttention`.

- `nn.modules.rnn` — `GRUCell`, `GRU`, `LSTMCell`, `LSTM`. Cells are
  built from complex `Linear` projections + `CSigmoid` + `CTanh` (not
  the dual-real-RNN wrapper trick from `complexPyTorch`, which has a
  parameterization quirk where `gru_re` does both `Wr.xr` and `Wr.xi`,
  giving a different model from a proper complex GRU). Multi-layer
  wrappers stack cells along time in Python; no CuDNN fused path
  because CuDNN's complex RNN isn't a thing. Each cell accepts
  `batchnorm=False`; setting it to `True` inserts `BatchNorm1d` after
  every internal linear projection (Recurrent BN, Cooijmans et al. 2017).

- `nn.modules.rmsnorm` — `RMSNorm`. Complex RMS-norm; equivalent to
  `torch.nn.RMSNorm` but with an optional per-feature 2x2 affine instead
  of a scalar gain.

- `nn.modules.groupnorm` — `GroupNorm`. Per-group 2x2 whitening
  (Trabelsi) + per-channel 2x2 affine + 2-vector bias. No running
  statistics. Reuses the existing `inv_sqrtm2x2` from `nn.functional`.

- `nn.modules.upsampling` — `Upsample` (split real/imag; analogous to
  `torchcvnn.nn.Upsample` and `complexPyTorch.complex_upsample`) and
  `PolarUpsample` (polar form; matches the phase-preserving
  `complex_upsample2`). The polar form preserves phase along smooth
  regions but introduces visible discontinuities at the +/-pi phase
  wrap; pick based on the smoothness profile of your data.

- `nn.modules.casting` — layout-conversion modules:
  `InterleavedToComplex`, `ComplexToInterleaved`, `ConcatenatedToComplex`,
  `ComplexToConcatenated`, `RealToComplex` (zero-imag lift). The
  redundant `TensorToCplx` / `CplxToTensor` from `cplxmodule` are
  deliberately omitted — `torch.view_as_complex` / `torch.view_as_real`
  already cover those.

- `nn.modules.phase` — `PhaseShift`, a learnable per-channel /
  per-feature phase rotation: `y = x * exp(j * phi)` with `phi` an
  `nn.Parameter` initialized uniformly in `[-pi, pi]`.

Additions to existing nn modules
--------------------------------

- `nn.modules.linear` — new `Bilinear` (complex bilinear with a
  `conjugate=True/False` flag selecting Hermitian vs plain bilinear
  form). `Linear` / `SlowLinear` default changed to `bias=True`.

- `nn.modules.batchnorm` — new `NaiveBatchNorm{1,2,3}d` (independent
  `torch.nn.BatchNorm*d` on real and imag — the split baseline, kept
  for parity with `complexPyTorch.NaiveComplexBatchNorm*d` and as a
  cheap drop-in when the Trabelsi 2x2 whitening overhead isn't
  warranted).

- `nn.modules.pooling` — new `MagMaxPool{1,2,3}d` (magnitude-argmax max
  pool — `torch.nn.MaxPool*d` doesn't define `>` on complex, so this
  is the canonical complex analogue; gathers the original complex
  sample at the max-magnitude position, preserving phase) and
  `AvgPool{1,2,3}d` (thin wrapper preserving the import-swap promise).
  The `MagMaxPool` name is deliberate: `MaxPool*d` would suggest a
  drop-in replacement for `torch.nn.MaxPool*d`, but the semantics are
  meaningfully different.

- `nn.modules.dropout` — new `Dropout1d` / `Dropout2d` / `Dropout3d`
  with shared real/imag mask (Trabelsi 2018) via a `view_as_real` round
  trip into `F.dropout*d`. Preserves the phase of surviving entries —
  the original `Dropout` (independent masks per part) is unchanged for
  backwards compatibility.

- `nn.modules.loss` — new `MSELoss` matching `torch.nn.MSELoss` exactly
  (no 1/2 factor). The existing `CVQuadError` is unchanged and retains
  its physics-style 1/2 factor; verified numerically that
  `MSELoss(x, y) == 2 * CVQuadError(x, y)` for any inputs.

- `nn.modules.activation` — eight new activations:
  - `CVSplitELU` + alias `CELU` (split Type-A ELU)
  - `CVSplitCELU` + alias `CCELU` (split Type-A `torch.nn.CELU`)
  - `CVSplitGELU` + alias `CGELU` (split Type-A GELU)
  - `zAbsReLU` (magnitude-thresholded with learnable threshold)
  - `zLeakyReLU` (leaky version of `zReLU`)
  - `Mod` (magnitude as a module — complex -> real)
  - `AdaptiveModReLU` (per-channel learnable threshold modReLU)
  - `modReLU` gains a `learnable: bool = False` flag — when `True`,
    the threshold becomes a single trainable scalar.

- `nn.modules.attention` — `MultiheadAttention` and
  `ScaledDotProductAttention` now use `k.conj().transpose(-2, -1)` for
  the Hermitian inner product (was the plain transpose; a math bug).
  New `softmax_on='complex'|'real'` flag: `'complex'` (default) keeps
  the existing `CVSoftMax`-on-`QK^H` behavior; `'real'` applies
  `torch.softmax(Re[QK^H/sqrt(d)])` for real-valued attention weights
  (the formulation used by torchcvnn).

- `nn.modules.conv` — Fast `Conv*d` / `ConvTranspose*d` default changed
  to `bias=True`. Class-level docstring sentence added noting the only
  behavioral difference vs `torch.nn` is the default `dtype=torch.cfloat`.

- `nn.__init__.py` — re-exports `ConvTranspose1d/2d/3d` (the classes
  already existed in `conv.py` but were missing from the public
  surface).

- `nn.functional` — the previously private whitening helpers
  `_whiten2x2_batch_norm` and `_whiten2x2_layer_norm` are now public
  (`whiten2x2_batch_norm`, `whiten2x2_layer_norm`); both are added to
  `__all__` alongside `batch_norm`, `layer_norm`, `inv_sqrtm2x2`.

Breaking changes
----------------

- `Linear` / `SlowLinear` and fast `Conv{1,2,3}d` /
  `ConvTranspose{1,2,3}d` default to `bias=True` (was `False`),
  matching `torch.nn`. Existing code that relied on the silent
  `bias=False` default and intentionally wanted no bias must pass
  `bias=False` explicitly.

- `MultiheadAttention` and `ScaledDotProductAttention` use the
  Hermitian inner product `Q K^H` instead of `Q K^T`. This is the
  standard math for complex inner products and matches all sibling
  libraries; results numerically differ from prior outputs for any
  complex `K`. There is no opt-out flag — the previous behavior was a
  bug, not a design choice.

Documentation
-------------

Every new public class and function appears in the Sphinx docs (GitHub
Pages workflow under `.github/workflows/docs.yml`):

- Seven new module RST files under `docs/source/nn/modules/`:
  `casting.rst`, `phase.rst`, `rmsnorm.rst`, `groupnorm.rst`,
  `upsampling.rst`, `rnn.rst`, `transformer.rst`.
- Four new subpackage RST files under `docs/source/nn/`: `init.rst`,
  `relevance.rst`, `masked.rst`, `utils.rst`.
- Four new top-level RST files under `docs/source/`: `signal.rst`,
  `transforms.rst`, `datasets.rst`, `models.rst`.
- Toctrees updated in `docs/source/index.rst`, `docs/source/nn.rst`,
  and `docs/source/nn/modules.rst`.
- `index.rst` gains a "Version 2.0 Release Notes" section enumerating
  the additions and the two breaking changes.

`sphinx-build -W --keep-going` returns clean (the only warning is a
pre-existing missing-`_static` directory unrelated to this change). All
35+ new public symbols I spot-checked resolve to a built HTML page.

Packaging
---------

- Bumped `__version__` to `2.0.0`.
- Added `scipy>=1.10.0` to runtime dependencies (required by
  `nn.relevance._expi.ExpiFunction`; the `ARD` variants don't need it
  at forward time, but the import-time check is unavoidable).
- Added `[datasets]` extra (`h5py>=3.7`) — for MICCAI MRI.
- Added `[datasets-alos]` extra (`rasterio>=1.3`) — gated separately
  because rasterio wraps GDAL, which is a system-level dep on Linux/Mac.
…ced in CI

Stand up a complete `tests/` mirror of the `complextorch/` package and wire
`pytest --cov-fail-under=100` into the existing GitHub Actions workflow, so
any future PR that drops coverage fails CI automatically.

## Infrastructure

- `pyproject.toml`: expand `[project.optional-dependencies].test` to include
  `pytest>=8`, `pytest-cov>=5`, `pytest-xdist>=3`, and `hypothesis>=6`. Add
  `[tool.pytest.ini_options]` with `-n auto`, `testpaths = ["tests"]`, and
  `--strict-markers --strict-config -ra`. Add `[tool.coverage.run]` (line
  coverage, source = `complextorch`) and `[tool.coverage.report]` with
  `fail_under = 100` plus `exclude_lines` for `raise NotImplementedError`,
  `pragma: no cover`, `if TYPE_CHECKING:`, `@overload`, and bare ellipsis.
- `.github/workflows/test.yml`: add `--cov-fail-under=100` and
  `--cov-report=term-missing` to the pytest invocation; quote `'.[test]'`
  for shell safety.

## Test suite (488 tests, mirroring the package tree)

- `tests/conftest.py`: autouse seeded RNG (`torch`, `random`, `numpy`),
  `cplx(*shape)` factory, `device` fixture.
- Unit tests under `tests/nn/{,modules,modules/activation,modules/attention,
  masked,relevance,utils}/`, `tests/datasets/`, `tests/models/`, and
  `tests/transforms/`, one test file per source file.
- Property tests under `tests/invariants/` using Hypothesis + parametrize:
  Fast/Slow conv & linear numerical equivalence (state-dict-aligned
  weights), polar round-trip, casting round-trip, FFT round-trip.
- One parameterized test covers all 11 `NotImplementedError` dataset stubs
  (`PolSFDataset`, `Bretigny`, `S1SLC`, `MSTARTargets`, `ATRNetSTAR`,
  `MICCAI2023`, `ALOSDataset`, `VolFile`, `LeaderFile`, `TrailerFile`,
  `SARImage`).
- `_expi.py` validated with `scipy.special.expi` parity at five sample
  points plus `torch.autograd.gradcheck` for the analytical backward.
- Full reduction matrix (`'mean'` / `'sum'` / `'none'`) plus invalid-
  reduction `ValueError` checks for every loss in `loss.py`.

## Source fixes surfaced by writing the tests

- `loss.py`: `PerpLossSSIM.forward` was passing the complex `(x, y)` pair
  to the real-only SSIM conv, which would raise at first use. Pass the
  precomputed magnitudes (`mag_input`, `mag_target`) instead, matching the
  cited perpendicular-loss reference.
- `masked/base.py`: drop the unreachable
  `elif mask_in_missing: missing_keys.remove(mask_key)` branch in
  `_load_from_state_dict` — PyTorch's `load_state_dict` hard-codes
  `strict=True` when calling `_load_from_state_dict`, and `super()` only
  appends to `missing_keys` under `strict=True`, so the `strict=False AND
  mask_in_missing=True` precondition is unreachable. Also remove a dead
  `if weight.is_complex():` check in `MaskedWeightMixin.sparsity` whose
  branches returned identical values.
- `transforms/transforms.py`: simplify `_resize_spectrum` by removing the
  real-input fallbacks — the helper is only ever called from `FFTResize`,
  which always feeds it a complex spectrum (`fftshift(fft2(x))`).
- `rnn.py`: remove the unused `_maybe_bn` helper.
- `transforms/Normalize`: delete the no-op `broadcast` line (`x.dim() - x.dim() + 2` is always 2 and the value was never read) and tighten the validation to spell out the supported shapes — `(C, H, W)` or `(B, C, H, W)`, channel at dim `-3`. Matches the existing convention used elsewhere in `transforms.py` (e.g. `RealImaginary`).
- `nn.Dropout{1,2,3}d`: drop the unused `real_view = torch.view_as_real(input)` and the stale comment trail describing a `view_as_real`-based approach that was abandoned in favor of the manual Bernoulli mask. `_dropout_fn` is kept — it handles the non-complex fallback path.
- `nn.MagMaxPool*`: drop the unused `spatial_dims` local in `_gather_max_by_magnitude`.
- `docs/source/conf.py`: drop `sys.path.insert(0, "../../")` and the `os`/`sys` imports it required. Both supported doc-build paths (CI's `pip install .[docs]` and the documented local workflow) install `complextorch` before invoking Sphinx, so the path manipulation was a no-op. Also fixes E402 from the previously mid-file imports.

No behavior changes.
Ports the novel layers from the CVPR 2022 "Co-Domain Symmetry for
Complex-Valued Deep Learning" paper (Singhal, Xing, Yu) and adds the
missing SurReal/wFM (Chakraborty, Xing, Yu — arxiv 1910.11334) primitives
that complement the existing `wFMConv1d/2d`. All ports use native
`torch.cfloat` per the project convention.

New modules (`complextorch.nn`):
- `PhaseDivConv{1,2,3}d` — `x · conj(g(x)) / (|g(x)|² + ε)`; U(1)-invariant.
- `PhaseConjConv{1,2,3}d` — `x · conj(g(x))`; also U(1)-invariant when `g`
  is C-linear (which it is with native cfloat) — see docstring note on
  the divergence from the paper's "phase-mixing" claim.
- `GTReLU` — learnable complex scaling + half-plane phase mask with a
  custom autograd Function whose gradient is the mask itself.
- `EquivariantPhaseReLU` — full U(1)-equivariant ReLU via channel-mean
  reference phase.
- `ComplexScaling` — learnable `(α + jβ)·z`, sibling of `PhaseShift`.
- `MagBatchNorm{1,2,3}d` — magnitude-only BatchNorm (preserves phase →
  U(1)-equivariant), distinct from the existing covariance-whitening
  `BatchNorm{1,2,3}d`.
- `PrototypeDistance` — learnable complex prototypes + temperature.
  Supports an optional reference rotation for U(1)-equivariant networks.
- `wFMReLU`, `wFMDistanceLinear` — manifold-aware ReLU and a real-valued
  distance-to-Fréchet-mean head, completing the SurReal building blocks
  alongside the existing `wFMConv{1,2}d`.

Reference models (`complextorch.models`):
- `CDSInvariant` (paper "I-type"), `CDSEquivariant` ("E-type"), `CDSMSTAR`
  (SAR backbone with a real-valued ResNet-lite tail).

Tests:
- New `tests/invariants/test_equivariance.py` verifies U(1) equivariance
  / invariance of every applicable module via `M(x·e^{jψ}) = M(x)·e^{jψ}`
  (or `= M(x)`).
- New per-module tests: `test_phase_modulation`, `test_prototype`,
  `tests/models/test_cds.py`.
- Extended existing files with tests for `ComplexScaling`, `MagBatchNorm*`,
  `GTReLU`, `EquivariantPhaseReLU`, `wFMReLU`, `wFMDistanceLinear` plus a
  `gradcheck`-style test for the half-plane phase autograd Function.
- Fixed a stale regex in `tests/transforms/test_transforms.py::
  test_normalize_wrong_channel_dim` (matched `"channel dim"`, but the
  error message had been reworded to mention `"with C=3"`).
Renames the hand-rolled real/imag-split conv/linear implementations out of
the top-level `complextorch.nn` surface and into a `complextorch.nn.gauss`
subpackage. The original `Slow*` prefix was a misleading legacy from when
these were *faster* than the naive split — since PyTorch 2.1.0's native
complex kernels they're slower than the `dtype=torch.cfloat` wrappers, and
the prefix described an accident of history rather than what the layers
actually do.

Renames (breaking, `2.0.0`):

- `complextorch.nn.SlowLinear`                → `complextorch.nn.gauss.Linear`
- `complextorch.nn.SlowConv{1,2,3}d`          → `complextorch.nn.gauss.Conv{1,2,3}d`
- `complextorch.nn.SlowConvTranspose{1,2,3}d` → `complextorch.nn.gauss.ConvTranspose{1,2,3}d`

The native-cfloat `Linear` / `Conv{1,2,3}d` / `ConvTranspose{1,2,3}d` keep
their top-level names. `complextorch.nn.gauss` is registered as a
subpackage alongside `init`, `relevance`, `masked`, `utils` — the prior
commit already wired the import into `complextorch/nn/__init__.py`; this
commit adds the subpackage itself plus the rest of the rename.

Internal callers:

- `MaskedChannelAttention{1,2,3}d` (six `conv_down` / `conv_up` sites in
  `attention/mca.py`) switched from `cvnn.SlowConv*d` to the native
  `cvnn.Conv*d`. These are 1×1 convs that gain nothing from the Gauss
  trick — the original choice was a holdover.

Tests:

- `tests/nn/gauss/{test_linear,test_conv}.py` — new files mirroring the
  package layout convention, with the forward / weight-property /
  `bias=False` cases moved from `tests/nn/modules/`.
- `tests/nn/modules/test_{conv,linear}.py` — stripped of `Slow*` blocks;
  now cover only the native cfloat wrappers.
- `tests/invariants/test_fast_slow_equivalence.py` →
  `test_native_gauss_equivalence.py` with imports and helper names
  updated. The property tests continue to assert that the two paths agree
  to floating-point tolerance on shared weights.

Docs:

- `CLAUDE.md` "Fast vs. Slow modules" section retitled to "Native vs.
  Gauss-trick modules" and rewritten to point at the new subpackage; the
  legacy `Slow*` naming is called out so future readers understand the
  history.

Git-rename caveat: because `modules/conv.py` and `modules/linear.py` were
*split* (Gauss-trick content into `gauss/`, fast wrappers stayed put),
git's default rename detection (≥50% similarity) sees them as "modified"
rather than "renamed". Use `git log --follow -C30%` or `git blame -C` to
trace the Gauss-trick content's history into its new location.
Replace the Sphinx + sphinx_rtd_theme + hand-maintained per-module `.rst`
stubs with the stack used by NumPy, SciPy, and PyTorch:

  - PyData Sphinx Theme (3-column layout, version switcher, dark mode)
  - MyST parser so authored pages are Markdown (existing `.rst` keeps working)
  - sphinx-autoapi auto-generates the API tree from docstrings — adding a new
    public class to `complextorch.nn.*` now requires zero doc-tree edits
  - intersphinx mapping to torch / numpy / scipy / python so `:class:torch.nn.X`
    cross-references resolve as real links
  - myst-nb executes `examples/getting_started.md` on every build; if the
    public API breaks, the docs build fails (analogue of the 100% coverage gate)
  - sphinx-multiversion builds `/latest/` from `main` plus any post-migration
    semver tag, with a version switcher in the navbar
  - sphinx-copybutton, sphinx-design (grid cards), sphinx-sitemap, opengraph

Doc-tree restructuring:

  - `docs/source/index.{rst → md}`, `installation.md`, `about.md` — rewritten
    as MyST. Drops the hand-edited release-date string (`05/11/2026`); the
    PyData theme renders the version from `__init__.py` automatically.
  - New `docs/source/concepts/activations.md` and `concepts/native-vs-gauss.md`
    promote the Type-A / Type-B / fully-complex theory and the native-cfloat
    vs. Gauss-trick discussion out of the API tree into a user guide.
  - New `docs/source/examples/getting_started.md` — an executed MyST-NB
    notebook covering the README Conv1d demo, activation comparison, and
    `pwelch` PSD.
  - Deleted ~40 per-module `.rst` stubs under `docs/source/nn/` plus the
    top-level `signal.rst` / `transforms.rst` / `datasets.rst` / `models.rst`
    — sphinx-autoapi covers them.

Release notes:

  - New `CHANGELOG.md` at the repo root (keepachangelog format) holds the
    v2.0 and v1.2 notes that used to live inline in `index.rst`.
  - `docs/source/changelog.md` `include`s it, so PyPI release pages and the
    docs site share a single source of truth.
  - `pyproject.toml` now advertises `[project.urls] Changelog = …` so it
    surfaces on PyPI as well.

CI:

  - `.github/workflows/docs.yml` gets a `pull_request` job that runs
    `sphinx-build -W` (warnings-as-errors) as a merge gate, mirroring the
    `--cov-fail-under=100` gate in `test.yml`. Broken cross-refs, missing
    toctree entries, and notebook-execution failures all block merge.
  - The `main` deploy job switches to `sphinx-multiversion` with
    `fetch-depth: 0` + `fetch-tags`, copies a `redirect.html` so the root
    URL bounces to `/latest/`, and bumps Python to 3.12 to match `test.yml`.

Source-side cleanup (opportunistic, gated by the migration):

  - Convert ad-hoc Markdown links to `torch.nn.*` URLs into Sphinx
    `:class:torch.nn.X` roles across `nn/gauss/conv.py`, `nn/gauss/linear.py`,
    `nn/modules/{batchnorm,mask,pooling,linear}.py` — intersphinx now
    auto-links them.
  - Fix broken `:doc:` refs in `nn/modules/attention/{__init__,eca,mca}.py`
    that pointed at the now-removed per-module RST tree; rewrite as `:mod:`
    refs to the canonical Python modules.
  - Extend the title underlines in `nn/modules/casting.py` and `manifold.py`
    that were too short for their multi-byte (`→`, `é`) titles.
  - Fold the `Attributes: is_sparse` block in `BaseMasked` into prose so the
    `@property` is the single documented site.
  - Expand the one-line docstrings on `ToTensor` / `Unsqueeze` / `HWC2CHW` /
    `ToReal` / `ToImaginary` in `transforms/transforms.py` with Args /
    behavioural notes — autoapi renders them as the canonical reference.

README:

  - Drop the inline 1.2 release notes (now in `CHANGELOG.md`).
  - Point the documentation link at the new GitHub Pages URL and add a
    pointer to the Getting Started notebook.

Caveats:

  - `sphinx-multiversion` 0.2.4 is incompatible with Sphinx ≥ 8
    (`Config.read` signature changed). The docs extra is pinned to
    `sphinx<8` (which pulls myst-parser <5 / myst-nb <2 with it). Migrate
    to `sphinx-polyversion` to unblock newer Sphinx — left as a follow-up
    noted in `pyproject.toml` and `CLAUDE.md`.
  - The `sphinx-multiversion` whitelist is `^2\.[1-9]\d*\.\d+$` — historical
    tags pre-2.1 won't re-render (their `conf.py` predates this migration).
    Those releases stay accessible via PyPI.

Verified locally with `sphinx-build -W -b html docs/source docs/build/html`
— zero warnings, 111 HTML pages, intersphinx hits PyTorch's inventory, the
notebook executes and emits 14 cell input/output spans.
Apply the full `ruff check` ruleset configured in `pyproject.toml` (E/W, F, I,
B, UP, SIM, RUF, C4, PT, PIE, RET, TCH) and bring the tree to zero warnings so
the new `lint.yml` workflow and `pre-commit` hooks can land green.

Mechanical modernizations (autofixed):

- `super(Cls, self).__init__()` → `super().__init__()` (UP008)
- `Tuple[...]` / `Dict[...]` / `Optional[...]` / `Union[X, Y]` → PEP-585/604
  builtins and `X | None` / `X | Y` (UP006, UP007, UP045)
- `from typing import Callable, Sequence, Iterator` → `from collections.abc`
  (UP035)
- Sorted `__all__` and import blocks (RUF022, I001)
- `pytest.mark.parametrize("a, b", ...)` → tuple form (PT006)
- `dict(...)` literals in tests → `{...}` (C408)
- Dropped `pass` after docstring in alias subclasses (PIE790)
- `zip(...)` calls now pass `strict=False` to preserve current behavior (B905)
- Misc: yoda-condition flip, redundant `open(p, "r")`, unused unpacked vars,
  tuple-concat → unpacking, `__setattr__` explicit `return`, etc.

Manual fixes:

- `nn/modules/linear.py`: annotate `__constants__: ClassVar[list[str]]`
  (matches PyTorch's own annotation; silences RUF012 false positive on the
  TorchScript convention).
- `nn/modules/upsampling.py`: rewrite `Optional[int | tuple[int, ...]]` as
  `int | tuple[int, ...] | None` and drop the now-unused `Optional` import
  (UP045 doesn't autofix mixed PEP-604 / `Optional` aliases).
- `nn/masked/__init__.py`: remove leftover `Dict, Tuple` imports.
- `nn/modules/batchnorm.py`: `# noqa: SIM102` on the nested `if self.training
  and self.track_running_stats: / if self.num_batches_tracked is not None:`
  block so the implementation stays line-for-line comparable to
  `torch.nn.modules.batchnorm._BatchNorm.forward`.
- `nn/utils/sparsity.py`: rename unused `mod_name` loop var to `_mod_name`.
- `transforms/functional.py`: collapse `x.abs() if x.is_complex() else
  x.abs()` to `x.abs()` (RUF034 — both branches were identical).

Config:

- `pyproject.toml`: add `allowed-confusables = ["×"]` so intentional math
  notation in docstrings/comments (`2×2 whitening`, `weights × V`) passes
  RUF002/RUF003.
@josiahwsmith10 josiahwsmith10 self-assigned this May 11, 2026
@josiahwsmith10 josiahwsmith10 merged commit 713d9d6 into main May 11, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant