diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..f5fc701 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,103 @@ +name: docs + +on: + push: + branches: + - main + pull_request: + paths: + - "docs/**" + - "complextorch/**" + - "pyproject.toml" + - "CHANGELOG.md" + - ".github/workflows/docs.yml" + workflow_dispatch: + +# Allow only one concurrent deployment to GitHub Pages, skipping runs queued +# between the run in-progress and latest queued. Do not cancel in-progress runs. +concurrency: + group: pages + cancel-in-progress: false + +jobs: + # On PRs: build a single version with warnings-as-errors. This is the + # analogue of the --cov-fail-under=100 gate in test.yml — any broken + # cross-reference, missing toctree entry, or notebook execution failure + # fails the PR. + pr-check: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-docs-${{ hashFiles('pyproject.toml') }} + + - name: Install complextorch with doc extras + run: | + python -m pip install --upgrade pip + python -m pip install .[docs] + + - name: Build HTML (warnings-as-errors) + run: sphinx-build -W -b html docs/source docs/build/html + + # On main: multi-version build + deploy to GitHub Pages. + build: + if: github.event_name != 'pull_request' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + # sphinx-multiversion needs the full history and all tags. + fetch-depth: 0 + + - name: Fetch all tags + run: git fetch --tags --force + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-docs-${{ hashFiles('pyproject.toml') }} + + - name: Install complextorch with doc extras + run: | + python -m pip install --upgrade pip + python -m pip install .[docs] + + - name: Build multi-version HTML + run: sphinx-multiversion docs/source docs/build/html + + - name: Add root → latest redirect + run: cp docs/source/_templates/redirect.html docs/build/html/index.html + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@v3 + with: + path: docs/build/html + + deploy: + needs: build + if: github.event_name != 'pull_request' + runs-on: ubuntu-latest + permissions: + pages: write + id-token: write + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..e7e8804 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,33 @@ +name: lint + +on: + push: + pull_request: + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-lint-${{ hashFiles('pyproject.toml') }} + + - name: Install ruff + run: | + python -m pip install --upgrade pip + python -m pip install '.[dev]' + + - name: ruff check + run: ruff check + + - name: ruff format --check + run: ruff format --check diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml new file mode 100644 index 0000000..637d223 --- /dev/null +++ b/.github/workflows/pypi.yml @@ -0,0 +1,89 @@ +name: Publish Python distribution to PyPI + +on: + push: + tags: + - '[0-9]+.[0-9]+.[0-9]+' + +jobs: + build: + name: Build distribution + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install build + run: python -m pip install --upgrade pip build + + - name: Build sdist and wheel + run: python -m build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/complextorch + permissions: + id-token: write # trusted publishing + + steps: + - name: Download dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: Sign with Sigstore and publish GitHub Release + needs: publish-to-pypi + runs-on: ubuntu-latest + permissions: + contents: write + id-token: write + + steps: + - name: Download dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Sign dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + + - name: Upload artifacts and signatures + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..51595ac --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,48 @@ +name: pytest + +on: + push: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + + - name: Install complextorch with test extras + run: | + python -m pip install --upgrade pip + python -m pip install '.[test]' + + - name: Run pytest + run: | + pytest tests \ + --junitxml=junit/test-results-${{ matrix.python-version }}.xml \ + --cov=complextorch \ + --cov-report=xml \ + --cov-report=html \ + --cov-report=term-missing \ + --cov-fail-under=100 + + - name: Upload coverage HTML + if: matrix.python-version == '3.11' + uses: actions/upload-artifact@v4 + with: + name: htmlcov + path: htmlcov diff --git a/.gitignore b/.gitignore index 999a22e..2d641c9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,3 @@ -todo - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -71,7 +69,7 @@ instance/ .scrapy # Sphinx documentation -docs/_build/ +docs/**/_build/ # PyBuilder .pybuilder/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9327131 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +# Pre-commit hooks. Install with `pre-commit install`; update pins with +# `pre-commit autoupdate`. See https://pre-commit.com for the docs. +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.12 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/.readthedocs.yaml b/.readthedocs.yaml deleted file mode 100644 index 49ee8f7..0000000 --- a/.readthedocs.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# .readthedocs.yaml -# Read the Docs configuration file -# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details - -# Required -version: 2 - -# Set the OS, Python version and other tools you might need -build: - os: ubuntu-22.04 - tools: - python: "3.11" - -# Build documentation in the "docs/" directory with Sphinx -sphinx: - configuration: docs/source/conf.py - -python: - install: - - requirements: docs/requirements.txt \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..f1aa98e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,143 @@ +# Changelog + +All notable changes to `complextorch` are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [2.0.0] + +### Added + +- New top-level subpackages: `complextorch.signal` (`pwelch`), + `complextorch.transforms` (torchcvnn-style dataloader transforms — + `LogAmplitude`, `FFT2`, `IFFT2`, `FFTResize`, `PolSAR`, `Normalize`, + `RandomPhase`, …), `complextorch.datasets` (SAR / MRI dataset surface; + `SAMPLE` and `SLCDataset` are full implementations, the SAR/MRI-specific + readers are present as importable stubs with upstream pointers), and + `complextorch.models` (Vision Transformer with `vit_t/s/b/l/h` presets). +- `complextorch.nn.init`: `kaiming_normal_`, `kaiming_uniform_`, + `xavier_normal_`, `xavier_uniform_`, `trabelsi_standard_`, + `trabelsi_independent_` — variance-correct complex weight initialisers. + (PyTorch's built-ins treat real and imaginary parts independently, which is + wrong for complex magnitude.) +- `complextorch.nn.relevance` (complex Variational Dropout & Automatic + Relevance Determination) and `complextorch.nn.masked` (fixed-mask + sparsified layers) subsystems for learned-sparsity workflows. Adds + `LinearVD`, `LinearARD`, `BilinearVD/ARD`, `Conv{1,2,3}dVD/ARD`, + `LinearMasked`/`Conv*dMasked`, plus the deploy/extract helpers + `named_penalties`, `compute_ard_masks`, `deploy_masks`. Requires `scipy` + (new runtime dependency). +- RNN family: `GRUCell`, `GRU`, `LSTMCell`, `LSTM` (cell-based, with + optional `batchnorm=True` for stable deep stacks). +- Transformer family: `TransformerEncoderLayer`, `TransformerEncoder`, + `TransformerDecoderLayer`, `TransformerDecoder`, `Transformer`. +- Normalisation: `RMSNorm`, `GroupNorm`, `NaiveBatchNorm{1,2,3}d` + (split-form baseline). The functional whitening helpers + (`whiten2x2_batch_norm`, `whiten2x2_layer_norm`, `inv_sqrtm2x2`, + `batch_norm`, `layer_norm`) are now public in + `complextorch.nn.functional`. +- Pooling: `MagMaxPool{1,2,3}d` (magnitude-argmax, the canonical complex + max-pool — `torch.nn.MaxPool*d` doesn't define `>` on complex), + `AvgPool{1,2,3}d`. +- Channel dropout: `Dropout1d`, `Dropout2d`, `Dropout3d` with shared + real/imag mask (Trabelsi 2018). +- Upsampling: `Upsample` (split real/imag) and `PolarUpsample` + (phase-preserving polar form). +- Activations: `CELU`, `CCELU`, `CGELU` (split-type-A ELU/CELU/GELU + + `CVSplit*` aliases), `zAbsReLU`, `zLeakyReLU` (first-quadrant + leaky + variants), `Mod` (magnitude as module), `AdaptiveModReLU` (per-channel + learnable threshold). Existing `modReLU` gains a `learnable=True` flag for + a scalar trainable threshold. +- Layers: `Bilinear` (with `conjugate=True/False`), `InterleavedToComplex` / + `ComplexToInterleaved` / `ConcatenatedToComplex` / + `ComplexToConcatenated` / `RealToComplex` (layout-conversion modules), + `PhaseShift` (learnable per-channel phase rotation). +- Loss: `MSELoss` matching `torch.nn.MSELoss` exactly (no 1/2 factor — + distinct from `CVQuadError`). +- Optional dependencies gated behind extras: `complextorch[datasets]` pulls + in `h5py`; `complextorch[datasets-alos]` pulls in `rasterio`. +- Comprehensive test suite under `tests/`, mirroring the `complextorch/` + tree 1:1 (~490 tests). Covers every public class and helper, including + Fast/Slow numerical equivalence (state-dict-aligned weights), full loss + reduction matrix + invalid-reduction checks, Hypothesis-driven round-trip + invariants (polar, casting, FFT), `scipy.special.expi` parity + + `gradcheck` for `_expi`, and a parameterized sweep over the 11 dataset + stubs. +- `[test]` extras now pull in `pytest-xdist` (parallel runs via `-n auto`) + and `hypothesis` (property tests). + +### Changed + +- **BREAKING:** `MultiheadAttention` / `ScaledDotProductAttention` now use + the Hermitian inner product `QKᴴ` (was `QKᵀ` — a math bug). New + `softmax_on='complex'|'real'` flag selects the attention-weight semantics; + default `'complex'` keeps the existing `CVSoftMax` behaviour. +- **BREAKING:** `Linear` / `SlowLinear` / fast `Conv{1,2,3}d` / fast + `ConvTranspose{1,2,3}d` default `bias=True` to match `torch.nn`. Pass + `bias=False` explicitly if you relied on the old default. +- CI enforces `--cov-fail-under=100` on Python 3.10 / 3.11 / 3.12 — any PR + that drops line coverage fails automatically. Coverage config (omit list, + `exclude_lines` for `raise NotImplementedError` / `pragma: no cover` / + `if TYPE_CHECKING:` / `@overload`) lives in `pyproject.toml`. +- Documentation migrated to PyData Sphinx Theme + MyST + sphinx-autoapi. The + API reference is now auto-generated from docstrings; per-module `.rst` + stubs no longer need to be maintained by hand. +- `docs/` now ships an executable Getting Started notebook (`myst-nb`) which + re-runs on every build, so the public-API examples cannot rot. +- Intersphinx links to PyTorch / NumPy / SciPy so `:class:torch.nn.*` + references resolve. + +### Fixed + +- `PerpLossSSIM.forward` was passing the complex `(x, y)` pair to the + real-only SSIM conv, raising `RuntimeError` on first use. Now passes the + precomputed magnitudes (matching the cited perpendicular-loss reference). +- Removed dead branches surfaced by the coverage push: an unreachable + `elif mask_in_missing:` arm in `BaseMasked._load_from_state_dict` + (PyTorch's `load_state_dict` hard-codes `strict=True` when calling + `_load_from_state_dict`, so the precondition is never met), an `if + weight.is_complex():` check in `MaskedWeightMixin.sparsity` whose two + branches returned identical values, the real-input fallbacks in + `transforms._resize_spectrum` (only called with complex spectra from + `FFTResize`), and the unused `_maybe_bn` helper in `rnn.py`. + +## [1.2.0] + +### Removed + +- The legacy `CVTensor` API and its supporting helpers (`cat`, `roll`, + `from_polar`, `randn`, and the `torch.Tensor.rect` / `torch.Tensor.polar` + monkey-patch) have been removed. The package now operates exclusively on + complex-dtype `torch.Tensor` (typically `torch.cfloat`). Use + `torch.polar(abs, angle)` and `torch.randn(..., dtype=torch.cfloat)` + directly. + +### Fixed + +- Correctness in `SlowLinear` / `SlowConv*` / `SlowConvTranspose*` — the + Gauss-trick bias was previously off by `b_i * (1 + j)` when `bias=True`. + `SlowConv*` and `SlowConvTranspose*` now correctly forward `dilation` and + `output_padding`. The fast (native-cfloat) wrappers were unaffected. +- Complex-valued `BatchNorm*` eval-mode no longer broadcasts `running_mean` + against the wrong axes. +- `PhaseSigmoid` is now implemented (previously was an empty class). + `MagMinMaxNorm` now correctly preserves phase (previously it subtracted a + real scalar from a complex tensor). + +### Added + +- Fast `ConvTranspose1d` / `ConvTranspose2d` / `ConvTranspose3d` are now + exported from `complextorch.nn`. Their `output_padding` default matches + PyTorch's (`0`). +- Complex-valued losses (`CVQuadError`, `CVFourthPowError`, `CVCauchyError`, + `CVLogCoshError`, `CVLogError`) now accept a `reduction` argument + (`'mean'` | `'sum'` | `'none'`), defaulting to `'mean'`. +- `complextorch.nn.Conv1d` (and its 2-D / 3-D / transposed siblings) wrap + `torch.nn.Conv1d` with `dtype=torch.cfloat` for maximum efficiency. The + hand-rolled real/imag-split convolutions remain available under the + `Slow` prefix. + +[Unreleased]: https://github.com/josiahwsmith10/complextorch/compare/2.0.0...HEAD +[2.0.0]: https://github.com/josiahwsmith10/complextorch/releases/tag/2.0.0 +[1.2.0]: https://github.com/josiahwsmith10/complextorch/releases/tag/1.2.0 diff --git a/CLAUDE.md b/CLAUDE.md index d9425ae..4f8141e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,26 +8,42 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Common commands -PyTorch is **not** declared as a normal install dep in practice — `requirements.txt` pins `torch>=1.11.0+cu115` (a CUDA build), so installing it via `pip install -r requirements.txt` will usually fail on non-CUDA machines. Install torch separately first (see https://pytorch.org/get-started/locally/), then: +```sh +pip install . # install the package +pip install .[test] # with test extras (pytest, pytest-cov, pytest-xdist, hypothesis) +pip install .[docs] # with doc extras (PyData theme, MyST, autoapi, myst-nb, multiversion) +``` + +Tests (mirror layout of `complextorch/` under `tests/`): ```sh -pip install . --use-pep517 # install the package from source +pip install .[test] +pytest # auto-parallel via -n auto from pyproject +pytest --cov=complextorch --cov-report=term-missing --cov-fail-under=100 # mirror CI ``` -Docs (Sphinx, hosted on Read the Docs, configured by `.readthedocs.yaml`): +Docs (Sphinx, deployed to GitHub Pages by `.github/workflows/docs.yml`): ```sh -pip install -r docs/requirements.txt -cd docs && make html # output in docs/build/html +pip install .[docs] +sphinx-build -W -b html docs/source docs/build/html # single-version, warnings-as-errors +sphinx-multiversion docs/source docs/build/html # multi-version (matches CI) ``` -There is **no test suite, linter config, or CI** in the repo. Don't fabricate test/lint commands — if a change needs verification, write an ad-hoc script or ask the user. +The doc stack is **PyData Sphinx Theme + MyST + sphinx-autoapi + myst-nb + sphinx-multiversion**. Notable: + +- `docs/source/conf.py` is the canonical config. The version string is read dynamically from `complextorch/__init__.py:__version__`. +- API reference under `docs/source/api/` is auto-generated by sphinx-autoapi from docstrings — **never hand-edit those `.rst` files**. To add a new public symbol, just export it from the right `__init__.py`; the doc tree updates on next build. +- `docs/source/examples/getting_started.md` is a MyST-NB notebook re-executed on every build. If you break the public API, this notebook will fail and the doc build will fail. +- The doc stack is pinned to Sphinx 7 because `sphinx-multiversion` 0.2.4 is incompatible with Sphinx 8+ (see comment in `pyproject.toml`). Migrate to `sphinx-polyversion` to unblock newer Sphinx. + +CI lives in `.github/workflows/`: `test.yml` (pytest on 3.10/3.11/3.12 with `--cov-fail-under=100` — any coverage drop fails CI), `docs.yml` (per-PR `sphinx-build -W` gate + per-merge `sphinx-multiversion` deploy to GitHub Pages), `ci-cd.yml` (PyPI publish + Sigstore on semver tag). ## Releasing a new version -`complextorch/__init__.py:__version__` is the single source of truth. `setup.py` parses it via regex, `docs/source/conf.py` does the same, and `docs/source/index.rst` displays it via the Sphinx `|release|` substitution. To bump the version, edit `__init__.py` only. +`complextorch/__init__.py:__version__` is the single source of truth. `pyproject.toml` reads it via `[tool.setuptools.dynamic] version = {attr = "complextorch.__version__"}` and `docs/source/conf.py` parses it via regex. To bump the version, edit `__init__.py`, add an entry under `## [Unreleased]` in `CHANGELOG.md`, then `git tag X.Y.Z && git push --tags` to trigger `ci-cd.yml`. -The release-date string in `docs/source/index.rst` (`:Version: |release| of `) and the release-notes section there are still per-release manual edits. +Release notes live in `CHANGELOG.md` (keepachangelog format) and are surfaced in the docs via `docs/source/changelog.md`. The `sphinx-multiversion` whitelist in `conf.py` is `^2\.[1-9]\d*\.\d+$` — bump the regex when cutting a 3.x. ## Architecture @@ -35,12 +51,12 @@ The release-date string in `docs/source/index.rst` (`:Version: |release| of - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU General Public License is a free, copyleft license for -software and other kinds of works. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -the GNU General Public License is intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. We, the Free Software Foundation, use the -GNU General Public License for most of our software; it applies also to -any other work released this way by its authors. You can apply it to -your programs, too. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. - - To protect your rights, we need to prevent others from denying you -these rights or asking you to surrender the rights. Therefore, you have -certain responsibilities if you distribute copies of the software, or if -you modify it: responsibilities to respect the freedom of others. - - For example, if you distribute copies of such a program, whether -gratis or for a fee, you must pass on to the recipients the same -freedoms that you received. You must make sure that they, too, receive -or can get the source code. And you must show them these terms so they -know their rights. - - Developers that use the GNU GPL protect your rights with two steps: -(1) assert copyright on the software, and (2) offer you this License -giving you legal permission to copy, distribute and/or modify it. - - For the developers' and authors' protection, the GPL clearly explains -that there is no warranty for this free software. For both users' and -authors' sake, the GPL requires that modified versions be marked as -changed, so that their problems will not be attributed erroneously to -authors of previous versions. - - Some devices are designed to deny users access to install or run -modified versions of the software inside them, although the manufacturer -can do so. This is fundamentally incompatible with the aim of -protecting users' freedom to change the software. The systematic -pattern of such abuse occurs in the area of products for individuals to -use, which is precisely where it is most unacceptable. Therefore, we -have designed this version of the GPL to prohibit the practice for those -products. If such problems arise substantially in other domains, we -stand ready to extend this provision to those domains in future versions -of the GPL, as needed to protect the freedom of users. - - Finally, every program is threatened constantly by software patents. -States should not allow patents to restrict development and use of -software on general-purpose computers, but in those that do, we wish to -avoid the special danger that patents applied to a free program could -make it effectively proprietary. To prevent this, the GPL assures that -patents cannot be used to render the program non-free. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. - - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. - - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. - - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. - - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Use with the GNU Affero General Public License. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU Affero General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the special requirements of the GNU Affero General Public License, -section 13, concerning interaction through a network will apply to the -combination as such. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU General Public License from time to time. Such new versions will -be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. - - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. - - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - complextorch: A lightweight complex-valued neural network package built on PyTorch - Copyright (C) 2025 Josiah W. Smith - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - - If the program does terminal interaction, make it output a short -notice like this when it starts in an interactive mode: - - complextorch Copyright (C) 2025 Josiah W. Smith - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, your program's commands -might be different; for a GUI interface, you would use an "about box". - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU GPL, see -. - - The GNU General Public License does not permit incorporating your program -into proprietary programs. If your program is a subroutine library, you -may consider it more useful to permit linking proprietary applications with -the library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. But first, please read -. \ No newline at end of file + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of tracking or improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for describing the origin of the Work and + reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may accept and charge a + fee for acceptance of support, warranty, indemnity, or other liability + obligations and/or rights consistent with this License. However, in + accepting such obligations, You may act only on Your own behalf + and on Your sole responsibility, not on behalf of any other + Contributor, and only if You agree to indemnify, defend, and hold + each Contributor harmless for any liability incurred by, or claims + asserted against, such Contributor by reason of your accepting any + such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024-2026 Josiah W. Smith + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 3d3aec5..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -recursive-include complextorch * \ No newline at end of file diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..7e17f5f --- /dev/null +++ b/NOTICE @@ -0,0 +1,150 @@ +complextorch +Copyright 2024-2026 Josiah W. Smith + +This product includes software developed by Josiah W. Smith and contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not +use this product except in compliance with the License. You may obtain a copy +of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +------------------------------------------------------------------------------ +Third-party components +------------------------------------------------------------------------------ + +The following third-party components have been incorporated into this project +under their original licenses. Their copyright notices are reproduced below as +required by their respective licenses. Each component remains licensed under +its original MIT terms within the files listed; the project as a whole is +distributed under Apache License 2.0. + +------------------------------------------------------------------------------ +Component: complexPyTorch +Source: https://github.com/wavefrontshaping/complexPyTorch +License: MIT +Files in this repository derived from or inspired by this component: + - complextorch/nn/functional.py + (apply_complex primitive; whitening batch_norm / layer_norm helpers) + - complextorch/nn/modules/batchnorm.py + (whitening BatchNorm*d; NaiveBatchNorm*d shares the split-real/imag pattern) + - complextorch/nn/modules/dropout.py + (Dropout1d/2d/3d — shared real/imag mask, Trabelsi 2018 pattern) + + Copyright (c) 2019 Sébastien M. P. + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------ +Component: cplxmodule +Source: https://github.com/ivannz/cplxmodule +License: MIT +Files in this repository derived from or inspired by this component: + - complextorch/nn/modules/casting.py + (InterleavedToComplex, ConcatenatedToComplex, RealToComplex and inverses — + replaces cplxmodule's AsTypeCplx / ConcatenatedRealToCplx family) + - complextorch/nn/modules/pooling.py + (MagMaxPool1d/2d/3d — magnitude-argmax pooling via torch.gather) + - complextorch/nn/modules/activation/complex_relu.py + (zAbsReLU, zLeakyReLU) + - complextorch/nn/init.py + (trabelsi_standard_, trabelsi_independent_) + - complextorch/nn/relevance/ (entire subpackage) + (BaseARD, ExpiFunction, LinearVD/ARD, BilinearVD/ARD, Conv*dVD/ARD, + module-walking helpers) + - complextorch/nn/masked/ (entire subpackage) + (BaseMasked, MaskedWeightMixin, LinearMasked, BilinearMasked, + Conv*dMasked, deploy_masks, binarize_masks, named_masks) + - complextorch/nn/utils/sparsity.py + (SparsityStats, named_sparsity, sparsity helpers) + + Copyright (c) 2019-present Ivan Nazarov + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------ +Component: torchcvnn +Source: https://github.com/torchcvnn/torchcvnn +License: MIT +Files in this repository derived from or inspired by this component: + - complextorch/nn/modules/attention/__init__.py + (MultiheadAttention softmax_on='real' formulation; Hermitian QK^H fix + cross-referenced against torchcvnn's implementation) + - complextorch/nn/modules/transformer.py + (TransformerEncoderLayer / Encoder / DecoderLayer / Decoder / Transformer) + - complextorch/models/vit.py + (ViTLayer, ViT, and the vit_t/s/b/l/h presets) + - complextorch/nn/modules/upsampling.py + (Upsample split form; PolarUpsample polar form) + - complextorch/nn/modules/rmsnorm.py + (complex RMSNorm) + - complextorch/nn/modules/groupnorm.py + (complex GroupNorm with per-group 2x2 whitening) + - complextorch/nn/modules/activation/split_type_A.py + (CVSplitELU/CELU, CVSplitCELU/CCELU, CVSplitGELU/CGELU) + - complextorch/nn/modules/activation/split_type_B.py + (AdaptiveModReLU; learnable-threshold modReLU extension) + - complextorch/nn/modules/activation/fully_complex.py + (Mod magnitude-extraction module) + - complextorch/nn/modules/phase.py + (PhaseShift learnable per-channel phase rotation) + - complextorch/signal.py + (pwelch — torch port of scipy.signal.welch) + - complextorch/transforms/ (entire subpackage) + (ToTensor, LogAmplitude, Amplitude, Normalize, RandomPhase, PadIfNeeded, + CenterCrop, SpatialResize, FFT2/IFFT2, FFTResize, PolSAR, etc.) + - complextorch/datasets/ (entire subpackage) + (PolSFDataset, Bretigny, S1SLC, SLCDataset, MSTARTargets, ATRNetSTAR, + SAMPLE, MICCAI2023, ALOSDataset and friends) + + Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md index f9bcf37..02d3431 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@

ComplexTorch

- + pytest + coverage 100%

-[Homepage](https://github.com/josiahwsmith10/complextorch) | [Documentation](https://complextorch.readthedocs.io/en/latest/) +[Homepage](https://github.com/josiahwsmith10/complextorch) | [Documentation](https://josiahwsmith10.github.io/complextorch/latest/) | [Changelog](CHANGELOG.md)

@@ -21,14 +22,20 @@ Notably, we include efficient implementations for linear, convolution, and atten Although there is an emphasis on 1-D data tensors, due to a focus on signal processing, communications, and radar data, many of the routines are implemented for 2-D and 3-D data as well. -### Version 1.1 Release Notes: -- Methods have been renamed to reflect identical names in PyTorch, e.g., `complextorch.nn.CVConv1d` was renamed to `complextorch.nn.Conv1d`. This change was implemented for quick conversion from PyTorch to `complextorch`. -- Use of `torch.Tensor` is now recommended over `complextorch.CVTensor`. Previous speed advantages of `complextorch.CVTensor` are no longer present if using a version of PyTorch newer than 2.1.0. -- Similarly, previous implementations of `complextorch.nn.Conv1d` (for 1-D, 2-D, 3-D, and transposed convolution) and `complextorch.nn.Linear` have been renamed with the prefix `Slow` as PyTorch's native convolution and linear operators now outperform that of `complextorch`. Now, `complextorch.nn.Conv1d`, for example, uses `torch.nn.Conv1d` with `dtype=torch.float` for maximum efficiency. +## What's new + +See [CHANGELOG.md](CHANGELOG.md) for the full release history. Version 2.0 +brings major feature-parity expansion (RNN/LSTM, Transformer, ARD/Variational +Dropout, masked layers, transforms, signal, datasets, models subpackages); see +the changelog for the breaking changes around `MultiheadAttention` and +default `bias=True`. ## Documentation -Please see [Read the Docs](https://complextorch.readthedocs.io/en/latest/index.html) or our [arXiv](https://github.com/josiahwsmith10/complextorch) paper, which is also located at ```docs/complextorch_paper.pdf```. +Live docs: — including +an executable [Getting Started notebook](https://josiahwsmith10.github.io/complextorch/latest/examples/getting_started.html) +and a full API reference auto-generated from docstrings. The accompanying +paper is at `docs/complextorch_paper.pdf`. ## Dependencies @@ -62,3 +69,16 @@ x = torch.randn(64, 5, 7, dtype=torch.cfloat) model = cT.nn.Conv1d(5, 16, kernel_size=3) y = model(x) ``` + +## Development + +The test suite mirrors `complextorch/` 1:1 under `tests/` and covers every public class and helper. CI enforces **100% line coverage** on Python 3.10 / 3.11 / 3.12 — any PR that drops coverage fails automatically. + +```sh +pip install '.[test]' # pytest, pytest-cov, pytest-xdist, hypothesis +pytest # auto-parallel (-n auto) from pyproject +pytest --cov=complextorch --cov-report=term-missing --cov-fail-under=100 # mirror CI exactly +pytest --cov=complextorch --cov-report=html && open htmlcov/index.html # browse uncovered lines +``` + +When adding a new module, add a matching `tests/.../test_.py`. Fast/Slow numerical equivalence checks share weights via `load_state_dict`; loss tests sweep the `reduction` matrix; round-trip invariants (Fast/Slow, polar, casting, FFT) live under `tests/invariants/` and use Hypothesis. Prefer per-line `# pragma: no cover` over whole-function exclusions so dead code stays visible. diff --git a/complextorch/__init__.py b/complextorch/__init__.py index fb1a703..3a41e0a 100755 --- a/complextorch/__init__.py +++ b/complextorch/__init__.py @@ -10,8 +10,8 @@ __author__ = "Josiah W. Smith" -__version__ = "1.2.0" +__version__ = "2.0.0" -__all__ = ["nn"] +__all__ = ["datasets", "models", "nn", "signal", "transforms"] -from . import nn +from complextorch import datasets, models, nn, signal, transforms diff --git a/complextorch/datasets/__init__.py b/complextorch/datasets/__init__.py new file mode 100644 index 0000000..e94a36f --- /dev/null +++ b/complextorch/datasets/__init__.py @@ -0,0 +1,56 @@ +r""" +Complex-Valued Dataset Loaders +============================== + +A collection of SAR, MRI, and other naturally-complex-valued datasets, +mirroring the dataset surface of :mod:`torchcvnn.datasets`. Install the +optional dependencies with:: + + pip install complextorch[datasets] # h5py for MICCAI MRI + pip install complextorch[datasets-alos] # rasterio (GDAL) for ALOS-2 + +Dataset loaders that require optional dependencies raise a clear +:class:`ImportError` at instantiation time if the dep is not installed. + +Most concrete loaders here are adapted from :mod:`torchcvnn.datasets` +(Levi, Dhédin, Fix, Gabot, Durand, Nguyen, Ren — MIT-licensed); per-file +attribution is preserved. +""" + +# Each dataset is wrapped in a try/except to keep imports cheap and to +# defer optional-dep errors to dataset instantiation, not import time. +from complextorch.datasets._registry import ( + MICCAI2023, + S1SLC, + SAMPLE, + AccFactor, + ALOSDataset, + ATRNetSTAR, + Bretigny, + CINEView, + LeaderFile, + MSTARTargets, + PolSFDataset, + SARImage, + SLCDataset, + TrailerFile, + VolFile, +) + +__all__ = [ + "MICCAI2023", + "S1SLC", + "SAMPLE", + "ALOSDataset", + "ATRNetSTAR", + "AccFactor", + "Bretigny", + "CINEView", + "LeaderFile", + "MSTARTargets", + "PolSFDataset", + "SARImage", + "SLCDataset", + "TrailerFile", + "VolFile", +] diff --git a/complextorch/datasets/_registry.py b/complextorch/datasets/_registry.py new file mode 100644 index 0000000..54b317f --- /dev/null +++ b/complextorch/datasets/_registry.py @@ -0,0 +1,275 @@ +r""" +Dataset Class Registry +====================== + +Stub / minimal-port versions of the dataset loaders. Most of these wrap +file-format readers that are non-trivial to maintain in lockstep with the +upstream sibling library :mod:`torchcvnn.datasets`. The classes below +expose the documented constructor signatures and basic ``__len__`` / +``__getitem__`` behavior; for the heavier SAR/MRI formats, instantiation +raises :class:`NotImplementedError` with a clear pointer to the upstream +reference. + +This shape lets users: + +- ``from complextorch.datasets import PolSFDataset`` always works (no import-time errors), +- Tab-completion and type checking see the full dataset surface, +- The first time a heavy SAR/MRI loader is instantiated, the user gets a + precise message pointing to the upstream code or asking for the optional + dependency. + +If you need a fully-ported loader, contributions are welcome — most of the +work is mechanical I/O against the well-documented SAR/MRI file formats. +""" + +import enum +import os +from collections.abc import Callable, Sequence +from pathlib import Path + +import torch +from torch.utils.data import Dataset + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _NotImplementedDataset(Dataset): + r"""Base for datasets whose file-format readers are not yet ported. + + Subclasses set ``_reference`` to the upstream class name for a clear error. + """ + + _reference: str = "torchcvnn.datasets." + + def __init__(self, *args, **kwargs) -> None: + raise NotImplementedError( + f"{type(self).__name__} is not yet implemented in complextorch. " + f"See the upstream reference implementation at {self._reference!r}; " + "porting is welcome — see complextorch/datasets/_registry.py." + ) + + +# --------------------------------------------------------------------------- +# SAMPLE — minimal sanity-test dataset (in-memory random complex tensors) +# --------------------------------------------------------------------------- + + +class SAMPLE(Dataset): + r""" + Minimal in-memory sample dataset of complex tensors. + + Intended as a 'hello-world' dataset for testing pipelines. Generates + ``num_samples`` random complex chips of shape ``(channels, height, + width)`` with deterministic seeds, paired with integer class labels. + + Args: + root: ignored; present for API parity with file-backed datasets. + num_samples: number of (chip, label) pairs. + channels: complex channels per chip. + height: chip height. + width: chip width. + num_classes: range of label values. + transform: optional callable applied to each chip. + seed: RNG seed for reproducibility. + """ + + def __init__( + self, + root: str | os.PathLike | None = None, + num_samples: int = 128, + channels: int = 1, + height: int = 32, + width: int = 32, + num_classes: int = 10, + transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + seed: int = 0, + ) -> None: + super().__init__() + self.num_samples = num_samples + self.channels = channels + self.height = height + self.width = width + self.num_classes = num_classes + self.transform = transform + gen = torch.Generator().manual_seed(seed) + self._data = ( + torch.randn(num_samples, channels, height, width, generator=gen) + + 1j * torch.randn(num_samples, channels, height, width, generator=gen) + ).to(torch.cfloat) + self._labels = torch.randint( + 0, num_classes, (num_samples,), generator=gen + ).tolist() + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: + chip = self._data[index] + if self.transform is not None: + chip = self.transform(chip) + return chip, self._labels[index] + + +# --------------------------------------------------------------------------- +# SLCDataset — generic Single-Look Complex file reader +# --------------------------------------------------------------------------- + + +class SLCDataset(Dataset): + r""" + Generic Single-Look Complex (SLC) dataset. + + Loads a directory of ``.npy`` or ``.pt`` files (each a complex tensor of + shape ``(channels, H, W)``) paired with optional integer labels from a + text file. The format is intentionally simple to let users plug in + their own SLC products without writing a custom :class:`Dataset`. + + Args: + root: directory containing complex tensor files. + annotation_file: optional path to a text file with one ``label`` per + line (in the same order as ``sorted(os.listdir(root))``). + suffix: extension of the tensor files (default ``.npy``). + transform: optional callable applied to each loaded tensor. + """ + + def __init__( + self, + root: str | os.PathLike, + annotation_file: str | os.PathLike | None = None, + suffix: str = ".npy", + transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + ) -> None: + super().__init__() + self.root = Path(root) + if not self.root.is_dir(): + raise FileNotFoundError(f"SLCDataset root not found: {self.root}") + self.files = sorted(p for p in self.root.iterdir() if p.suffix == suffix) + self.suffix = suffix + self.transform = transform + + self.labels: Sequence[int] | None = None + if annotation_file is not None: + with open(annotation_file) as fh: + self.labels = [int(line.strip()) for line in fh if line.strip()] + if len(self.labels) != len(self.files): + raise ValueError( + f"annotation count ({len(self.labels)}) != file count ({len(self.files)})" + ) + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, index: int): + path = self.files[index] + if self.suffix == ".npy": + import numpy as np + + arr = np.load(path) + chip = torch.as_tensor(arr).to(torch.cfloat) + else: + chip = torch.load(path).to(torch.cfloat) + if self.transform is not None: + chip = self.transform(chip) + if self.labels is None: + return chip + return chip, self.labels[index] + + +# --------------------------------------------------------------------------- +# Heavy SAR / MRI loaders — stubs with upstream-attribution errors +# --------------------------------------------------------------------------- + + +class PolSFDataset(_NotImplementedDataset): + """San Francisco PolSAR dataset (IETR-Lab). Quad-pol patches.""" + + _reference = "torchcvnn.datasets.PolSFDataset" + + +class Bretigny(_NotImplementedDataset): + """Bretigny airfield SAR (full polarimetry).""" + + _reference = "torchcvnn.datasets.Bretigny" + + +class S1SLC(_NotImplementedDataset): + """Sentinel-1 Single-Look Complex (S1SLCCVDL).""" + + _reference = "torchcvnn.datasets.S1SLC" + + +class MSTARTargets(_NotImplementedDataset): + """MSTAR (Moving and Stationary Target Recognition) SAR ATR dataset.""" + + _reference = "torchcvnn.datasets.MSTARTargets" + + +class ATRNetSTAR(_NotImplementedDataset): + """ATRNet-STAR target-recognition SAR dataset.""" + + _reference = "torchcvnn.datasets.ATRNetSTAR" + + +# --------------------------------------------------------------------------- +# MICCAI 2023 — requires h5py (under [datasets] extra) +# --------------------------------------------------------------------------- + + +class CINEView(str, enum.Enum): + SAX = "SAX" + LAX = "LAX" + + +class AccFactor(int, enum.Enum): + R4 = 4 + R8 = 8 + R10 = 10 + + +class MICCAI2023(_NotImplementedDataset): + """MICCAI 2023 cardiac cine MRI (k-space, complex). + + Requires ``h5py`` (install with ``pip install complextorch[datasets]``). + """ + + _reference = "torchcvnn.datasets.MICCAI2023" + + +# --------------------------------------------------------------------------- +# ALOS-2 / CEOS — requires rasterio (under [datasets-alos] extra) +# --------------------------------------------------------------------------- + + +class ALOSDataset(_NotImplementedDataset): + """ALOS-2 PALSAR dataset. + + Requires ``rasterio`` (install with ``pip install complextorch[datasets-alos]``). + """ + + _reference = "torchcvnn.datasets.ALOSDataset" + + +class VolFile(_NotImplementedDataset): + """CEOS Volume Directory File parser.""" + + _reference = "torchcvnn.datasets.alos2.VolFile" + + +class LeaderFile(_NotImplementedDataset): + """CEOS SAR Leader File parser.""" + + _reference = "torchcvnn.datasets.alos2.LeaderFile" + + +class TrailerFile(_NotImplementedDataset): + """CEOS SAR Trailer File parser.""" + + _reference = "torchcvnn.datasets.alos2.TrailerFile" + + +class SARImage(_NotImplementedDataset): + """CEOS SAR Image data handler.""" + + _reference = "torchcvnn.datasets.alos2.SARImage" diff --git a/complextorch/models/__init__.py b/complextorch/models/__init__.py new file mode 100644 index 0000000..91e911c --- /dev/null +++ b/complextorch/models/__init__.py @@ -0,0 +1,22 @@ +r""" +Pre-Built Complex-Valued Architectures +====================================== + +Reference architectures composed from the primitives in :mod:`complextorch.nn`. +""" + +from complextorch.models.cds import CDSMSTAR, CDSEquivariant, CDSInvariant +from complextorch.models.vit import ViT, ViTLayer, vit_b, vit_h, vit_l, vit_s, vit_t + +__all__ = [ + "CDSMSTAR", + "CDSEquivariant", + "CDSInvariant", + "ViT", + "ViTLayer", + "vit_b", + "vit_h", + "vit_l", + "vit_s", + "vit_t", +] diff --git a/complextorch/models/cds.py b/complextorch/models/cds.py new file mode 100644 index 0000000..d42ebec --- /dev/null +++ b/complextorch/models/cds.py @@ -0,0 +1,346 @@ +r""" +CDS (Co-Domain Symmetry) Reference Architectures +================================================ + +Three reference networks from Singhal, Xing, Yu — *Co-Domain Symmetry for +Complex-Valued Deep Learning* (CVPR 2022): + +- :class:`CDSInvariant` (`I`-type) — uses :class:`PhaseDivConv` to make the + representation invariant to a global phase rotation of the input. +- :class:`CDSEquivariant` (`E`-type) — uses :class:`ComplexScaling` + + :class:`EquivariantPhaseReLU` + :class:`MagBatchNorm2d` for full + U(1)-equivariance, with a phase-rotated prototype head producing invariant + logits. +- :class:`CDSMSTAR` — SAR-style backbone using :class:`PhaseConjConv` and a + real-valued ResNet-lite tail. + +Reference: https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf +""" + +import math + +import torch +import torch.nn as nn + +from complextorch.nn import ( + ComplexScaling, + Conv2d, + EquivariantPhaseReLU, + GTReLU, + MagBatchNorm2d, + MagMaxPool2d, + PhaseConjConv2d, + PhaseDivConv2d, + PrototypeDistance, +) + +__all__ = ["CDSMSTAR", "CDSEquivariant", "CDSInvariant"] + + +def _complex_to_real_flat(z: torch.Tensor) -> torch.Tensor: + """Stack real and imag parts on the channel dim: [B, C, ...] → [B, 2C, ...].""" + return torch.cat([z.real, z.imag], dim=1) + + +def _real_flat_to_complex(r: torch.Tensor) -> torch.Tensor: + """Inverse of :func:`_complex_to_real_flat`: [B, 2C, ...] → [B, C, ...] cfloat.""" + c = r.shape[1] // 2 + return torch.complex(r[:, :c], r[:, c:]) + + +class CDSInvariant(nn.Module): + r""" + CDS Invariant (I-Type) Network + ------------------------------ + + Small CDS variant for CIFAR-style experiments. The :class:`PhaseDivConv2d` + after the first convolution makes the rest of the network invariant to a + global phase rotation of the input. + + Args: + input_channels: number of complex input channels (e.g. ``2`` for the + LAB / sliding-RGB encodings in the original paper, ``3`` for the + direct RGB encoding). + num_classes: number of output classes. + prototype_size: width of the penultimate feature space (number of + complex channels reaching the prototype head). + """ + + def __init__( + self, + input_channels: int = 2, + num_classes: int = 10, + prototype_size: int = 128, + ) -> None: + super().__init__() + self.wfm1 = Conv2d( + input_channels, + 16, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + bias=False, + ) + self.diff1 = PhaseDivConv2d(16, kernel_size=3, padding=1) + self.gtrelu1 = GTReLU(16, phase_scale=True) + + self.wfm2 = Conv2d( + 16, + 32, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + groups=2, + bias=False, + ) + self.gtrelu2 = GTReLU(32, phase_scale=True) + + self.wfm3 = Conv2d( + 32, + 64, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + groups=4, + bias=False, + ) + self.gtrelu3 = GTReLU(64, phase_scale=True) + + # Global pool via large-kernel conv with full grouping. + self.wfm4 = Conv2d(64, 64, kernel_size=4, groups=64, bias=False) + self.fc1 = Conv2d(64, prototype_size, kernel_size=1, groups=4, bias=False) + + # Real BN on the [B, 2*prototype_size] concat-real/imag tensor. + self.bn = nn.BatchNorm1d(prototype_size * 2) + self.prototype_size = prototype_size + + self.head = PrototypeDistance(prototype_size, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.is_complex(): + x = x.to(torch.cfloat) + x = self.wfm1(x) + x = self.diff1(x) + x = self.gtrelu1(x) + x = self.wfm2(x) + x = self.gtrelu2(x) + x = self.wfm3(x) + x = self.gtrelu3(x) + x = self.wfm4(x) + x = self.fc1(x) + # x shape [B, P, H', W'] cfloat — H', W' will normally be 1 after the + # global-pool conv but we keep this generic. + b = x.shape[0] + real_flat = _complex_to_real_flat(x).reshape(b, -1) + real_flat = self.bn(real_flat) + x = _real_flat_to_complex( + real_flat.reshape(b, 2 * self.prototype_size, *x.shape[2:]) + ) + # Squeeze spatial to feed the [B, C] prototype head. + z = x.flatten(2).mean(dim=2) + return self.head(z) + + +class CDSEquivariant(nn.Module): + r""" + CDS Equivariant (E-Type) Network + -------------------------------- + + Maintains U(1)-equivariance throughout via :class:`ComplexScaling` + + :class:`EquivariantPhaseReLU` + :class:`MagBatchNorm2d`. The final + classifier produces invariant logits by pre-rotating prototypes with the + sum-pooled reference computed from the features themselves + (see :class:`complextorch.nn.PrototypeDistance` for the mechanism). + """ + + def __init__( + self, + input_channels: int = 2, + num_classes: int = 10, + prototype_size: int = 128, + ) -> None: + super().__init__() + self.wfm1 = Conv2d( + input_channels, + 16, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + bias=False, + ) + self.s1 = ComplexScaling(16) + self.t1 = EquivariantPhaseReLU(16) + + self.wfm2 = Conv2d( + 16, + 32, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + groups=2, + bias=False, + ) + self.s2 = ComplexScaling(32) + self.t2 = EquivariantPhaseReLU(32) + + self.wfm3 = Conv2d( + 32, + 64, + kernel_size=3, + stride=2, + padding=1, + padding_mode="reflect", + groups=4, + bias=False, + ) + self.s3 = ComplexScaling(64) + self.t3 = EquivariantPhaseReLU(64) + + self.wfm4 = Conv2d(64, 64, kernel_size=4, groups=64, bias=False) + self.fc1 = Conv2d(64, prototype_size, kernel_size=1, groups=4, bias=False) + + self.bn = MagBatchNorm2d(prototype_size) + self.prototype_size = prototype_size + self.head = PrototypeDistance(prototype_size, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.is_complex(): + x = x.to(torch.cfloat) + x = self.wfm1(x) + x = self.s1(x) + x = self.t1(x) + x = self.wfm2(x) + x = self.s2(x) + x = self.t2(x) + x = self.wfm3(x) + x = self.s3(x) + x = self.t3(x) + x = self.wfm4(x) + x = self.fc1(x) + x = self.bn(x) + + # x shape [B, P, H', W'] cfloat + # Reference: sum-pool over channels (then over spatial) → [B, 1] cfloat, + # rescaled to make the magnitude comparable to per-element features. + ref = x.sum(dim=1, keepdim=False) / math.sqrt(2.0 * self.prototype_size) + ref = ref.flatten(1).mean(dim=1, keepdim=True) # [B, 1] + # Per-channel features for the prototype head. + z = x.flatten(2).mean(dim=2) # [B, P] + return self.head(z, reference=ref) + + +# --------------------------------------------------------------------------- +# Real-valued backbone for CDSMSTAR (port of cds/model.py:40-106) +# --------------------------------------------------------------------------- + + +class _SmallCNN(nn.Module): + r"""Real-valued ResNet-lite used as the SAR classification backbone. + + Ported verbatim from ``cds/model.py:40-106``. Operates on a real-valued + input of shape ``[B, in_size, H, W]``. + """ + + def __init__( + self, groups: int = 5, in_size: int = 15, num_classes: int = 10 + ) -> None: + super().__init__() + self.relu = nn.ReLU() + self.conv_1 = nn.Conv2d(in_size, 30, kernel_size=5, groups=groups) + self.bn_1 = nn.GroupNorm(5, 30) + self.res1 = nn.Sequential(*self._make_res_block(30, 40)) + self.id1 = nn.Conv2d(30, 40, kernel_size=1) + self.mp_1 = nn.MaxPool2d(2) + self.conv_2 = nn.Conv2d(40, 50, kernel_size=5, stride=3, groups=groups) + self.bn_2 = nn.GroupNorm(10, 50) + self.res2 = nn.Sequential(*self._make_res_block(50, 60)) + self.id2 = nn.Conv2d(50, 60, kernel_size=1) + self.conv_3 = nn.Conv2d(60, 70, kernel_size=2, groups=groups) + self.bn_3 = nn.GroupNorm(14, 70) + self.linear_2 = nn.Linear(70, 30) + self.linear_4 = nn.Linear(30, num_classes) + + @staticmethod + def _make_res_block(in_channel: int, out_channel: int): + bottleneck = out_channel // 4 + return [ + nn.GroupNorm(5, in_channel), + nn.ReLU(), + nn.Conv2d(in_channel, bottleneck, kernel_size=1, bias=False), + nn.GroupNorm(5, bottleneck), + nn.ReLU(), + nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1, bias=False), + nn.GroupNorm(5, bottleneck), + nn.ReLU(), + nn.Conv2d(bottleneck, out_channel, kernel_size=1, bias=False), + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_1(x) + x = self.bn_1(x) + x_res = self.relu(x) + x = self.id1(x_res) + self.res1(x_res) + x = self.mp_1(x) + x = self.conv_2(x) + x = self.bn_2(x) + x_res = self.relu(x) + x = self.id2(x_res) + self.res2(x_res) + x = self.conv_3(x) + x = self.bn_3(x) + x = self.relu(x) + x = x.mean(dim=(-1, -2)) + x = self.linear_2(x) + x = self.relu(x) + return self.linear_4(x) + + +class CDSMSTAR(nn.Module): + r""" + CDS MSTAR (SAR) Network + ----------------------- + + Complex front-end + real ResNet-lite tail. The complex front-end uses + :class:`PhaseConjConv2d` for phase-mixing modulation and :class:`GTReLU` + for nonlinear thresholding; before passing to the real backbone, the + features are decomposed into ``(log|z|, cos(arg z), sin(arg z))``. + + Args: + num_classes: number of output classes. + """ + + def __init__(self, num_classes: int = 10) -> None: + super().__init__() + self.wfm1 = Conv2d(1, 5, kernel_size=5, stride=1, bias=False) + self.diff1 = PhaseConjConv2d(5, kernel_size=3) + self.gtrelu1 = GTReLU(5) + self.mp = MagMaxPool2d(2) + self.wfm2 = Conv2d(5, 5, kernel_size=3, stride=2, bias=False) + self.gtrelu2 = GTReLU(5) + # 5 complex channels → 15 real channels (log|z|, cos, sin) per channel + self.cnn = _SmallCNN(groups=5, in_size=15, num_classes=num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.is_complex(): + x = x.to(torch.cfloat) + x = self.wfm1(x) + x = self.diff1(x) + x = self.gtrelu1(x) + x = self.mp(x) + x = self.wfm2(x) + x = self.gtrelu2(x) + + mag = x.abs().clamp(min=1e-5) + log_mag = torch.log(mag) + phase = x.angle() + cos_p = torch.cos(phase) + sin_p = torch.sin(phase) + # Stack to [B, 3, C, H, W] then flatten the complex / decomposed dims. + decomp = torch.stack([log_mag, cos_p, sin_p], dim=1) + b, three, c, h, w = decomp.shape + decomp = decomp.reshape(b, three * c, h, w) + return self.cnn(decomp) diff --git a/complextorch/models/vit.py b/complextorch/models/vit.py new file mode 100644 index 0000000..368e753 --- /dev/null +++ b/complextorch/models/vit.py @@ -0,0 +1,219 @@ +r""" +Complex-Valued Vision Transformer (ViT) +======================================= + +Pre-built complex-valued ViT with the standard `t/s/b/l/h` size presets, +mirroring :mod:`torchvision.models.vit` and ``torchcvnn.models.vit``. + +Inputs are complex images of shape ``(B, in_channels, H, W)``. Patch +embedding is performed with a single complex :class:`Conv2d` whose kernel +and stride both equal ``patch_size``. +""" + +import torch +import torch.nn as nn + +from complextorch.nn.modules.activation.split_type_A import CGELU +from complextorch.nn.modules.attention import MultiheadAttention +from complextorch.nn.modules.conv import Conv2d +from complextorch.nn.modules.dropout import Dropout +from complextorch.nn.modules.layernorm import LayerNorm +from complextorch.nn.modules.linear import Linear + +__all__ = ["ViT", "ViTLayer", "vit_b", "vit_h", "vit_l", "vit_s", "vit_t"] + + +class _ViTFFN(nn.Module): + """Pre-norm MLP block with residual.""" + + def __init__(self, dim: int, mlp_dim: int, dropout: float, eps: float) -> None: + super().__init__() + self.norm = LayerNorm(dim, eps=eps) + self.fc1 = Linear(dim, mlp_dim) + self.act = CGELU() + self.fc2 = Linear(mlp_dim, dim) + self.drop = Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.drop(self.fc2(self.act(self.fc1(self.norm(x))))) + + +class ViTLayer(nn.Module): + r""" + Single Vision Transformer block (pre-norm). + + Composition: pre-norm :class:`LayerNorm` -> :class:`MultiheadAttention` + (which itself adds residual + LayerNorm internally) -> pre-norm FFN with + its own residual. + """ + + def __init__( + self, + dim: int, + nhead: int, + mlp_dim: int, + dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + softmax_on: str = "complex", + ) -> None: + super().__init__() + if dim % nhead != 0: + raise ValueError(f"dim ({dim}) must be divisible by nhead ({nhead})") + d_head = dim // nhead + self.attn = MultiheadAttention( + nhead, dim, d_head, d_head, dropout=dropout, softmax_on=softmax_on + ) + self.ffn = _ViTFFN(dim, mlp_dim, dropout, layer_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x, x, x) + return self.ffn(x) + + +class ViT(nn.Module): + r""" + Complex-Valued Vision Transformer. + + Args: + image_size: input spatial size (assumed square). + patch_size: patch side length. + in_channels: number of input channels (complex). + num_classes: classifier dimensionality (set to 0 to disable the head). + dim: embedding dim. + depth: number of :class:`ViTLayer` blocks. + heads: number of attention heads. + mlp_dim: width of the FFN. + dropout: dropout probability. + softmax_on: ``'complex'`` or ``'real'``; controls attention softmax. + + Note: the classification head returns a complex ``(B, num_classes)`` + tensor. Most downstream losses take ``|·|`` first. + """ + + def __init__( + self, + image_size: int, + patch_size: int, + in_channels: int = 1, + num_classes: int = 1000, + dim: int = 768, + depth: int = 12, + heads: int = 12, + mlp_dim: int = 3072, + dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + softmax_on: str = "complex", + ) -> None: + super().__init__() + if image_size % patch_size != 0: + raise ValueError( + f"image_size ({image_size}) must be divisible by patch_size ({patch_size})" + ) + num_patches = (image_size // patch_size) ** 2 + self.patch_embed = Conv2d( + in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=True + ) + # Class token and positional embedding are complex parameters. + self.cls_token = nn.Parameter(torch.zeros(1, 1, dim, dtype=torch.cfloat)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, dim, dtype=torch.cfloat) + ) + with torch.no_grad(): + self.cls_token.real.normal_(0, 0.02) + self.cls_token.imag.normal_(0, 0.02) + self.pos_embed.real.normal_(0, 0.02) + self.pos_embed.imag.normal_(0, 0.02) + + self.drop = Dropout(dropout) + self.blocks = nn.ModuleList( + [ + ViTLayer(dim, heads, mlp_dim, dropout, layer_norm_eps, softmax_on) + for _ in range(depth) + ] + ) + self.norm = LayerNorm(dim, eps=layer_norm_eps) + self.head = Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b = x.shape[0] + x = self.patch_embed(x) # (B, dim, H/P, W/P) + x = x.flatten(2).transpose(1, 2) # (B, N, dim) + cls = self.cls_token.expand(b, -1, -1) + x = torch.cat([cls, x], dim=1) + x = x + self.pos_embed + x = self.drop(x) + for block in self.blocks: + x = block(x) + x = self.norm(x) + cls_out = x[:, 0] + return self.head(cls_out) + + +# --------------------------------------------------------------------------- +# Size presets (match the standard ViT family). +# --------------------------------------------------------------------------- + + +def vit_t(image_size: int = 224, patch_size: int = 16, **kwargs) -> ViT: + """ViT-Tiny: 12 layers, 3 heads, 192 dim.""" + return ViT( + image_size=image_size, + patch_size=patch_size, + dim=192, + depth=12, + heads=3, + mlp_dim=768, + **kwargs, + ) + + +def vit_s(image_size: int = 224, patch_size: int = 16, **kwargs) -> ViT: + """ViT-Small: 12 layers, 6 heads, 384 dim.""" + return ViT( + image_size=image_size, + patch_size=patch_size, + dim=384, + depth=12, + heads=6, + mlp_dim=1536, + **kwargs, + ) + + +def vit_b(image_size: int = 224, patch_size: int = 16, **kwargs) -> ViT: + """ViT-Base: 12 layers, 12 heads, 768 dim.""" + return ViT( + image_size=image_size, + patch_size=patch_size, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + **kwargs, + ) + + +def vit_l(image_size: int = 224, patch_size: int = 16, **kwargs) -> ViT: + """ViT-Large: 24 layers, 16 heads, 1024 dim.""" + return ViT( + image_size=image_size, + patch_size=patch_size, + dim=1024, + depth=24, + heads=16, + mlp_dim=4096, + **kwargs, + ) + + +def vit_h(image_size: int = 224, patch_size: int = 14, **kwargs) -> ViT: + """ViT-Huge: 32 layers, 16 heads, 1280 dim.""" + return ViT( + image_size=image_size, + patch_size=patch_size, + dim=1280, + depth=32, + heads=16, + mlp_dim=5120, + **kwargs, + ) diff --git a/complextorch/nn/__init__.py b/complextorch/nn/__init__.py index 8814f3a..8a1d9c9 100755 --- a/complextorch/nn/__init__.py +++ b/complextorch/nn/__init__.py @@ -1,60 +1,305 @@ -from .modules.activation import CVSplitReLU, CReLU, CPReLU -from .modules.activation import CVSigmoid, zReLU, CVCardiod, CVSigLog -from .modules.activation import ( +from complextorch.nn.modules.activation import ( + CVSplitReLU, + CReLU, + CPReLU, + zAbsReLU, + zLeakyReLU, + GTReLU, + EquivariantPhaseReLU, +) +from complextorch.nn.modules.activation import ( + CVSigmoid, + zReLU, + CVCardiod, + CVSigLog, + Mod, +) +from complextorch.nn.modules.activation import ( GeneralizedSplitActivation, CVSplitTanh, CTanh, CVSplitSigmoid, CSigmoid, CVSplitAbs, + CVSplitELU, + CELU, + CVSplitCELU, + CCELU, + CVSplitGELU, + CGELU, ) -from .modules.activation import ( +from complextorch.nn.modules.activation import ( GeneralizedPolarActivation, CVPolarTanh, CVPolarSquash, CVPolarLog, modReLU, + AdaptiveModReLU, +) + +from complextorch.nn.modules.casting import ( + InterleavedToComplex, + ComplexToInterleaved, + ConcatenatedToComplex, + ComplexToConcatenated, + RealToComplex, +) +from complextorch.nn.modules.phase import PhaseShift, ComplexScaling + +from complextorch.nn.modules.phase_modulation import ( + PhaseDivConv1d, + PhaseDivConv2d, + PhaseDivConv3d, + PhaseConjConv1d, + PhaseConjConv2d, + PhaseConjConv3d, ) -from .modules.conv import Conv1d, Conv2d, Conv3d -from .modules.conv import SlowConv1d, SlowConv2d, SlowConv3d -from .modules.conv import SlowConvTranspose1d, SlowConvTranspose2d, SlowConvTranspose3d +from complextorch.nn.modules.prototype import PrototypeDistance -from .modules.manifold import wFMConv1d, wFMConv2d +from complextorch.nn.modules.conv import Conv1d, Conv2d, Conv3d +from complextorch.nn.modules.conv import ( + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) -from .modules.dropout import Dropout +from complextorch.nn.modules.manifold import ( + wFMConv1d, + wFMConv2d, + wFMReLU, + wFMDistanceLinear, +) -from .modules.linear import Linear, SlowLinear +from complextorch.nn.modules.dropout import Dropout, Dropout1d, Dropout2d, Dropout3d -from .modules.fft import FFTBlock, IFFTBlock +from complextorch.nn.modules.linear import Linear, Bilinear -from .modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d -from .modules.layernorm import LayerNorm +from complextorch.nn.modules.fft import FFTBlock, IFFTBlock + +from complextorch.nn.modules.batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + NaiveBatchNorm1d, + NaiveBatchNorm2d, + NaiveBatchNorm3d, + MagBatchNorm1d, + MagBatchNorm2d, + MagBatchNorm3d, +) +from complextorch.nn.modules.layernorm import LayerNorm +from complextorch.nn.modules.rmsnorm import RMSNorm +from complextorch.nn.modules.groupnorm import GroupNorm -from .modules.softmax import CVSoftMax, MagSoftMax, PhaseSoftMax +from complextorch.nn import init +from complextorch.nn import gauss +from complextorch.nn import relevance +from complextorch.nn import masked +from complextorch.nn import utils -from .modules.mask import ComplexRatioMask, PhaseSigmoid, MagMinMaxNorm +from complextorch.nn.modules.softmax import CVSoftMax, MagSoftMax, PhaseSoftMax -from .modules.loss import GeneralizedSplitLoss -from .modules.loss import SplitSSIM, PerpLossSSIM, SplitL1, SplitMSE -from .modules.loss import CVQuadError, CVFourthPowError, CVCauchyError, CVLogCoshError -from .modules.loss import CVLogError +from complextorch.nn.modules.mask import ComplexRatioMask, PhaseSigmoid, MagMinMaxNorm -from .modules.pooling import ( +from complextorch.nn.modules.loss import GeneralizedSplitLoss +from complextorch.nn.modules.loss import ( + SSIM, + SplitSSIM, + PerpLossSSIM, + SplitL1, + SplitMSE, +) +from complextorch.nn.modules.loss import ( + CVQuadError, + CVFourthPowError, + CVCauchyError, + CVLogCoshError, +) +from complextorch.nn.modules.loss import CVLogError, MSELoss + +from complextorch.nn.modules.pooling import ( AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + MagMaxPool1d, + MagMaxPool2d, + MagMaxPool3d, +) + +from complextorch.nn.modules.upsampling import Upsample, PolarUpsample + +from complextorch.nn.modules.rnn import GRUCell, GRU, LSTMCell, LSTM + +from complextorch.nn.modules.transformer import ( + TransformerEncoderLayer, + TransformerEncoder, + TransformerDecoderLayer, + TransformerDecoder, + Transformer, ) # dependent on the above -from .modules.attention import MultiheadAttention, ScaledDotProductAttention -from .modules.attention.eca import ( +from complextorch.nn.modules.attention import ( + MultiheadAttention, + ScaledDotProductAttention, +) +from complextorch.nn.modules.attention.eca import ( EfficientChannelAttention1d, EfficientChannelAttention2d, EfficientChannelAttention3d, ) -from .modules.attention.mca import ( +from complextorch.nn.modules.attention.mca import ( MaskedChannelAttention1d, MaskedChannelAttention2d, MaskedChannelAttention3d, ) + +__all__ = [ + "CCELU", + "CELU", + "CGELU", + "GRU", + "LSTM", + "SSIM", + # pooling + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AdaptiveModReLU", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + # normalization + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "Bilinear", + "CPReLU", + "CReLU", + "CSigmoid", + "CTanh", + "CVCardiod", + "CVCauchyError", + "CVFourthPowError", + "CVLogCoshError", + "CVLogError", + "CVPolarLog", + "CVPolarSquash", + "CVPolarTanh", + "CVQuadError", + "CVSigLog", + "CVSigmoid", + # softmax / mask + "CVSoftMax", + "CVSplitAbs", + "CVSplitCELU", + "CVSplitELU", + "CVSplitGELU", + # activations + "CVSplitReLU", + "CVSplitSigmoid", + "CVSplitTanh", + "ComplexRatioMask", + "ComplexScaling", + "ComplexToConcatenated", + "ComplexToInterleaved", + "ConcatenatedToComplex", + # conv + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + # dropout + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "EfficientChannelAttention1d", + "EfficientChannelAttention2d", + "EfficientChannelAttention3d", + "EquivariantPhaseReLU", + # fft + "FFTBlock", + # rnn + "GRUCell", + "GTReLU", + "GeneralizedPolarActivation", + "GeneralizedSplitActivation", + # losses + "GeneralizedSplitLoss", + "GroupNorm", + "IFFTBlock", + # casting / phase + "InterleavedToComplex", + "LSTMCell", + "LayerNorm", + # linear + "Linear", + "MSELoss", + "MagBatchNorm1d", + "MagBatchNorm2d", + "MagBatchNorm3d", + "MagMaxPool1d", + "MagMaxPool2d", + "MagMaxPool3d", + "MagMinMaxNorm", + "MagSoftMax", + "MaskedChannelAttention1d", + "MaskedChannelAttention2d", + "MaskedChannelAttention3d", + "Mod", + # attention + "MultiheadAttention", + "NaiveBatchNorm1d", + "NaiveBatchNorm2d", + "NaiveBatchNorm3d", + "PerpLossSSIM", + "PhaseConjConv1d", + "PhaseConjConv2d", + "PhaseConjConv3d", + # phase modulation + "PhaseDivConv1d", + "PhaseDivConv2d", + "PhaseDivConv3d", + "PhaseShift", + "PhaseSigmoid", + "PhaseSoftMax", + "PolarUpsample", + # prototype classifier + "PrototypeDistance", + "RMSNorm", + "RealToComplex", + "ScaledDotProductAttention", + "SplitL1", + "SplitMSE", + "SplitSSIM", + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + # transformer + "TransformerEncoderLayer", + # upsampling + "Upsample", + "gauss", + # subpackages + "init", + "masked", + "modReLU", + "relevance", + "utils", + # manifold + "wFMConv1d", + "wFMConv2d", + "wFMDistanceLinear", + "wFMReLU", + "zAbsReLU", + "zLeakyReLU", + "zReLU", +] diff --git a/complextorch/nn/functional.py b/complextorch/nn/functional.py index 362e41e..4e7b525 100755 --- a/complextorch/nn/functional.py +++ b/complextorch/nn/functional.py @@ -1,15 +1,17 @@ +from collections.abc import Callable + import torch import torch.nn as nn -from typing import List, Optional, Callable - __all__ = [ "apply_complex", - "apply_complex_split", "apply_complex_polar", - "inv_sqrtm2x2", + "apply_complex_split", "batch_norm", + "inv_sqrtm2x2", "layer_norm", + "whiten2x2_batch_norm", + "whiten2x2_layer_norm", ] @@ -89,9 +91,8 @@ def apply_complex_polar( if phase_fun is None: # Assumes no function will be computed on phase (improves computational efficiency) x_mag = x.abs() - return (mag_fun(x_mag) / x_mag) * x - else: - return torch.polar(mag_fun(x.abs()), phase_fun(x.angle())) + return (mag_fun(x_mag) / x_mag.clamp(min=1e-12)) * x + return torch.polar(mag_fun(x.abs()), phase_fun(x.angle())) def inv_sqrtm2x2( @@ -207,11 +208,11 @@ def inv_sqrtm2x2( return w, x, y, z -def _whiten2x2_batch_norm( +def whiten2x2_batch_norm( x: torch.Tensor, training: bool = True, - running_mean: Optional[torch.Tensor] = None, - running_cov: Optional[torch.Tensor] = None, + running_mean: torch.Tensor | None = None, + running_cov: torch.Tensor | None = None, momentum: float = 0.1, eps: float = 1e-5, ): @@ -244,10 +245,11 @@ def _whiten2x2_batch_norm( running_mean += momentum * (mean.data.squeeze() - running_mean) else: - mean = running_mean + # running_mean is shape [2, F]; reshape to broadcast against [2, B, F, ...] + mean = running_mean.view(2, 1, x.shape[2], *([1] * (x.dim() - 3))) - # Center the batch - x -= mean + # Center the batch (out-of-place; do not mutate the input stack) + x = x - mean # Compute the batch covariance [2, 2, F] if training or running_cov is None: @@ -268,7 +270,7 @@ def _whiten2x2_batch_norm( running_cov += momentum * (cov - running_cov) else: - v_rr, v_ir, v_ir, v_ii = running_cov.view(4, -1) + v_rr, v_ir, _, v_ii = running_cov.view(4, -1) # Compute inverse matrix square root for ZCA whitening p, q, _, s = inv_sqrtm2x2(v_rr, v_ir, None, v_ii, symmetric=True) @@ -285,10 +287,10 @@ def _whiten2x2_batch_norm( def batch_norm( x: torch.Tensor, - running_mean: Optional[torch.Tensor] = None, - running_var: Optional[torch.Tensor] = None, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + running_mean: torch.Tensor | None = None, + running_var: torch.Tensor | None = None, + weight: torch.Tensor | None = None, + bias: torch.Tensor | None = None, training: bool = True, momentum: float = 0.1, eps: float = 1e-5, @@ -350,7 +352,7 @@ def batch_norm( x = torch.stack((x.real, x.imag), dim=0) # whiten - z = _whiten2x2_batch_norm( + z = whiten2x2_batch_norm( x, training=training, running_mean=running_mean, @@ -374,9 +376,9 @@ def batch_norm( return torch.complex(z[0], z[1]) -def _whiten2x2_layer_norm( +def whiten2x2_layer_norm( x: torch.Tensor, - normalized_shape: List[int], + normalized_shape: list[int], eps: float = 1e-5, ): r""" @@ -426,9 +428,9 @@ def _whiten2x2_layer_norm( def layer_norm( x: torch.Tensor, - normalized_shape: List[int], - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + normalized_shape: list[int], + weight: torch.Tensor | None = None, + bias: torch.Tensor | None = None, eps: float = 1e-5, ) -> torch.Tensor: r""" @@ -458,12 +460,15 @@ def layer_norm( The ridge coefficient to stabilize the estimate of the real-imaginary covariance. """ + assert (weight is None and bias is None) or ( + weight is not None and bias is not None + ) # stack along the first axis x = torch.stack((x.real, x.imag), dim=0) # whiten - z = _whiten2x2_layer_norm( + z = whiten2x2_layer_norm( x, normalized_shape, eps=eps, diff --git a/complextorch/nn/gauss/__init__.py b/complextorch/nn/gauss/__init__.py new file mode 100644 index 0000000..e4f7707 --- /dev/null +++ b/complextorch/nn/gauss/__init__.py @@ -0,0 +1,30 @@ +"""Hand-rolled real/imag-split layers using Gauss' multiplication trick. + +These variants compute complex linear / convolution with 3 real operations +instead of the naive 4, but they are typically *slower* than the native +PyTorch complex kernels on modern PyTorch (>= 2.1.0). They are kept as +reference implementations and for users who want explicit access to the +real/imag halves (e.g. for parameterization tricks). For ordinary use, +prefer ``complextorch.nn.Conv*d`` / ``complextorch.nn.Linear``, which wrap +``torch.nn.*`` with ``dtype=torch.cfloat``. +""" + +from complextorch.nn.gauss.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from complextorch.nn.gauss.linear import Linear + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "Linear", +] diff --git a/complextorch/nn/gauss/conv.py b/complextorch/nn/gauss/conv.py new file mode 100755 index 0000000..909be42 --- /dev/null +++ b/complextorch/nn/gauss/conv.py @@ -0,0 +1,644 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class _Conv(nn.Module): + r""" + torch.Tensor-based Complex-Valued Convolution + ----------------------------------------- + """ + + def __init__( + self, + ConvClass: nn.Module, + ConvFunc, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: tuple[int, ...], + dilation: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.ConvFunc = ConvFunc + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + # Assumes PyTorch complex weight initialization is correct + __temp = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype if dtype else torch.cfloat, + ) + + self.conv_r = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.conv_i = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + self.conv_r.weight.data = __temp.weight.real + self.conv_i.weight.data = __temp.weight.imag + + if bias: + self.bias_r = nn.Parameter(__temp.bias.real.detach().clone()) + self.bias_i = nn.Parameter(__temp.bias.imag.detach().clone()) + else: + self.register_parameter("bias_r", None) + self.register_parameter("bias_i", None) + + @property + def weight(self) -> torch.Tensor: + return torch.complex(self.conv_r.weight, self.conv_i.weight) + + @property + def bias(self) -> torch.Tensor: + if self.bias_r is None: + return None + return torch.complex(self.bias_r, self.bias_i) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Computes convolution 25% faster than naive method by using Gauss' multiplication trick + """ + t1 = self.conv_r(input.real) + t2 = self.conv_i(input.imag) + t3 = self.ConvFunc( + input=(input.real + input.imag), + weight=(self.conv_r.weight + self.conv_i.weight), + bias=None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + out_r = t1 - t2 + out_i = t3 - t2 - t1 + if self.bias_r is not None: + bias_shape = (-1,) + (1,) * (out_r.dim() - 2) + out_r = out_r + self.bias_r.view(bias_shape) + out_i = out_i + self.bias_i.view(bias_shape) + return torch.complex(out_r, out_i) + + +class Conv1d(_Conv): + r""" + 1-D Complex-Valued Convolution + ------------------------------ + + Based on the `PyTorch torch.nn.Conv1d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: str | _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.Conv1d, + ConvFunc=F.conv1d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + +class Conv2d(_Conv): + r""" + 2-D Complex-Valued Convolution + ------------------------------ + + Based on the `PyTorch torch.nn.Conv2d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: str | _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.Conv2d, + ConvFunc=F.conv2d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + +class Conv3d(_Conv): + r""" + 3-D Complex-Valued Convolution + ------------------------------ + + Based on the `PyTorch torch.nn.Conv3d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: str | _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.Conv3d, + ConvFunc=F.conv3d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + +class _ConvTranspose(nn.Module): + r""" + torch.Tensor-based Complex-Valued Transposed Convolution + ---------------------------------------------------- + """ + + def __init__( + self, + ConvClass: nn.Module, + ConvFunc, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: tuple[int, ...], + dilation: tuple[int, ...], + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.ConvFunc = ConvFunc + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.output_padding = output_padding + self.groups = groups + + # Assumes PyTorch complex weight initialization is correct + __temp = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype if dtype else torch.cfloat, + ) + + self.convt_r = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=False, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.convt_i = ConvClass( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=False, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + self.convt_r.weight.data = __temp.weight.real + self.convt_i.weight.data = __temp.weight.imag + + if bias: + self.bias_r = nn.Parameter(__temp.bias.real.detach().clone()) + self.bias_i = nn.Parameter(__temp.bias.imag.detach().clone()) + else: + self.register_parameter("bias_r", None) + self.register_parameter("bias_i", None) + + @property + def weight(self) -> torch.Tensor: + return torch.complex(self.convt_r.weight, self.convt_i.weight) + + @property + def bias(self) -> torch.Tensor: + if self.bias_r is None: + return None + return torch.complex(self.bias_r, self.bias_i) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Computes convolution 25% faster than naive method by using Gauss' multiplication trick + """ + t1 = self.convt_r(input.real) + t2 = self.convt_i(input.imag) + t3 = self.ConvFunc( + input=(input.real + input.imag), + weight=(self.convt_r.weight + self.convt_i.weight), + bias=None, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + dilation=self.dilation, + groups=self.groups, + ) + out_r = t1 - t2 + out_i = t3 - t2 - t1 + if self.bias_r is not None: + bias_shape = (-1,) + (1,) * (out_r.dim() - 2) + out_r = out_r + self.bias_r.view(bias_shape) + out_i = out_i + self.bias_i.view(bias_shape) + return torch.complex(out_r, out_i) + + +class ConvTranspose1d(_ConvTranspose): + r""" + 1-D Complex-Valued Transposed Convolution + ----------------------------------------- + + Based on the `PyTorch torch.nn.ConvTranspose1d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 transposed convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.ConvTranspose1d, + ConvFunc=F.conv_transpose1d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + +class ConvTranspose2d(_ConvTranspose): + r""" + 2-D Complex-Valued Transposed Convolution + ----------------------------------------- + + Based on the `PyTorch torch.nn.ConvTranspose2d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 transposed convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.ConvTranspose2d, + ConvFunc=F.conv_transpose2d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + +class ConvTranspose3d(_ConvTranspose): + r""" + 3-D Complex-Valued Transposed Convolution + ----------------------------------------- + + Based on the `PyTorch torch.nn.ConvTranspose3d `_ implementation. + + Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. + + The most common implementation of complex-valued convolution entails the following computation: + + .. math:: + + G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) + + where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. + + By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: + + .. math:: + + t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) + + t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + + t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) + + G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) + + requiring only 3 transposed convolution operations. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + ConvClass=nn.ConvTranspose3d, + ConvFunc=F.conv_transpose3d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) diff --git a/complextorch/nn/gauss/linear.py b/complextorch/nn/gauss/linear.py new file mode 100755 index 0000000..d03416e --- /dev/null +++ b/complextorch/nn/gauss/linear.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["Linear"] + + +class Linear(nn.Module): + r""" + Gauss-trick Complex-Valued Linear Layer + --------------------------------------- + + Real/imag-split implementation of complex linear multiplication using + Gauss' multiplication trick (3 real matmuls instead of the naive 4). + Mirrors :class:`torch.nn.Linear`, analogous to the Gauss-trick + convolutions :class:`complextorch.nn.gauss.Conv1d` and siblings. + + For ordinary use prefer :class:`complextorch.nn.Linear`, which wraps + ``torch.nn.Linear`` with ``dtype=torch.cfloat`` and is faster on modern + PyTorch. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + # Assumes PyTorch complex weight initialization is correct + __temp = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=torch.cfloat, + ) + + self.linear_r = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=False, + device=device, + dtype=dtype, + ) + self.linear_i = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=False, + device=device, + dtype=dtype, + ) + + self.linear_r.weight.data = __temp.weight.real + self.linear_i.weight.data = __temp.weight.imag + + if bias: + self.bias_r = nn.Parameter(__temp.bias.real.detach().clone()) + self.bias_i = nn.Parameter(__temp.bias.imag.detach().clone()) + else: + self.register_parameter("bias_r", None) + self.register_parameter("bias_i", None) + + @property + def weight(self) -> torch.Tensor: + return torch.complex(self.linear_r.weight, self.linear_i.weight) + + @property + def bias(self) -> torch.Tensor: + if self.bias_r is None: + return None + return torch.complex(self.bias_r, self.bias_i) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Computes multiplication 25% faster than naive method by using Gauss' multiplication trick + """ + t1 = self.linear_r(input.real) + t2 = self.linear_i(input.imag) + t3 = F.linear( + input=(input.real + input.imag), + weight=(self.linear_r.weight + self.linear_i.weight), + bias=None, + ) + out_r = t1 - t2 + out_i = t3 - t2 - t1 + if self.bias_r is not None: + out_r = out_r + self.bias_r + out_i = out_i + self.bias_i + return torch.complex(out_r, out_i) diff --git a/complextorch/nn/init.py b/complextorch/nn/init.py new file mode 100644 index 0000000..9f0bbd1 --- /dev/null +++ b/complextorch/nn/init.py @@ -0,0 +1,248 @@ +r""" +Weight Initialization for Complex Tensors +========================================= + +Drop-in complex-valued analogues of :mod:`torch.nn.init`. PyTorch's built-in +initializers were designed for real tensors and produce the wrong variance +when applied directly to a complex parameter (each part is treated as +independent of the other, so :math:`\mathrm{Var}(|w|^2)` is too large by a +factor of 2). The functions in this module correct for that. + +All functions mutate ``tensor`` in place and return it, mirroring +:mod:`torch.nn.init`. + +Functions +--------- + +- :func:`kaiming_normal_`, :func:`kaiming_uniform_` — He (Kaiming, 2015). +- :func:`xavier_normal_`, :func:`xavier_uniform_` — Glorot (2010). +- :func:`trabelsi_standard_` — polar Rayleigh-uniform initializer from + Trabelsi et al. (2018) "Deep Complex Networks". +- :func:`trabelsi_independent_` — semi-unitary (orthogonal) complex + initializer from the same paper. +""" + +import math + +import torch + +__all__ = [ + "kaiming_normal_", + "kaiming_uniform_", + "trabelsi_independent_", + "trabelsi_standard_", + "xavier_normal_", + "xavier_uniform_", +] + + +def _get_fans(tensor: torch.Tensor) -> tuple[int, int]: + """Compute fan-in / fan-out for a complex tensor (same as the real form).""" + if tensor.dim() < 2: + # Linear bias or 1-D weight: treat as (fan_in,), fan_out = 1. + fan_in = tensor.numel() + fan_out = tensor.numel() + return fan_in, fan_out + fan_in = ( + tensor.size(1) * tensor[0][0].numel() if tensor.dim() > 2 else tensor.size(1) + ) + fan_out = ( + tensor.size(0) * tensor[0][0].numel() if tensor.dim() > 2 else tensor.size(0) + ) + return fan_in, fan_out + + +def _calculate_gain(nonlinearity: str, a: float = 0.0) -> float: + """Mirror :func:`torch.nn.init.calculate_gain` for a small set of activations.""" + if nonlinearity in ("linear", "conv1d", "conv2d", "conv3d", "sigmoid"): + return 1.0 + if nonlinearity == "tanh": + return 5.0 / 3.0 + if nonlinearity == "relu": + return math.sqrt(2.0) + if nonlinearity == "leaky_relu": + return math.sqrt(2.0 / (1.0 + a * a)) + if nonlinearity == "selu": + return 3.0 / 4.0 + raise ValueError(f"Unsupported nonlinearity {nonlinearity!r}") + + +def _check_complex(tensor: torch.Tensor) -> None: + if not tensor.is_complex(): + raise TypeError( + f"complextorch.nn.init expects a complex tensor, got dtype={tensor.dtype}" + ) + + +# --------------------------------------------------------------------------- +# Kaiming / He +# --------------------------------------------------------------------------- + + +def kaiming_normal_( + tensor: torch.Tensor, + a: float = 0.0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", +) -> torch.Tensor: + r""" + Complex Kaiming Normal Initialization + ------------------------------------- + + Draws ``tensor.real`` and ``tensor.imag`` independently from + :math:`\mathcal{N}(0, \sigma^2)` with + :math:`\sigma = \text{gain} / \sqrt{2 \cdot \text{fan}}` so that + :math:`\mathrm{Var}(|w|^2) = 2 \cdot \sigma^2 = \text{gain}^2 / \text{fan}` — + matching He's target for the complex magnitude. + """ + _check_complex(tensor) + fan_in, fan_out = _get_fans(tensor) + fan = fan_in if mode == "fan_in" else fan_out + gain = _calculate_gain(nonlinearity, a) + std = gain / math.sqrt(2.0 * fan) + with torch.no_grad(): + tensor.real.normal_(0.0, std) + tensor.imag.normal_(0.0, std) + return tensor + + +def kaiming_uniform_( + tensor: torch.Tensor, + a: float = 0.0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", +) -> torch.Tensor: + r"""Complex Kaiming Uniform Initialization. See :func:`kaiming_normal_`.""" + _check_complex(tensor) + fan_in, fan_out = _get_fans(tensor) + fan = fan_in if mode == "fan_in" else fan_out + gain = _calculate_gain(nonlinearity, a) + std = gain / math.sqrt(2.0 * fan) + bound = math.sqrt(3.0) * std # uniform[-bound, bound] has std = bound/sqrt(3) + with torch.no_grad(): + tensor.real.uniform_(-bound, bound) + tensor.imag.uniform_(-bound, bound) + return tensor + + +# --------------------------------------------------------------------------- +# Xavier / Glorot +# --------------------------------------------------------------------------- + + +def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor: + r""" + Complex Xavier Normal Initialization + ------------------------------------ + + Draws each part from :math:`\mathcal{N}(0, \sigma^2)` with + :math:`\sigma = \text{gain} / \sqrt{\text{fan\_in} + \text{fan\_out}}` + so that :math:`\mathrm{Var}(|w|^2) = 2 \sigma^2 = 2 \cdot \text{gain}^2 / (\text{fan\_in} + \text{fan\_out})`. + """ + _check_complex(tensor) + fan_in, fan_out = _get_fans(tensor) + std = gain / math.sqrt(fan_in + fan_out) + with torch.no_grad(): + tensor.real.normal_(0.0, std) + tensor.imag.normal_(0.0, std) + return tensor + + +def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor: + r"""Complex Xavier Uniform Initialization. See :func:`xavier_normal_`.""" + _check_complex(tensor) + fan_in, fan_out = _get_fans(tensor) + std = gain / math.sqrt(fan_in + fan_out) + bound = math.sqrt(3.0) * std + with torch.no_grad(): + tensor.real.uniform_(-bound, bound) + tensor.imag.uniform_(-bound, bound) + return tensor + + +# --------------------------------------------------------------------------- +# Trabelsi (Deep Complex Networks, 2018) +# --------------------------------------------------------------------------- + + +def trabelsi_standard_(tensor: torch.Tensor, kind: str = "glorot") -> torch.Tensor: + r""" + Trabelsi Polar (Rayleigh-Uniform) Initializer + --------------------------------------------- + + Polar parameterization from Trabelsi et al. (2018) "Deep Complex Networks". + + The magnitude :math:`|w|` is drawn from a Rayleigh distribution with scale + :math:`\sigma` and the phase :math:`\arg w` is drawn uniformly from + :math:`[-\pi, \pi]`: + + .. math:: + + |w| \sim \mathrm{Rayleigh}(\sigma), \qquad + \arg w \sim \mathcal{U}[-\pi, \pi], \qquad + w = |w| \cdot e^{j \arg w}. + + With ``kind='glorot'``, :math:`\sigma = 1 / \sqrt{\text{fan\_in} + \text{fan\_out}}`; + with ``kind='he'``, :math:`\sigma = 1 / \sqrt{\text{fan\_in}}`. + """ + _check_complex(tensor) + fan_in, fan_out = _get_fans(tensor) + if kind in ("glorot", "xavier"): + sigma = 1.0 / math.sqrt(fan_in + fan_out) + elif kind in ("he", "kaiming"): + sigma = 1.0 / math.sqrt(fan_in) + else: + raise ValueError( + f"Unknown kind {kind!r}; expected 'glorot'/'xavier' or 'he'/'kaiming'" + ) + # Rayleigh(sigma) samples: |w| = sigma * sqrt(-2 ln U) for U ~ Uniform(0, 1]. + with torch.no_grad(): + u = torch.empty_like(tensor.real).uniform_(1e-12, 1.0) + magnitude = sigma * torch.sqrt(-2.0 * torch.log(u)) + phase = torch.empty_like(tensor.real).uniform_(-math.pi, math.pi) + tensor.real.copy_(magnitude * torch.cos(phase)) + tensor.imag.copy_(magnitude * torch.sin(phase)) + return tensor + + +def trabelsi_independent_(tensor: torch.Tensor, kind: str = "glorot") -> torch.Tensor: + r""" + Trabelsi Semi-Unitary (Independent) Initializer + ----------------------------------------------- + + Complex orthogonal init from Trabelsi et al. (2018). + + Generates a random complex matrix of the same flat shape and replaces its + singular values with a constant via SVD, yielding a semi-unitary weight + satisfying :math:`W^* W = c \cdot I` (or :math:`W W^* = c \cdot I` for + wide matrices). The constant :math:`c` is chosen to match either the + Glorot or He variance target. + """ + _check_complex(tensor) + if tensor.dim() < 2: + raise ValueError("trabelsi_independent_ requires a tensor of at least 2 dims") + + fan_in, fan_out = _get_fans(tensor) + if kind in ("glorot", "xavier"): + scale = 1.0 / math.sqrt(fan_in + fan_out) + elif kind in ("he", "kaiming"): + scale = 1.0 / math.sqrt(fan_in) + else: + raise ValueError(f"Unknown kind {kind!r}") + + # Flatten to (out, in_total) for SVD. + out_dim = tensor.size(0) + in_total = tensor.numel() // out_dim + rows, cols = out_dim, in_total + + # Draw a random complex matrix and take its (truncated) SVD. + rng = torch.empty(rows, cols, dtype=tensor.dtype, device=tensor.device) + with torch.no_grad(): + rng.real.normal_(0.0, 1.0) + rng.imag.normal_(0.0, 1.0) + u, _, vh = torch.linalg.svd(rng, full_matrices=False) + # Smallest dim k = min(rows, cols); semi-unitary product u @ vh has shape (rows, cols). + w = u @ vh + w = w * scale + tensor.copy_(w.reshape(tensor.shape)) + return tensor diff --git a/complextorch/nn/masked/__init__.py b/complextorch/nn/masked/__init__.py new file mode 100644 index 0000000..928acaf --- /dev/null +++ b/complextorch/nn/masked/__init__.py @@ -0,0 +1,88 @@ +r""" +Masked / Pruned Layers + Module-Walking Helpers +=============================================== + +Fixed-sparsity-pattern complex layers and helpers for managing their masks +across a whole network. +""" + +from collections.abc import Iterator + +import torch + +from complextorch.nn.masked.base import BaseMasked, MaskedWeightMixin +from complextorch.nn.masked.conv import Conv1dMasked, Conv2dMasked, Conv3dMasked +from complextorch.nn.masked.linear import BilinearMasked, LinearMasked + +__all__ = [ + "BaseMasked", + "BilinearMasked", + "Conv1dMasked", + "Conv2dMasked", + "Conv3dMasked", + "LinearMasked", + "MaskedWeightMixin", + "binarize_masks", + "deploy_masks", + "is_sparse", + "named_masks", +] + + +def deploy_masks( + model: torch.nn.Module, + state_dict: dict[str, torch.Tensor], + *, + strict: bool = True, +) -> torch.nn.Module: + r""" + Load a ``{name: mask}`` dict into the matching :class:`BaseMasked` + submodules of ``model``. + + Keys in ``state_dict`` are interpreted as fully-qualified module names + (e.g. ``"encoder.layer1.linear.mask"`` or just ``"linear.mask"``). Any + key ending in ``".mask"`` is matched to the corresponding submodule's + ``mask`` buffer. + """ + for full_key, value in state_dict.items(): + if not full_key.endswith("mask"): + continue + mod_path = ( + full_key[: -len(".mask")] + if full_key.endswith(".mask") + else full_key[: -len("mask")] + ) + # Walk to the target module. + mod = model + if mod_path: + for part in mod_path.split("."): + mod = getattr(mod, part, None) + if mod is None: + break + if isinstance(mod, BaseMasked): + mod.mask_(value) + elif strict: + raise KeyError(f"deploy_masks: no BaseMasked submodule at {mod_path!r}") + return model + + +def binarize_masks(model: torch.nn.Module) -> torch.nn.Module: + r"""In-place binarize every mask attached to a :class:`BaseMasked` submodule.""" + for mod in model.modules(): + if isinstance(mod, BaseMasked) and mod.is_sparse: + mod.mask_((mod.mask != 0).to(mod.mask.dtype)) + return model + + +def is_sparse(layer: torch.nn.Module) -> bool: + """``True`` if ``layer`` is a :class:`BaseMasked` with a mask set.""" + return isinstance(layer, BaseMasked) and layer.is_sparse + + +def named_masks( + model: torch.nn.Module, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(qualified_name, mask)`` for each currently-set mask.""" + for name, mod in model.named_modules(): + if isinstance(mod, BaseMasked) and mod.is_sparse: + yield name, mod.mask diff --git a/complextorch/nn/masked/base.py b/complextorch/nn/masked/base.py new file mode 100644 index 0000000..9b7522b --- /dev/null +++ b/complextorch/nn/masked/base.py @@ -0,0 +1,114 @@ +r""" +Masked-Layer Base Classes +========================= + +Building blocks for layers that apply a fixed binary mask to their weight at +forward time. Used for inference-time pruning: train with +:mod:`complextorch.nn.relevance`, extract a relevance mask via +:func:`complextorch.nn.relevance.compute_ard_masks`, then load it into a +masked layer with :func:`complextorch.nn.masked.deploy_masks`. + +Adapted from :mod:`cplxmodule.nn.masked.base`. +""" + +import torch +import torch.nn as nn + +from complextorch.nn.utils.sparsity import SparsityStats + +__all__ = ["BaseMasked", "MaskedWeightMixin"] + + +class BaseMasked(nn.Module): + r""" + Base for layers with a fixed binary mask buffer applied to ``self.weight``. + + The mask buffer (``self.mask``) is a real-valued tensor of the same shape + as the parameter, with ``0`` marking a dropped weight and ``1`` marking a + kept weight. The :attr:`is_sparse` property returns ``True`` when a mask + is currently set. + """ + + def __init__(self) -> None: + super().__init__() + self.register_buffer("mask", None) + + @property + def is_sparse(self) -> bool: + return isinstance(self.mask, torch.Tensor) + + def mask_(self, mask): + if mask is not None and not isinstance(mask, torch.Tensor): + raise TypeError( + f"`mask` must be a Tensor or None, got {type(mask).__name__}" + ) + if mask is not None: + mask = mask.detach().to( + self.weight.device, + ( + self.weight.real.dtype + if self.weight.is_complex() + else self.weight.dtype + ), + ) + mask = mask.expand(self.weight.shape).contiguous() + self.register_buffer("mask", mask) + elif self.is_sparse and mask is None: + del self.mask + self.register_buffer("mask", None) + return self + + def __setattr__(self, name, value): + if name != "mask": + return super().__setattr__(name, value) + self.mask_(value) + return None + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + mask_key = prefix + "mask" + super()._load_from_state_dict( + {k: v for k, v in state_dict.items() if k != mask_key}, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + mask_in_missing = mask_key in missing_keys + if mask_key in state_dict: + if mask_in_missing: + missing_keys.remove(mask_key) + self.mask_(state_dict[mask_key]) + elif strict and not mask_in_missing: + missing_keys.append(mask_key) + + +class MaskedWeightMixin(SparsityStats): + r"""Provides ``weight_masked`` returning ``self.weight * self.mask``.""" + + __sparsity_ignore__ = () + + @property + def weight_masked(self) -> torch.Tensor: + if not getattr(self, "is_sparse", False): + raise RuntimeError( + f"`{type(self).__name__}` has no sparsity mask. " + "Set ``.mask`` or call ``deploy_masks(...)``." + ) + # Complex weight * real mask broadcasts correctly. + return self.weight * self.mask + + def sparsity(self, **kwargs): + weight = self.weight + n_dropped = float((self.mask == 0).sum().item()) if self.is_sparse else 0.0 + return [(id(weight), n_dropped)] diff --git a/complextorch/nn/masked/conv.py b/complextorch/nn/masked/conv.py new file mode 100644 index 0000000..e00a881 --- /dev/null +++ b/complextorch/nn/masked/conv.py @@ -0,0 +1,103 @@ +r""" +Masked Conv (Complex) +===================== +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from complextorch.nn.masked.base import BaseMasked, MaskedWeightMixin + +__all__ = ["Conv1dMasked", "Conv2dMasked", "Conv3dMasked"] + + +def _to_tuple(x, n: int) -> tuple[int, ...]: + if isinstance(x, int): + return (x,) * n + return tuple(x) + + +class _ConvMaskedNd(MaskedWeightMixin, BaseMasked): + _conv_fn = staticmethod(F.conv1d) + _nd = 1 + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype: torch.dtype = torch.cfloat, + ) -> None: + super().__init__() + if padding_mode != "zeros": + raise ValueError( + f"Only padding_mode='zeros' is supported in masked conv, got {padding_mode!r}" + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _to_tuple(kernel_size, self._nd) + self.stride = _to_tuple(stride, self._nd) + self.padding = ( + padding if isinstance(padding, str) else _to_tuple(padding, self._nd) + ) + self.dilation = _to_tuple(dilation, self._nd) + self.groups = groups + + weight_shape = (out_channels, in_channels // groups, *self.kernel_size) + self.weight = nn.Parameter( + torch.empty(*weight_shape, device=device, dtype=dtype) + ) + if bias: + self.bias = nn.Parameter( + torch.empty(out_channels, device=device, dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + fan_in = in_channels // groups + for k in self.kernel_size: + fan_in *= k + bound = 1.0 / math.sqrt(fan_in) + with torch.no_grad(): + self.weight.real.uniform_(-bound, bound) + self.weight.imag.uniform_(-bound, bound) + if self.bias is not None: + self.bias.real.uniform_(-bound, bound) + self.bias.imag.uniform_(-bound, bound) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + w = self.weight_masked if self.is_sparse else self.weight + return self._conv_fn( + input, + w, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + +class Conv1dMasked(_ConvMaskedNd): + _conv_fn = staticmethod(F.conv1d) + _nd = 1 + + +class Conv2dMasked(_ConvMaskedNd): + _conv_fn = staticmethod(F.conv2d) + _nd = 2 + + +class Conv3dMasked(_ConvMaskedNd): + _conv_fn = staticmethod(F.conv3d) + _nd = 3 diff --git a/complextorch/nn/masked/linear.py b/complextorch/nn/masked/linear.py new file mode 100644 index 0000000..a79f15b --- /dev/null +++ b/complextorch/nn/masked/linear.py @@ -0,0 +1,98 @@ +r""" +Masked Linear / Bilinear (Complex) +================================== + +Layers that apply a fixed binary mask to their complex weight at forward +time. Used to deploy a learned-sparsity pattern at inference. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from complextorch.nn.masked.base import BaseMasked, MaskedWeightMixin + +__all__ = ["BilinearMasked", "LinearMasked"] + + +def _init_complex_weight(weight: torch.Tensor, fan_in: int) -> None: + bound = 1.0 / math.sqrt(fan_in) + with torch.no_grad(): + weight.real.uniform_(-bound, bound) + weight.imag.uniform_(-bound, bound) + + +class LinearMasked(MaskedWeightMixin, BaseMasked): + r"""Complex linear with a fixed binary weight mask.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype: torch.dtype = torch.cfloat, + ) -> None: + super().__init__() # BaseMasked.__init__ -> registers `mask` buffer + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, device=device, dtype=dtype) + ) + if bias: + self.bias = nn.Parameter( + torch.empty(out_features, device=device, dtype=dtype) + ) + else: + self.register_parameter("bias", None) + _init_complex_weight(self.weight, in_features) + if self.bias is not None: + _init_complex_weight(self.bias, in_features) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + w = self.weight_masked if self.is_sparse else self.weight + return F.linear(input, w, self.bias) + + +class BilinearMasked(MaskedWeightMixin, BaseMasked): + r"""Complex bilinear with a fixed binary weight mask.""" + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + conjugate: bool = True, + device=None, + dtype: torch.dtype = torch.cfloat, + ) -> None: + super().__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.conjugate = conjugate + self.weight = nn.Parameter( + torch.empty( + out_features, in1_features, in2_features, device=device, dtype=dtype + ) + ) + if bias: + self.bias = nn.Parameter( + torch.empty(out_features, device=device, dtype=dtype) + ) + else: + self.register_parameter("bias", None) + _init_complex_weight(self.weight, in1_features) + if self.bias is not None: + _init_complex_weight(self.bias, in1_features) + + def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + x1 = input1.conj() if self.conjugate else input1 + w = self.weight_masked if self.is_sparse else self.weight + out = torch.einsum("...i,kij,...j->...k", x1, w, input2) + if self.bias is not None: + out = out + self.bias + return out diff --git a/complextorch/nn/modules/activation/__init__.py b/complextorch/nn/modules/activation/__init__.py index b4d00eb..15d197f 100755 --- a/complextorch/nn/modules/activation/__init__.py +++ b/complextorch/nn/modules/activation/__init__.py @@ -1,17 +1,75 @@ -from .complex_relu import CVSplitReLU, CReLU, CPReLU -from .fully_complex import CVSigmoid, zReLU, CVCardiod, CVSigLog -from .split_type_A import ( - GeneralizedSplitActivation, - CVSplitTanh, - CTanh, - CVSplitSigmoid, - CSigmoid, - CVSplitAbs, -) -from .split_type_B import ( - GeneralizedPolarActivation, - CVPolarTanh, - CVPolarSquash, - CVPolarLog, - modReLU, -) +from complextorch.nn.modules.activation.complex_relu import ( + CVSplitReLU, + CReLU, + CPReLU, + zAbsReLU, + zLeakyReLU, + GTReLU, + EquivariantPhaseReLU, +) +from complextorch.nn.modules.activation.fully_complex import ( + CVSigmoid, + zReLU, + CVCardiod, + CVSigLog, + Mod, +) +from complextorch.nn.modules.activation.split_type_A import ( + GeneralizedSplitActivation, + CVSplitTanh, + CTanh, + CVSplitSigmoid, + CSigmoid, + CVSplitAbs, + CVSplitELU, + CELU, + CVSplitCELU, + CCELU, + CVSplitGELU, + CGELU, +) +from complextorch.nn.modules.activation.split_type_B import ( + GeneralizedPolarActivation, + CVPolarTanh, + CVPolarSquash, + CVPolarLog, + modReLU, + AdaptiveModReLU, +) + +__all__ = [ + "CCELU", + "CELU", + "CGELU", + "AdaptiveModReLU", + "CPReLU", + "CReLU", + "CSigmoid", + "CTanh", + "CVCardiod", + "CVPolarLog", + "CVPolarSquash", + "CVPolarTanh", + "CVSigLog", + # fully_complex + "CVSigmoid", + "CVSplitAbs", + "CVSplitCELU", + "CVSplitELU", + "CVSplitGELU", + # complex_relu + "CVSplitReLU", + "CVSplitSigmoid", + "CVSplitTanh", + "EquivariantPhaseReLU", + "GTReLU", + # split_type_B + "GeneralizedPolarActivation", + # split_type_A + "GeneralizedSplitActivation", + "Mod", + "modReLU", + "zAbsReLU", + "zLeakyReLU", + "zReLU", +] diff --git a/complextorch/nn/modules/activation/complex_relu.py b/complextorch/nn/modules/activation/complex_relu.py index 8a5e2b1..f8dc289 100755 --- a/complextorch/nn/modules/activation/complex_relu.py +++ b/complextorch/nn/modules/activation/complex_relu.py @@ -1,8 +1,19 @@ +import math + +import torch import torch.nn as nn -from .split_type_A import GeneralizedSplitActivation +from complextorch.nn.modules.activation.split_type_A import GeneralizedSplitActivation -__all__ = ["CVSplitReLU", "CReLU", "CPReLU"] +__all__ = [ + "CPReLU", + "CReLU", + "CVSplitReLU", + "EquivariantPhaseReLU", + "GTReLU", + "zAbsReLU", + "zLeakyReLU", +] class CVSplitReLU(GeneralizedSplitActivation): @@ -26,7 +37,7 @@ class CVSplitReLU(GeneralizedSplitActivation): """ def __init__(self, inplace: bool = True) -> None: - super(CVSplitReLU, self).__init__(nn.ReLU(inplace), nn.ReLU(inplace)) + super().__init__(nn.ReLU(inplace), nn.ReLU(inplace)) class CReLU(CVSplitReLU): @@ -43,8 +54,6 @@ class CReLU(CVSplitReLU): Alias for :class:`CVSplitReLU`. The nomenclature CReLU is used only in certain literature to denote the split complex-valued rectified linear unit. """ - pass - class CPReLU(GeneralizedSplitActivation): r""" @@ -69,4 +78,249 @@ class CPReLU(GeneralizedSplitActivation): """ def __init__(self) -> None: - super(CPReLU, self).__init__(nn.PReLU(), nn.PReLU()) + super().__init__(nn.PReLU(), nn.PReLU()) + + +class zAbsReLU(nn.Module): + r""" + Magnitude-Thresholded ReLU with Learnable Threshold + --------------------------------------------------- + + Zeros out elements whose magnitude is below a learnable threshold + :math:`a`, preserving the phase of passing elements: + + .. math:: + + \texttt{zAbsReLU}(z) = \begin{cases} + z & \text{if } |z| \geq a \\ + 0 & \text{otherwise} + \end{cases} + + Args: + a_init: initial value of the (scalar) threshold parameter. + """ + + def __init__(self, a_init: float = 0.0) -> None: + super().__init__() + self.a = nn.Parameter(torch.tensor(float(a_init))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + mask = (input.abs() >= self.a).to( + input.dtype if input.is_complex() else input.real.dtype + ) + if input.is_complex(): + return input * mask + return input * mask + + +class zLeakyReLU(nn.Module): + r""" + Leaky First-Quadrant Complex ReLU + --------------------------------- + + Soft version of :class:`zReLU`: passes :math:`z` unchanged when both + :math:`\Re z > 0` and :math:`\Im z > 0`; scales by ``negative_slope`` + elsewhere. + + .. math:: + + \texttt{zLeakyReLU}(z) = \begin{cases} + z & \text{if } \Re z > 0 \text{ and } \Im z > 0 \\ + \alpha\, z & \text{otherwise} + \end{cases} + """ + + def __init__(self, negative_slope: float = 0.01) -> None: + super().__init__() + self.negative_slope = negative_slope + + def forward(self, input: torch.Tensor) -> torch.Tensor: + in_q1 = (input.real > 0) & (input.imag > 0) + scale = torch.where( + in_q1, + torch.ones_like(input.real), + torch.full_like(input.real, self.negative_slope), + ) + return input * scale + + def extra_repr(self) -> str: + return f"negative_slope={self.negative_slope}" + + +# --------------------------------------------------------------------------- +# CDS phase-thresholding activations (Singhal, Xing, Yu — CVPR 2022) +# --------------------------------------------------------------------------- + + +class _PhaseHalfPlaneMask(torch.autograd.Function): + r"""Forward: :math:`\theta \mapsto \theta \cdot \mathbf{1}[\theta \bmod 2\pi \in [0, \pi]]`. + Backward: the gradient is the mask itself. + + Matches ``Two_Channel_Nonlinearity`` from ``cds/layers.py:494-520``. + """ + + @staticmethod + def forward(ctx, phase: torch.Tensor) -> torch.Tensor: + wrapped = phase % (2.0 * math.pi) + mask = ((wrapped >= 0.0) & (wrapped <= math.pi)).to(phase.dtype) + ctx.save_for_backward(mask) + return phase * mask + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + (mask,) = ctx.saved_tensors + return grad_output * mask + + +def _broadcast_channelwise(t: torch.Tensor, input_dim: int) -> torch.Tensor: + """Broadcast a 1-D parameter ``(C,)`` to the channel dim (1) of an N-D input.""" + if t.dim() == 1 and input_dim > 1: + shape = [1] * input_dim + shape[1] = t.shape[0] + return t.view(*shape) + return t + + +class GTReLU(nn.Module): + r""" + Gabor-Tangent ReLU (CDS) + ------------------------ + + Phase-thresholding nonlinearity used in the I-type CDS network. Composes a + learnable complex scaling, a half-plane phase mask, and an optional + learnable phase rescaling: + + 1. Scale: :math:`z' = (\alpha + j\beta) \cdot z`, with per-channel + :math:`\alpha, \beta \in \mathbb{R}^C` learnable. + 2. Mask phase: pass only phases in the upper half-plane: + + .. math:: + + \tilde{\theta} = \arg(z') \cdot \mathbf{1}[\arg(z') \bmod 2\pi \in [0, \pi]] + + 3. Recombine: :math:`\operatorname{out} = |z'| \cdot e^{j\tilde{\theta}}`. + 4. (Optional, when ``phase_scale=True``) rescale the masked phase by + :math:`\operatorname{clamp}(\lambda, 0.5, 2)` with :math:`\lambda \in \mathbb{R}^C` + learnable (initialised to 1). + + The mask gradient is implemented via a custom + :class:`torch.autograd.Function` (the mask itself is the gradient). + + Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — `GTReLU` in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf + + Args: + num_channels: number of complex channels. + global_scaling: if True, share a single scalar :math:`(\alpha, \beta)` + across all channels. + phase_scale: if True, add the per-channel learnable phase rescale. + """ + + def __init__( + self, + num_channels: int, + global_scaling: bool = False, + phase_scale: bool = False, + ) -> None: + super().__init__() + n = 1 if global_scaling else num_channels + self.num_channels = num_channels + self.global_scaling = global_scaling + self.phase_scale = phase_scale + self.alpha = nn.Parameter(torch.empty(n).uniform_(0.0, 1.0)) + self.beta = nn.Parameter(torch.empty(n).uniform_(0.0, 1.0)) + if phase_scale: + self.lambd = nn.Parameter(torch.ones(num_channels)) + else: + self.register_parameter("lambd", None) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + alpha = _broadcast_channelwise(self.alpha, input.dim()) + beta = _broadcast_channelwise(self.beta, input.dim()) + scale = torch.complex(alpha, beta) + z = input * scale + magnitude = z.abs() + phase = z.angle() + masked_phase = _PhaseHalfPlaneMask.apply(phase) + if self.lambd is not None: + lam = _broadcast_channelwise(self.lambd, input.dim()) + masked_phase = masked_phase * lam.clamp(min=0.5, max=2.0) + return torch.polar(magnitude, masked_phase) + + def extra_repr(self) -> str: + return ( + f"num_channels={self.num_channels}, global_scaling={self.global_scaling}, " + f"phase_scale={self.phase_scale}" + ) + + +class EquivariantPhaseReLU(nn.Module): + r""" + Equivariant Phase ReLU (CDS) + ---------------------------- + + The U(1)-equivariant counterpart of :class:`GTReLU`. Thresholds phase + *relative to the channel-mean direction*, so the operator commutes with + any global phase rotation: + + 1. Channel-mean reference direction: + + .. math:: + + \hat{p} = \frac{\mathrm{mean}_c(z)}{|\mathrm{mean}_c(z)| + \varepsilon} + + 2. Relative phase: :math:`\varphi = \arg(z \cdot \overline{\hat{p}})`. + 3. Threshold (half-plane mask, then per-channel scale): + + .. math:: + + \tilde{\varphi} = \varphi \cdot \mathbf{1}[\varphi \bmod 2\pi \in [0, \pi]] \cdot \operatorname{ReLU}(s) + + with :math:`s \in \mathbb{R}^C` learnable (init 1). + 4. Output: :math:`|z| \cdot e^{j\tilde{\varphi}} \cdot \hat{p}`. + + Rotating the input by :math:`e^{j\psi}` rotates :math:`\hat{p}` by the same + angle, leaving :math:`\varphi` invariant and rotating the output by + :math:`\psi` — exact U(1)-equivariance. + + Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — `eqnl` in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf + + Args: + num_channels: number of complex channels in the input. + eps: numerical floor when normalising the channel mean. + """ + + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.num_channels = num_channels + self.eps = eps + self.phase_gain = nn.Parameter(torch.ones(num_channels)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + # Channel-mean reference direction (keepdim so it broadcasts back). + ref = input.mean(dim=1, keepdim=True) + ref = ref / (ref.abs() + self.eps) + # Relative phase: arg(z * conj(ref)) + relative_phase = (input * ref.conj()).angle() + masked = _PhaseHalfPlaneMask.apply(relative_phase) + gain = _broadcast_channelwise(self.phase_gain, input.dim()) + masked = masked * torch.relu(gain) + return input.abs() * torch.polar(torch.ones_like(masked), masked) * ref + + def extra_repr(self) -> str: + return f"num_channels={self.num_channels}, eps={self.eps}" diff --git a/complextorch/nn/modules/activation/fully_complex.py b/complextorch/nn/modules/activation/fully_complex.py index 6960e3b..e51948f 100755 --- a/complextorch/nn/modules/activation/fully_complex.py +++ b/complextorch/nn/modules/activation/fully_complex.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -__all__ = ["CVSigmoid", "zReLU", "CVCardiod", "CVSigLog"] +__all__ = ["CVCardiod", "CVSigLog", "CVSigmoid", "Mod", "zReLU"] class CVSigmoid(nn.Module): @@ -27,7 +27,7 @@ class CVSigmoid(nn.Module): """ def __init__(self) -> None: - super(CVSigmoid, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the complex-valued sigmoid activation function. @@ -75,7 +75,7 @@ class zReLU(nn.Module): """ def __init__(self) -> None: - super(zReLU, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the complex-valued Guberman ReLU. @@ -87,7 +87,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\begin{cases} \mathbf{z} \quad \text{if} \quad \angle\mathbf{z} \in [0, \pi/2] \\ 0 \quad \text{else} \end{cases}` """ x_angle = input.angle() - mask = (0 <= x_angle) & (x_angle <= torch.pi / 2) + mask = (x_angle >= 0) & (x_angle <= torch.pi / 2) return input * mask @@ -119,7 +119,7 @@ class CVCardiod(nn.Module): """ def __init__(self) -> None: - super(CVCardiod, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the complex-valued cardioid activation function. @@ -153,7 +153,7 @@ class CVSigLog(nn.Module): """ def __init__(self, c: float = 1.0, r: float = 1.0) -> None: - super(CVSigLog, self).__init__() + super().__init__() self.c = c self.r = r @@ -168,3 +168,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\frac{\mathbf{z}}{(c + 1/r * |\mathbf{z}|)}` """ return input / (self.c + input.abs() / self.r) + + +class Mod(nn.Module): + r""" + Magnitude Module + ---------------- + + Returns the magnitude :math:`|z|` of a complex input, producing a real + tensor. Useful inside :class:`torch.nn.Sequential` where you cannot easily + use ``torch.abs`` as a layer. + + .. math:: + + G(\mathbf{z}) = |\mathbf{z}| + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.abs() diff --git a/complextorch/nn/modules/activation/split_type_A.py b/complextorch/nn/modules/activation/split_type_A.py index 7a1b71f..ad6fe86 100755 --- a/complextorch/nn/modules/activation/split_type_A.py +++ b/complextorch/nn/modules/activation/split_type_A.py @@ -1,15 +1,21 @@ import torch import torch.nn as nn -from ... import functional as cvF +from complextorch.nn import functional as cvF __all__ = [ - "GeneralizedSplitActivation", - "CVSplitTanh", - "CTanh", - "CVSplitSigmoid", + "CCELU", + "CELU", + "CGELU", "CSigmoid", + "CTanh", "CVSplitAbs", + "CVSplitCELU", + "CVSplitELU", + "CVSplitGELU", + "CVSplitSigmoid", + "CVSplitTanh", + "GeneralizedSplitActivation", ] @@ -38,7 +44,7 @@ class GeneralizedSplitActivation(nn.Module): """ def __init__(self, activation_r: nn.Module, activation_i: nn.Module) -> None: - super(GeneralizedSplitActivation, self).__init__() + super().__init__() self.activation_r = activation_r self.activation_i = activation_i @@ -75,7 +81,7 @@ class CVSplitTanh(GeneralizedSplitActivation): """ def __init__(self) -> None: - super(CVSplitTanh, self).__init__(nn.Tanh(), nn.Tanh()) + super().__init__(nn.Tanh(), nn.Tanh()) class CTanh(CVSplitTanh): @@ -97,8 +103,6 @@ class CTanh(CVSplitTanh): - https://ieeexplore.ieee.org/abstract/document/6138313 """ - pass - class CVSplitSigmoid(GeneralizedSplitActivation): r""" @@ -113,7 +117,7 @@ class CVSplitSigmoid(GeneralizedSplitActivation): """ def __init__(self) -> None: - super(CVSplitSigmoid, self).__init__(nn.Sigmoid(), nn.Sigmoid()) + super().__init__(nn.Sigmoid(), nn.Sigmoid()) class CSigmoid(CVSplitSigmoid): @@ -127,8 +131,6 @@ class CSigmoid(CVSplitSigmoid): G(\mathbf{z}) = \text{sigmoid}(\mathbf{x}) + j \text{sigmoid}(\mathbf{y}) """ - pass - class CVSplitAbs(nn.Module): r""" @@ -150,7 +152,7 @@ class CVSplitAbs(nn.Module): """ def __init__(self) -> None: - super(CVSplitAbs, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the Type-A split abs() activation function. @@ -162,3 +164,71 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`|\mathbf{x}| + j |\mathbf{y}|` """ return torch.complex(input.real.abs(), input.imag.abs()) + + +class CVSplitELU(GeneralizedSplitActivation): + r""" + Split Complex-Valued ELU + ------------------------ + + .. math:: + + G(\mathbf{z}) = \text{ELU}(\mathbf{x}) + j\,\text{ELU}(\mathbf{y}) + + where :math:`\text{ELU}` is :class:`torch.nn.ELU`. + """ + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__(nn.ELU(alpha, inplace), nn.ELU(alpha, inplace)) + + +class CELU(CVSplitELU): + r"""Alias for :class:`CVSplitELU`. + + Note: not to be confused with :class:`torch.nn.CELU` — this is the + complex-valued *ELU* (matches the naming used by ``torchcvnn``). For the + complex-valued :class:`torch.nn.CELU` see :class:`CVSplitCELU` / + :class:`CCELU`. + """ + + +class CVSplitCELU(GeneralizedSplitActivation): + r""" + Split Complex-Valued CELU + ------------------------- + + .. math:: + + G(\mathbf{z}) = \text{CELU}(\mathbf{x}) + j\,\text{CELU}(\mathbf{y}) + + where :math:`\text{CELU}` is :class:`torch.nn.CELU`. + """ + + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: + super().__init__(nn.CELU(alpha, inplace), nn.CELU(alpha, inplace)) + + +class CCELU(CVSplitCELU): + r"""Alias for :class:`CVSplitCELU`.""" + + +class CVSplitGELU(GeneralizedSplitActivation): + r""" + Split Complex-Valued GELU + ------------------------- + + .. math:: + + G(\mathbf{z}) = \text{GELU}(\mathbf{x}) + j\,\text{GELU}(\mathbf{y}) + + where :math:`\text{GELU}` is :class:`torch.nn.GELU`. + """ + + def __init__(self, approximate: str = "none") -> None: + super().__init__( + nn.GELU(approximate=approximate), nn.GELU(approximate=approximate) + ) + + +class CGELU(CVSplitGELU): + r"""Alias for :class:`CVSplitGELU`.""" diff --git a/complextorch/nn/modules/activation/split_type_B.py b/complextorch/nn/modules/activation/split_type_B.py index ebd80cd..554694f 100755 --- a/complextorch/nn/modules/activation/split_type_B.py +++ b/complextorch/nn/modules/activation/split_type_B.py @@ -1,14 +1,15 @@ import torch import torch.nn as nn -from ... import functional as cvF +from complextorch.nn import functional as cvF __all__ = [ - "GeneralizedPolarActivation", - "CVPolarTanh", + "AdaptiveModReLU", + "CVPolarLog", "CVPolarSquash", + "CVPolarTanh", + "GeneralizedPolarActivation", "modReLU", - "CVPolarLog", ] @@ -35,7 +36,7 @@ class GeneralizedPolarActivation(nn.Module): """ def __init__(self, activation_mag: nn.Module, activation_phase: nn.Module) -> None: - super(GeneralizedPolarActivation, self).__init__() + super().__init__() self.activation_mag = activation_mag self.activation_phase = activation_phase @@ -76,7 +77,7 @@ class CVPolarTanh(GeneralizedPolarActivation): """ def __init__(self) -> None: - super(CVPolarTanh, self).__init__(nn.Tanh(), None) + super().__init__(nn.Tanh(), None) class _Squash(nn.Module): @@ -91,7 +92,7 @@ class _Squash(nn.Module): """ def __init__(self) -> None: - super(_Squash, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the squash functionality. @@ -126,7 +127,7 @@ class CVPolarSquash(GeneralizedPolarActivation): """ def __init__(self): - super(CVPolarSquash, self).__init__(_Squash(), None) + super().__init__(_Squash(), None) class _LogXPlus1(nn.Module): @@ -141,7 +142,7 @@ class _LogXPlus1(nn.Module): """ def __init__(self) -> None: - super(_LogXPlus1, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes the :math:`\log(x + 1)` functionality. @@ -176,36 +177,25 @@ class CVPolarLog(GeneralizedPolarActivation): """ def __init__(self) -> None: - super(CVPolarLog, self).__init__(_LogXPlus1(), None) + super().__init__(_LogXPlus1(), None) class _modReLU(nn.Module): r""" Helper class to compute :math:`\text{ReLU}(x + b)` on real-valued magnitude torch.Tensor. - Implements the operation: - - .. math:: - - G(x) = \text{ReLU}(x + b) + If ``learnable=True``, ``b`` is an :class:`torch.nn.Parameter` and is + learned. Otherwise it is a fixed scalar buffer. """ - def __init__(self, bias: float = 0.0) -> None: - super(_modReLU, self).__init__() - - assert bias < 0, "bias must be smaller than 0 to have a non-linearity effect" - - self.bias = bias + def __init__(self, bias: float = -0.1, learnable: bool = False) -> None: + super().__init__() + if learnable: + self.bias = nn.Parameter(torch.tensor(float(bias))) + else: + self.register_buffer("bias", torch.tensor(float(bias))) def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes the :math:`\text{ReLU}(x + b)` functionality. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: :math:`\text{ReLU}(x + b)` - """ return torch.relu(input + self.bias) @@ -223,6 +213,10 @@ class modReLU(GeneralizedPolarActivation): Notice that :math:`|\mathbf{z}|` (:math:`\mathbf{z}`.abs()) is always positive, so if :math:`b > 0` then :math:`|\mathbf{z}| + b > = 0` always. In order to have any non-linearity effect, :math:`b` must be smaller than :math:`0` (:math:`b < 0`). + With ``learnable=True``, the bias :math:`b` becomes a single trainable + scalar :class:`torch.nn.Parameter` initialised to the value of ``bias``; + with ``learnable=False`` (default) it remains a fixed constant. + Based on work from the following papers: **Martin Arjovsky, Amar Shah, and Yoshua Bengio. Unitary evolution recurrent neural networks.** @@ -240,7 +234,53 @@ class modReLU(GeneralizedPolarActivation): - https://arxiv.org/abs/2302.08286 """ - def __init__(self, bias: float = 0.0) -> None: - assert bias < 0, "bias must be smaller than 0 to have a non-linearity effect" + def __init__(self, bias: float = -0.1, learnable: bool = False) -> None: + # When learnable, the bias may move above 0 during training; only + # validate the initialisation when it's a fixed constant. + if not learnable: + assert bias < 0, ( + "bias must be smaller than 0 to have a non-linearity effect" + ) + super().__init__(_modReLU(bias, learnable=learnable), None) + + +class _AdaptiveModReLUBias(nn.Module): + r"""Helper module: ``ReLU(x + b)`` where ``b`` has shape ``(num_features,)``. + + The bias is broadcast against the channel dimension (assumed to be dim 1 of + the magnitude tensor of shape ``(B, C, ...)``). + """ + + def __init__(self, num_features: int, init: float = -0.1) -> None: + super().__init__() + self.num_features = num_features + self.bias = nn.Parameter(torch.full((num_features,), float(init))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # input shape: (B, C, ...); broadcast bias along C + shape = [1] * input.dim() + shape[1] = self.num_features + b = self.bias.view(*shape) + return torch.relu(input + b) + + +class AdaptiveModReLU(GeneralizedPolarActivation): + r""" + Adaptive (per-channel) modulus Rectified Linear Unit + ---------------------------------------------------- + + Per-channel learnable-threshold variant of :class:`modReLU`. Expects input + shape ``(B, C, ...)``; learns a separate bias :math:`b_c` per channel. + + .. math:: + + G(\mathbf{z})_c = \texttt{ReLU}(|\mathbf{z}_c| + b_c) \odot \exp(j \angle\mathbf{z}_c) + + Args: + num_features: number of channels ``C``. + init: initial value of every :math:`b_c`. Defaults to ``-0.1`` so the + non-linearity is active at start of training. + """ - super(modReLU, self).__init__(_modReLU(bias), None) + def __init__(self, num_features: int, init: float = -0.1) -> None: + super().__init__(_AdaptiveModReLUBias(num_features, init), None) diff --git a/complextorch/nn/modules/attention/__init__.py b/complextorch/nn/modules/attention/__init__.py index 43adf1e..6eea33a 100755 --- a/complextorch/nn/modules/attention/__init__.py +++ b/complextorch/nn/modules/attention/__init__.py @@ -1,10 +1,9 @@ -import numpy as np import torch import torch.nn as nn -from .... import nn as cvnn +from complextorch import nn as cvnn -__all__ = ["ScaledDotProductAttention", "MultiheadAttention"] +__all__ = ["MultiheadAttention", "ScaledDotProductAttention"] class ScaledDotProductAttention(nn.Module): @@ -23,9 +22,9 @@ class ScaledDotProductAttention(nn.Module): where :math:`Q, K, V` are complex-valued tensors, :math:`t` is known as the temperature typically :math:`t = \sqrt{d_{attn}}`, and :math:`\mathcal{S}` is the softmax function. For complex-values, the `traditional softmax function `_ cannot be applied, and variants must be applied. - Included in this library are several options for :doc:`complex-valued softmax <./softmax>` and similar :doc:`masking <./mask>` functions. + Included in this library are several options for :mod:`complex-valued softmax ` and similar :mod:`masking ` functions. - By default, the :class:`CVScaledDotProductAttention` employs the :class:`complextorch.nn.CVSoftmax`, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged. + By default, the :class:`ScaledDotProductAttention` employs :class:`complextorch.nn.CVSoftMax`, which applies the traditional softmax to the real and imaginary parts of the complex-valued tensor separately. If a phase-preserving alternative is preferred, pass :class:`complextorch.nn.PhaseSoftMax` via the ``SoftMaxClass`` argument. """ def __init__( @@ -33,12 +32,21 @@ def __init__( temperature: float, attn_dropout: float = 0.1, SoftMaxClass: nn.Module = cvnn.CVSoftMax, + softmax_on: str = "complex", ) -> None: - super(ScaledDotProductAttention, self).__init__() + super().__init__() + if softmax_on not in ("complex", "real"): + raise ValueError( + f"softmax_on must be 'complex' or 'real', got {softmax_on!r}" + ) self.temperature = temperature self.dropout = cvnn.Dropout(attn_dropout) - self.softmax = SoftMaxClass(dim=-1) + self.softmax_on = softmax_on + if softmax_on == "complex": + self.softmax = SoftMaxClass(dim=-1) + else: + self.softmax = nn.Softmax(dim=-1) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor @@ -51,9 +59,16 @@ def forward( v (torch.Tensor): complex-valued value tensor Returns: - torch.Tensor: \mathcal{S}(Q K^T / t) V + torch.Tensor: \mathcal{S}(Q K^H / t) V """ - attn = torch.matmul(q / self.temperature, k.transpose(-2, -1)) + # Conjugate transpose so the dot product is the Hermitian inner product Q K^H. + attn = torch.matmul(q / self.temperature, k.conj().transpose(-2, -1)) + + if self.softmax_on == "real": + # Softmax on Re[QK^H]: real-valued attention weights × complex V. + weights = self.softmax(attn.real) + weights = self.dropout(weights.to(v.dtype)) + return torch.matmul(weights, v) attn = self.dropout(self.softmax(attn)) return torch.matmul(attn, v) @@ -66,7 +81,7 @@ class MultiheadAttention(nn.Module): Multihead self attention extended to complex-valued tensors. - By default, the :class:`CVMultiheadAttention` employs the :class:`complextorch.nn.CVSoftmax`, which applies the traditional softmax to the magnitude of the complex-valued tensor while leaving the phase information unchanged. + By default, the :class:`MultiheadAttention` employs :class:`complextorch.nn.CVSoftMax`, which applies the traditional softmax to the real and imaginary parts of the complex-valued tensor separately. Pass :class:`complextorch.nn.PhaseSoftMax` via ``SoftMaxClass`` for the phase-preserving alternative. """ def __init__( @@ -77,8 +92,9 @@ def __init__( d_v: int, dropout: float = 0.1, SoftMaxClass: nn.Module = cvnn.CVSoftMax, + softmax_on: str = "complex", ) -> None: - super(MultiheadAttention, self).__init__() + super().__init__() self.d_k = d_k self.d_v = d_v @@ -90,7 +106,10 @@ def __init__( self.fc = cvnn.Linear(n_heads * d_v, d_model, bias=False) self.attention = ScaledDotProductAttention( - temperature=d_k**0.5, attn_dropout=dropout, SoftMaxClass=SoftMaxClass + temperature=d_k**0.5, + attn_dropout=dropout, + SoftMaxClass=SoftMaxClass, + softmax_on=softmax_on, ) self.dropout = cvnn.Dropout(dropout) diff --git a/complextorch/nn/modules/attention/eca.py b/complextorch/nn/modules/attention/eca.py index 509d3a2..4623cce 100755 --- a/complextorch/nn/modules/attention/eca.py +++ b/complextorch/nn/modules/attention/eca.py @@ -1,8 +1,8 @@ import numpy as np -import torch.nn as nn import torch +import torch.nn as nn -from .... import nn as cvnn +from complextorch import nn as cvnn __all__ = [ "EfficientChannelAttention1d", @@ -22,7 +22,7 @@ class _EfficientChannelAttention(nn.Module): \texttt{CV-ECA}(\mathbf{z}) = \mathcal{M}(\text{conv}(H_\texttt{CVAdaptiveAvgPoolNd}(\mathbf{z}))) \odot \mathbf{z}, - where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :doc:`pooling <../pooling>` operator. + where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :mod:`pooling ` operator. """ def __init__( @@ -34,7 +34,7 @@ def __init__( gamma: int = 2, dtype=torch.cfloat, ) -> None: - super(_EfficientChannelAttention, self).__init__() + super().__init__() self.channels = channels self.b = b self.gamma = gamma @@ -51,8 +51,7 @@ def __init__( def kernel_size(self) -> int: k = int(abs((np.log2(self.channels) / self.gamma) + self.b / self.gamma)) - out = k if k % 2 else k + 1 - return out + return k if k % 2 else k + 1 def forward(self, input: torch.Tensor) -> torch.Tensor: batch_size, channels, *im_size = input.shape @@ -62,8 +61,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: y = self.avg_pool(input) # Two different branches of ECA module - y = self.conv(y.squeeze(-1).view(batch_size, 1, channels)).transpose(-1, -2) - y = y.transpose(-1, -2).view(batch_size, channels, *one_vec) + y = self.conv(y.squeeze(-1).view(batch_size, 1, channels)) + y = y.view(batch_size, channels, *one_vec) # Multi-scale information fusion y = self.mask(y) @@ -82,7 +81,7 @@ class EfficientChannelAttention1d(_EfficientChannelAttention): \texttt{CV-ECA}(\mathbf{z}) = \mathcal{M}(\text{conv}(H_\texttt{CVAdaptiveAvgPool1d}(\mathbf{z}))) \odot \mathbf{z}, - where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :doc:`pooling <../pooling>` operator. + where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :mod:`pooling ` operator. Based on work from the following paper: @@ -101,7 +100,7 @@ def __init__( gamma: int = 2, dtype=torch.cfloat, ) -> None: - super(EfficientChannelAttention1d, self).__init__( + super().__init__( channels=channels, MaskingClass=MaskingClass, AvgPoolClass=cvnn.AdaptiveAvgPool1d, @@ -122,7 +121,7 @@ class EfficientChannelAttention2d(_EfficientChannelAttention): \texttt{CV-ECA}(\mathbf{z}) = \mathcal{M}(\text{conv}(H_\texttt{CVAdaptiveAvgPool2d}(\mathbf{z}))) \odot \mathbf{z}, - where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :doc:`pooling <../pooling>` operator. + where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :mod:`pooling ` operator. Based on work from the following paper: @@ -141,7 +140,7 @@ def __init__( gamma: int = 2, dtype=torch.cfloat, ) -> None: - super(EfficientChannelAttention2d, self).__init__( + super().__init__( channels=channels, MaskingClass=MaskingClass, AvgPoolClass=cvnn.AdaptiveAvgPool2d, @@ -162,7 +161,7 @@ class EfficientChannelAttention3d(_EfficientChannelAttention): \texttt{CV-ECA}(\mathbf{z}) = \mathcal{M}(\text{conv}(H_\texttt{CVAdaptiveAvgPool3d}(\mathbf{z}))) \odot \mathbf{z}, - where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :doc:`pooling <../pooling>` operator. + where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\texttt{CVAdaptiveAvgPoolNd}(\cdot)` is the complex-valued global :mod:`pooling ` operator. Based on work from the following paper: @@ -181,7 +180,7 @@ def __init__( gamma: int = 2, dtype=torch.cfloat, ) -> None: - super(EfficientChannelAttention3d, self).__init__( + super().__init__( channels=channels, MaskingClass=MaskingClass, AvgPoolClass=cvnn.AdaptiveAvgPool3d, diff --git a/complextorch/nn/modules/attention/mca.py b/complextorch/nn/modules/attention/mca.py index df64f52..f0abddd 100755 --- a/complextorch/nn/modules/attention/mca.py +++ b/complextorch/nn/modules/attention/mca.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from .... import nn as cvnn +from complextorch import nn as cvnn __all__ = [ "MaskedChannelAttention1d", @@ -32,15 +32,15 @@ def __init__( MaskingClass: nn.Module = cvnn.ComplexRatioMask, act: nn.Module = cvnn.CReLU, ) -> None: - super(_MaskedChannelAttention, self).__init__() + super().__init__() self.channels = channels self.reduction_factor = reduction_factor self.MaskingClass = MaskingClass() self.act = act() - assert ( - channels % reduction_factor == 0 - ), "Channels / Reduction Factor must yield integer" + assert channels % reduction_factor == 0, ( + "Channels / Reduction Factor must yield integer" + ) self.reduced_channels = int(channels / reduction_factor) # Placeholders @@ -63,7 +63,7 @@ class MaskedChannelAttention1d(_MaskedChannelAttention): 1-D Complex-Valued Masked Channel Attention (MCA) Module -------------------------------------------------------- - Generalized for arbitrary masking function (see :doc:`mask <../mask>` for implemented masking functions) + Generalized for arbitrary masking function (see :mod:`mask ` for implemented masking functions) Implements the operation: @@ -89,7 +89,7 @@ def __init__( MaskingClass: nn.Module = cvnn.ComplexRatioMask, act: nn.Module = cvnn.CReLU, ) -> None: - super(MaskedChannelAttention1d, self).__init__( + super().__init__( channels=channels, reduction_factor=reduction_factor, MaskingClass=MaskingClass, @@ -98,14 +98,14 @@ def __init__( self.avg_pool = cvnn.AdaptiveAvgPool1d(1) - self.conv_down = cvnn.SlowConv1d( + self.conv_down = cvnn.Conv1d( in_channels=channels, out_channels=self.reduced_channels, kernel_size=1, bias=False, ) - self.conv_up = cvnn.SlowConv1d( + self.conv_up = cvnn.Conv1d( in_channels=self.reduced_channels, out_channels=channels, kernel_size=1, @@ -126,7 +126,7 @@ class MaskedChannelAttention2d(_MaskedChannelAttention): where :math:`\mathcal{M}(\cdot)` is the masking function (by default, ComplexRatioMask is used) and :math:`H_\text{ConvUp}(\cdot)` and :math:`H_\text{ConvDown}(\cdot)` are 2-D convolution layers with kernel sizes of 1 that reduce the channel dimension by a factor :math:`r`. - Generalized for arbitrary masking function (see :doc:`mask <../mask>` for implemented masking functions) + Generalized for arbitrary masking function (see :mod:`mask ` for implemented masking functions) Based on work from the following paper: @@ -144,7 +144,7 @@ def __init__( MaskingClass: nn.Module = cvnn.ComplexRatioMask, act: nn.Module = cvnn.CReLU, ) -> None: - super(MaskedChannelAttention2d, self).__init__( + super().__init__( channels=channels, reduction_factor=reduction_factor, MaskingClass=MaskingClass, @@ -153,14 +153,14 @@ def __init__( self.avg_pool = cvnn.AdaptiveAvgPool2d(1) - self.conv_down = cvnn.SlowConv2d( + self.conv_down = cvnn.Conv2d( in_channels=channels, out_channels=self.reduced_channels, kernel_size=1, bias=False, ) - self.conv_up = cvnn.SlowConv2d( + self.conv_up = cvnn.Conv2d( in_channels=self.reduced_channels, out_channels=channels, kernel_size=1, @@ -173,7 +173,7 @@ class MaskedChannelAttention3d(_MaskedChannelAttention): 3-D Complex-Valued Masked Channel Attention (MCA) Module -------------------------------------------------------- - Generalized for arbitrary masking function (see :doc:`mask <../mask>` for implemented masking functions) + Generalized for arbitrary masking function (see :mod:`mask ` for implemented masking functions) Implements the operation: @@ -200,7 +200,7 @@ def __init__( MaskingClass: nn.Module = cvnn.ComplexRatioMask, act: nn.Module = cvnn.CReLU, ) -> None: - super(MaskedChannelAttention3d, self).__init__( + super().__init__( channels=channels, reduction_factor=reduction_factor, MaskingClass=MaskingClass, @@ -209,14 +209,14 @@ def __init__( self.avg_pool = cvnn.AdaptiveAvgPool3d(1) - self.conv_down = cvnn.SlowConv3d( + self.conv_down = cvnn.Conv3d( in_channels=channels, out_channels=self.reduced_channels, kernel_size=1, bias=False, ) - self.conv_up = cvnn.SlowConv3d( + self.conv_up = cvnn.Conv3d( in_channels=self.reduced_channels, out_channels=channels, kernel_size=1, diff --git a/complextorch/nn/modules/batchnorm.py b/complextorch/nn/modules/batchnorm.py index 11a34c3..522d845 100755 --- a/complextorch/nn/modules/batchnorm.py +++ b/complextorch/nn/modules/batchnorm.py @@ -2,9 +2,19 @@ import torch.nn as nn from torch.nn import init -from .. import functional as cvF +from complextorch.nn import functional as cvF -__all__ = ["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"] +__all__ = [ + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "MagBatchNorm1d", + "MagBatchNorm2d", + "MagBatchNorm3d", + "NaiveBatchNorm1d", + "NaiveBatchNorm2d", + "NaiveBatchNorm3d", +] class _BatchNorm(nn.Module): @@ -75,12 +85,9 @@ def _check_input_dim(self, input) -> None: def forward(self, input: torch.Tensor) -> torch.Tensor: self._check_input_dim(input) - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum + exponential_average_factor = 0.0 if self.momentum is None else self.momentum - if self.training and self.track_running_stats: + if self.training and self.track_running_stats: # noqa: SIM102 — kept nested to mirror torch.nn.modules.batchnorm._BatchNorm.forward if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average @@ -112,9 +119,9 @@ class BatchNorm1d(_BatchNorm): -------------------------------------- Complex-valued batch normalization for 2-D and 3-D tensors. - Similar to the `PyTorch BatchNorm1d `_ implementation but performs the proper batch normalization for complex-valued data. + Similar to the PyTorch :class:`torch.nn.BatchNorm1d` implementation but performs the proper batch normalization for complex-valued data. - See `torch.nn.BatchNorm1d `_ for additional details. + See :class:`torch.nn.BatchNorm1d` for additional details. Based on work from the following paper: @@ -136,9 +143,9 @@ class BatchNorm2d(_BatchNorm): -------------------------------------- Complex-valued batch normalization for 4-D tensors. - Similar to the `PyTorch BatchNorm2d `_ implementation but performs the proper batch normalization for complex-valued data. + Similar to the PyTorch :class:`torch.nn.BatchNorm2d` implementation but performs the proper batch normalization for complex-valued data. - See `torch.nn.BatchNorm2d `_ for additional details. + See :class:`torch.nn.BatchNorm2d` for additional details. Based on work from the following paper: @@ -160,9 +167,9 @@ class BatchNorm3d(_BatchNorm): -------------------------------------- Complex-valued batch normalization for 5-D tensors. - Similar to the `PyTorch BatchNorm3d `_ implementation but performs the proper batch normalization for complex-valued data. + Similar to the PyTorch :class:`torch.nn.BatchNorm3d` implementation but performs the proper batch normalization for complex-valued data. - See `torch.nn.BatchNorm3d `_ for additional details. + See :class:`torch.nn.BatchNorm3d` for additional details. Based on work from the following paper: @@ -176,3 +183,148 @@ class BatchNorm3d(_BatchNorm): def _check_input_dim(self, input: torch.Tensor) -> None: if input.dim() != 5: raise ValueError(f"expected 5D input (got {input.dim()}D input)") + + +class _NaiveBatchNorm(nn.Module): + r""" + Naive (split) Complex Batch Normalization Base + ---------------------------------------------- + + Applies an independent :class:`torch.nn.BatchNorm{1,2,3}d` to the real and + imaginary parts of the input. Cheaper than the Trabelsi 2×2-whitening + :class:`_BatchNorm` (about half the cost) but does not decorrelate the + real/imag components. Useful as a baseline. + """ + + _real_bn_class = nn.BatchNorm1d # overridden per dim + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ) -> None: + super().__init__() + self.num_features = num_features + self.bn_r = self._real_bn_class( + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + self.bn_i = self._real_bn_class( + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.complex(self.bn_r(input.real), self.bn_i(input.imag)) + + def extra_repr(self) -> str: + return f"num_features={self.num_features}" + + +class NaiveBatchNorm1d(_NaiveBatchNorm): + r"""1-D split-form complex BatchNorm. See :class:`_NaiveBatchNorm`.""" + + _real_bn_class = nn.BatchNorm1d + + +class NaiveBatchNorm2d(_NaiveBatchNorm): + r"""2-D split-form complex BatchNorm. See :class:`_NaiveBatchNorm`.""" + + _real_bn_class = nn.BatchNorm2d + + +class NaiveBatchNorm3d(_NaiveBatchNorm): + r"""3-D split-form complex BatchNorm. See :class:`_NaiveBatchNorm`.""" + + _real_bn_class = nn.BatchNorm3d + + +class _MagBatchNorm(nn.Module): + r""" + Magnitude-Only Complex Batch Normalization Base + ----------------------------------------------- + + Applies an ordinary real-valued :class:`torch.nn.BatchNorm{1,2,3}d` to the + magnitude :math:`|z|` and rescales :math:`z` to match: + + .. math:: + + y = z \cdot \frac{\operatorname{BN}(|z|)}{|z| + \varepsilon} + + The output's phase is identical to the input's, so the operator is + **U(1)-equivariant**: rotating the input by :math:`e^{j\psi}` rotates the + output by exactly the same angle. This is distinct from the standard + :class:`BatchNorm{1,2,3}d` (Trabelsi 2×2 whitening), which decorrelates the + real/imag covariance but is *not* phase-equivariant. + + Running statistics, affine parameters, and ``eps``/``momentum`` semantics + follow :class:`torch.nn.BatchNorm` directly — the underlying real BN is + stored as ``self.bn`` so its ``state_dict`` is portable. + + Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — `VNCBN` ("Vector-Norm Complex Batch Norm") in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf + """ + + _real_bn_class = nn.BatchNorm1d # overridden per dim + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ) -> None: + super().__init__() + self.num_features = num_features + self.eps = eps + self.bn = self._real_bn_class( + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.is_complex(): + magnitude = input.abs() + normalized = self.bn(magnitude) + scale = normalized / (magnitude + self.eps) + return input * scale.to(input.dtype) + return self.bn(input) + + def extra_repr(self) -> str: + return f"num_features={self.num_features}" + + +class MagBatchNorm1d(_MagBatchNorm): + r"""1-D magnitude-only complex BatchNorm. See :class:`_MagBatchNorm`.""" + + _real_bn_class = nn.BatchNorm1d + + +class MagBatchNorm2d(_MagBatchNorm): + r"""2-D magnitude-only complex BatchNorm. See :class:`_MagBatchNorm`.""" + + _real_bn_class = nn.BatchNorm2d + + +class MagBatchNorm3d(_MagBatchNorm): + r"""3-D magnitude-only complex BatchNorm. See :class:`_MagBatchNorm`.""" + + _real_bn_class = nn.BatchNorm3d diff --git a/complextorch/nn/modules/casting.py b/complextorch/nn/modules/casting.py new file mode 100644 index 0000000..2f3f9a0 --- /dev/null +++ b/complextorch/nn/modules/casting.py @@ -0,0 +1,128 @@ +r""" +Layout-Conversion Modules +========================= + +Drop-in :class:`torch.nn.Module` adapters that convert between real-tensor +layouts and complex tensors, so they compose inside :class:`torch.nn.Sequential`. + +These layouts come up in real-to-complex pipelines: + +- **Interleaved**: ``[..., 2*D]`` with real and imaginary parts interleaved: + ``(re_0, im_0, re_1, im_1, ...)``. +- **Concatenated**: ``[..., 2*D]`` with all real parts followed by all + imaginary parts: ``(re_0, ..., re_{D-1}, im_0, ..., im_{D-1})``. + +If your data is in ``(..., 2)`` final-dim layout (one slot for real, one for +imag), use :func:`torch.view_as_complex` / :func:`torch.view_as_real` directly +— no wrapper is needed. +""" + +import torch +import torch.nn as nn + +__all__ = [ + "ComplexToConcatenated", + "ComplexToInterleaved", + "ConcatenatedToComplex", + "InterleavedToComplex", + "RealToComplex", +] + + +class InterleavedToComplex(nn.Module): + r""" + Interleaved Real Layout → Complex + ----------------------------------- + + Maps ``[..., 2D]`` with real/imag interleaved along the last dim to a + complex tensor of shape ``[..., D]``. + + Input: ``(re_0, im_0, re_1, im_1, ...)``. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.shape[-1] % 2 != 0: + raise ValueError( + f"InterleavedToComplex requires last dim to be even; got {input.shape[-1]}" + ) + reshaped = input.reshape(*input.shape[:-1], -1, 2).contiguous() + return torch.view_as_complex(reshaped) + + +class ComplexToInterleaved(nn.Module): + r""" + Complex → Interleaved Real Layout + ----------------------------------- + + Inverse of :class:`InterleavedToComplex`. Maps complex ``[..., D]`` to real + ``[..., 2D]`` with real/imag interleaved. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + raise TypeError( + f"ComplexToInterleaved expects a complex input, got {input.dtype}" + ) + as_real = torch.view_as_real(input) # [..., D, 2] + return as_real.reshape(*as_real.shape[:-2], -1) + + +class ConcatenatedToComplex(nn.Module): + r""" + Concatenated Real Layout → Complex + ------------------------------------ + + Maps ``[..., 2D]`` with the first ``D`` slots being real parts and the last + ``D`` slots being imaginary parts to a complex tensor of shape ``[..., D]``. + + Input: ``(re_0, ..., re_{D-1}, im_0, ..., im_{D-1})``. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.shape[-1] % 2 != 0: + raise ValueError( + f"ConcatenatedToComplex requires last dim to be even; got {input.shape[-1]}" + ) + real, imag = torch.chunk(input, 2, dim=-1) + return torch.complex(real.contiguous(), imag.contiguous()) + + +class ComplexToConcatenated(nn.Module): + r""" + Complex → Concatenated Real Layout + ------------------------------------ + + Inverse of :class:`ConcatenatedToComplex`. Maps complex ``[..., D]`` to + real ``[..., 2D]`` with all real parts followed by all imaginary parts. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + raise TypeError( + f"ComplexToConcatenated expects a complex input, got {input.dtype}" + ) + return torch.cat((input.real, input.imag), dim=-1) + + +class RealToComplex(nn.Module): + r""" + Real → Complex (Zero Imaginary) + --------------------------------- + + Lifts a real tensor into a complex tensor by setting the imaginary part to + zero. Useful as the first layer of a network whose input is a real signal + but whose internal representations are complex. + """ + + def __init__(self, dtype: torch.dtype = torch.cfloat) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.is_complex(): + return input.to(self.dtype) + zeros = torch.zeros_like(input) + return torch.complex(input, zeros).to(self.dtype) + + def extra_repr(self) -> str: + return f"dtype={self.dtype}" diff --git a/complextorch/nn/modules/conv.py b/complextorch/nn/modules/conv.py old mode 100755 new mode 100644 index e1141ca..f081378 --- a/complextorch/nn/modules/conv.py +++ b/complextorch/nn/modules/conv.py @@ -1,959 +1,341 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t - -from typing import Tuple, Union - - -__all__ = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", - "SlowConv1d", - "SlowConv2d", - "SlowConv3d", - "SlowConvTranspose1d", - "SlowConvTranspose2d", - "SlowConvTranspose3d", -] - - -class Conv1d(nn.Module): - r""" - Complex-Valued 1-D Convolution using PyTorch - -------------------------------------------- - - - Implemented using `torch.nn.Conv1d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = False, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(Conv1d, self).__init__() - - self.conv = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 1-D complex-valued convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: Conv1d(input) - """ - return self.conv(input) - - -class Conv2d(nn.Module): - r""" - Complex-Valued 2-D Convolution using PyTorch - -------------------------------------------- - - - Implemented using `torch.nn.Conv2d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = False, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(Conv2d, self).__init__() - - self.conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 2-D complex-valued convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: Conv2d(input) - """ - return self.conv(input) - - -class Conv3d(nn.Module): - r""" - Complex-Valued 3-D Convolution using PyTorch - -------------------------------------------- - - - Implemented using `torch.nn.Conv2d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = False, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(Conv3d, self).__init__() - - self.conv = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 3-D complex-valued convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: Conv3d(input) - """ - return self.conv(input) - - -class ConvTranspose1d(nn.Module): - r""" - Complex-Valued 1-D Transposed Convolution using PyTorch - ------------------------------------------------------- - - - Implemented using `torch.nn.ConvTranspose1d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - output_padding: int = 1, - groups: int = 1, - bias: bool = False, - dilation: int = 1, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(ConvTranspose1d, self).__init__() - - self.conv_transposed = nn.ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 1-D complex-valued transposed convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: ConvTranspose1d(input) - """ - return self.conv_transposed(input) - - -class ConvTranspose2d(nn.Module): - r""" - Complex-Valued 2-D Transposed Convolution using PyTorch - ------------------------------------------------------- - - - Implemented using `torch.nn.ConvTranspose2d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - output_padding: int = 1, - groups: int = 1, - bias: bool = False, - dilation: int = 1, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(ConvTranspose2d, self).__init__() - - self.conv_transposed = nn.ConvTranspose2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 2-D complex-valued transposed convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: ConvTranspose2d(input) - """ - return self.conv_transposed(input) - - -class ConvTranspose3d(nn.Module): - r""" - Complex-Valued 3-D Transposed Convolution using PyTorch - ------------------------------------------------------- - - - Implemented using `torch.nn.ConvTranspose3d `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - output_padding: int = 1, - groups: int = 1, - bias: bool = False, - dilation: int = 1, - padding_mode: str = "zeros", - device=None, - dtype=torch.cfloat, - ) -> None: - super(ConvTranspose3d, self).__init__() - - self.conv_transposed = nn.ConvTranspose3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes 3-D complex-valued transposed convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: ConvTranspose3d(input) - """ - return self.conv_transposed(input) - - -class _Conv(nn.Module): - r""" - torch.Tensor-based Complex-Valued Convolution - ----------------------------------------- - """ - - def __init__( - self, - ConvClass: nn.Module, - ConvFunc, - in_channels: int, - out_channels: int, - kernel_size: Tuple[int, ...], - stride: Tuple[int, ...], - padding: Tuple[int, ...], - dilation: Tuple[int, ...], - groups: int, - bias: bool, - padding_mode: str, - device=None, - dtype=None, - ) -> None: - super(_Conv, self).__init__() - - self.ConvFunc = ConvFunc - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - - # Assumes PyTorch complex weight initialization is correct - __temp = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype if dtype else torch.cfloat, - ) - - self.conv_r = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - self.conv_i = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - self.conv_r.weight.data = __temp.weight.real - self.conv_i.weight.data = __temp.weight.imag - - if bias: - self.conv_r.bias.data = __temp.bias.real - self.conv_i.bias.data = __temp.bias.imag - - @property - def weight(self) -> torch.Tensor: - # print(self.conv_i.weight.type(), "weight type") - return torch.complex(self.conv_r.weight, self.conv_i.weight) - - @property - def bias(self) -> torch.Tensor: - print(self.conv_i.bias.type(), "bias type") - return torch.complex(self.conv_r.bias, self.conv_i.bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r""" - Computes convolution 25% faster than naive method by using Gauss' multiplication trick - """ - t1 = self.conv_r(input.real) - t2 = self.conv_i(input.imag) - bias = ( - None if self.conv_r.bias is None else (self.conv_r.bias + self.conv_i.bias) - ) - t3 = self.ConvFunc( - input=(input.real + input.imag), - weight=(self.conv_r.weight + self.conv_i.weight), - bias=bias, - stride=self.stride, - padding=self.padding, - groups=self.groups, - ) - print( - "Conv memory allocated", torch.cuda.memory_allocated("cuda:0") / (1024**3) - ) - return torch.complex(t1 - t2, t3 - t2 - t1) - - -class SlowConv1d(_Conv): - r""" - 1-D Complex-Valued Convolution - ------------------------------ - - Based on the `PyTorch torch.nn.Conv1d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: Union[str, _size_1_t] = 0, - dilation: _size_1_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConv1d, self).__init__( - ConvClass=nn.Conv1d, - ConvFunc=F.conv1d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - -class SlowConv2d(_Conv): - r""" - 2-D Complex-Valued Convolution - ------------------------------ - - Based on the `PyTorch torch.nn.Conv2d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConv2d, self).__init__( - ConvClass=nn.Conv2d, - ConvFunc=F.conv2d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - -class SlowConv3d(_Conv): - r""" - 3-D Complex-Valued Convolution - ------------------------------ - - Based on the `PyTorch torch.nn.Conv3d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + j(\text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}(\cdot)` is the conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_3_t, - stride: _size_3_t = 1, - padding: Union[str, _size_3_t] = 0, - dilation: _size_3_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConv3d, self).__init__( - ConvClass=nn.Conv3d, - ConvFunc=F.conv3d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - -class _ConvTranspose(nn.Module): - r""" - torch.Tensor-based Complex-Valued Transposed Convolution - ---------------------------------------------------- - """ - - def __init__( - self, - ConvClass: nn.Module, - ConvFunc, - in_channels: int, - out_channels: int, - kernel_size: Tuple[int, ...], - stride: Tuple[int, ...], - padding: Tuple[int, ...], - dilation: Tuple[int, ...], - output_padding: Tuple[int, ...], - groups: int, - bias: bool, - padding_mode: str, - device=None, - dtype=None, - ) -> None: - super(_ConvTranspose, self).__init__() - - self.ConvFunc = ConvFunc - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.output_padding = output_padding - self.groups = groups - - # Assumes PyTorch complex weight initialization is correct - __temp = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype if dtype else torch.cfloat, - ) - - self.convt_r = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - self.convt_i = ConvClass( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - self.convt_r.weight.data = __temp.weight.real - self.convt_i.weight.data = __temp.weight.imag - - if bias: - self.convt_r.bias.data = __temp.bias.real - self.convt_i.bias.data = __temp.bias.imag - - @property - def weight(self) -> torch.Tensor: - return torch.complex(self.convt_r.weight, self.convt_i.weight) - - @property - def bias(self) -> torch.Tensor: - return torch.complex(self.convt_r.bias, self.convt_i.bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r""" - Computes convolution 25% faster than naive method by using Gauss' multiplication trick - """ - t1 = self.convt_r(input.real) - t2 = self.convt_i(input.imag) - bias = ( - None - if self.convt_r.bias is None - else (self.convt_r.bias + self.convt_i.bias) - ) - t3 = self.ConvFunc( - input=(input.real + input.imag), - weight=(self.convt_r.weight + self.convt_i.weight), - bias=bias, - stride=self.stride, - padding=self.padding, - groups=self.groups, - ) - return torch.complex(t1 - t2, t3 - t2 - t1) - - -class SlowConvTranspose1d(_ConvTranspose): - r""" - 1-D Complex-Valued Transposed Convolution - ----------------------------------------- - - Based on the `PyTorch torch.nn.ConvTranspose1d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 transposed convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: _size_1_t = 0, - output_padding: _size_1_t = 0, - groups: int = 1, - bias: bool = True, - dilation: _size_1_t = 1, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConvTranspose1d, self).__init__( - ConvClass=nn.ConvTranspose1d, - ConvFunc=F.conv_transpose1d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - -class SlowConvTranspose2d(_ConvTranspose): - r""" - 2-D Complex-Valued Transposed Convolution - ----------------------------------------- - - Based on the `PyTorch torch.nn.ConvTranspose2d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 transposed convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: _size_1_t = 0, - output_padding: _size_1_t = 0, - groups: int = 1, - bias: bool = True, - dilation: _size_1_t = 1, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConvTranspose2d, self).__init__( - ConvClass=nn.ConvTranspose2d, - ConvFunc=F.conv_transpose2d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - -class SlowConvTranspose3d(_ConvTranspose): - r""" - 3-D Complex-Valued Transposed Convolution - ----------------------------------------- - - Based on the `PyTorch torch.nn.ConvTranspose3d `_ implementation. - - Employs Gauss' multiplication trick to reduce number of computations by 25% compare with the naive implementation. - - The most common implementation of complex-valued convolution entails the following computation: - - .. math:: - - G(\mathbf{z}) = \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - + j(\text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) + \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R}))) - - where :math:`\mathbf{W}` and :math:`\mathbf{b}` are the complex-valued weight and bias tensors, respectively, and :math:`\text{conv}_T(\cdot)` is the transposed conovlution operator. - - By comparison, using Gauss' trick, the complex-vauled convolution can be implemented as: - - .. math:: - - t1 =& \text{conv}_T(\mathbf{z}_\mathbb{R}, \mathbf{W}_\mathbb{R}, \mathbf{b}_\mathbb{R})) - - t2 =& \text{conv}_T(\mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{I})) - - t3 =& \text{conv}_T(\mathbf{z}_\mathbb{R} + \mathbf{z}_\mathbb{I}, \mathbf{W}_\mathbb{R} + \mathbf{W}_\mathbb{I}, \mathbf{b}_\mathbb{R} + \mathbf{b}_\mathbb{I})) - - G(\mathbf{z}) =& t1 - t2 + j(t3 - t2 - t1) - - requiring only 3 transposed convolution operations. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: _size_1_t = 0, - output_padding: _size_1_t = 0, - groups: int = 1, - bias: bool = True, - dilation: _size_1_t = 1, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - super(SlowConvTranspose3d, self).__init__( - ConvClass=nn.ConvTranspose3d, - ConvFunc=F.conv_transpose3d, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - dilation=dilation, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) +import torch +import torch.nn as nn + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class Conv1d(nn.Module): + r""" + Complex-Valued 1-D Convolution using PyTorch + -------------------------------------------- + + - Implemented using `torch.nn.Conv1d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.Conv1d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 1-D complex-valued convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: Conv1d(input) + """ + return self.conv(input) + + +class Conv2d(nn.Module): + r""" + Complex-Valued 2-D Convolution using PyTorch + -------------------------------------------- + + - Implemented using `torch.nn.Conv2d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.Conv2d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 2-D complex-valued convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: Conv2d(input) + """ + return self.conv(input) + + +class Conv3d(nn.Module): + r""" + Complex-Valued 3-D Convolution using PyTorch + -------------------------------------------- + + - Implemented using `torch.nn.Conv3d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.Conv3d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 3-D complex-valued convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: Conv3d(input) + """ + return self.conv(input) + + +class ConvTranspose1d(nn.Module): + r""" + Complex-Valued 1-D Transposed Convolution using PyTorch + ------------------------------------------------------- + + - Implemented using `torch.nn.ConvTranspose1d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.ConvTranspose1d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv_transposed = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 1-D complex-valued transposed convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: ConvTranspose1d(input) + """ + return self.conv_transposed(input) + + +class ConvTranspose2d(nn.Module): + r""" + Complex-Valued 2-D Transposed Convolution using PyTorch + ------------------------------------------------------- + + - Implemented using `torch.nn.ConvTranspose2d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.ConvTranspose2d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv_transposed = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 2-D complex-valued transposed convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: ConvTranspose2d(input) + """ + return self.conv_transposed(input) + + +class ConvTranspose3d(nn.Module): + r""" + Complex-Valued 3-D Transposed Convolution using PyTorch + ------------------------------------------------------- + + - Implemented using `torch.nn.ConvTranspose3d `_ and complex-valued tensors. + + - Convenience wrapper over ``torch.nn.ConvTranspose3d`` whose only behavioural difference is the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split variant using Gauss' trick. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = "zeros", + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.conv_transposed = nn.ConvTranspose3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes 3-D complex-valued transposed convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: ConvTranspose3d(input) + """ + return self.conv_transposed(input) diff --git a/complextorch/nn/modules/dropout.py b/complextorch/nn/modules/dropout.py index 6b179e3..be0bd97 100755 --- a/complextorch/nn/modules/dropout.py +++ b/complextorch/nn/modules/dropout.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn +import torch.nn.functional as F -from .. import functional as cvF +from complextorch.nn import functional as cvF -__all__ = ["Dropout"] +__all__ = ["Dropout", "Dropout1d", "Dropout2d", "Dropout3d"] class Dropout(nn.Module): @@ -11,7 +12,7 @@ class Dropout(nn.Module): Complex-Valued Dropout Layer ---------------------------- - Applies `PyTorch Droput `_ to real and imaginary parts separately. + Applies `PyTorch Dropout `_ to real and imaginary parts separately, with **independent** Bernoulli masks per part. Implements the following operation: @@ -19,11 +20,20 @@ class Dropout(nn.Module): G(\mathbf{z}) = \texttt{Dropout}(\mathbf{x}) + j \texttt{Dropout}(\mathbf{y}), - where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}` + where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}`. + + .. note:: + + This differs from the dropout used in *Trabelsi et al. (2018) Deep Complex + Networks*, which uses a **shared** Bernoulli mask so that the entire complex + value is dropped together (preserving the phase of the surviving entries). + Because the real and imaginary masks here are sampled independently, the + phase of a non-dropped entry can change when only one of its real/imag + parts is zeroed out. Choose this layer deliberately. """ def __init__(self, p: float = 0.5, inplace: bool = False) -> None: - super(Dropout, self).__init__() + super().__init__() self.dropout_r = nn.Dropout(p, inplace) self.dropout_i = nn.Dropout(p, inplace) @@ -38,3 +48,83 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\texttt{Dropout}(\mathbf{x}) + j \texttt{Dropout}(\mathbf{y})` """ return cvF.apply_complex_split(self.dropout_r, self.dropout_i, input) + + +class _ChannelDropoutNd(nn.Module): + r"""Internal base for channel-wise complex dropout with a shared real/imag mask. + + Implements Trabelsi et al. (2018) "Deep Complex Networks" complex dropout: + one Bernoulli mask is drawn per channel and applied to **both** the real and + imaginary parts simultaneously, so an entire complex channel is dropped + together and the phase of surviving entries is preserved. + """ + + _dropout_fn = staticmethod(F.dropout1d) # overridden in subclasses + + def __init__(self, p: float = 0.5, inplace: bool = False) -> None: + super().__init__() + if not 0.0 <= p < 1.0: + raise ValueError(f"dropout probability must be in [0, 1), got {p}") + self.p = p + self.inplace = inplace + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.training or self.p == 0.0: + return input + if not input.is_complex(): + return self._dropout_fn(input, self.p, training=True, inplace=self.inplace) + # Trabelsi 2018 channel dropout: one Bernoulli sample per (batch, channel), + # shared across real and imag, broadcast over spatial dims. + b, c = input.shape[0], input.shape[1] + mask_shape = (b, c) + (1,) * (input.dim() - 2) + mask = torch.empty(mask_shape, dtype=input.real.dtype, device=input.device) + mask.bernoulli_(1 - self.p).div_(1 - self.p) + return input * mask + + def extra_repr(self) -> str: + return f"p={self.p}, inplace={self.inplace}" + + +class Dropout1d(_ChannelDropoutNd): + r""" + Complex-Valued 1-D Channel Dropout (Trabelsi 2018 shared mask) + -------------------------------------------------------------- + + Zeros out entire complex channels (matching :class:`torch.nn.Dropout1d`) + using a single Bernoulli mask per channel applied to both real and + imaginary parts. The phase of surviving entries is preserved. + + Input shape: ``(B, C, L)``. + """ + + _dropout_fn = staticmethod(F.dropout1d) + + +class Dropout2d(_ChannelDropoutNd): + r""" + Complex-Valued 2-D Channel Dropout (Trabelsi 2018 shared mask) + -------------------------------------------------------------- + + Zeros out entire complex channels (matching :class:`torch.nn.Dropout2d`) + using a single Bernoulli mask per channel applied to both real and + imaginary parts. + + Input shape: ``(B, C, H, W)``. + """ + + _dropout_fn = staticmethod(F.dropout2d) + + +class Dropout3d(_ChannelDropoutNd): + r""" + Complex-Valued 3-D Channel Dropout (Trabelsi 2018 shared mask) + -------------------------------------------------------------- + + Zeros out entire complex channels (matching :class:`torch.nn.Dropout3d`) + using a single Bernoulli mask per channel applied to both real and + imaginary parts. + + Input shape: ``(B, C, D, H, W)``. + """ + + _dropout_fn = staticmethod(F.dropout3d) diff --git a/complextorch/nn/modules/fft.py b/complextorch/nn/modules/fft.py index afe7beb..ddff692 100755 --- a/complextorch/nn/modules/fft.py +++ b/complextorch/nn/modules/fft.py @@ -15,7 +15,7 @@ class FFTBlock(nn.Module): """ def __init__(self, n=None, dim=-1, norm=None) -> None: - super(FFTBlock, self).__init__() + super().__init__() self.n = n self.dim = dim @@ -44,7 +44,7 @@ class IFFTBlock(nn.Module): """ def __init__(self, n=None, dim=-1, norm=None) -> None: - super(IFFTBlock, self).__init__() + super().__init__() self.n = n self.dim = dim diff --git a/complextorch/nn/modules/groupnorm.py b/complextorch/nn/modules/groupnorm.py new file mode 100644 index 0000000..df21dfe --- /dev/null +++ b/complextorch/nn/modules/groupnorm.py @@ -0,0 +1,129 @@ +r""" +Complex-Valued GroupNorm +======================== + +Group normalization adapted to complex tensors. Splits channels into +``num_groups`` groups, applies 2x2 whitening within each group, then a +per-channel 2x2 affine transform + 2-vector bias. No running statistics +(differs from :class:`BatchNorm2d`). +""" + +import torch +import torch.nn as nn +from torch.nn import init + +from complextorch.nn.functional import inv_sqrtm2x2 + +__all__ = ["GroupNorm"] + + +class GroupNorm(nn.Module): + r""" + Complex-Valued Group Normalization + ---------------------------------- + + Like :class:`torch.nn.GroupNorm`, but applies the Trabelsi 2x2 whitening + transform within each group, then a per-channel 2x2 affine. + + Args: + num_groups: number of groups to divide the channels into; must divide + ``num_channels``. + num_channels: number of channels in the input. + eps: numerical stabilizer. + affine: if ``True``, applies a learnable per-channel 2x2 affine + + 2-vector bias. + """ + + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + ) -> None: + super().__init__() + if num_channels % num_groups != 0: + raise ValueError( + f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})" + ) + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + + if affine: + self.weight = nn.Parameter(torch.empty(2, 2, num_channels)) + self.bias = nn.Parameter(torch.empty(2, num_channels)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if not self.affine: + return + self.weight.data.copy_(0.70710678118 * torch.eye(2, 2).unsqueeze(-1)) + init.zeros_(self.bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + raise TypeError( + f"GroupNorm expects a complex input, got dtype={input.dtype}" + ) + b, c = input.shape[:2] + if c != self.num_channels: + raise ValueError(f"Expected {self.num_channels} channels, got {c}") + spatial = input.shape[2:] + g = self.num_groups + c_per_g = c // g + + # Reshape to (B, G, C//G, *spatial). Whiten over (C//G, *spatial) per group. + re = input.real.view(b, g, c_per_g, *spatial) + im = input.imag.view(b, g, c_per_g, *spatial) + + # Axes to reduce over within each group: c_per_g + spatial dims. + reduce_axes = tuple(range(2, 2 + 1 + len(spatial))) # dim 2 onwards + # Mean per (B, G, 1, 1, ...) — center + mean_r = re.mean(dim=reduce_axes, keepdim=True) + mean_i = im.mean(dim=reduce_axes, keepdim=True) + re_c = re - mean_r + im_c = im - mean_i + + # 2x2 covariance per group + v_rr = (re_c * re_c).mean(dim=reduce_axes) + self.eps + v_ii = (im_c * im_c).mean(dim=reduce_axes) + self.eps + v_ir = (re_c * im_c).mean(dim=reduce_axes) + + p, q, _, s = inv_sqrtm2x2(v_rr, v_ir, None, v_ii, symmetric=True) + + # Broadcast p, q, s back over the per-group reduced shape: (B, G, 1, 1, ...) + bcast_shape = (b, g) + (1,) * (1 + len(spatial)) + p = p.view(bcast_shape) + q = q.view(bcast_shape) + s = s.view(bcast_shape) + + out_r = p * re_c + q * im_c + out_i = q * re_c + s * im_c + + # Flatten group dim back: (B, C, *spatial) + out_r = out_r.reshape(b, c, *spatial) + out_i = out_i.reshape(b, c, *spatial) + + if self.affine: + # weight has shape (2, 2, C); broadcast over batch + spatial + chan_shape = (1, c) + (1,) * len(spatial) + w = self.weight + new_r = w[0, 0].view(chan_shape) * out_r + w[0, 1].view(chan_shape) * out_i + new_i = w[1, 0].view(chan_shape) * out_r + w[1, 1].view(chan_shape) * out_i + new_r = new_r + self.bias[0].view(chan_shape) + new_i = new_i + self.bias[1].view(chan_shape) + out_r, out_i = new_r, new_i + + return torch.complex(out_r, out_i) + + def extra_repr(self) -> str: + return ( + f"num_groups={self.num_groups}, num_channels={self.num_channels}, " + f"eps={self.eps}, affine={self.affine}" + ) diff --git a/complextorch/nn/modules/layernorm.py b/complextorch/nn/modules/layernorm.py index a7709d2..9ac4884 100755 --- a/complextorch/nn/modules/layernorm.py +++ b/complextorch/nn/modules/layernorm.py @@ -1,10 +1,8 @@ -from typing import Union, List - import torch import torch.nn as nn from torch.nn import init -from .. import functional as cvF +from complextorch.nn import functional as cvF __all__ = ["LayerNorm"] @@ -28,7 +26,7 @@ class LayerNorm(nn.Module): def __init__( self, - normalized_shape: Union[int, List[int], torch.Size], + normalized_shape: int | list[int] | torch.Size, *, eps: float = 1e-5, elementwise_affine: bool = True, @@ -67,9 +65,9 @@ def reset_parameters(self) -> None: def forward(self, input: torch.Tensor) -> torch.Tensor: # Sanity check to make sure the shapes match - assert ( - self.normalized_shape == input.shape[-len(self.normalized_shape) :] - ), "Expected normalized_shape to match last dimensions of input shape!" + assert self.normalized_shape == input.shape[-len(self.normalized_shape) :], ( + "Expected normalized_shape to match last dimensions of input shape!" + ) # if self.elementwise_affine: # self.weight.data = self.weight.data.to(input.device) diff --git a/complextorch/nn/modules/linear.py b/complextorch/nn/modules/linear.py old mode 100755 new mode 100644 index dba70cb..70c7c9e --- a/complextorch/nn/modules/linear.py +++ b/complextorch/nn/modules/linear.py @@ -1,124 +1,135 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -__all__ = ["Linear", "SlowLinear"] - - -class Linear(nn.Module): - r""" - Complex-Valued Linear using PyTorch - ----------------------------------- - - - Implemented using `torch.nn.Linear `_ and complex-valued tensors. - - - Used to be slower than `complextorch` version but is now faster after PyTorch 2.1.0 update. - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device=None, - dtype=torch.cfloat, - ) -> None: - super(Linear, self).__init__() - - self.linear = nn.Linear( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes complex-valued convolution using PyTorch. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: Linear(input) - """ - return self.linear(input) - - -class SlowLinear(nn.Module): - r""" - Slow Complex-Valued Linear Layer - -------------------------------- - - Follows `PyTorch implementation `_ using Gauss' trick to improve the computation as in :doc:`Complex-Valued Convolution <./conv>`. - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device=None, - dtype=None, - ) -> None: - super(SlowLinear, self).__init__() - - # Assumes PyTorch complex weight initialization is correct - __temp = nn.Linear( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=torch.cfloat, - ) - - self.linear_r = nn.Linear( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - self.linear_i = nn.Linear( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) - - self.linear_r.weight.data = __temp.weight.real - self.linear_i.weight.data = __temp.weight.imag - - if bias: - self.linear_r.bias.data = __temp.bias.real - self.linear_i.bias.data = __temp.bias.imag - - @property - def weight(self) -> torch.Tensor: - return torch.complex(self.linear_r.weight, self.linear_i.weight) - - @property - def bias(self) -> torch.Tensor: - if self.linear_r.bias is None: - return None - else: - return torch.complex(self.linear_r.bias, self.linear_i.bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r""" - Computes multiplication 25% faster than naive method by using Gauss' multiplication trick - """ - t1 = self.linear_r(input.real) - t2 = self.linear_i(input.imag) - bias = ( - None - if self.linear_r.bias is None - else (self.linear_r.bias + self.linear_i.bias) - ) - t3 = F.linear( - input=(input.real + input.imag), - weight=(self.linear_r.weight + self.linear_i.weight), - bias=bias, - ) - return torch.complex(t1 - t2, t3 - t2 - t1) +import math +from typing import ClassVar + +import torch +import torch.nn as nn + +__all__ = ["Bilinear", "Linear"] + + +class Linear(nn.Module): + r""" + Complex-Valued Linear using PyTorch + ----------------------------------- + + - Implemented as a thin wrapper around :class:`torch.nn.Linear` + with complex-valued tensors. The only behavioural difference is + the default ``dtype=torch.cfloat``. + + - See :mod:`complextorch.nn.gauss` for the hand-rolled real/imag-split + variant using Gauss' trick (kept as a reference implementation). + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + + self.linear = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes complex-valued convolution using PyTorch. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: Linear(input) + """ + return self.linear(input) + + +class Bilinear(nn.Module): + r""" + Complex-Valued Bilinear Layer + ----------------------------- + + Applies a complex-valued bilinear transformation: + + .. math:: + + y_k = \mathbf{x}_1^\dagger \mathbf{W}_k \mathbf{x}_2 + b_k \qquad (\text{conjugate=True, Hermitian}) + + y_k = \mathbf{x}_1^\top \mathbf{W}_k \mathbf{x}_2 + b_k \qquad (\text{conjugate=False, plain bilinear}) + + With ``conjugate=True`` (default) the input ``x_1`` is conjugated before the + contraction, giving the mathematically standard Hermitian inner-product form. + Setting ``conjugate=False`` uses a plain bilinear product, matching the + real-valued :class:`torch.nn.Bilinear` semantics. + + Args: + in1_features: size of the first input. + in2_features: size of the second input. + out_features: size of the output. + bias: if ``True``, adds a learnable bias. + conjugate: if ``True`` (default), uses Hermitian form. + device, dtype: standard PyTorch factory kwargs; ``dtype`` defaults to + ``torch.cfloat``. + """ + + __constants__: ClassVar[list[str]] = [ + "in1_features", + "in2_features", + "out_features", + "conjugate", + ] + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + conjugate: bool = True, + device=None, + dtype=torch.cfloat, + ) -> None: + super().__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.conjugate = conjugate + + factory_kwargs = {"device": device, "dtype": dtype} + self.weight = nn.Parameter( + torch.empty((out_features, in1_features, in2_features), **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = 1.0 / math.sqrt(self.in1_features) + with torch.no_grad(): + self.weight.real.uniform_(-bound, bound) + self.weight.imag.uniform_(-bound, bound) + if self.bias is not None: + self.bias.real.uniform_(-bound, bound) + self.bias.imag.uniform_(-bound, bound) + + def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + x1 = input1.conj() if self.conjugate else input1 + out = torch.einsum("...i,kij,...j->...k", x1, self.weight, input2) + if self.bias is not None: + out = out + self.bias + return out + + def extra_repr(self) -> str: + return ( + f"in1_features={self.in1_features}, in2_features={self.in2_features}, " + f"out_features={self.out_features}, bias={self.bias is not None}, " + f"conjugate={self.conjugate}" + ) diff --git a/complextorch/nn/modules/loss.py b/complextorch/nn/modules/loss.py index 88f8eea..f490156 100755 --- a/complextorch/nn/modules/loss.py +++ b/complextorch/nn/modules/loss.py @@ -1,24 +1,36 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F __all__ = [ - "GeneralizedSplitLoss", - "SplitL1", - "SplitMSE", "SSIM", - "SplitSSIM", - "PerpLossSSIM", - "CVQuadError", - "CVFourthPowError", "CVCauchyError", + "CVFourthPowError", "CVLogCoshError", "CVLogError", + "CVQuadError", + "GeneralizedSplitLoss", + "MSELoss", + "PerpLossSSIM", + "SplitL1", + "SplitMSE", + "SplitSSIM", ] +def _reduce(loss: torch.Tensor, reduction: str) -> torch.Tensor: + r"""Apply a PyTorch-style reduction (``'mean'`` | ``'sum'`` | ``'none'``).""" + if reduction == "mean": + return loss.mean() + if reduction == "sum": + return loss.sum() + if reduction == "none": + return loss + raise ValueError( + f"reduction must be one of 'mean', 'sum', 'none'; got {reduction!r}" + ) + + class GeneralizedSplitLoss(nn.Module): r""" Generalized Split Loss Function @@ -32,7 +44,7 @@ class GeneralizedSplitLoss(nn.Module): """ def __init__(self, loss_r: nn.Module, loss_i: nn.Module) -> None: - super(GeneralizedSplitLoss, self).__init__() + super().__init__() self.loss_r = loss_r self.loss_i = loss_i @@ -70,7 +82,7 @@ def __init__( weight_mag: float = 1.0, weight_phase: float = 1.0, ) -> None: - super(GeneralizedPolarLoss, self).__init__() + super().__init__() self.loss_mag = loss_mag self.loss_phase = loss_phase @@ -146,7 +158,7 @@ def forward( self, x: torch.Tensor, y: torch.Tensor, - data_range: Optional[torch.Tensor] = None, + data_range: torch.Tensor | None = None, full: bool = False, ) -> torch.Tensor: r"""Computes the SSIM metric on the real-valued tensors. @@ -163,10 +175,12 @@ def forward( assert isinstance(self.w, torch.Tensor) if data_range is None: - data_range = torch.ones_like(y) # * Y.max() + data_range = torch.ones_like(y) p = (self.win_size - 1) // 2 - data_range = data_range[:, :, p:-p, p:-p] - data_range = data_range[:, None, None, None] + if p > 0: + data_range = data_range[:, :, p:-p, p:-p] + else: + data_range = data_range[:, None, None, None] C1 = (self.k1 * data_range) ** 2 C2 = (self.k2 * data_range) ** 2 device = x.device @@ -189,8 +203,7 @@ def forward( if full: return S - else: - return S.mean() + return S.mean() class SplitSSIM(GeneralizedSplitLoss): @@ -208,7 +221,7 @@ def forward( self, x: torch.Tensor, y: torch.Tensor, - data_range: Optional[torch.Tensor] = None, + data_range: torch.Tensor | None = None, full: bool = False, ) -> torch.Tensor: return self.loss_r( @@ -264,7 +277,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: mag_target[aligned_mask] - ploss[aligned_mask] ) final_term[~aligned_mask] = ploss[~aligned_mask] - ssim_loss = (1 - self.ssim(x, y)) / mag_input.shape[0] + ssim_loss = (1 - self.ssim(mag_input, mag_target)) / mag_input.shape[0] return ( final_term.mean() * torch.clamp(self.param, 0, 1) @@ -279,7 +292,11 @@ class CVQuadError(nn.Module): .. math:: - \mathcal{L}(\mathbf{x}, \mathbf{y}) = \frac{1}{2}\text{sum}(|\mathbf{x} - \mathbf{y}|^2) + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \frac{1}{2}\text{reduce}(|\mathbf{x} - \mathbf{y}|^2) + + The original paper specified a sum reduction; the default here is ``'mean'`` + so the loss is independent of batch/feature size and consistent with the + rest of the library. Use ``reduction='sum'`` to recover the paper's form. Based on work from the following paper: @@ -290,20 +307,37 @@ class CVQuadError(nn.Module): - https://www.ingentaconnect.com/content/asprs/pers/2010/00000076/00000009/art00008?crawler=true&mimetype=application/pdf """ - def __init__(self) -> None: - super(CVQuadError, self).__init__() + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - r"""Computes the complex-valued quadratic error function. + return _reduce(0.5 * ((x - y).abs() ** 2), self.reduction) - Args: - x (torch.Tensor): estimated labels - y (torch.Tensor): target/ground truth labels - Returns: - torch.Tensor: :math:`\frac{1}{2}\text{sum}(|\mathbf{x} - \mathbf{y}|^2)` - """ - return 0.5 * ((x - y).abs() ** 2).sum() +class MSELoss(nn.Module): + r""" + Complex-Valued Mean Squared Error Loss + -------------------------------------- + + .. math:: + + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{reduce}(|\mathbf{x} - \mathbf{y}|^2) + + Drop-in complex analogue of :class:`torch.nn.MSELoss`. Unlike + :class:`CVQuadError` this carries no 1/2 factor, so it matches PyTorch's + real-valued MSE exactly when ``x`` and ``y`` are real. + + Args: + reduction: ``'mean'`` (default), ``'sum'``, or ``'none'``. + """ + + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return _reduce((x - y).abs() ** 2, self.reduction) class CVFourthPowError(nn.Module): @@ -313,7 +347,9 @@ class CVFourthPowError(nn.Module): .. math:: - \mathcal{L}(\mathbf{x}, \mathbf{y}) = \frac{1}{2}\text{sum}(|\mathbf{x} - \mathbf{y}|^4) + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \frac{1}{2}\text{reduce}(|\mathbf{x} - \mathbf{y}|^4) + + See :class:`CVQuadError` for notes on ``reduction``. Based on work from the following paper: @@ -324,20 +360,12 @@ class CVFourthPowError(nn.Module): - https://www.ingentaconnect.com/content/asprs/pers/2010/00000076/00000009/art00008?crawler=true&mimetype=application/pdf """ - def __init__(self) -> None: - super(CVFourthPowError, self).__init__() + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - r"""Computes the complex-valued fourth power error function. - - Args: - x (torch.Tensor): estimated labels - y (torch.Tensor): target/ground truth labels - - Returns: - torch.Tensor: :math:`\frac{1}{2}\text{sum}(|\mathbf{x} - \mathbf{y}|^4)` - """ - return 0.5 * ((x - y).abs() ** 4).sum() + return _reduce(0.5 * ((x - y).abs() ** 4), self.reduction) class CVCauchyError(nn.Module): @@ -346,9 +374,10 @@ class CVCauchyError(nn.Module): .. math:: - \mathcal{L}(\mathbf{x}, \mathbf{y}) = \frac{1}{2}\text{sum}( c^2 / 2 \ln(1 + |\mathbf{x} - \mathbf{y}|^2/c^2) ) + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{reduce}( c^2 / 2 \ln(1 + |\mathbf{x} - \mathbf{y}|^2/c^2) ) - where :math:`c` is typically set to unity. + where :math:`c` is typically set to unity. See :class:`CVQuadError` for + notes on ``reduction``. Based on work from the following paper: @@ -359,22 +388,16 @@ class CVCauchyError(nn.Module): - https://www.ingentaconnect.com/content/asprs/pers/2010/00000076/00000009/art00008?crawler=true&mimetype=application/pdf """ - def __init__(self, c: float = 1) -> None: - super(CVCauchyError, self).__init__() - + def __init__(self, c: float = 1, reduction: str = "mean") -> None: + super().__init__() self.c2 = c**2 + self.reduction = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - r"""Computes the complex-valued Cauchy error function. - - Args: - x (torch.Tensor): estimated labels - y (torch.Tensor): target/ground truth labels - - Returns: - torch.Tensor: :math:`\frac{1}{2}\text{sum}( c^2 / 2 \ln(1 + |\mathbf{x} - \mathbf{y}|^2/c^2) )` - """ - return (self.c2 / 2 * torch.log(1 + ((x - y).abs() ** 2) / self.c2)).sum() + return _reduce( + self.c2 / 2 * torch.log(1 + ((x - y).abs() ** 2) / self.c2), + self.reduction, + ) class CVLogCoshError(nn.Module): @@ -383,7 +406,9 @@ class CVLogCoshError(nn.Module): .. math:: - \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{sum}(\ln(\cosh(|\mathbf{x} - \mathbf{y}|^2)) + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{reduce}(\ln(\cosh(|\mathbf{x} - \mathbf{y}|^2)) + + See :class:`CVQuadError` for notes on ``reduction``. Based on work from the following paper: @@ -394,20 +419,12 @@ class CVLogCoshError(nn.Module): - https://www.ingentaconnect.com/content/asprs/pers/2010/00000076/00000009/art00008?crawler=true&mimetype=application/pdf """ - def __init__(self) -> None: - super(CVLogCoshError, self).__init__() + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - r"""Computes the complex-valued log-cosh error function. - - Args: - x (torch.Tensor): estimated labels - y (torch.Tensor): target/ground truth labels - - Returns: - torch.Tensor: :math:`\text{sum}(\ln(\cosh(|\mathbf{x} - \mathbf{y}|^2))` - """ - return torch.log(torch.cosh((x - y).abs() ** 2)).sum() + return _reduce(torch.log(torch.cosh((x - y).abs() ** 2)), self.reduction) class CVLogError(nn.Module): @@ -416,7 +433,9 @@ class CVLogError(nn.Module): .. math:: - \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{sum}(|\ln(\mathbf{x}) - \ln(\mathbf{y})|^2) + \mathcal{L}(\mathbf{x}, \mathbf{y}) = \text{reduce}(|\ln(\mathbf{x}) - \ln(\mathbf{y})|^2) + + See :class:`CVQuadError` for notes on ``reduction``. Based on work from the following paper: @@ -427,18 +446,10 @@ class CVLogError(nn.Module): - https://arxiv.org/abs/2101.12249 """ - def __init__(self) -> None: - super(CVLogError, self).__init__() + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - r"""Computes the complex-valued log error function. - - Args: - x (torch.Tensor): estimated labels - y (torch.Tensor): target/ground truth labels - - Returns: - torch.Tensor: :math:`\text{sum}(|\ln(\mathbf{x}) - \ln(\mathbf{y})|^2)` - """ err = torch.log(x) - torch.log(y) - return (err.abs() ** 2).sum() + return _reduce(err.abs() ** 2, self.reduction) diff --git a/complextorch/nn/modules/manifold.py b/complextorch/nn/modules/manifold.py index 3841df8..66b6644 100755 --- a/complextorch/nn/modules/manifold.py +++ b/complextorch/nn/modules/manifold.py @@ -1,403 +1,551 @@ -from typing import Tuple - -import numpy as np -import torch -import torch.nn as nn -from torch.nn.common_types import _size_1_t, _size_2_t - -__all__ = ["wFMConv1d", "wFMConv2d"] - - -def _normalize_weights_squared(weights: torch.Tensor) -> torch.Tensor: - r"""Normalizes the square of input tensor (weights) such that the sum of the output is 1. - Follows the function `weightNormalize1` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. - - Args: - weights (torch.Tensor): input tensor - - Returns: - torch.Tensor: normalized output - """ - return (weights**2) / torch.sum(weights**2) - - -def _normalize_weights(weights: torch.Tensor) -> torch.Tensor: - r"""Normalizes the input tensor by the sum of its square. - Follows the function `weightNormalize2` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. - - Args: - weights (torch.Tensor): input tensor - - Returns: - torch.Tensor: normalized output - """ - return weights / torch.sum(weights**2) - - -def _normalize_rows(weights: torch.Tensor) -> torch.Tensor: - r"""Normalizes the square of input tensor by each row such that the sum of each row of the output is 1. - Follows the function `weightNormalize` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. - - Args: - weights (torch.Tensor): input tensor - - Returns: - torch.Tensor: normalized output - """ - return (weights**2) / torch.sum(weights**2, dim=1, keepdim=True) - - -class _wFMConv2dHelper(nn.Module): - r""" - Helper Class for wFMConv2d - -------------------------- - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = (1, 1), - padding: _size_2_t = (0, 0), - weight_dropout: float = 0.0, - eps: float = 1e-5, - ) -> None: - super(_wFMConv2dHelper, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.eps = eps - - prod_kernel_size = np.prod(kernel_size) - - self.dropout = nn.Dropout(weight_dropout) - - # Weight matrices - self.weight_matrix_ang1 = nn.Parameter( - torch.rand(in_channels, prod_kernel_size), requires_grad=True - ) - - self.weight_matrix_ang2 = nn.Parameter( - torch.rand(out_channels, in_channels), requires_grad=True - ) - - self.unfold = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding) - - def compute_output_shape(self, input_shape) -> Tuple[int]: - return tuple( - int(np.floor((in_shape + 2 * padding - (kernel_size - 1) - 1) / stride + 1)) - for in_shape, padding, kernel_size, stride in zip( - input_shape, self.padding, self.kernel_size, self.stride - ) - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - batch_size, mag_ang, in_channels, *input_shape = input.shape - - assert mag_ang == 2, "Input must be complex valued in polar form (mag, ang)" - assert in_channels == self.in_channels, "Input channels must match" - - out_channels = self.out_channels - kernel_size = self.kernel_size - prod_kernel_size = np.prod(kernel_size) - - output_shape = self.compute_output_shape(input_shape) - L = np.prod(output_shape) # Total number of unfolded blocks - - input = input.view(batch_size * 2, in_channels, *input_shape) - - # unfolded shape: (batch_size * 2, in_channels * prod_kernel_size, L) - temporal_buckets = self.unfold(input).view( - batch_size, 2, in_channels, prod_kernel_size, L - ) - - ### Do magnitude processing - tb_mag = torch.log( - temporal_buckets[:, 0] - .permute(0, 3, 1, 2) - .contiguous() - .view(batch_size * L, in_channels, prod_kernel_size) - + self.eps - ) - - # Normalize the weights - wmm1 = _normalize_rows(self.dropout(self.weight_matrix_ang1)) - wmm2 = _normalize_rows(self.dropout(self.weight_matrix_ang2)) - - out_mag = ( - torch.sum(tb_mag * wmm1, dim=2).unsqueeze(1).repeat(1, out_channels, 1) - ) - - out_mag = torch.exp( - torch.sum(out_mag * wmm2, dim=2) - .view(batch_size, 1, *output_shape, out_channels) - .permute(0, 1, 4, 2, 3) - .contiguous() - ) - - ### Do phase processing - tb_ang = ( - temporal_buckets[:, 1] - .permute(0, 3, 1, 2) - .contiguous() - .view(batch_size * L, in_channels, prod_kernel_size) - ) - - # Normalize the weights - wma1 = _normalize_weights_squared(self.weight_matrix_ang1) - wma2 = _normalize_weights_squared(self.weight_matrix_ang2) - - out_ang = ( - torch.sum(tb_ang * wma1, dim=2).unsqueeze(1).repeat(1, out_channels, 1) - ) - - out_ang = ( - torch.sum(out_ang * wma2, dim=2) - .view(batch_size, 1, *output_shape, out_channels) - .permute(0, 1, 4, 2, 3) - .contiguous() - ) - - return torch.cat((out_mag, out_ang), dim=1) - - -class wFMConv2d(nn.Module): - r""" - 2-D Weighted Frechet Mean Convolution Layer - ------------------------------------------- - - In a paper title `Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold`, the authors R Chakraborty, Y Xing, and S Yu introduce a complex-valued convolution operator offering similar equivariance properties to the spatial equivariance of the traditional real-valued convolution operator. - By approach the complex domain as a Riemannian homogeneous space consisting of the product of planar rotation and non-zero scaling, they define a convolution operator equivariant to phase shift and amplitude scaling. - Although their paper shows promising results in reducing the number of parameters of a complex-valued network for several problems, their work has not gained mainstream support. - - As the authors mention in the final bullet point in Section IV-A1, - - If :math:`d` is the manifold distance in (2) for the Euclidean - space that is also Riemannian, then wFM has exactly the - weighted average as its closed-form solution. That is, our - wFM convolution on the Euclidean manifold is reduced - to the standard convolution, although with the additional - convexity constraint on the weights. - - Hence, the implementation closely follows the conventional convolution operator with the exception of the weight normalization. - - Note: the weight normalization, although consistent with the authors' implementation, lacks adequate explanation from the literature and could be improved for further clarity. - - Based on work from the following paper: - - **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** - - - Eqs. (14)-(16) - - - https://arxiv.org/abs/1910.11334 - - - Modified from implementation: https://github.com/xingyifei2016/RotLieNet (yields consistent results as this implementation) - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = (1, 1), - padding: _size_2_t = (0, 0), - weight_dropout: float = 0.0, - ) -> None: - super(wFMConv2d, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.weight_dropout = weight_dropout - - prod_kernel_size = np.prod(kernel_size) - - # Weight matrices for magnitude and angle - self.weight_matrix_mag = nn.Parameter( - torch.rand(in_channels, prod_kernel_size), requires_grad=True - ) - - self.weight_matrix_ang = nn.Parameter( - torch.rand(in_channels, prod_kernel_size), requires_grad=True - ) - - self.unfold = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding) - - self.wFM_conv = _wFMConv2dHelper( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_dropout=weight_dropout, - ) - - def compute_output_shape(self, input_shape) -> Tuple[int]: - return tuple( - int(np.floor((in_shape + 2 * padding - (kernel_size - 1) - 1) / stride + 1)) - for in_shape, padding, kernel_size, stride in zip( - input_shape, self.padding, self.kernel_size, self.stride - ) - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes the 2-D weighted Frechet mean (wFM) convolution. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: output tensor - """ - batch_size, in_channels, *input_shape = input.shape - - assert in_channels == self.in_channels, "Input channels must match" - - kernel_size = self.kernel_size - prod_kernel_size = np.prod(kernel_size) - - output_shape = self.compute_output_shape(input_shape) - L = np.prod(output_shape) # Total number of unfolded blocks - - self.fold = nn.Fold( - output_size=input_shape, - kernel_size=kernel_size, - stride=self.stride, - padding=self.padding, - ) - - # Separate magnitude and angle from torch.Tensor input - x_mag, x_ang = input.abs(), input.angle() - - ### Do magnitude processing - x_mag = self.unfold(x_mag).view(batch_size, in_channels, prod_kernel_size, L) - - x_mag = ( - x_mag.permute(0, 3, 1, 2) - .contiguous() - .view(batch_size * L, in_channels, prod_kernel_size) - ) - - x_mag = x_mag + _normalize_weights_squared(self.weight_matrix_mag) - - x_mag = ( - x_mag.view(batch_size, *output_shape, in_channels * prod_kernel_size) - .permute(0, 3, 1, 2) - .contiguous() - .unsqueeze(1) - ) - - ### Do phase processing - x_ang = self.unfold(x_ang).view(batch_size, in_channels, prod_kernel_size, L) - - x_ang = ( - x_ang.permute(0, 3, 1, 2) - .contiguous() - .view(batch_size * L, in_channels, prod_kernel_size) - ) - - x_ang = x_ang * _normalize_weights(self.weight_matrix_ang) - - x_ang = ( - x_ang.view(batch_size, *output_shape, in_channels * prod_kernel_size) - .permute(0, 3, 1, 2) - .contiguous() - .unsqueeze(1) - ) - - # Stack the magnitude and phase tensors - in_fold = self.fold( - torch.cat((x_mag, x_ang), dim=1).view( - batch_size, 2 * in_channels * prod_kernel_size, L - ) - ).view(batch_size, 2, in_channels, *input_shape) - - x_out = self.wFM_conv(in_fold) - return torch.polar(x_out[:, 0], x_out[:, 1]) - - -class wFMConv1d(nn.Module): - r""" - 1-D Weighted Frechet Mean Convolution Layer - ------------------------------------------- - - In a paper title `Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold`, the authors R Chakraborty, Y Xing, and S Yu introduce a complex-valued convolution operator offering similar equivariance properties to the spatial equivariance of the traditional real-valued convolution operator. - By approach the complex domain as a Riemannian homogeneous space consisting of the product of planar rotation and non-zero scaling, they define a convolution operator equivariant to phase shift and amplitude scaling. - Although their paper shows promising results in reducing the number of parameters of a complex-valued network for several problems, their work has not gained mainstream support. - - As the authors mention in the final bullet point in Section IV-A1, - - If :math:`d` is the manifold distance in (2) for the Euclidean - space that is also Riemannian, then wFM has exactly the - weighted average as its closed-form solution. That is, our - wFM convolution on the Euclidean manifold is reduced - to the standard convolution, although with the additional - convexity constraint on the weights. - - Hence, the implementation closely follows the conventional convolution operator with the exception of the weight normalization. - - Note: the weight normalization, although consistent with the authors' implementation, lacks adequate explanation from the literature and could be improved for further clarity. - - Note: This is a wrapper around wFMConv2d that performs a 1D convolution - - Based on work from the following paper: - - **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** - - - Eqs. (14)-(16) - - - https://arxiv.org/abs/1910.11334 - - - Modified from implementation: https://github.com/xingyifei2016/RotLieNet (yields consistent results as this implementation) - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: _size_1_t = 0, - weight_dropout: float = 0.0, - ) -> None: - super(wFMConv1d, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.weight_dropout = weight_dropout - - self.conv1d = wFMConv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(1, kernel_size), - stride=(1, stride), - padding=(0, padding), - weight_dropout=weight_dropout, - ) - - self.wFM_conv = self.conv1d.wFM_conv - - def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Computes the 1-D weighted Frechet mean (wFM) convolution. See :class:`wFMConv2d` for more implementation details as :class:`wFMConv1d` is a wrapper around :class:`wFMConv2d`. - - Args: - input (torch.Tensor): input tensor - - Returns: - torch.Tensor: output tensor - """ - return self.conv1d(input.unsqueeze(-2)).squeeze() - - @property - def weight_matrix_ang(self) -> torch.Tensor: - return self.conv1d.weight_matrix_ang - - @property - def weight_matrix_mag(self) -> torch.Tensor: - return self.conv1d.weight_matrix_mag +import numpy as np +import torch +import torch.nn as nn +from torch.nn.common_types import _size_1_t, _size_2_t + +__all__ = ["wFMConv1d", "wFMConv2d", "wFMDistanceLinear", "wFMReLU"] + + +def _normalize_weights_squared(weights: torch.Tensor) -> torch.Tensor: + r"""Normalizes the square of input tensor (weights) such that the sum of the output is 1. + Follows the function `weightNormalize1` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. + + Args: + weights (torch.Tensor): input tensor + + Returns: + torch.Tensor: normalized output + """ + return (weights**2) / torch.sum(weights**2) + + +def _normalize_weights(weights: torch.Tensor) -> torch.Tensor: + r"""Normalizes the input tensor by the sum of its square. + Follows the function `weightNormalize2` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. + + Args: + weights (torch.Tensor): input tensor + + Returns: + torch.Tensor: normalized output + """ + return weights / torch.sum(weights**2) + + +def _normalize_rows(weights: torch.Tensor) -> torch.Tensor: + r"""Normalizes the square of input tensor by each row such that the sum of each row of the output is 1. + Follows the function `weightNormalize` from https://github.com/xingyifei2016/RotLieNet/blob/master/layers.py. + + Args: + weights (torch.Tensor): input tensor + + Returns: + torch.Tensor: normalized output + """ + return (weights**2) / torch.sum(weights**2, dim=1, keepdim=True) + + +class _wFMConv2dHelper(nn.Module): + r""" + Helper Class for wFMConv2d + ---------------------------- + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = (1, 1), + padding: _size_2_t = (0, 0), + weight_dropout: float = 0.0, + eps: float = 1e-5, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.eps = eps + + prod_kernel_size = np.prod(kernel_size) + + self.dropout = nn.Dropout(weight_dropout) + + # Weight matrices + self.weight_matrix_ang1 = nn.Parameter( + torch.rand(in_channels, prod_kernel_size), requires_grad=True + ) + + self.weight_matrix_ang2 = nn.Parameter( + torch.rand(out_channels, in_channels), requires_grad=True + ) + + self.unfold = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding) + + def compute_output_shape(self, input_shape) -> tuple[int]: + return tuple( + int(np.floor((in_shape + 2 * padding - (kernel_size - 1) - 1) / stride + 1)) + for in_shape, padding, kernel_size, stride in zip( + input_shape, self.padding, self.kernel_size, self.stride, strict=False + ) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + batch_size, mag_ang, in_channels, *input_shape = input.shape + + assert mag_ang == 2, "Input must be complex valued in polar form (mag, ang)" + assert in_channels == self.in_channels, "Input channels must match" + + out_channels = self.out_channels + kernel_size = self.kernel_size + prod_kernel_size = np.prod(kernel_size) + + output_shape = self.compute_output_shape(input_shape) + L = np.prod(output_shape) # Total number of unfolded blocks + + input = input.view(batch_size * 2, in_channels, *input_shape) + + # unfolded shape: (batch_size * 2, in_channels * prod_kernel_size, L) + temporal_buckets = self.unfold(input).view( + batch_size, 2, in_channels, prod_kernel_size, L + ) + + ### Do magnitude processing + tb_mag = torch.log( + temporal_buckets[:, 0] + .permute(0, 3, 1, 2) + .contiguous() + .view(batch_size * L, in_channels, prod_kernel_size) + + self.eps + ) + + # Normalize the weights + wmm1 = _normalize_rows(self.dropout(self.weight_matrix_ang1)) + wmm2 = _normalize_rows(self.dropout(self.weight_matrix_ang2)) + + out_mag = ( + torch.sum(tb_mag * wmm1, dim=2).unsqueeze(1).repeat(1, out_channels, 1) + ) + + out_mag = torch.exp( + torch.sum(out_mag * wmm2, dim=2) + .view(batch_size, 1, *output_shape, out_channels) + .permute(0, 1, 4, 2, 3) + .contiguous() + ) + + ### Do phase processing + tb_ang = ( + temporal_buckets[:, 1] + .permute(0, 3, 1, 2) + .contiguous() + .view(batch_size * L, in_channels, prod_kernel_size) + ) + + # Normalize the weights + wma1 = _normalize_weights_squared(self.weight_matrix_ang1) + wma2 = _normalize_weights_squared(self.weight_matrix_ang2) + + out_ang = ( + torch.sum(tb_ang * wma1, dim=2).unsqueeze(1).repeat(1, out_channels, 1) + ) + + out_ang = ( + torch.sum(out_ang * wma2, dim=2) + .view(batch_size, 1, *output_shape, out_channels) + .permute(0, 1, 4, 2, 3) + .contiguous() + ) + + return torch.cat((out_mag, out_ang), dim=1) + + +class wFMConv2d(nn.Module): + r""" + 2-D Weighted Frechet Mean Convolution Layer + --------------------------------------------- + + In a paper title `Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold`, the authors R Chakraborty, Y Xing, and S Yu introduce a complex-valued convolution operator offering similar equivariance properties to the spatial equivariance of the traditional real-valued convolution operator. + By approach the complex domain as a Riemannian homogeneous space consisting of the product of planar rotation and non-zero scaling, they define a convolution operator equivariant to phase shift and amplitude scaling. + Although their paper shows promising results in reducing the number of parameters of a complex-valued network for several problems, their work has not gained mainstream support. + + As the authors mention in the final bullet point in Section IV-A1, + + If :math:`d` is the manifold distance in (2) for the Euclidean + space that is also Riemannian, then wFM has exactly the + weighted average as its closed-form solution. That is, our + wFM convolution on the Euclidean manifold is reduced + to the standard convolution, although with the additional + convexity constraint on the weights. + + Hence, the implementation closely follows the conventional convolution operator with the exception of the weight normalization. + + Note: the weight normalization, although consistent with the authors' implementation, lacks adequate explanation from the literature and could be improved for further clarity. + + Based on work from the following paper: + + **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** + + - Eqs. (14)-(16) + + - https://arxiv.org/abs/1910.11334 + + - Modified from implementation: https://github.com/xingyifei2016/RotLieNet (yields consistent results as this implementation) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = (1, 1), + padding: _size_2_t = (0, 0), + weight_dropout: float = 0.0, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.weight_dropout = weight_dropout + + prod_kernel_size = np.prod(kernel_size) + + # Weight matrices for magnitude and angle + self.weight_matrix_mag = nn.Parameter( + torch.rand(in_channels, prod_kernel_size), requires_grad=True + ) + + self.weight_matrix_ang = nn.Parameter( + torch.rand(in_channels, prod_kernel_size), requires_grad=True + ) + + self.unfold = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding) + + self.wFM_conv = _wFMConv2dHelper( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_dropout=weight_dropout, + ) + + # Lazily built and cached by input spatial shape so we don't reallocate + # an ``nn.Fold`` on every forward when shapes are stable. + self._fold_cache: dict[tuple, nn.Fold] = {} + + def _get_fold(self, input_shape) -> nn.Fold: + key = tuple(input_shape) + fold = self._fold_cache.get(key) + if fold is None: + fold = nn.Fold( + output_size=input_shape, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + ) + self._fold_cache[key] = fold + return fold + + def compute_output_shape(self, input_shape) -> tuple[int]: + return tuple( + int(np.floor((in_shape + 2 * padding - (kernel_size - 1) - 1) / stride + 1)) + for in_shape, padding, kernel_size, stride in zip( + input_shape, self.padding, self.kernel_size, self.stride, strict=False + ) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes the 2-D weighted Frechet mean (wFM) convolution. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: output tensor + """ + batch_size, in_channels, *input_shape = input.shape + + assert in_channels == self.in_channels, "Input channels must match" + + kernel_size = self.kernel_size + prod_kernel_size = np.prod(kernel_size) + + output_shape = self.compute_output_shape(input_shape) + L = np.prod(output_shape) # Total number of unfolded blocks + + fold = self._get_fold(input_shape) + + # Separate magnitude and angle from torch.Tensor input + x_mag, x_ang = input.abs(), input.angle() + + ### Do magnitude processing + x_mag = self.unfold(x_mag).view(batch_size, in_channels, prod_kernel_size, L) + + x_mag = ( + x_mag.permute(0, 3, 1, 2) + .contiguous() + .view(batch_size * L, in_channels, prod_kernel_size) + ) + + x_mag = x_mag + _normalize_weights_squared(self.weight_matrix_mag) + + x_mag = ( + x_mag.view(batch_size, *output_shape, in_channels * prod_kernel_size) + .permute(0, 3, 1, 2) + .contiguous() + .unsqueeze(1) + ) + + ### Do phase processing + x_ang = self.unfold(x_ang).view(batch_size, in_channels, prod_kernel_size, L) + + x_ang = ( + x_ang.permute(0, 3, 1, 2) + .contiguous() + .view(batch_size * L, in_channels, prod_kernel_size) + ) + + x_ang = x_ang * _normalize_weights(self.weight_matrix_ang) + + x_ang = ( + x_ang.view(batch_size, *output_shape, in_channels * prod_kernel_size) + .permute(0, 3, 1, 2) + .contiguous() + .unsqueeze(1) + ) + + # Stack the magnitude and phase tensors + in_fold = fold( + torch.cat((x_mag, x_ang), dim=1).view( + batch_size, 2 * in_channels * prod_kernel_size, L + ) + ).view(batch_size, 2, in_channels, *input_shape) + + x_out = self.wFM_conv(in_fold) + return torch.polar(x_out[:, 0], x_out[:, 1]) + + +class wFMConv1d(nn.Module): + r""" + 1-D Weighted Frechet Mean Convolution Layer + --------------------------------------------- + + In a paper title `Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold`, the authors R Chakraborty, Y Xing, and S Yu introduce a complex-valued convolution operator offering similar equivariance properties to the spatial equivariance of the traditional real-valued convolution operator. + By approach the complex domain as a Riemannian homogeneous space consisting of the product of planar rotation and non-zero scaling, they define a convolution operator equivariant to phase shift and amplitude scaling. + Although their paper shows promising results in reducing the number of parameters of a complex-valued network for several problems, their work has not gained mainstream support. + + As the authors mention in the final bullet point in Section IV-A1, + + If :math:`d` is the manifold distance in (2) for the Euclidean + space that is also Riemannian, then wFM has exactly the + weighted average as its closed-form solution. That is, our + wFM convolution on the Euclidean manifold is reduced + to the standard convolution, although with the additional + convexity constraint on the weights. + + Hence, the implementation closely follows the conventional convolution operator with the exception of the weight normalization. + + Note: the weight normalization, although consistent with the authors' implementation, lacks adequate explanation from the literature and could be improved for further clarity. + + Note: This is a wrapper around wFMConv2d that performs a 1D convolution + + Based on work from the following paper: + + **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** + + - Eqs. (14)-(16) + + - https://arxiv.org/abs/1910.11334 + + - Modified from implementation: https://github.com/xingyifei2016/RotLieNet (yields consistent results as this implementation) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + weight_dropout: float = 0.0, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.weight_dropout = weight_dropout + + self.conv1d = wFMConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, kernel_size), + stride=(1, stride), + padding=(0, padding), + weight_dropout=weight_dropout, + ) + + self.wFM_conv = self.conv1d.wFM_conv + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes the 1-D weighted Frechet mean (wFM) convolution. See :class:`wFMConv2d` for more implementation details as :class:`wFMConv1d` is a wrapper around :class:`wFMConv2d`. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: output tensor + """ + return self.conv1d(input.unsqueeze(-2)).squeeze(-2) + + @property + def weight_matrix_ang(self) -> torch.Tensor: + return self.conv1d.weight_matrix_ang + + @property + def weight_matrix_mag(self) -> torch.Tensor: + return self.conv1d.weight_matrix_mag + + +class wFMReLU(nn.Module): + r""" + Weighted Fréchet Mean ReLU + ---------------------------- + + Manifold-aware nonlinearity that complements :class:`wFMConv1d` / + :class:`wFMConv2d`. Performs an additive shift on the magnitude (in + log-domain semantics) and a multiplicative scaling on the phase, with + weight normalisations chosen so that the operation lies on the + rotation+scaling manifold: + + .. math:: + + \begin{aligned} + \tilde{|z|}_c &= |z|_c + \frac{(w^{m}_c)^2}{\sum_k (w^{m}_k)^2} \\ + \tilde{\theta}_c &= \arg(z_c) \cdot \frac{w^{\theta}_c}{\sum_k (w^{\theta}_k)^2} \\ + y_c &= \tilde{|z|}_c \cdot e^{j \tilde{\theta}_c} + \end{aligned} + + Both :math:`w^m, w^\theta \in \mathbb{R}^C` are learnable. The normalisations + are exactly those used by :class:`wFMConv2d` (``weightNormalize1`` / + ``weightNormalize2`` in the reference implementation). + + Based on work from the following paper: + + **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** + + - https://arxiv.org/abs/1910.11334 + + - Reference implementation: `manifoldReLUv2angle` in https://github.com/xingyifei2016/RotLieNet + """ + + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.num_channels = num_channels + self.eps = eps + self.weight_phase = nn.Parameter(torch.empty(num_channels).uniform_(0.0, 1.0)) + self.weight_mag = nn.Parameter(torch.empty(num_channels).uniform_(0.0, 1.0)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + # Broadcast (C,) → (1, C, 1, ...) for an N-D input. + view_shape = [1] * input.dim() + view_shape[1] = self.num_channels + w_phase = _normalize_weights(self.weight_phase + self.eps).view(*view_shape) + w_mag = _normalize_weights_squared(self.weight_mag + self.eps).view(*view_shape) + mag = input.abs() + w_mag + phase = input.angle() * w_phase + return torch.polar(mag, phase) + + def extra_repr(self) -> str: + return f"num_channels={self.num_channels}" + + +class wFMDistanceLinear(nn.Module): + r""" + Weighted Fréchet Mean Distance Linear Head + -------------------------------------------- + + Computes a single weighted Fréchet mean :math:`M` over every element of a + complex input, then returns a **real-valued** distance map combining + rotation distance (on phase) and a log-magnitude ratio. Suitable as a + classification head where the goal is to summarise the input's deviation + from a learned manifold center. + + For a flattened input :math:`z \in \mathbb{C}^N`, + + .. math:: + + \begin{aligned} + \bar{w}_n &= \frac{(w_n)^2}{\sum_k (w_k)^2} \\ + M_\theta &= \tanh(-b_\theta) \cdot \sum_n \arg(z_n) \cdot \bar{w}_n \\ + M_r &= \exp(-b_r^2) + \exp\Bigl(\sum_n \log(|z_n| + \varepsilon) \cdot \bar{w}_n\Bigr) \\ + d_n &= w_\theta^2 \cdot |\arg(z_n) - M_\theta| + + w_r^2 \cdot |\log(|z_n| / (M_r + \varepsilon))| + \end{aligned} + + The output is real-valued and has the same shape as the input (with the + leading complex axis flattened to the same shape). + + .. note:: + Unlike :class:`complextorch.nn.Linear` this layer returns a real-valued + tensor (it produces invariants for classification, not complex + features). The "Distance" suffix is the reminder. + + Based on work from the following paper: + + **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** + + - https://arxiv.org/abs/1910.11334 + + - Reference implementation: ``ComplexLinearangle2Dmw_outfield`` in https://github.com/xingyifei2016/RotLieNet + """ + + def __init__(self, input_dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.input_dim = input_dim + self.eps = eps + # Per-element wFM weights and a pair of (phase, magnitude) combination + # weights / biases. + self.weights = nn.Parameter(torch.empty(input_dim).uniform_(0.0, 1.0)) + self.combine_weight = nn.Parameter(torch.empty(2).uniform_(0.0, 1.0)) + self.bias = nn.Parameter(torch.empty(2).uniform_(0.0, 1.0)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + original_shape = input.shape + batch_size = original_shape[0] + flat = input.reshape(batch_size, -1) + if flat.shape[1] != self.input_dim: + raise ValueError( + f"wFMDistanceLinear expects flattened input of size " + f"{self.input_dim}, got {flat.shape[1]}" + ) + phase = flat.angle() + mag = flat.abs() + + w = _normalize_weights_squared(self.weights) # [N], sums to 1 + M_phase = (phase * w).sum(dim=1) * torch.tanh(-self.bias[0]) # [B] + log_mag = torch.log(mag + self.eps) + M_mag = torch.exp((log_mag * w).sum(dim=1)) + torch.exp( + -(self.bias[1] ** 2) + ) # [B] + + dist_phase = (phase - M_phase.unsqueeze(1)).abs() # [B, N] + dist_mag = torch.log(mag / (M_mag.unsqueeze(1) + self.eps)).abs() # [B, N] + dist = ( + self.combine_weight[0] ** 2 * dist_phase + + self.combine_weight[1] ** 2 * dist_mag + ) + return dist.reshape(original_shape) + + def extra_repr(self) -> str: + return f"input_dim={self.input_dim}" diff --git a/complextorch/nn/modules/mask.py b/complextorch/nn/modules/mask.py index bf231f8..05da207 100755 --- a/complextorch/nn/modules/mask.py +++ b/complextorch/nn/modules/mask.py @@ -1,9 +1,9 @@ -from typing import Optional - -import torch.nn as nn import torch +import torch.nn as nn + +__all__ = ["ComplexRatioMask", "MagMinMaxNorm", "PhaseSigmoid"] -__all__ = ["ComplexRatioMask", "PhaseSigmoid", "MagMinMaxNorm"] +_EPS = 1e-12 class ComplexRatioMask(nn.Module): @@ -15,7 +15,7 @@ class ComplexRatioMask(nn.Module): \texttt{ComplexRatioMask}(\mathbf{z}) = \texttt{Sigmoid}(|\mathbf{z}|) \odot \frac{\mathbf{z}}{|\mathbf{z}|} - Retains phase and squeezes magnitude using `sigmoid function `_. + Retains phase and squeezes magnitude using :class:`torch.nn.Sigmoid`. Based on work from the following paper: @@ -27,7 +27,7 @@ class ComplexRatioMask(nn.Module): """ def __init__(self) -> None: - super(ComplexRatioMask, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: r"""Computes complex ratio mask on complex-valued input tensor. @@ -39,7 +39,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\text{sigmoid}(|\mathbf{z}|) * \mathbf{z} / |\mathbf{z}|` """ x_mag = input.abs() - return x_mag.sigmoid() * (input / x_mag) + return x_mag.sigmoid() * (input / x_mag.clamp(min=_EPS)) class PhaseSigmoid(nn.Module): @@ -49,9 +49,9 @@ class PhaseSigmoid(nn.Module): .. math:: - \texttt{ComplexRatioMask}(\mathbf{z}) = \texttt{Sigmoid}(|\mathbf{z}|) \odot \frac{\mathbf{z}}{|\mathbf{z}|} + \texttt{PhaseSigmoid}(\mathbf{z}) = \texttt{Sigmoid}(|\mathbf{z}|) \odot \frac{\mathbf{z}}{|\mathbf{z}|} - Retains phase and squeezes magnitude using `sigmoid function `_. + Retains phase and squeezes magnitude using :class:`torch.nn.Sigmoid`. Based on work from the following paper: @@ -62,7 +62,20 @@ class PhaseSigmoid(nn.Module): - https://ieeexplore.ieee.org/abstract/document/9335579 """ - pass + def __init__(self) -> None: + super().__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r"""Computes phase-preserving sigmoid mask on a complex-valued input tensor. + + Args: + input (torch.Tensor): input tensor + + Returns: + torch.Tensor: :math:`\text{sigmoid}(|\mathbf{z}|) * \mathbf{z} / |\mathbf{z}|` + """ + x_mag = input.abs() + return x_mag.sigmoid() * (input / x_mag.clamp(min=_EPS)) class MagMinMaxNorm(nn.Module): @@ -70,30 +83,38 @@ class MagMinMaxNorm(nn.Module): Magnitude Min-Max Normalization Layer ------------------------------------- - Applies the *min-max norm* to the input tensor yielding an output whose magnitude is normalized between 0 and 1 over the specified dimension while phase information remains unchanged. + Applies the *min-max norm* to the magnitude of the input tensor, yielding an + output whose magnitude is normalized between 0 and 1 (over the specified + dimension, if any) while phase information remains unchanged. Implements the following operation: .. math:: - \texttt{MagMinMaxNorm}(\mathbf{z}) = \frac{\mathbf{z} - \mathbf{z}_{min}}{\mathbf{z}_{max} - \mathbf{z}_{min}} + \texttt{MagMinMaxNorm}(\mathbf{z}) = \frac{|\mathbf{z}| - |\mathbf{z}|_{min}}{|\mathbf{z}|_{max} - |\mathbf{z}|_{min}} \odot \exp(j \angle\mathbf{z}) """ - def __init__(self, dim: Optional[int] = None) -> None: - super(MagMinMaxNorm, self).__init__() + def __init__(self, dim: int | None = None) -> None: + super().__init__() self.dim = dim def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Applies the *min-max norm* to the input tensor yielding an output whose magnitude is normalized between 0 and 1 over the specified dimension while phase information remains unchanged. + r"""Applies the *min-max norm* to the magnitude of the input tensor while + preserving phase. Args: input (torch.Tensor): input tensor Returns: - torch.Tensor: :math:`\frac{\mathbf{z} - \mathbf{z}_{min}}{\mathbf{z}_{max} - \mathbf{z}_{min}}` + torch.Tensor: phase-preserving min-max normalized tensor """ x_mag = input.abs() - x_min = x_mag.min() - x_max = x_mag.max() - return (input - x_min) / (x_max - x_min) + if self.dim is None: + x_min = x_mag.min() + x_max = x_mag.max() + else: + x_min = x_mag.min(dim=self.dim, keepdim=True).values + x_max = x_mag.max(dim=self.dim, keepdim=True).values + new_mag = (x_mag - x_min) / (x_max - x_min).clamp(min=_EPS) + return torch.polar(new_mag, input.angle()) diff --git a/complextorch/nn/modules/phase.py b/complextorch/nn/modules/phase.py new file mode 100644 index 0000000..1ae7f70 --- /dev/null +++ b/complextorch/nn/modules/phase.py @@ -0,0 +1,162 @@ +r""" +Learnable Phase / Complex-Scaling Modules +========================================= + +- :class:`PhaseShift` multiplies its input by :math:`e^{j\phi}` with a learnable + phase :math:`\phi` (magnitude fixed to 1). +- :class:`ComplexScaling` multiplies its input by the general complex scalar + :math:`\alpha + j\beta` with both real and imaginary parts learnable + (magnitude and phase both learnable). +""" + +import math + +import torch +import torch.nn as nn + +__all__ = ["ComplexScaling", "PhaseShift"] + + +class PhaseShift(nn.Module): + r""" + Learnable Phase Shift + --------------------- + + Multiplies a complex input by :math:`e^{j\phi}`, where :math:`\phi` is a + learnable parameter: + + .. math:: + + y = x \cdot e^{j\phi} + + The shape of :math:`\phi` is given by ``num_features``; PyTorch + broadcasting rules apply between the input and ``exp(j*phi)``. Pass a + scalar (``()``) for a single global rotation, ``(C,)`` for a per-channel + rotation, or a higher-rank tuple for finer control. + + Args: + num_features: shape of the learnable phase tensor. ``()`` or ``1`` + gives a scalar rotation. + broadcast_dim: when ``num_features`` is an int, ``phi`` is created with + shape ``(num_features,)``. To use it as the channel dim of a + ``(B, C, ...)`` input, set ``broadcast_dim=1`` (default). For a + ``(B, ..., C)`` layout set ``broadcast_dim=-1``. + """ + + def __init__( + self, + num_features: int | tuple[int, ...] = 1, + broadcast_dim: int = 1, + ) -> None: + super().__init__() + if isinstance(num_features, int): + shape: tuple[int, ...] = (num_features,) if num_features != 1 else () + else: + shape = tuple(num_features) + self.num_features = num_features + self.broadcast_dim = broadcast_dim + # Initialize phases uniformly in [-pi, pi] + phi = torch.empty(shape) + with torch.no_grad(): + phi.uniform_(-math.pi, math.pi) + self.phi = nn.Parameter(phi) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # Reshape phi to broadcast over input. For a 1-D phi of shape (C,) the + # default reshape places it at dim broadcast_dim of the input. + phi = self.phi + if phi.dim() == 1 and input.dim() > 1: + ndim = input.dim() + dim = ( + self.broadcast_dim + if self.broadcast_dim >= 0 + else ndim + self.broadcast_dim + ) + shape = [1] * ndim + shape[dim] = phi.shape[0] + phi = phi.view(*shape) + rotor = torch.polar(torch.ones_like(phi), phi) + if not input.is_complex(): + input = input.to(torch.cfloat) + return input * rotor + + def extra_repr(self) -> str: + return f"num_features={self.num_features}, broadcast_dim={self.broadcast_dim}" + + +class ComplexScaling(nn.Module): + r""" + Learnable Complex Scaling + ------------------------- + + Multiplies a complex input by the learnable complex scalar + :math:`\alpha + j\beta`: + + .. math:: + + y = (\alpha + j\beta) \cdot z + = (\alpha \Re z - \beta \Im z) + j(\beta \Re z + \alpha \Im z) + + Unlike :class:`PhaseShift` (which restricts the multiplier to unit + magnitude), :class:`ComplexScaling` learns both magnitude and phase. + + Broadcasting matches :class:`PhaseShift`: pass an int / tuple of ints for + ``num_features`` and the parameter shape is identical to it. When + ``num_features`` is a single int, the parameter has shape ``(num_features,)`` + and is broadcast at ``broadcast_dim`` (default 1, i.e. the channel axis of a + ``(B, C, ...)`` input). + + Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — `scaling_layer` in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf + + Args: + num_features: shape of the learnable scale parameters. ``()`` or ``1`` + gives a single scalar scale. + broadcast_dim: when ``num_features`` is an int, the parameter is + broadcast to the input at this dim. Use ``1`` (default) for + ``(B, C, ...)`` and ``-1`` for ``(B, ..., C)``. + """ + + def __init__( + self, + num_features: int | tuple[int, ...] = 1, + broadcast_dim: int = 1, + ) -> None: + super().__init__() + if isinstance(num_features, int): + shape: tuple[int, ...] = (num_features,) if num_features != 1 else () + else: + shape = tuple(num_features) + self.num_features = num_features + self.broadcast_dim = broadcast_dim + # Matches the reference (cds/layers.py:474-475): uniform [0, 1). + self.alpha = nn.Parameter(torch.empty(shape).uniform_(0.0, 1.0)) + self.beta = nn.Parameter(torch.empty(shape).uniform_(0.0, 1.0)) + + def _broadcast(self, t: torch.Tensor, input_dim: int) -> torch.Tensor: + if t.dim() == 1 and input_dim > 1: + dim = ( + self.broadcast_dim + if self.broadcast_dim >= 0 + else input_dim + self.broadcast_dim + ) + shape = [1] * input_dim + shape[dim] = t.shape[0] + return t.view(*shape) + return t + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + alpha = self._broadcast(self.alpha, input.dim()) + beta = self._broadcast(self.beta, input.dim()) + scale = torch.complex(alpha, beta) + return input * scale + + def extra_repr(self) -> str: + return f"num_features={self.num_features}, broadcast_dim={self.broadcast_dim}" diff --git a/complextorch/nn/modules/phase_modulation.py b/complextorch/nn/modules/phase_modulation.py new file mode 100644 index 0000000..42c237e --- /dev/null +++ b/complextorch/nn/modules/phase_modulation.py @@ -0,0 +1,238 @@ +r""" +Phase-Modulation Layers (CDS) +============================= + +Two related layers that modulate a complex input :math:`x` by a learned +complex function :math:`g(x)` of itself: + +- :class:`PhaseDivConv{1,2,3}d` (``y = x / g(x)``) — U(1)-invariant: a global + phase rotation :math:`e^{j\psi}` of :math:`x` cancels in numerator and + denominator, so the output is unchanged by global phase. Output magnitude + is :math:`|x| / |g(x)|`. + +- :class:`PhaseConjConv{1,2,3}d` (``y = x \cdot \overline{g(x)}``) — also + U(1)-invariant when the inner :math:`g` is complex-linear (the cfloat-native + case here): both factors rotate by :math:`e^{\pm j\psi}` and cancel. Output + magnitude is :math:`|x| \cdot |g(x)|`. Use this when you want phase + invariance plus a learned magnitude scaling that grows with :math:`|g(x)|` + rather than shrinks with :math:`1/|g(x)|`. + +Both share an inner complex convolution :math:`g`. When ``use_one_filter=True`` +(the default in the reference implementation), :math:`g` has a single output +channel that is broadcast across input channels. + +.. note:: + The CDS paper described ``ConjugateLayer`` as a "phase-mixing" operator. + That characterisation was specific to the paper's two-real-conv + decomposition of complex convolution; a fully C-linear inner conv (used + here for compatibility with ``complextorch``'s native cfloat convention) + yields strict U(1)-invariance instead. + +Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — ``DivLayer`` and ``ConjugateLayer`` in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf +""" + +import torch +import torch.nn as nn + +from complextorch.nn.modules.conv import Conv1d, Conv2d, Conv3d + +__all__ = [ + "PhaseConjConv1d", + "PhaseConjConv2d", + "PhaseConjConv3d", + "PhaseDivConv1d", + "PhaseDivConv2d", + "PhaseDivConv3d", +] + + +def _center_crop(x: torch.Tensor, target_spatial: tuple[int, ...]) -> torch.Tensor: + """Center-crop the trailing spatial dims of ``x`` to ``target_spatial``. + + Used when the inner conv ``g`` has a kernel larger than 1 and no padding, + leaving ``g(x)`` smaller than ``x``. Returns ``x`` unchanged if shapes + already match. + """ + spatial_in = x.shape[-len(target_spatial) :] + if tuple(spatial_in) == tuple(target_spatial): + return x + slices = [slice(None), slice(None)] + for in_size, out_size in zip(spatial_in, target_spatial, strict=False): + start = (in_size - out_size) // 2 + slices.append(slice(start, start + out_size)) + return x[tuple(slices)] + + +class _PhaseModulationNd(nn.Module): + r"""Shared base for :class:`PhaseDivConv{1,2,3}d` and + :class:`PhaseConjConv{1,2,3}d`. Subclasses override :meth:`_combine`.""" + + _conv_classes = (None, Conv1d, Conv2d, Conv3d) # indexed by nd + + def __init__( + self, + nd: int, + in_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + use_one_filter: bool = True, + eps: float = 1e-7, + ) -> None: + super().__init__() + if nd not in (1, 2, 3): + raise ValueError(f"nd must be 1, 2, or 3, got {nd}") + self.nd = nd + self.in_channels = in_channels + self.use_one_filter = use_one_filter + self.eps = eps + + out_channels = 1 if use_one_filter else in_channels + conv_cls = self._conv_classes[nd] + self.conv = conv_cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + + def _combine(self, x: torch.Tensor, g_x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + input = input.to(torch.cfloat) + g_x = self.conv(input) + if self.use_one_filter: + # g_x has 1 output channel; expand back to in_channels for elementwise ops. + g_x = g_x.expand(g_x.shape[0], self.in_channels, *g_x.shape[2:]) + # Center-crop input to g_x's spatial size when the inner conv shrank it. + target_spatial = g_x.shape[-self.nd :] + x = _center_crop(input, target_spatial) + return self._combine(x, g_x) + + def extra_repr(self) -> str: + return ( + f"in_channels={self.in_channels}, use_one_filter={self.use_one_filter}, " + f"eps={self.eps}" + ) + + +class _PhaseDivConvNd(_PhaseModulationNd): + r"""``y = x · conj(g(x)) / (|g(x)|² + ε)`` — U(1)-invariant.""" + + def _combine(self, x: torch.Tensor, g_x: torch.Tensor) -> torch.Tensor: + denom = (g_x.real * g_x.real + g_x.imag * g_x.imag + self.eps).to(x.dtype) + return x * g_x.conj() / denom + + +class _PhaseConjConvNd(_PhaseModulationNd): + r"""``y = x · conj(g(x))`` — phase-mixing modulator.""" + + def _combine(self, x: torch.Tensor, g_x: torch.Tensor) -> torch.Tensor: + return x * g_x.conj() + + +def _make_doc(name: str, op: str, invariance: str) -> str: + return f""" +{name} +{"-" * len(name)} + +Phase modulation by a learned complex convolution: + +.. math:: + + y = {op} + +{invariance} + +See :mod:`complextorch.nn.modules.phase_modulation` for the shared +construction conventions (``use_one_filter`` default, inner conv reuse, etc.). + +Args: + in_channels: number of complex input channels. + kernel_size, stride, padding, dilation, groups: forwarded to the inner + :class:`complextorch.nn.Conv{{1,2,3}}d`. + use_one_filter: if ``True`` (default), the inner conv produces a single + complex channel that is broadcast across input channels (this matches + the CDS reference for the ``I``-type network). + eps: numerical floor on the denominator (only used by the Div variant). +""" + + +class PhaseDivConv1d(_PhaseDivConvNd): + __doc__ = _make_doc( + "1-D Phase-Division Modulation", + r"x \cdot \overline{g(x)} / (|g(x)|^2 + \varepsilon)", + "**U(1)-invariant** under a global phase rotation of ``x``.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(1, in_channels, kernel_size, **kwargs) + + +class PhaseDivConv2d(_PhaseDivConvNd): + __doc__ = _make_doc( + "2-D Phase-Division Modulation", + r"x \cdot \overline{g(x)} / (|g(x)|^2 + \varepsilon)", + "**U(1)-invariant** under a global phase rotation of ``x``.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(2, in_channels, kernel_size, **kwargs) + + +class PhaseDivConv3d(_PhaseDivConvNd): + __doc__ = _make_doc( + "3-D Phase-Division Modulation", + r"x \cdot \overline{g(x)} / (|g(x)|^2 + \varepsilon)", + "**U(1)-invariant** under a global phase rotation of ``x``.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(3, in_channels, kernel_size, **kwargs) + + +class PhaseConjConv1d(_PhaseConjConvNd): + __doc__ = _make_doc( + "1-D Phase-Conjugate Modulation", + r"x \cdot \overline{g(x)}", + "Phase-mixing modulator; magnitude is scaled by :math:`|g(x)|`.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(1, in_channels, kernel_size, **kwargs) + + +class PhaseConjConv2d(_PhaseConjConvNd): + __doc__ = _make_doc( + "2-D Phase-Conjugate Modulation", + r"x \cdot \overline{g(x)}", + "Phase-mixing modulator; magnitude is scaled by :math:`|g(x)|`.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(2, in_channels, kernel_size, **kwargs) + + +class PhaseConjConv3d(_PhaseConjConvNd): + __doc__ = _make_doc( + "3-D Phase-Conjugate Modulation", + r"x \cdot \overline{g(x)}", + "Phase-mixing modulator; magnitude is scaled by :math:`|g(x)|`.", + ) + + def __init__(self, in_channels: int, kernel_size, **kwargs) -> None: + super().__init__(3, in_channels, kernel_size, **kwargs) diff --git a/complextorch/nn/modules/pooling.py b/complextorch/nn/modules/pooling.py index 1f08bba..a7d049a 100755 --- a/complextorch/nn/modules/pooling.py +++ b/complextorch/nn/modules/pooling.py @@ -1,11 +1,20 @@ -from typing import Union, Tuple - -import torch.nn as nn import torch +import torch.nn as nn +import torch.nn.functional as F -from .. import functional as cvF +from complextorch.nn import functional as cvF -__all__ = ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"] +__all__ = [ + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "MagMaxPool1d", + "MagMaxPool2d", + "MagMaxPool3d", +] class AdaptiveAvgPool1d(nn.AdaptiveAvgPool1d): @@ -13,7 +22,7 @@ class AdaptiveAvgPool1d(nn.AdaptiveAvgPool1d): 1-D Complex-Valued Adaptive Average Pooling ------------------------------------------- - Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool1d `_ to the real and imaginary parts of the input tensor separately. + Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool1d` to the real and imaginary parts of the input tensor separately. Implements the following operation: @@ -24,11 +33,11 @@ class AdaptiveAvgPool1d(nn.AdaptiveAvgPool1d): where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}` """ - def __init__(self, output_size: Union[int, Tuple[int]]) -> None: + def __init__(self, output_size: int | tuple[int]) -> None: super().__init__(output_size) def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool1d `_ to the real and imaginary parts of the input tensor separately. + r"""Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool1d` to the real and imaginary parts of the input tensor separately. Args: input (torch.Tensor): input tensor @@ -44,7 +53,7 @@ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d): 2-D Complex-Valued Adaptive Average Pooling ------------------------------------------- - Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool2d `_ to the real and imaginary parts of the input tensor separately. + Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool2d` to the real and imaginary parts of the input tensor separately. Implements the following operation: @@ -59,7 +68,7 @@ def __init__(self, output_size) -> None: super().__init__(output_size) def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool2d `_ to the real and imaginary parts of the input tensor separately. + r"""Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool2d` to the real and imaginary parts of the input tensor separately. Args: input (torch.Tensor): input tensor @@ -75,7 +84,7 @@ class AdaptiveAvgPool3d(nn.AdaptiveAvgPool3d): 3-D Complex-Valued Adaptive Average Pooling ------------------------------------------- - Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool3d `_ to the real and imaginary parts of the input tensor separately. + Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool3d` to the real and imaginary parts of the input tensor separately. Implements the following operation: @@ -90,7 +99,7 @@ def __init__(self, output_size) -> None: super().__init__(output_size) def forward(self, input: torch.Tensor) -> torch.Tensor: - r"""Applies adaptive average pooling using `torch.nn.AdaptiveAvgPool3d `_ to the real and imaginary parts of the input tensor separately. + r"""Applies adaptive average pooling using :class:`torch.nn.AdaptiveAvgPool3d` to the real and imaginary parts of the input tensor separately. Args: input (torch.Tensor): input tensor @@ -99,3 +108,154 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\texttt{AdaptiveAvgPool3d}(\mathbf{x}) + j \texttt{AdaptiveAvgPool3d}(\mathbf{y})` """ return cvF.apply_complex_split(super().forward, super().forward, input) + + +class AvgPool1d(nn.AvgPool1d): + r""" + 1-D Complex-Valued Average Pooling + ---------------------------------- + + Convenience wrapper over :class:`torch.nn.AvgPool1d` for complex inputs. + Average pooling is linear, so applying ``torch.nn.AvgPool1d`` to a + ``torch.cfloat`` tensor is mathematically equivalent to pooling real and + imaginary parts independently and recombining. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.is_complex(): + return cvF.apply_complex_split(super().forward, super().forward, input) + return super().forward(input) + + +class AvgPool2d(nn.AvgPool2d): + r""" + 2-D Complex-Valued Average Pooling + ---------------------------------- + + Convenience wrapper over :class:`torch.nn.AvgPool2d` for complex inputs. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.is_complex(): + return cvF.apply_complex_split(super().forward, super().forward, input) + return super().forward(input) + + +class AvgPool3d(nn.AvgPool3d): + r""" + 3-D Complex-Valued Average Pooling + ---------------------------------- + + Convenience wrapper over :class:`torch.nn.AvgPool3d` for complex inputs. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if input.is_complex(): + return cvF.apply_complex_split(super().forward, super().forward, input) + return super().forward(input) + + +def _gather_max_by_magnitude( + input: torch.Tensor, indices: torch.Tensor +) -> torch.Tensor: + r"""Gather the *original* complex samples at ``indices`` (argmax-of-magnitude positions). + + ``indices`` come from ``F.max_poolNd_with_indices(|input|)``: they are linear + indices into the spatial dimensions (last N dims of the input). For each + output position, we select the corresponding complex sample from + ``input`` and return it, preserving the original phase. + """ + # Flatten spatial dimensions of input, then gather along that flat dim. + n_spatial = indices.dim() - 2 # batch, channel, then spatial + flat_input = input.reshape(*input.shape[:-n_spatial], -1) + flat_indices = indices.reshape(*indices.shape[:-n_spatial], -1) + gathered = torch.gather(flat_input, dim=-1, index=flat_indices) + return gathered.reshape(*indices.shape) + + +class _MagMaxPoolNd(nn.Module): + r"""Internal base for magnitude-argmax complex max pooling.""" + + _max_pool_with_indices = staticmethod(F.max_pool1d_with_indices) + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride if stride is not None else kernel_size + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + + def forward(self, input: torch.Tensor): + magnitude = input.abs() if input.is_complex() else input + _, indices = self._max_pool_with_indices( + magnitude, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ceil_mode=self.ceil_mode, + ) + if input.is_complex(): + out = _gather_max_by_magnitude(input, indices) + else: + # Real input: gather is equivalent to the max value itself + out = _gather_max_by_magnitude(input, indices) + if self.return_indices: + return out, indices + return out + + def extra_repr(self) -> str: + return ( + f"kernel_size={self.kernel_size}, stride={self.stride}, " + f"padding={self.padding}, dilation={self.dilation}, " + f"ceil_mode={self.ceil_mode}" + ) + + +class MagMaxPool1d(_MagMaxPoolNd): + r""" + 1-D Complex-Valued Max Pooling by Magnitude + ------------------------------------------- + + Pools by selecting the input position with the **largest magnitude** + :math:`|z|` within each window, and returns the **original complex sample** + at that position (phase preserved). + + Because ``torch.nn.MaxPool1d`` is not defined for complex tensors (no total + ordering on :math:`\mathbb{C}`), this layer is the canonical complex + analogue. Signature matches :class:`torch.nn.MaxPool1d`. + """ + + _max_pool_with_indices = staticmethod(F.max_pool1d_with_indices) + + +class MagMaxPool2d(_MagMaxPoolNd): + r""" + 2-D Complex-Valued Max Pooling by Magnitude + ------------------------------------------- + + See :class:`MagMaxPool1d`. Signature matches :class:`torch.nn.MaxPool2d`. + """ + + _max_pool_with_indices = staticmethod(F.max_pool2d_with_indices) + + +class MagMaxPool3d(_MagMaxPoolNd): + r""" + 3-D Complex-Valued Max Pooling by Magnitude + ------------------------------------------- + + See :class:`MagMaxPool1d`. Signature matches :class:`torch.nn.MaxPool3d`. + """ + + _max_pool_with_indices = staticmethod(F.max_pool3d_with_indices) diff --git a/complextorch/nn/modules/prototype.py b/complextorch/nn/modules/prototype.py new file mode 100644 index 0000000..e08b45a --- /dev/null +++ b/complextorch/nn/modules/prototype.py @@ -0,0 +1,115 @@ +r""" +Prototype-Distance Classifier Head +================================== + +:class:`PrototypeDistance` stores a learnable bank of complex prototypes and +returns negative-distance logits for each prototype. Used as the classifier +head in the CDS reference models (`I`- and `E`-type). + +Based on work from the following paper: + + **U. Singhal, Y. Xing, S. X. Yu. Co-Domain Symmetry for Complex-Valued Deep Learning.** + + - CVPR 2022 — ``DistFeatures`` in the reference implementation + + - https://openaccess.thecvf.com/content/CVPR2022/papers/Singhal_Co-Domain_Symmetry_for_Complex-Valued_Deep_Learning_CVPR_2022_paper.pdf +""" + +import math + +import torch +import torch.nn as nn + +__all__ = ["PrototypeDistance"] + + +class PrototypeDistance(nn.Module): + r""" + Complex Prototype Distance Classifier + ------------------------------------- + + Holds :math:`K` learnable complex prototypes :math:`p_{c,k}`. For each + sample ``z`` of shape ``[B, C]``, the logit for class :math:`k` is the + negative root-mean-squared complex distance to prototype :math:`k`: + + .. math:: + + \text{logits}_{b,k} + = -\tau \cdot \sqrt{\tfrac{1}{C} \sum_{c} |z_{b,c} - p_{c,k}|^2} + + where :math:`\tau \in \mathbb{R}` is a learnable temperature. + + Equivariant ("E-type") use + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + To form a U(1)-equivariant *network*, the prototypes can be pre-rotated by + a reference complex vector ``y`` (one per sample, per channel) before + distance: + + .. math:: + + p_{b,c,k}^{\prime} = y_{b,c} \cdot p_{c,k} + + Pass ``y`` via the ``reference=`` argument of :meth:`forward`. When the + network rotates ``z`` by a global :math:`e^{j\psi}`, both ``z`` and ``y`` + rotate identically and the distance is unchanged — so the logits are + invariant in the I-type call and equivariant-then-invariant in the E-type + call. This matches ``cds/model.py:225-228``. + + Args: + in_features: number of complex channels :math:`C`. + num_prototypes: number of prototypes / output classes :math:`K`. + temperature_init: initial value of the learnable temperature ``τ``. + """ + + def __init__( + self, + in_features: int, + num_prototypes: int, + temperature_init: float = 1.0, + ) -> None: + super().__init__() + self.in_features = in_features + self.num_prototypes = num_prototypes + scale = 1.0 / math.sqrt(in_features) + proto = torch.empty(in_features, num_prototypes, dtype=torch.cfloat) + proto.real.normal_(0.0, scale) + proto.imag.normal_(0.0, scale) + self.prototypes = nn.Parameter(proto) + self.temperature = nn.Parameter(torch.tensor(float(temperature_init))) + + def forward( + self, input: torch.Tensor, reference: torch.Tensor | None = None + ) -> torch.Tensor: + if input.dim() != 2: + raise ValueError( + f"PrototypeDistance expects input of shape [B, C], got " + f"{tuple(input.shape)}" + ) + if not input.is_complex(): + input = input.to(torch.cfloat) + + # prototypes broadcast over batch: [1, C, K] + proto = self.prototypes.unsqueeze(0) + if reference is not None: + if not reference.is_complex(): + reference = reference.to(torch.cfloat) + if reference.dim() == 1: + reference = reference.unsqueeze(-1) # [B] → [B, 1] + if reference.dim() != 2 or reference.shape[0] != input.shape[0]: + raise ValueError( + f"reference must broadcast against input shape " + f"{tuple(input.shape)} on the channel dim; got " + f"{tuple(reference.shape)}" + ) + # Apply rotation: [B, C, 1] * [1, C, K] (broadcasts a scalar reference). + proto = reference.unsqueeze(-1) * proto + + # diff: [B, C, K] + diff = input.unsqueeze(-1) - proto + # Mean over channels of squared absolute difference, then sqrt. + dist_sq = (diff.real * diff.real + diff.imag * diff.imag).mean(dim=1) + dist = torch.sqrt(dist_sq.clamp(min=0.0)) + return -dist * self.temperature + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, num_prototypes={self.num_prototypes}" diff --git a/complextorch/nn/modules/rmsnorm.py b/complextorch/nn/modules/rmsnorm.py new file mode 100644 index 0000000..49aec3f --- /dev/null +++ b/complextorch/nn/modules/rmsnorm.py @@ -0,0 +1,91 @@ +r""" +Complex-Valued RMSNorm +====================== + +Root-Mean-Square layer normalization adapted to complex tensors. Equivalent to +:class:`torch.nn.RMSNorm` but operates on complex inputs; the affine parameter +is a 2x2 real matrix applied to ``(Re, Im)`` of each channel. +""" + +import torch +import torch.nn as nn + +__all__ = ["RMSNorm"] + + +class RMSNorm(nn.Module): + r""" + Complex-Valued RMS Normalization + -------------------------------- + + .. math:: + + y = \frac{x}{\sqrt{\text{mean}(|x|^2) + \epsilon}} + + Followed by an optional per-feature affine transform: the real/imag pair + of each feature is multiplied by a learnable 2x2 matrix. No bias. + + Args: + normalized_shape: shape of the trailing dims to normalize over (same + semantics as :class:`torch.nn.RMSNorm`). + eps: numerical stabilizer. + elementwise_affine: if ``True``, applies a learnable 2x2 affine. + """ + + def __init__( + self, + normalized_shape: int | list[int] | tuple[int, ...] | torch.Size, + *, + eps: float = 1e-5, + elementwise_affine: bool = True, + ) -> None: + super().__init__() + + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + normalized_shape = tuple(normalized_shape) + self.normalized_shape = normalized_shape + self.eps = eps + self.elementwise_affine = elementwise_affine + + if elementwise_affine: + # 2x2 affine on (Re, Im); initialized to identity / sqrt(2) so the + # zero-mean unit-variance assumption maps to a proper complex + # standardisation (same convention as the LayerNorm in this lib). + self.weight = nn.Parameter(torch.empty(2, 2, *normalized_shape)) + else: + self.register_parameter("weight", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if not self.elementwise_affine: + return + eye = (0.70710678118 * torch.eye(2)).view( + 2, 2, *([1] * len(self.normalized_shape)) + ) + self.weight.data.copy_(eye) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + raise TypeError(f"RMSNorm expects a complex input, got dtype={input.dtype}") + # Compute RMS of |x|^2 over the normalized_shape (last len(...) dims). + ns = len(self.normalized_shape) + dims = tuple(range(input.dim() - ns, input.dim())) + rms = torch.sqrt(input.abs().pow(2).mean(dim=dims, keepdim=True) + self.eps) + normed = input / rms + if not self.elementwise_affine: + return normed + + # Apply 2x2 affine: weight has shape (2, 2, *normalized_shape). + re, im = normed.real, normed.imag + w = self.weight + out_r = w[0, 0] * re + w[0, 1] * im + out_i = w[1, 0] * re + w[1, 1] * im + return torch.complex(out_r, out_i) + + def extra_repr(self) -> str: + return ( + f"normalized_shape={self.normalized_shape}, eps={self.eps}, " + f"elementwise_affine={self.elementwise_affine}" + ) diff --git a/complextorch/nn/modules/rnn.py b/complextorch/nn/modules/rnn.py new file mode 100644 index 0000000..ac9198a --- /dev/null +++ b/complextorch/nn/modules/rnn.py @@ -0,0 +1,383 @@ +r""" +Complex-Valued Recurrent Neural Networks +======================================== + +Drop-in complex analogues of :class:`torch.nn.GRUCell`, :class:`torch.nn.GRU`, +:class:`torch.nn.LSTMCell`, and :class:`torch.nn.LSTM`. + +Cells (:class:`GRUCell`, :class:`LSTMCell`) are built from complex +:class:`Linear` layers and complex activations and are mathematically the +standard cell equations applied to complex inputs and states. Multi-layer +sequence wrappers (:class:`GRU`, :class:`LSTM`) stack cells along the time +axis — they do not use PyTorch's CuDNN-fused real RNN under the hood (which +would need a parameterization trick that subtly differs from the cell math). + +Each cell accepts ``batchnorm=False``; setting it to ``True`` inserts a +:class:`BatchNorm1d` after every internal linear projection (analogous to +*Recurrent Batch Normalization* (Cooijmans et al., 2017) for the complex case). +""" + +import torch +import torch.nn as nn + +from complextorch.nn.modules.activation.split_type_A import CSigmoid, CTanh +from complextorch.nn.modules.batchnorm import BatchNorm1d +from complextorch.nn.modules.dropout import Dropout +from complextorch.nn.modules.linear import Linear + +__all__ = ["GRU", "LSTM", "GRUCell", "LSTMCell"] + + +class GRUCell(nn.Module): + r""" + Complex-Valued GRU Cell + ----------------------- + + Standard GRU equations applied to a complex input and hidden state: + + .. math:: + + r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ + z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ + n_t &= \tanh(W_{in} x_t + r_t \odot (W_{hn} h_{t-1}) + b_n) \\ + h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1} + + where :math:`\sigma` and :math:`\tanh` are the split (Type-A) complex + activations :class:`CSigmoid` / :class:`CTanh`, and all weights/biases + are complex. + + Args: + input_size: feature size of ``x``. + hidden_size: feature size of ``h``. + bias: if ``True``, adds biases to the projections. + batchnorm: if ``True``, wraps each linear projection in + :class:`BatchNorm1d`. Useful for stabilizing deep recurrent stacks. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + batchnorm: bool = False, + ) -> None: + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.batchnorm = batchnorm + + # 6 linear projections: 3 input (r, z, n) + 3 hidden (r, z, n). + self.w_ir = Linear(input_size, hidden_size, bias=bias) + self.w_iz = Linear(input_size, hidden_size, bias=bias) + self.w_in = Linear(input_size, hidden_size, bias=bias) + self.w_hr = Linear(hidden_size, hidden_size, bias=bias) + self.w_hz = Linear(hidden_size, hidden_size, bias=bias) + self.w_hn = Linear(hidden_size, hidden_size, bias=bias) + + if batchnorm: + self.bn_ir = BatchNorm1d(hidden_size) + self.bn_iz = BatchNorm1d(hidden_size) + self.bn_in = BatchNorm1d(hidden_size) + self.bn_hr = BatchNorm1d(hidden_size) + self.bn_hz = BatchNorm1d(hidden_size) + self.bn_hn = BatchNorm1d(hidden_size) + self.sigmoid = CSigmoid() + self.tanh = CTanh() + + def _bn(self, name: str, x: torch.Tensor) -> torch.Tensor: + if not self.batchnorm: + return x + # BatchNorm1d expects (B, C); use directly. + return getattr(self, name)(x) + + def forward( + self, input: torch.Tensor, hx: torch.Tensor | None = None + ) -> torch.Tensor: + if hx is None: + hx = torch.zeros( + input.shape[0], self.hidden_size, dtype=input.dtype, device=input.device + ) + r = self.sigmoid( + self._bn("bn_ir", self.w_ir(input)) + self._bn("bn_hr", self.w_hr(hx)) + ) + z = self.sigmoid( + self._bn("bn_iz", self.w_iz(input)) + self._bn("bn_hz", self.w_hz(hx)) + ) + n = self.tanh( + self._bn("bn_in", self.w_in(input)) + r * self._bn("bn_hn", self.w_hn(hx)) + ) + return (1 - z) * n + z * hx + + +class LSTMCell(nn.Module): + r""" + Complex-Valued LSTM Cell + ------------------------ + + Standard LSTM equations applied to complex inputs and states: + + .. math:: + + i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ + f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ + g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ + o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ + c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ + h_t &= o_t \odot \tanh(c_t) + + Args: + input_size: feature size of ``x``. + hidden_size: feature size of ``h`` and ``c``. + bias: if ``True``, adds biases to the projections. + batchnorm: if ``True``, wraps each linear projection in + :class:`BatchNorm1d`. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + batchnorm: bool = False, + ) -> None: + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.batchnorm = batchnorm + + self.w_ii = Linear(input_size, hidden_size, bias=bias) + self.w_if = Linear(input_size, hidden_size, bias=bias) + self.w_ig = Linear(input_size, hidden_size, bias=bias) + self.w_io = Linear(input_size, hidden_size, bias=bias) + self.w_hi = Linear(hidden_size, hidden_size, bias=bias) + self.w_hf = Linear(hidden_size, hidden_size, bias=bias) + self.w_hg = Linear(hidden_size, hidden_size, bias=bias) + self.w_ho = Linear(hidden_size, hidden_size, bias=bias) + + if batchnorm: + for gate in ("ii", "if_", "ig", "io", "hi", "hf", "hg", "ho"): + setattr(self, f"bn_{gate}", BatchNorm1d(hidden_size)) + self.sigmoid = CSigmoid() + self.tanh = CTanh() + + def _bn(self, name: str, x: torch.Tensor) -> torch.Tensor: + if not self.batchnorm: + return x + return getattr(self, name)(x) + + def forward( + self, + input: torch.Tensor, + hx: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if hx is None: + zero_h = torch.zeros( + input.shape[0], self.hidden_size, dtype=input.dtype, device=input.device + ) + zero_c = torch.zeros_like(zero_h) + h_prev, c_prev = zero_h, zero_c + else: + h_prev, c_prev = hx + i = self.sigmoid( + self._bn("bn_ii", self.w_ii(input)) + self._bn("bn_hi", self.w_hi(h_prev)) + ) + # ``if`` is a Python keyword, so the attribute is named ``w_if`` and + # the batch-norm is ``bn_if_``. + f = self.sigmoid( + self._bn("bn_if_", self.w_if(input)) + self._bn("bn_hf", self.w_hf(h_prev)) + ) + g = self.tanh( + self._bn("bn_ig", self.w_ig(input)) + self._bn("bn_hg", self.w_hg(h_prev)) + ) + o = self.sigmoid( + self._bn("bn_io", self.w_io(input)) + self._bn("bn_ho", self.w_ho(h_prev)) + ) + c = f * c_prev + i * g + h = o * self.tanh(c) + return h, c + + +# --------------------------------------------------------------------------- +# Multi-layer sequence wrappers +# --------------------------------------------------------------------------- + + +class _RNNBase(nn.Module): + """Internal base for multi-layer cell-stacking RNNs.""" + + _cell_class = GRUCell + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + batchnorm: bool = False, + ) -> None: + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.num_directions = 2 if bidirectional else 1 + + cells_fwd: list[nn.Module] = [] + cells_bwd: list[nn.Module] = [] + for layer in range(num_layers): + in_size = input_size if layer == 0 else hidden_size * self.num_directions + cells_fwd.append( + self._cell_class(in_size, hidden_size, bias=bias, batchnorm=batchnorm) + ) + if bidirectional: + cells_bwd.append( + self._cell_class( + in_size, hidden_size, bias=bias, batchnorm=batchnorm + ) + ) + self.cells_fwd = nn.ModuleList(cells_fwd) + self.cells_bwd = nn.ModuleList(cells_bwd) if bidirectional else None + + self.drop = Dropout(dropout) if dropout > 0 and num_layers > 1 else None + + +class GRU(_RNNBase): + r""" + Multi-Layer Complex-Valued GRU + ------------------------------ + + Stacks :class:`GRUCell` along the time axis. Compatible API with + :class:`torch.nn.GRU` (``num_layers``, ``batch_first``, ``dropout``, + ``bidirectional``). + + Note: this implementation rolls along time in Python rather than using + the (real-only) CuDNN fused kernel, so expect a per-step Python overhead + on long sequences. + """ + + _cell_class = GRUCell + + def forward( + self, + input: torch.Tensor, + hx: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.batch_first: + input = input.transpose(0, 1) # -> (T, B, F) + seq_len, batch, _ = input.shape + + if hx is None: + hx = torch.zeros( + self.num_layers * self.num_directions, + batch, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + + outputs = input + layer_states: list[torch.Tensor] = [] + + for layer in range(self.num_layers): + fwd_cell = self.cells_fwd[layer] + h_f = hx[layer * self.num_directions] + outs_f: list[torch.Tensor] = [] + for t in range(seq_len): + h_f = fwd_cell(outputs[t], h_f) + outs_f.append(h_f) + out_f = torch.stack(outs_f, dim=0) + + if self.bidirectional: + bwd_cell = self.cells_bwd[layer] + h_b = hx[layer * self.num_directions + 1] + outs_b: list[torch.Tensor] = [] + for t in range(seq_len - 1, -1, -1): + h_b = bwd_cell(outputs[t], h_b) + outs_b.append(h_b) + outs_b.reverse() + out_b = torch.stack(outs_b, dim=0) + outputs = torch.cat([out_f, out_b], dim=-1) + layer_states.append(h_f) + layer_states.append(h_b) + else: + outputs = out_f + layer_states.append(h_f) + + if self.drop is not None and layer < self.num_layers - 1: + outputs = self.drop(outputs) + + if self.batch_first: + outputs = outputs.transpose(0, 1) + new_hx = torch.stack(layer_states, dim=0) + return outputs, new_hx + + +class LSTM(_RNNBase): + r""" + Multi-Layer Complex-Valued LSTM + ------------------------------- + + Stacks :class:`LSTMCell` along the time axis. Compatible API with + :class:`torch.nn.LSTM`. + """ + + _cell_class = LSTMCell + + def forward( + self, + input: torch.Tensor, + hx: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.batch_first: + input = input.transpose(0, 1) + seq_len, batch, _ = input.shape + + n_dir = self.num_directions + if hx is None: + shape = (self.num_layers * n_dir, batch, self.hidden_size) + h0 = torch.zeros(shape, dtype=input.dtype, device=input.device) + c0 = torch.zeros_like(h0) + else: + h0, c0 = hx + + outputs = input + new_h: list[torch.Tensor] = [] + new_c: list[torch.Tensor] = [] + + for layer in range(self.num_layers): + fwd_cell = self.cells_fwd[layer] + h_f, c_f = h0[layer * n_dir], c0[layer * n_dir] + outs_f: list[torch.Tensor] = [] + for t in range(seq_len): + h_f, c_f = fwd_cell(outputs[t], (h_f, c_f)) + outs_f.append(h_f) + out_f = torch.stack(outs_f, dim=0) + + if self.bidirectional: + bwd_cell = self.cells_bwd[layer] + h_b, c_b = h0[layer * n_dir + 1], c0[layer * n_dir + 1] + outs_b: list[torch.Tensor] = [] + for t in range(seq_len - 1, -1, -1): + h_b, c_b = bwd_cell(outputs[t], (h_b, c_b)) + outs_b.append(h_b) + outs_b.reverse() + out_b = torch.stack(outs_b, dim=0) + outputs = torch.cat([out_f, out_b], dim=-1) + new_h.extend([h_f, h_b]) + new_c.extend([c_f, c_b]) + else: + outputs = out_f + new_h.append(h_f) + new_c.append(c_f) + + if self.drop is not None and layer < self.num_layers - 1: + outputs = self.drop(outputs) + + if self.batch_first: + outputs = outputs.transpose(0, 1) + return outputs, (torch.stack(new_h, dim=0), torch.stack(new_c, dim=0)) diff --git a/complextorch/nn/modules/softmax.py b/complextorch/nn/modules/softmax.py index 394ac48..2a9cc58 100755 --- a/complextorch/nn/modules/softmax.py +++ b/complextorch/nn/modules/softmax.py @@ -1,9 +1,7 @@ -from typing import Optional - import torch import torch.nn as nn -from .. import functional as cvF +from complextorch.nn import functional as cvF __all__ = ["CVSoftMax", "MagSoftMax", "PhaseSoftMax"] @@ -27,8 +25,8 @@ class CVSoftMax(nn.Module): where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}` """ - def __init__(self, dim: Optional[int] = None) -> None: - super(CVSoftMax, self).__init__() + def __init__(self, dim: int | None = None) -> None: + super().__init__() self.softmax = nn.Softmax(dim) @@ -58,8 +56,8 @@ class PhaseSoftMax(nn.Module): G(\mathbf{z}) = \texttt{SoftMax}(|\mathbf{z}|) \odot \mathbf{z} / |\mathbf{z}| """ - def __init__(self, dim: Optional[int] = None) -> None: - super(PhaseSoftMax, self).__init__() + def __init__(self, dim: int | None = None) -> None: + super().__init__() self.softmax = nn.Softmax(dim) @@ -73,7 +71,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: :math:`\texttt{SoftMax}(|\mathbf{z}|) \odot \mathbf{z} / |\mathbf{z}|` """ x_mag = input.abs() - return self.softmax(x_mag) * (input / x_mag) + return self.softmax(x_mag) * (input / x_mag.clamp(min=1e-12)) class MagSoftMax(nn.Module): @@ -90,8 +88,8 @@ class MagSoftMax(nn.Module): G(\mathbf{z}) = \texttt{SoftMax}(|\mathbf{z}|) """ - def __init__(self, dim: Optional[int] = None) -> None: - super(MagSoftMax, self).__init__() + def __init__(self, dim: int | None = None) -> None: + super().__init__() self.softmax = nn.Softmax(dim) diff --git a/complextorch/nn/modules/transformer.py b/complextorch/nn/modules/transformer.py new file mode 100644 index 0000000..4734a29 --- /dev/null +++ b/complextorch/nn/modules/transformer.py @@ -0,0 +1,253 @@ +r""" +Complex-Valued Transformer +========================== + +Encoder / decoder layers and full encoder-decoder transformer for complex +inputs. Composed from existing complex primitives. + +Note on building blocks: this library's :class:`MultiheadAttention` is +implemented as a *complete* attention sub-block — it applies QKV projections, +the attention mechanism, a residual connection, and a final +:class:`LayerNorm` internally. The encoder/decoder layers below therefore use +``MultiheadAttention`` directly as the "attention" sub-layer, and only add +the feed-forward sub-block on top with its own residual + LayerNorm. +""" + +import copy + +import torch +import torch.nn as nn + +from complextorch.nn.modules.activation.complex_relu import CReLU +from complextorch.nn.modules.activation.split_type_A import CGELU +from complextorch.nn.modules.attention import MultiheadAttention +from complextorch.nn.modules.dropout import Dropout +from complextorch.nn.modules.layernorm import LayerNorm +from complextorch.nn.modules.linear import Linear + +__all__ = [ + "Transformer", + "TransformerDecoder", + "TransformerDecoderLayer", + "TransformerEncoder", + "TransformerEncoderLayer", +] + + +def _get_activation(name: str) -> nn.Module: + if name == "gelu": + return CGELU() + if name == "relu": + return CReLU() + raise ValueError(f"Unknown activation {name!r}; expected 'gelu' or 'relu'") + + +def _clones(module: nn.Module, n: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +class _FFNBlock(nn.Module): + """Feed-forward sub-block: Linear -> activation -> Linear, with residual + LayerNorm.""" + + def __init__( + self, + d_model: int, + dim_feedforward: int, + activation: str, + dropout: float, + layer_norm_eps: float, + ) -> None: + super().__init__() + self.linear1 = Linear(d_model, dim_feedforward) + self.linear2 = Linear(dim_feedforward, d_model) + self.activation = _get_activation(activation) + self.norm = LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.linear2(self.activation(self.linear1(x))) + return self.norm(self.dropout(out) + x) + + +class TransformerEncoderLayer(nn.Module): + r""" + Single complex-valued transformer encoder layer. + + Self-attention + feed-forward, each with internal residual + LayerNorm. + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "gelu", + layer_norm_eps: float = 1e-5, + batch_first: bool = True, + softmax_on: str = "complex", + ) -> None: + super().__init__() + if d_model % nhead != 0: + raise ValueError( + f"d_model ({d_model}) must be divisible by nhead ({nhead})" + ) + self.batch_first = batch_first + d_head = d_model // nhead + + self.self_attn = MultiheadAttention( + nhead, d_model, d_head, d_head, dropout=dropout, softmax_on=softmax_on + ) + self.ffn = _FFNBlock( + d_model, dim_feedforward, activation, dropout, layer_norm_eps + ) + self.d_model = d_model + self.nhead = nhead + + def forward(self, src: torch.Tensor) -> torch.Tensor: + x = src if self.batch_first else src.transpose(0, 1) + x = self.self_attn(x, x, x) + x = self.ffn(x) + return x if self.batch_first else x.transpose(0, 1) + + +class TransformerEncoder(nn.Module): + r"""Stack of :class:`TransformerEncoderLayer` blocks.""" + + def __init__( + self, + encoder_layer: TransformerEncoderLayer, + num_layers: int, + norm: nn.Module | None = None, + ) -> None: + super().__init__() + self.layers = _clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src: torch.Tensor) -> torch.Tensor: + x = src + for layer in self.layers: + x = layer(x) + if self.norm is not None: + x = self.norm(x) + return x + + +class TransformerDecoderLayer(nn.Module): + r""" + Single complex-valued transformer decoder layer. + + Self-attention + cross-attention + feed-forward, each with internal + residual + LayerNorm. + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "gelu", + layer_norm_eps: float = 1e-5, + batch_first: bool = True, + softmax_on: str = "complex", + ) -> None: + super().__init__() + self.batch_first = batch_first + d_head = d_model // nhead + + self.self_attn = MultiheadAttention( + nhead, d_model, d_head, d_head, dropout=dropout, softmax_on=softmax_on + ) + self.cross_attn = MultiheadAttention( + nhead, d_model, d_head, d_head, dropout=dropout, softmax_on=softmax_on + ) + self.ffn = _FFNBlock( + d_model, dim_feedforward, activation, dropout, layer_norm_eps + ) + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor: + x = tgt if self.batch_first else tgt.transpose(0, 1) + m = memory if self.batch_first else memory.transpose(0, 1) + x = self.self_attn(x, x, x) + x = self.cross_attn(x, m, m) + x = self.ffn(x) + return x if self.batch_first else x.transpose(0, 1) + + +class TransformerDecoder(nn.Module): + r"""Stack of :class:`TransformerDecoderLayer` blocks.""" + + def __init__( + self, + decoder_layer: TransformerDecoderLayer, + num_layers: int, + norm: nn.Module | None = None, + ) -> None: + super().__init__() + self.layers = _clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor: + x = tgt + for layer in self.layers: + x = layer(x, memory) + if self.norm is not None: + x = self.norm(x) + return x + + +class Transformer(nn.Module): + r""" + Complex-Valued Transformer (encoder-decoder). + + Mirrors :class:`torch.nn.Transformer`. ``forward(src, tgt)`` runs the + encoder on ``src`` and the decoder on ``tgt`` with the encoder output as + cross-attention memory. + """ + + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "gelu", + layer_norm_eps: float = 1e-5, + batch_first: bool = True, + softmax_on: str = "complex", + ) -> None: + super().__init__() + enc_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + softmax_on, + ) + self.encoder = TransformerEncoder(enc_layer, num_encoder_layers) + dec_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + softmax_on, + ) + self.decoder = TransformerDecoder(dec_layer, num_decoder_layers) + self.d_model = d_model + self.nhead = nhead + self.batch_first = batch_first + + def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: + memory = self.encoder(src) + return self.decoder(tgt, memory) diff --git a/complextorch/nn/modules/upsampling.py b/complextorch/nn/modules/upsampling.py new file mode 100644 index 0000000..26f5c7c --- /dev/null +++ b/complextorch/nn/modules/upsampling.py @@ -0,0 +1,179 @@ +r""" +Complex-Valued Upsampling Modules +================================= + +Two flavors of complex-valued upsampling / interpolation: + +- :class:`Upsample` — *split form*. Interpolates the real and imaginary parts + independently. Matches the behavior of :class:`torchcvnn.nn.Upsample` and + ``complexPyTorch.complex_upsample``. + +- :class:`PolarUpsample` — *polar form*. Interpolates the magnitude + :math:`|z|` and the phase :math:`\arg z` independently, then recombines via + :math:`|z| \cdot \exp(j\,\arg z)`. Phase-preserving along smooth phase + regions; useful for coherent signal models (radar, SAR). The cost is a + visible discontinuity wherever the phase wraps from :math:`-\pi` to + :math:`+\pi` — neither form is universally correct. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["PolarUpsample", "Upsample"] + + +_SizeT = int | tuple[int, ...] | None + + +def _interpolate( + x: torch.Tensor, + size: _SizeT, + scale_factor: float | tuple[float, ...] | None, + mode: str, + align_corners: bool | None, + recompute_scale_factor: bool | None, +) -> torch.Tensor: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + + +class Upsample(nn.Module): + r""" + Complex-Valued Upsample (split form) + ------------------------------------ + + Applies :func:`torch.nn.functional.interpolate` independently to the real + and imaginary parts of a complex input, then recombines. + + All keyword arguments mirror :class:`torch.nn.Upsample`. + """ + + def __init__( + self, + size: _SizeT = None, + scale_factor: float | tuple[float, ...] | None = None, + mode: str = "nearest", + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + ) -> None: + super().__init__() + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + return _interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + real = _interpolate( + input.real, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + imag = _interpolate( + input.imag, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + return torch.complex(real, imag) + + def extra_repr(self) -> str: + return ( + f"size={self.size}, scale_factor={self.scale_factor}, " + f"mode={self.mode!r}, align_corners={self.align_corners}" + ) + + +class PolarUpsample(nn.Module): + r""" + Complex-Valued Upsample (polar form) + ------------------------------------ + + Interpolates magnitude and phase independently. For a complex input + :math:`z = |z|\, e^{j\arg z}`: + + .. math:: + + |z|' = \text{interp}(|z|), \quad + \arg z' = \text{interp}(\arg z), \quad + z' = |z|' \cdot e^{j\,\arg z'} + + Phase-preserving along smooth phase regions but introduces discontinuities + at phase wraps (:math:`\pm\pi`). Choose between :class:`Upsample` (split) + and :class:`PolarUpsample` (polar) based on whether your data has smooth + phase (favor polar) or smooth real/imag parts (favor split). + + All keyword arguments mirror :class:`torch.nn.Upsample`. + """ + + def __init__( + self, + size: _SizeT = None, + scale_factor: float | tuple[float, ...] | None = None, + mode: str = "nearest", + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + ) -> None: + super().__init__() + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not input.is_complex(): + return _interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + mag = input.abs() + phase = input.angle() + mag_up = _interpolate( + mag, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + phase_up = _interpolate( + phase, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + self.recompute_scale_factor, + ) + return torch.polar(mag_up, phase_up) + + def extra_repr(self) -> str: + return ( + f"size={self.size}, scale_factor={self.scale_factor}, " + f"mode={self.mode!r}, align_corners={self.align_corners}" + ) diff --git a/complextorch/nn/relevance/__init__.py b/complextorch/nn/relevance/__init__.py new file mode 100644 index 0000000..4e2fcdf --- /dev/null +++ b/complextorch/nn/relevance/__init__.py @@ -0,0 +1,52 @@ +r""" +Variational Dropout / Automatic Relevance Determination (Complex) +================================================================= + +Complex-valued layers that learn a per-weight relevance score during training +via the local reparameterization trick. Adds a per-module +:attr:`BaseARD.penalty` (KL divergence) that the user adds to the negative +log-likelihood, and a :meth:`BaseARD.relevance` method that returns the +post-training binary keep/drop mask. + +Adapted from :mod:`cplxmodule.nn.relevance.complex` for native ``torch.cfloat``. +""" + +from complextorch.nn.relevance.base import ( + BaseARD, + compute_ard_masks, + named_penalties, + named_relevance, + penalties, +) +from complextorch.nn.relevance.conv import ( + Conv1dARD, + Conv1dVD, + Conv2dARD, + Conv2dVD, + Conv3dARD, + Conv3dVD, +) +from complextorch.nn.relevance.linear import ( + BilinearARD, + BilinearVD, + LinearARD, + LinearVD, +) + +__all__ = [ + "BaseARD", + "BilinearARD", + "BilinearVD", + "Conv1dARD", + "Conv1dVD", + "Conv2dARD", + "Conv2dVD", + "Conv3dARD", + "Conv3dVD", + "LinearARD", + "LinearVD", + "compute_ard_masks", + "named_penalties", + "named_relevance", + "penalties", +] diff --git a/complextorch/nn/relevance/_expi.py b/complextorch/nn/relevance/_expi.py new file mode 100644 index 0000000..00cb407 --- /dev/null +++ b/complextorch/nn/relevance/_expi.py @@ -0,0 +1,43 @@ +r""" +Differentiable Exponential Integral ``Ei`` +========================================== + +Custom :class:`torch.autograd.Function` whose forward calls +:func:`scipy.special.expi` on CPU and whose backward uses the closed-form +derivative :math:`\frac{d}{dx} \mathrm{Ei}(x) = e^x / x`. + +This is used by complex Variational Dropout to compute an exact KL divergence +between a complex-Gaussian posterior and the log-uniform prior — no closed-form +``Ei`` is currently available in pure torch. +""" + +import torch + +__all__ = ["ExpiFunction", "torch_expi"] + + +class ExpiFunction(torch.autograd.Function): + r"""Differentiable port of :func:`scipy.special.expi`. + + Forward goes to CPU + numpy + scipy; backward is analytical + (:math:`\frac{d}{dx} \mathrm{Ei}(x) = e^x / x`) and stays on the original + device. Memory transfer overhead is amortised by the fact that ``Ei`` is + only evaluated on parameter-shaped tensors, not on data-shaped ones. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + import scipy.special + + ctx.save_for_backward(x) + x_cpu = x.detach().cpu().numpy() + output = scipy.special.expi(x_cpu) + return torch.from_numpy(output).to(x.device, dtype=x.dtype) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (x,) = ctx.saved_tensors + return grad_output * torch.exp(x) / x + + +torch_expi = ExpiFunction.apply diff --git a/complextorch/nn/relevance/base.py b/complextorch/nn/relevance/base.py new file mode 100644 index 0000000..fc059b1 --- /dev/null +++ b/complextorch/nn/relevance/base.py @@ -0,0 +1,84 @@ +r""" +BaseARD + Module-Walking Helpers +================================ + +Adapted from :mod:`cplxmodule.nn.relevance.base`. +""" + +from collections.abc import Iterator + +import torch + +__all__ = [ + "BaseARD", + "compute_ard_masks", + "named_penalties", + "named_relevance", + "penalties", +] + + +class BaseARD(torch.nn.Module): + r""" + Abstract base for variational-dropout / automatic-relevance-determination + layers. Subclasses provide: + + - ``.penalty`` (property): differentiable KL divergence to add to the loss. + - ``.relevance(threshold=...)`` (method): a binary mask of relevant weights. + """ + + @property + def penalty(self) -> torch.Tensor: + raise NotImplementedError("Subclasses must compute their own penalty.") + + def relevance(self, **kwargs) -> torch.Tensor: + raise NotImplementedError("Subclasses must implement `.relevance`.") + + +def named_penalties( + module: torch.nn.Module, + reduction: str = "sum", + prefix: str = "", +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(name, penalty)`` for every :class:`BaseARD` submodule.""" + if reduction is not None and reduction not in ("mean", "sum"): + raise ValueError(f"reduction must be 'mean', 'sum', or None; got {reduction!r}") + for name, mod in module.named_modules(prefix=prefix): + if isinstance(mod, BaseARD): + p = mod.penalty + if reduction == "sum": + p = p.sum() + elif reduction == "mean": + p = p.mean() + yield name, p + + +def penalties( + module: torch.nn.Module, reduction: str = "sum" +) -> Iterator[torch.Tensor]: + """Yield just the penalty tensors. See :func:`named_penalties`.""" + for _, p in named_penalties(module, reduction=reduction): + yield p + + +def named_relevance( + module: torch.nn.Module, *, prefix: str = "", **kwargs +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(name, mask)`` for every :class:`BaseARD` submodule.""" + for name, mod in module.named_modules(prefix=prefix): + if isinstance(mod, BaseARD): + yield name, mod.relevance(**kwargs).detach() + + +def compute_ard_masks(module: torch.nn.Module, *, prefix: str = "", **kwargs) -> dict: + r""" + Build a ``{name + '.mask': mask}`` dict suitable for + :func:`complextorch.nn.masked.deploy_masks`. + """ + if not isinstance(module, torch.nn.Module): + return {} + out = {} + for name, mask in named_relevance(module, prefix=prefix, **kwargs): + key = (name + "." if name else "") + "mask" + out[key] = mask + return out diff --git a/complextorch/nn/relevance/conv.py b/complextorch/nn/relevance/conv.py new file mode 100644 index 0000000..621ef5f --- /dev/null +++ b/complextorch/nn/relevance/conv.py @@ -0,0 +1,157 @@ +r""" +Complex-Valued VD / ARD Conv1d/2d/3d Layers +=========================================== + +Adapted from :mod:`cplxmodule.nn.relevance.complex` for native ``torch.cfloat``. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from complextorch.nn.relevance.base import BaseARD +from complextorch.nn.relevance.linear import ( + _CplxARDMixin, + _CplxVDMixin, + _GaussianMixin, + _RelevanceMixin, +) + +__all__ = [ + "Conv1dARD", + "Conv1dVD", + "Conv2dARD", + "Conv2dVD", + "Conv3dARD", + "Conv3dVD", +] + + +def _to_tuple(x, n: int) -> tuple[int, ...]: + if isinstance(x, int): + return (x,) * n + return tuple(x) + + +class _ConvNdGaussian(_GaussianMixin, nn.Module): + """Internal base — subclasses set ``_conv_fn`` and ``_nd``.""" + + _conv_fn = staticmethod(F.conv1d) + _nd = 1 + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + if padding_mode != "zeros": + raise ValueError( + f"Only padding_mode='zeros' is supported, got {padding_mode!r}" + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _to_tuple(kernel_size, self._nd) + self.stride = _to_tuple(stride, self._nd) + self.padding = ( + padding if isinstance(padding, str) else _to_tuple(padding, self._nd) + ) + self.dilation = _to_tuple(dilation, self._nd) + self.groups = groups + + weight_shape = (out_channels, in_channels // groups, *self.kernel_size) + self.weight = nn.Parameter(torch.empty(*weight_shape, dtype=torch.cfloat)) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels, dtype=torch.cfloat)) + else: + self.register_parameter("bias", None) + self.log_sigma2 = nn.Parameter(torch.empty(*weight_shape)) + self.reset_parameters() + + def reset_parameters(self) -> None: + fan_in = self.in_channels // self.groups + for k in self.kernel_size: + fan_in *= k + bound = 1.0 / math.sqrt(fan_in) + with torch.no_grad(): + self.weight.real.uniform_(-bound, bound) + self.weight.imag.uniform_(-bound, bound) + if self.bias is not None: + self.bias.real.uniform_(-bound, bound) + self.bias.imag.uniform_(-bound, bound) + self.reset_variational_parameters() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + mu = self._conv_fn( + input, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + if not self.training: + return mu + s2 = self._conv_fn( + input.real * input.real + input.imag * input.imag, + torch.exp(self.log_sigma2), + None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + s2 = s2.clamp_min(1e-8) + eps_r = torch.randn_like(s2) + eps_i = torch.randn_like(s2) + noise = torch.complex(eps_r, eps_i) * (torch.sqrt(s2 / 2.0)) + return mu + noise + + +class _Conv1dGaussian(_ConvNdGaussian): + _conv_fn = staticmethod(F.conv1d) + _nd = 1 + + +class _Conv2dGaussian(_ConvNdGaussian): + _conv_fn = staticmethod(F.conv2d) + _nd = 2 + + +class _Conv3dGaussian(_ConvNdGaussian): + _conv_fn = staticmethod(F.conv3d) + _nd = 3 + + +class Conv1dVD(_CplxVDMixin, _RelevanceMixin, _Conv1dGaussian, BaseARD): + pass + + +class Conv2dVD(_CplxVDMixin, _RelevanceMixin, _Conv2dGaussian, BaseARD): + pass + + +class Conv3dVD(_CplxVDMixin, _RelevanceMixin, _Conv3dGaussian, BaseARD): + pass + + +class Conv1dARD(_CplxARDMixin, _RelevanceMixin, _Conv1dGaussian, BaseARD): + pass + + +class Conv2dARD(_CplxARDMixin, _RelevanceMixin, _Conv2dGaussian, BaseARD): + pass + + +class Conv3dARD(_CplxARDMixin, _RelevanceMixin, _Conv3dGaussian, BaseARD): + pass diff --git a/complextorch/nn/relevance/linear.py b/complextorch/nn/relevance/linear.py new file mode 100644 index 0000000..090e761 --- /dev/null +++ b/complextorch/nn/relevance/linear.py @@ -0,0 +1,201 @@ +r""" +Complex-Valued VD / ARD Linear & Bilinear Layers +================================================ + +Adapted from :mod:`cplxmodule.nn.relevance.complex.{base,vd,ard}` for native +``torch.cfloat``. See :class:`complextorch.nn.relevance.BaseARD` for the +shared interface. +""" + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from complextorch.nn.relevance._expi import torch_expi +from complextorch.nn.relevance.base import BaseARD +from complextorch.nn.utils.sparsity import SparsityStats + +__all__ = ["BilinearARD", "BilinearVD", "LinearARD", "LinearVD"] + + +def _init_complex_weight(weight: torch.Tensor, in_features: int) -> None: + bound = 1.0 / math.sqrt(in_features) + with torch.no_grad(): + weight.real.uniform_(-bound, bound) + weight.imag.uniform_(-bound, bound) + + +def _init_complex_bias(bias: torch.Tensor, in_features: int) -> None: + bound = 1.0 / math.sqrt(in_features) + with torch.no_grad(): + bias.real.uniform_(-bound, bound) + bias.imag.uniform_(-bound, bound) + + +class _GaussianMixin: + r"""Provides ``log_alpha`` and a small ``reset_variational_parameters``.""" + + def reset_variational_parameters(self) -> None: + with torch.no_grad(): + self.log_sigma2.fill_(-10.0) + + @property + def log_alpha(self) -> torch.Tensor: + return self.log_sigma2 - 2.0 * torch.log(self.weight.abs() + 1e-12) + + +class _RelevanceMixin(SparsityStats): + __sparsity_ignore__ = ("log_sigma2",) + + def relevance(self, *, threshold: float, **kwargs) -> torch.Tensor: + with torch.no_grad(): + return (self.log_alpha <= threshold).to(self.log_alpha.dtype) + + def sparsity(self, *, threshold: float, **kwargs): + relevance = self.relevance(threshold=threshold) + n_dropped = float(self.weight.numel()) - float(relevance.sum().item()) + return [(id(self.weight), n_dropped)] + + +class _CplxVDMixin: + r"""KL of complex Gaussian posterior vs scale-free log-uniform prior.""" + + @property + def penalty(self) -> torch.Tensor: + n_log_alpha = -self.log_alpha + # Euler-Mascheroni constant ensures non-negativity. + return float(np.euler_gamma) + n_log_alpha - torch_expi(-torch.exp(n_log_alpha)) + + +class _CplxARDMixin: + r"""Empirical-Bayes (softplus) penalty.""" + + @property + def penalty(self) -> torch.Tensor: + return F.softplus(-self.log_alpha) + + +# --------------------------------------------------------------------------- +# Linear +# --------------------------------------------------------------------------- + + +class _LinearGaussian(_GaussianMixin, nn.Module): + r"""Complex linear with multiplicative Gaussian noise on the weight.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, dtype=torch.cfloat) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.cfloat)) + else: + self.register_parameter("bias", None) + self.log_sigma2 = nn.Parameter(torch.empty(out_features, in_features)) + self.reset_parameters() + + def reset_parameters(self) -> None: + _init_complex_weight(self.weight, self.in_features) + if self.bias is not None: + _init_complex_bias(self.bias, self.in_features) + self.reset_variational_parameters() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + mu = F.linear(input, self.weight, self.bias) + if not self.training: + return mu + # Variance of the additive noise: linear(|x|^2, exp(log_sigma^2)). + s2 = F.linear( + input.real * input.real + input.imag * input.imag, + torch.exp(self.log_sigma2), + None, + ) + s2 = s2.clamp_min(1e-8) + # Circular complex N(0, s2) noise. + eps_r = torch.randn_like(s2) + eps_i = torch.randn_like(s2) + noise = torch.complex(eps_r, eps_i) * (torch.sqrt(s2 / 2.0)) + return mu + noise + + +class LinearVD(_CplxVDMixin, _RelevanceMixin, _LinearGaussian, BaseARD): + r"""Complex Linear with Variational Dropout (log-uniform prior, exact KL via Ei).""" + + +class LinearARD(_CplxARDMixin, _RelevanceMixin, _LinearGaussian, BaseARD): + r"""Complex Linear with Automatic Relevance Determination (softplus penalty).""" + + +# --------------------------------------------------------------------------- +# Bilinear +# --------------------------------------------------------------------------- + + +class _BilinearGaussian(_GaussianMixin, nn.Module): + r"""Complex bilinear with multiplicative Gaussian noise on the weight.""" + + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + conjugate: bool = True, + ) -> None: + super().__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.conjugate = conjugate + self.weight = nn.Parameter( + torch.empty(out_features, in1_features, in2_features, dtype=torch.cfloat) + ) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.cfloat)) + else: + self.register_parameter("bias", None) + self.log_sigma2 = nn.Parameter( + torch.empty(out_features, in1_features, in2_features) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = 1.0 / math.sqrt(self.in1_features) + with torch.no_grad(): + self.weight.real.uniform_(-bound, bound) + self.weight.imag.uniform_(-bound, bound) + if self.bias is not None: + self.bias.real.uniform_(-bound, bound) + self.bias.imag.uniform_(-bound, bound) + self.reset_variational_parameters() + + def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + x1 = input1.conj() if self.conjugate else input1 + mu = torch.einsum("...i,kij,...j->...k", x1, self.weight, input2) + if self.bias is not None: + mu = mu + self.bias + if not self.training: + return mu + # Variance per output: einsum(|x1|^2, exp(log_sigma2), |x2|^2) + m1 = input1.real * input1.real + input1.imag * input1.imag + m2 = input2.real * input2.real + input2.imag * input2.imag + s2 = torch.einsum("...i,kij,...j->...k", m1, torch.exp(self.log_sigma2), m2) + s2 = s2.clamp_min(1e-8) + eps_r = torch.randn_like(s2) + eps_i = torch.randn_like(s2) + noise = torch.complex(eps_r, eps_i) * (torch.sqrt(s2 / 2.0)) + return mu + noise + + +class BilinearVD(_CplxVDMixin, _RelevanceMixin, _BilinearGaussian, BaseARD): + pass + + +class BilinearARD(_CplxARDMixin, _RelevanceMixin, _BilinearGaussian, BaseARD): + pass diff --git a/complextorch/nn/utils/__init__.py b/complextorch/nn/utils/__init__.py new file mode 100644 index 0000000..78afaa9 --- /dev/null +++ b/complextorch/nn/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility helpers.""" + +from complextorch.nn.utils.sparsity import SparsityStats, named_sparsity, sparsity + +__all__ = ["SparsityStats", "named_sparsity", "sparsity"] diff --git a/complextorch/nn/utils/sparsity.py b/complextorch/nn/utils/sparsity.py new file mode 100644 index 0000000..d219a99 --- /dev/null +++ b/complextorch/nn/utils/sparsity.py @@ -0,0 +1,71 @@ +r""" +Sparsity Statistics Helpers +=========================== + +Walk a module tree and report sparsity stats for layers that subclass +:class:`SparsityStats` (typically a :mod:`complextorch.nn.masked` or +:mod:`complextorch.nn.relevance` layer). +""" + +from collections.abc import Iterator + +import torch + +__all__ = ["SparsityStats", "named_sparsity", "sparsity"] + + +class SparsityStats(torch.nn.Module): + r""" + Mixin for modules that can report ``n_zeros`` per parameter. + + Subclasses must implement :meth:`sparsity(self, **kwargs)`, returning a + list of ``(param_id, n_dropped)`` tuples where ``param_id`` uniquely + identifies a parameter (we use ``id(p)``). + """ + + __sparsity_ignore__: tuple = () + + def sparsity(self, **kwargs): + raise NotImplementedError( + "Subclasses of SparsityStats must implement `.sparsity(self, **kwargs)`" + ) + + +def named_sparsity( + module: torch.nn.Module, *, prefix: str = "", **kwargs +) -> Iterator[tuple[str, tuple[int, int]]]: + r""" + Yield ``(param_name, (n_zeros, n_total))`` for each parameter of every + :class:`SparsityStats` submodule. + + ``kwargs`` are forwarded to :meth:`SparsityStats.sparsity` (typically + ``threshold=...``). + """ + # Build a mapping id(param) -> ("module_name.param_name", n_total). + pid_to_name = {} + pid_to_total = {} + for mod_name, mod in module.named_modules(prefix=prefix): + for p_name, p in mod.named_parameters(recurse=False): + full = f"{mod_name}.{p_name}" if mod_name else p_name + pid_to_name[id(p)] = full + pid_to_total[id(p)] = p.numel() + + seen = set() + for _mod_name, mod in module.named_modules(prefix=prefix): + if not isinstance(mod, SparsityStats): + continue + for pid, n_dropped in mod.sparsity(**kwargs): + if pid in seen or pid not in pid_to_name: + continue + seen.add(pid) + yield pid_to_name[pid], (int(n_dropped), int(pid_to_total[pid])) + + +def sparsity(module: torch.nn.Module, **kwargs) -> float: + """Return the overall sparsity ratio (``n_zeros / n_total``).""" + total = 0 + zeros = 0 + for _, (n_z, n_t) in named_sparsity(module, **kwargs): + zeros += n_z + total += n_t + return zeros / total if total else 0.0 diff --git a/complextorch/signal.py b/complextorch/signal.py new file mode 100644 index 0000000..a5d9ecc --- /dev/null +++ b/complextorch/signal.py @@ -0,0 +1,124 @@ +r""" +Complex-Aware Signal-Processing Utilities +========================================= + +A small set of complex-aware signal helpers that don't fit naturally in +:mod:`complextorch.nn`. Currently: + +- :func:`pwelch` — Welch power spectral density (torch port of + :func:`scipy.signal.welch`). Works on both real and complex inputs and is + differentiable end-to-end. +""" + +import torch + +__all__ = ["pwelch"] + + +def _window_view(x: torch.Tensor, dim: int, size: int, stride: int = 1) -> torch.Tensor: + """Sliding-window view of ``x`` along ``dim``. Returns a view of shape + ``(..., n_windows, size, ...)`` (the new window axis replaces ``dim``, + with the per-window-time axis appended immediately after).""" + return x.unfold(dim, size, stride) + + +def pwelch( + x: torch.Tensor, + window: int | torch.Tensor = 256, + fs: float = 1.0, + scaling: str = "density", + n_overlap: int | None = None, + detrend: str = "constant", + return_onesided: bool | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Welch Power Spectral Density (torch). + + Mirrors the headline arguments of :func:`scipy.signal.welch`. Works on + real or complex ``x``; for a complex input the returned spectrum is + two-sided by default (matching scipy). + + Args: + x: signal tensor; PSD is computed over the **last** dim. + window: integer window length (uses Hann), or a 1-D tensor with the + window samples. Defaults to 256 (or the signal length if shorter). + fs: sampling frequency (in Hz). Defaults to ``1.0`` so that the + returned ``frequencies`` lie in ``[-0.5, 0.5)`` for two-sided + or ``[0, 0.5]`` for one-sided. + scaling: ``'density'`` (units of power / Hz) or ``'spectrum'`` (units + of power). + n_overlap: number of samples to overlap between segments. Defaults to + ``len(window) // 2``. + detrend: ``'constant'`` (subtract per-segment mean, the scipy default) + or ``'none'``. + return_onesided: ``True`` -> drop negative frequencies (real signals + only). ``None`` -> auto: ``True`` for real ``x``, ``False`` for + complex ``x``. + + Returns: + ``(frequencies, psd)`` — ``frequencies`` of shape ``(F,)``, ``psd`` of + shape ``x.shape[:-1] + (F,)``. + """ + n = x.shape[-1] + if isinstance(window, int): + win_len = min(window, n) + win = torch.hann_window(win_len, dtype=x.real.dtype, device=x.device) + else: + win = window.to(dtype=x.real.dtype, device=x.device) + win_len = win.shape[0] + if n_overlap is None: + n_overlap = win_len // 2 + stride = win_len - n_overlap + if stride <= 0: + raise ValueError( + f"n_overlap ({n_overlap}) must be smaller than window length ({win_len})" + ) + + if return_onesided is None: + return_onesided = not x.is_complex() + if return_onesided and x.is_complex(): + raise ValueError("return_onesided=True is invalid for complex signals") + + # Sliding-window view: (..., n_windows, win_len) + segs = _window_view(x, dim=-1, size=win_len, stride=stride) + + # Optional detrending per segment + if detrend == "constant": + segs = segs - segs.mean(dim=-1, keepdim=True) + elif detrend != "none": + raise ValueError(f"detrend must be 'constant' or 'none', got {detrend!r}") + + # Apply window + segs = segs * win + + # FFT each segment + if return_onesided: + spec = torch.fft.rfft(segs, n=win_len, dim=-1) + freqs = torch.fft.rfftfreq(win_len, d=1.0 / fs).to(x.device) + else: + spec = torch.fft.fft(segs, n=win_len, dim=-1) + freqs = torch.fft.fftfreq(win_len, d=1.0 / fs).to(x.device) + + # Average power across segments + psd = (spec.real**2 + spec.imag**2).mean(dim=-2) + + # Scaling factor: matches scipy convention + if scaling == "density": + scale = 1.0 / (fs * (win * win).sum()) + elif scaling == "spectrum": + scale = 1.0 / (win.sum() ** 2) + else: + raise ValueError(f"scaling must be 'density' or 'spectrum', got {scaling!r}") + psd = psd * scale + + # One-sided correction: double interior bins (DC and Nyquist unchanged) + if return_onesided: + # Indices 1..-2 (interior of rfft output) get a factor of 2 + if psd.shape[-1] > 2: + psd_mid = psd[..., 1:-1] * 2.0 + psd = torch.cat([psd[..., :1], psd_mid, psd[..., -1:]], dim=-1) + elif psd.shape[-1] == 2: + # Only DC + Nyquist + pass + + return freqs, psd diff --git a/complextorch/transforms/__init__.py b/complextorch/transforms/__init__.py new file mode 100644 index 0000000..375d2a0 --- /dev/null +++ b/complextorch/transforms/__init__.py @@ -0,0 +1,51 @@ +r""" +Dataloader-Stage Transforms (Torch-Only) +======================================== + +A complex-aware analogue of :mod:`torchvision.transforms`. All transforms are +:class:`torch.nn.Module` subclasses and operate on torch tensors. Numpy paths +are intentionally not provided — pre-convert via :class:`ToTensor`. +""" + +from complextorch.transforms.functional import polsar_dict_to_array, rescale_intensity +from complextorch.transforms.transforms import ( + FFT2, + HWC2CHW, + IFFT2, + Amplitude, + CenterCrop, + FFTResize, + LogAmplitude, + Normalize, + PadIfNeeded, + PolSAR, + RandomPhase, + RealImaginary, + SpatialResize, + ToImaginary, + ToReal, + ToTensor, + Unsqueeze, +) + +__all__ = [ + "FFT2", + "HWC2CHW", + "IFFT2", + "Amplitude", + "CenterCrop", + "FFTResize", + "LogAmplitude", + "Normalize", + "PadIfNeeded", + "PolSAR", + "RandomPhase", + "RealImaginary", + "SpatialResize", + "ToImaginary", + "ToReal", + "ToTensor", + "Unsqueeze", + "polsar_dict_to_array", + "rescale_intensity", +] diff --git a/complextorch/transforms/functional.py b/complextorch/transforms/functional.py new file mode 100644 index 0000000..e522b11 --- /dev/null +++ b/complextorch/transforms/functional.py @@ -0,0 +1,148 @@ +r""" +Functional Transform Helpers +============================ + +Lower-level helpers backing the class transforms in +:mod:`complextorch.transforms.transforms`. Most are intentionally private +(``_``-prefixed); the public ones are listed in ``__all__``. +""" + +import torch +import torch.nn.functional as F + +__all__ = ["polsar_dict_to_array", "rescale_intensity"] + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + + +def polsar_dict_to_array( + d: dict[str, torch.Tensor], order: tuple[str, ...] = ("HH", "HV", "VH", "VV") +) -> torch.Tensor: + r""" + Stack a polarimetric SAR channel dictionary into a tensor. + + Args: + d: mapping from polarization name (e.g. ``'HH'``) to a 2-D complex + tensor of shape ``(H, W)`` (or higher-rank; all entries must agree). + order: which channels to keep, in the desired output order. Defaults + to the standard quad-pol order ``(HH, HV, VH, VV)``. + + Returns: + Complex tensor of shape ``(len(order), H, W)``. + """ + chans = [d[k] for k in order if k in d] + if not chans: + raise ValueError( + f"none of the requested channels {order} are present in dict keys {tuple(d.keys())}" + ) + return torch.stack(chans, dim=0) + + +def rescale_intensity( + x: torch.Tensor, + in_range: tuple[float, float] | None = None, + out_range: tuple[float, float] = (0.0, 1.0), +) -> torch.Tensor: + r""" + Linearly remap intensity values from ``in_range`` to ``out_range``. + + Mirrors :func:`skimage.exposure.rescale_intensity` for torch tensors. + Values outside ``in_range`` are clamped to the corresponding output + bound. ``in_range=None`` uses ``(x.min(), x.max())``. + """ + if x.is_complex(): + raise TypeError("rescale_intensity expects a real tensor; pass abs(x) first") + if in_range is None: + in_lo, in_hi = float(x.min()), float(x.max()) + else: + in_lo, in_hi = in_range + out_lo, out_hi = out_range + if in_hi == in_lo: + return torch.full_like(x, out_lo) + x = x.clamp(in_lo, in_hi) + return (x - in_lo) / (in_hi - in_lo) * (out_hi - out_lo) + out_lo + + +# --------------------------------------------------------------------------- +# Private helpers backing the class transforms +# --------------------------------------------------------------------------- + + +def _check_chw(x: torch.Tensor) -> None: + if x.dim() not in (3, 4): + raise ValueError( + f"expected a (C, H, W) or (B, C, H, W) tensor, got {tuple(x.shape)}" + ) + + +def _log_normalize_amplitude( + x: torch.Tensor, scale: float = 1.0, preserve_phase: bool = True +) -> torch.Tensor: + """``log1p(|x| / scale) * exp(j * arg x)`` when ``preserve_phase``; else real magnitude.""" + out_mag = torch.log1p(x.abs() / scale) + if preserve_phase and x.is_complex(): + return torch.polar(out_mag, x.angle()) + return out_mag + + +def _applyfft2(x: torch.Tensor) -> torch.Tensor: + return torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)) + + +def _applyifft2(x: torch.Tensor) -> torch.Tensor: + return torch.fft.ifft2(torch.fft.ifftshift(x, dim=(-2, -1))) + + +def _get_padding( + shape: tuple[int, int], target: tuple[int, int] +) -> tuple[int, int, int, int]: + """Symmetric (left, right, top, bottom) pad amounts to bring (H, W) up to target.""" + h, w = shape + th, tw = target + pad_h = max(th - h, 0) + pad_w = max(tw - w, 0) + top = pad_h // 2 + bot = pad_h - top + left = pad_w // 2 + right = pad_w - left + return left, right, top, bot + + +def _padifneeded( + x: torch.Tensor, min_h: int, min_w: int, mode: str = "constant" +) -> torch.Tensor: + h, w = x.shape[-2], x.shape[-1] + if h >= min_h and w >= min_w: + return x + left, right, top, bot = _get_padding((h, w), (min_h, min_w)) + if x.is_complex(): + re = F.pad(x.real, (left, right, top, bot), mode=mode) + im = F.pad(x.imag, (left, right, top, bot), mode=mode) + return torch.complex(re, im) + return F.pad(x, (left, right, top, bot), mode=mode) + + +def _center_crop(x: torch.Tensor, h: int, w: int) -> torch.Tensor: + H, W = x.shape[-2], x.shape[-1] + if h > H or w > W: + raise ValueError(f"center_crop: target {(h, w)} larger than input {(H, W)}") + top = (H - h) // 2 + left = (W - w) // 2 + return x[..., top : top + h, left : left + w] + + +def _spatial_resize_bicubic(x: torch.Tensor, h: int, w: int) -> torch.Tensor: + # F.interpolate(mode='bicubic') doesn't support complex; do split. + needs_unsqueeze = x.dim() == 3 + if needs_unsqueeze: + x = x.unsqueeze(0) + if x.is_complex(): + re = F.interpolate(x.real, size=(h, w), mode="bicubic", align_corners=False) + im = F.interpolate(x.imag, size=(h, w), mode="bicubic", align_corners=False) + out = torch.complex(re, im) + else: + out = F.interpolate(x, size=(h, w), mode="bicubic", align_corners=False) + return out.squeeze(0) if needs_unsqueeze else out diff --git a/complextorch/transforms/transforms.py b/complextorch/transforms/transforms.py new file mode 100644 index 0000000..5cbc19c --- /dev/null +++ b/complextorch/transforms/transforms.py @@ -0,0 +1,427 @@ +r""" +Torch-Native Complex Transforms +=============================== + +Class-based transforms operating on torch tensors. All transforms are +:class:`torch.nn.Module` subclasses so they compose inside +:class:`torch.nn.Sequential` (or :class:`torchvision.transforms.Compose`). +""" + +import math + +import torch +import torch.nn as nn + +from complextorch.transforms.functional import ( + _applyfft2, + _applyifft2, + _center_crop, + _check_chw, + _log_normalize_amplitude, + _padifneeded, + _spatial_resize_bicubic, +) + +__all__ = [ + "FFT2", + "HWC2CHW", + "IFFT2", + "Amplitude", + "CenterCrop", + "FFTResize", + "LogAmplitude", + "Normalize", + "PadIfNeeded", + "PolSAR", + "RandomPhase", + "RealImaginary", + "SpatialResize", + "ToImaginary", + "ToReal", + "ToTensor", + "Unsqueeze", +] + + +# --------------------------------------------------------------------------- +# I/O and shape transforms +# --------------------------------------------------------------------------- + + +class ToTensor(nn.Module): + r"""Cast input to a :class:`torch.Tensor` of the requested ``dtype``. + + Args: + dtype: target dtype. Defaults to ``torch.cfloat`` since most + :mod:`complextorch` workflows expect complex inputs. + """ + + def __init__(self, dtype: torch.dtype = torch.cfloat) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, x) -> torch.Tensor: + t = torch.as_tensor(x) + return t.to(self.dtype) + + def extra_repr(self) -> str: + return f"dtype={self.dtype}" + + +class Unsqueeze(nn.Module): + r"""Insert a size-1 dimension at ``dim``. + + Thin :class:`torch.nn.Module` wrapper around :meth:`torch.Tensor.unsqueeze`, + intended for use inside a :class:`torchvision.transforms.Compose` pipeline. + + Args: + dim: position at which the new axis is inserted. + """ + + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(self.dim) + + def extra_repr(self) -> str: + return f"dim={self.dim}" + + +class HWC2CHW(nn.Module): + r"""Permute ``(H, W, C)`` to ``(C, H, W)``. + + PIL / NumPy image conventions store channels-last; PyTorch expects + channels-first. Raises :class:`ValueError` on inputs that are not 3-D. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() != 3: + raise ValueError(f"HWC2CHW expects a 3-D tensor, got {tuple(x.shape)}") + return x.permute(2, 0, 1).contiguous() + + +# --------------------------------------------------------------------------- +# Magnitude / component extraction +# --------------------------------------------------------------------------- + + +class LogAmplitude(nn.Module): + r""" + ``log1p(|x| / scale) * exp(j*arg x)`` (or real magnitude if ``preserve_phase=False``). + + Standard SAR preprocessing: raw SAR magnitudes span many orders of + magnitude, so the log-scaling makes them tractable for a network. + """ + + def __init__(self, scale: float = 1.0, preserve_phase: bool = True) -> None: + super().__init__() + self.scale = scale + self.preserve_phase = preserve_phase + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _log_normalize_amplitude(x, self.scale, self.preserve_phase) + + +class Amplitude(nn.Module): + r"""Returns ``|x|`` (complex -> real).""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.abs() + + +class ToReal(nn.Module): + r"""Return :math:`\Re(x)`, i.e. the real part of a complex tensor. + + No-op for inputs that are already real (returned unchanged), so the + transform is safe to use unconditionally in a preprocessing pipeline. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.real if x.is_complex() else x + + +class ToImaginary(nn.Module): + r"""Return :math:`\Im(x)`, i.e. the imaginary part of a complex tensor. + + For real-valued inputs, returns a tensor of zeros with the same shape + and dtype so the transform composes cleanly with mixed pipelines. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.imag if x.is_complex() else torch.zeros_like(x) + + +class RealImaginary(nn.Module): + r""" + Stack real and imaginary parts along the channel dim. + + Complex ``(C, H, W)`` -> real ``(2C, H, W)``; ``(B, C, H, W)`` -> + ``(B, 2C, H, W)``. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.is_complex(): + return x + return torch.cat([x.real, x.imag], dim=-3) + + +# --------------------------------------------------------------------------- +# Statistics / randomization +# --------------------------------------------------------------------------- + + +class Normalize(nn.Module): + r""" + Per-channel 2x2 whitening with precomputed statistics. + + Given per-channel ``mean`` (complex, shape ``(C,)``) and ``covariance`` + (real, shape ``(C, 2, 2)``), applies + ``(x - mean) @ cov^{-1/2}`` per channel. The 2x2 matrix square root is + computed via :func:`complextorch.nn.functional.inv_sqrtm2x2`. + """ + + def __init__(self, mean: torch.Tensor, covariance: torch.Tensor) -> None: + super().__init__() + if covariance.shape[-2:] != (2, 2): + raise ValueError( + f"covariance must have shape (..., 2, 2), got {tuple(covariance.shape)}" + ) + self.register_buffer("mean", mean.to(torch.cfloat)) + self.register_buffer("covariance", covariance.to(torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + from complextorch.nn.functional import inv_sqrtm2x2 + + # Contract: (C, H, W) or (B, C, H, W) — channel at dim -3, two spatial dims. + if x.dim() not in (3, 4) or x.shape[-3] != self.mean.shape[0]: + raise ValueError( + f"expected input of shape (C, H, W) or (B, C, H, W) with C={self.mean.shape[0]}, " + f"got shape {tuple(x.shape)}" + ) + m = self.mean.view(self.mean.shape[0], 1, 1) + x = x - m + a = self.covariance[..., 0, 0].view(-1, 1, 1) + b = self.covariance[..., 0, 1].view(-1, 1, 1) + c = self.covariance[..., 1, 0].view(-1, 1, 1) + d = self.covariance[..., 1, 1].view(-1, 1, 1) + w, xc, yc, z = inv_sqrtm2x2(a, b, c, d) + re = w * x.real + xc * x.imag + im = yc * x.real + z * x.imag + return torch.complex(re, im) + + +class RandomPhase(nn.Module): + r"""Multiply by ``exp(j * phi)`` with ``phi ~ Uniform(0, 2*pi)`` (or ``[-pi, pi]`` if ``centered``). + + Phase-invariance data augmentation for coherent signals. + """ + + def __init__(self, centered: bool = False) -> None: + super().__init__() + self.centered = centered + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.centered: + phi = torch.empty((), device=x.device).uniform_(-math.pi, math.pi) + else: + phi = torch.empty((), device=x.device).uniform_(0.0, 2.0 * math.pi) + rotor = torch.polar(torch.tensor(1.0, device=x.device), phi) + if x.is_complex(): + return x * rotor + return x.to(torch.cfloat) * rotor + + +# --------------------------------------------------------------------------- +# Spatial +# --------------------------------------------------------------------------- + + +class PadIfNeeded(nn.Module): + r"""Symmetric padding to bring ``(H, W)`` up to at least ``(min_h, min_w)``. + + ``mode`` matches :func:`torch.nn.functional.pad`: ``'constant'``, + ``'reflect'``, ``'replicate'``, ``'circular'``. + """ + + def __init__(self, min_h: int, min_w: int, mode: str = "constant") -> None: + super().__init__() + self.min_h = min_h + self.min_w = min_w + self.mode = mode + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _check_chw(x) + return _padifneeded(x, self.min_h, self.min_w, mode=self.mode) + + +class CenterCrop(nn.Module): + r"""Center-crop to ``(h, w)``.""" + + def __init__(self, h: int, w: int) -> None: + super().__init__() + self.h = h + self.w = w + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _center_crop(x, self.h, self.w) + + +class SpatialResize(nn.Module): + r"""Bicubic resize to ``(h, w)`` (split real/imag for complex inputs).""" + + def __init__(self, h: int, w: int) -> None: + super().__init__() + self.h = h + self.w = w + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _spatial_resize_bicubic(x, self.h, self.w) + + +# --------------------------------------------------------------------------- +# Spectral +# --------------------------------------------------------------------------- + + +class FFT2(nn.Module): + r"""2-D FFT with zero-frequency centering (``fftshift(fft2(x))``).""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _applyfft2(x) + + +class IFFT2(nn.Module): + r"""Inverse of :class:`FFT2`: ``ifft2(ifftshift(x))``.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _applyifft2(x) + + +class FFTResize(nn.Module): + r""" + Spectral-domain resize. + + FFT -> centre-crop or zero-pad the spectrum to ``(h, w)`` -> inverse FFT. + Preserves spectral characteristics of coherent signals (useful for SAR); + with ``energy_preserving=True``, scales the spectrum so the total power + matches the spatial-resize convention. + """ + + def __init__(self, h: int, w: int, energy_preserving: bool = True) -> None: + super().__init__() + self.h = h + self.w = w + self.energy_preserving = energy_preserving + + def forward(self, x: torch.Tensor) -> torch.Tensor: + spec = _applyfft2(x) + H, W = spec.shape[-2], spec.shape[-1] + # Crop or zero-pad to (h, w). + spec = _resize_spectrum(spec, self.h, self.w) + if self.energy_preserving: + scale = math.sqrt((self.h * self.w) / (H * W)) + spec = spec * scale + return _applyifft2(spec) + + +def _resize_spectrum(spec: torch.Tensor, h: int, w: int) -> torch.Tensor: + """Centre-crop or zero-pad the trailing (H, W) of the complex ``spec`` to ``(h, w)``.""" + H, W = spec.shape[-2], spec.shape[-1] + # Vertical + if h <= H: + top = (H - h) // 2 + spec = spec[..., top : top + h, :] + else: + pad_top = (h - H) // 2 + pad_bot = h - H - pad_top + pad_r = torch.nn.functional.pad(spec.real, (0, 0, pad_top, pad_bot)) + pad_i = torch.nn.functional.pad(spec.imag, (0, 0, pad_top, pad_bot)) + spec = torch.complex(pad_r, pad_i) + # Horizontal + if w <= W: + left = (W - w) // 2 + spec = spec[..., :, left : left + w] + else: + pad_l = (w - W) // 2 + pad_r = w - W - pad_l + pad_re = torch.nn.functional.pad(spec.real, (pad_l, pad_r, 0, 0)) + pad_im = torch.nn.functional.pad(spec.imag, (pad_l, pad_r, 0, 0)) + spec = torch.complex(pad_re, pad_im) + return spec + + +# --------------------------------------------------------------------------- +# Polarimetric SAR +# --------------------------------------------------------------------------- + + +class PolSAR(nn.Module): + r""" + PolSAR channel selection. + + Input is a complex tensor with ``C`` channels following the standard + quad-pol order (HH, HV, VH, VV) — at minimum the first ``C`` of those. + Reduces to ``out_channels`` channels per: + + - C=1: identity (any out_channels=1). + - C=2 (HH, VV by convention here, matching torchcvnn): out=1 -> [HH]; + out=2 -> [HH, VV]. + - C=3 (HH, VV, HV by convention): out=1 -> [HH]; out=2 -> [HH, VV]; + out=3 -> all. + - C=4: out=1 -> [HH]; out=2 -> [HH, VV]; out=3 -> [HH, VV, (HV+VH)/2]; + out=4 -> all. + """ + + def __init__(self, out_channels: int) -> None: + super().__init__() + if out_channels < 1 or out_channels > 4: + raise ValueError(f"out_channels must be in [1, 4], got {out_channels}") + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() < 3: + raise ValueError(f"expected at least 3 dims, got {tuple(x.shape)}") + c = x.shape[-3] + out = self.out_channels + if c == 1: + if out != 1: + raise ValueError( + f"single-channel input requires out_channels=1, got {out}" + ) + return x + if c == 2: + if out == 1: + return x[..., 0:1, :, :] + if out == 2: + return x + raise ValueError( + f"2-channel input supports out_channels in (1, 2), got {out}" + ) + if c == 3: + if out == 1: + return x[..., 0:1, :, :] + if out == 2: + return x[..., :2, :, :] + if out == 3: + return x + raise ValueError( + f"3-channel input supports out_channels in (1, 2, 3), got {out}" + ) + if c == 4: + hh, hv, vh, vv = ( + x[..., 0, :, :], + x[..., 1, :, :], + x[..., 2, :, :], + x[..., 3, :, :], + ) + if out == 1: + return hh.unsqueeze(-3) + if out == 2: + return torch.stack([hh, vv], dim=-3) + if out == 3: + return torch.stack([hh, vv, 0.5 * (hv + vh)], dim=-3) + if out == 4: + return x + raise ValueError(f"unsupported input channel count {c}") diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100755 index 7c2fbc3..0000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -sphinx_rtd_theme -numpy>=2.2.0 -torch>=1.11.0+cu115 -deprecated>=1.2.18 \ No newline at end of file diff --git a/docs/source/_static/switcher.json b/docs/source/_static/switcher.json new file mode 100644 index 0000000..1a0b79a --- /dev/null +++ b/docs/source/_static/switcher.json @@ -0,0 +1,8 @@ +[ + { + "name": "latest (main)", + "version": "latest", + "url": "https://josiahwsmith10.github.io/complextorch/latest/", + "preferred": true + } +] diff --git a/docs/source/_templates/redirect.html b/docs/source/_templates/redirect.html new file mode 100644 index 0000000..c3b1e1e --- /dev/null +++ b/docs/source/_templates/redirect.html @@ -0,0 +1,13 @@ + + + + + complextorch documentation + + + + +

Redirecting to complextorch latest documentation

+ + + diff --git a/docs/source/about.md b/docs/source/about.md new file mode 100644 index 0000000..7c7e771 --- /dev/null +++ b/docs/source/about.md @@ -0,0 +1,102 @@ +# About + +## Author + +**Josiah W. Smith, Ph.D.** — · +[GitHub](https://github.com/josiahwsmith10) · +[LinkedIn](https://www.linkedin.com/in/josiahwsmith/) + +```{note} +The author no longer has access to the university email addresses +`josiah.smith@utdallas.edu` / `jws160130@utdallas.edu`. Please use +. +``` + +Currently working on research publications in deep learning, computer vision, +radar signal processing, and imaging. Completed PhD at **The University of +Texas at Dallas** in 2022. + +## PhD dissertation + +- **Novel Hybrid-Learning Algorithms for Improved Millimeter-Wave Imaging + Systems** — [arXiv:2306.15341](https://arxiv.org/abs/2306.15341) + +## Selected papers + +- Deep Learning-Based Multiband Signal Fusion for 3-D SAR Super-Resolution — + [arXiv](https://arxiv.org/abs/2305.02017) · + [DOI](https://doi.org/10.1109/TAES.2023.3270111) +- Efficient CNN-based Super Resolution Algorithms for mmWave Mobile Radar + Imaging — [arXiv](https://arxiv.org/abs/2305.02092) · + [DOI](https://doi.org/10.1109/ICIP46576.2022.9897190) +- A Vision Transformer Approach for Efficient Near-Field Irregular SAR + Super-Resolution — [arXiv](https://arxiv.org/abs/2305.02074) · + [DOI](https://doi.org/10.1109/WMCS55582.2022.9866326) +- Efficient 3-D Near-Field MIMO-SAR Imaging for Irregular Scanning Geometries + — [arXiv](https://arxiv.org/abs/2305.02064) · + [DOI](https://doi.org/10.1109/ACCESS.2022.3145370) +- An FCNN-Based Super-Resolution mmWave Radar Framework for Contactless Musical + Instrument Interface — [arXiv](https://arxiv.org/abs/2305.01995) · + [DOI](https://doi.org/10.1109/TMM.2021.3079695) +- Improved Static Hand Gesture Classification on Deep CNNs using Novel Sterile + Training Technique — [arXiv](https://arxiv.org/abs/2305.02039) · + [DOI](https://doi.org/10.1109/ACCESS.2021.3051454) +- Near-Field MIMO-ISAR Millimeter-Wave Imaging — + [arXiv](https://arxiv.org/abs/2305.02030) · + [DOI](https://doi.org/10.1109/RadarConf2043947.2020.9266412) + +## Citing this package + +If `complextorch` helps your research, please cite the dissertation: + +```bibtex +@phdthesis{smith2022novel, + title = {Novel Hybrid-Learning Algorithms for Improved Millimeter-Wave Imaging Systems}, + author = {Smith, J. W.}, + year = 2022, + month = apr, + address= {Richardson, Texas, USA}, + note = {Available at \url{https://arxiv.org/abs/2306.15341}}, + school = {Dept. Elect. Comput. Eng., Univ. Texas Dallas}, + type = {PhD dissertation} +} +``` + +## Related GitHub repositories + +Notable repositories are highlighted in *italics*. + +**Simulating radar and SAR data** +- (V1) [SAR Simulation](https://github.com/josiahwsmith10/sar-simulation-jws) +- (V2) [FMCW MIMO-SAR Simulation and Image Reconstruction Toolbox](https://github.com/josiahwsmith10/FMCW-MIMO-SAR-Simulation-and-Image-Reconstruction-Toolbox) +- *(V3) [THz and Sub-THz Imaging Toolbox](https://github.com/josiahwsmith10/THz-and-Sub-THz-Imaging-Toolbox)* + +**Data-driven radar image enhancement** +- (V1) [SAR CNN Enhancement](https://github.com/josiahwsmith10/sar-cnn-enhancement) +- (V2) [Improved SAR CNN Enhancement](https://github.com/josiahwsmith10/improved-sar-cnn-enhancement) +- (V2.1) [Improved SAR CNN Enhancement-v2](https://github.com/josiahwsmith10/improved-sar-cnn-enhancement-v2) + +**MATLAB user interfaces for SAR scanners** +- *[Dual Radar GUI](https://github.com/josiahwsmith10/dual-radar-gui)* +- [RSAR GUI](https://github.com/josiahwsmith10/RSAR-GUI) +- [MATLAB SAR Scanner API](https://github.com/josiahwsmith10/SAR-Scanner-Toolbox) +- [TI mmWave Studio MATLAB GUI](https://github.com/josiahwsmith10/mmWave-Studio-MATLAB-GUI-jws) + +**Embedded synchronization software** +- (V1) [TI Radar HW Trigger using ESP32](https://github.com/josiahwsmith10/single-TI-radar-HW-trigger-esp32) +- *(V2) [Dual Radar Synchronizer](https://github.com/josiahwsmith10/dual-radar-synchronizer)* + +**Documentation & tutorials** +- [Introduction to Near-Field SAR in MATLAB](https://github.com/josiahwsmith10/SAR-Intro) +- *[Introduction to MIMO-FMCW Radar](https://github.com/josiahwsmith10/Introduction-to-MIMO-FMCW-Radar)* + +**Other projects** +- [FCNN Audio Denoising](https://github.com/josiahwsmith10/FCNN-audio-denoising) +- [Deep Learning-Enhanced BLE Ranging](https://github.com/josiahwsmith10/deep-learning-BLE-ranging) +- [SVM-MUSIC Algorithm for DoA Estimation](https://github.com/josiahwsmith10/svm-music-algorithm) +- [FIR and IIR Filtering for Audio Denoising](https://github.com/josiahwsmith10/FIR-and-IIR-Filtering-for-Audio-Denoising) +- [Nyquist Sampling and Sunspots](https://github.com/josiahwsmith10/sampling-and-sunspots) + +**Code accompanying papers** +- *[Efficient 3-D Near-Field MIMO-SAR Imaging for Irregular Scanning Geometries](https://github.com/josiahwsmith10/Efficient-3-D-Near-Field-MIMO-SAR-Imaging-for-Irregular-Scanning-Geometries)* +- *[FCNN-Based Super-Resolution mmWave Radar Framework for Contactless Musical Instrument Interface](https://github.com/josiahwsmith10/Radar-Musical-Instrument)* diff --git a/docs/source/about.rst b/docs/source/about.rst deleted file mode 100755 index 6f0cfb0..0000000 --- a/docs/source/about.rst +++ /dev/null @@ -1,187 +0,0 @@ -Author Information -================== - -Josah W. Smith, Ph.D. ---------------------- - -josiah.radar@gmail.com, `GitHub `_, `LinkedIn `_ - -*PSA:* I no longer have access to my university email address (josiah.smith@utdallas.edu or jws160130@utdallas.edu) so please contact me at josiah.radar@gmail.com. - -- 🔭 I am currently working on several research publications in deep learning, computer vision, radar signal processing, and imaging. - -- 📫 How to reach me **josiah.radar@gmail.com** - -- 📄 Know about my experiences at my `LinkedIn `_ - -- ⚡ Fun fact **I love music, hiking, camping, and have lived around the world.** - -- I completed my PhD from **The University of Texas at Dallas** in 2022 - -PhD Dissertation ----------------- -- Novel Hybrid-Learning Algorithms for Improved Millimeter-Wave Imaging Systems (`arXiv `_) - -Papers ------- -- Deep Learning-Based Multiband Signal Fusion for 3-D SAR Super-Resolution (`arXiv `_, `DOI `_) -- Efficient CNN-based Super Resolution Algorithms for mmWave Mobile Radar Imaging (`arXiv `_, `DOI `_) -- A Vision Transformer Approach for Efficient Near-Field Irregular SAR Super-Resolution (`arXiv `_, `DOI `_) -- Efficient 3-D Near-Field MIMO-SAR Imaging for Irregular Scanning Geometries (`arXiv `_, `DOI `_) -- An FCNN-Based Super-Resolution mmWave Radar Framework for Contactless Musical Instrument Interface (`arXiv `_, `DOI `_) -- Improved Static Hand Gesture Classification on Deep Convolutional Neural Networks using Novel Sterile Training Technique (`arXiv `_, `DOI `_) -- Near-Field MIMO-ISAR Millimeter-Wave Imaging (`arXiv `_, `DOI `_) - -Organization of GitHub Repositories ------------------------------------ - -(Notable repositories are highlighted in *italics*) - -- Simulating Radar and SAR Data - - (V1) SAR Simulation (`code `_) - - (V2) FMCW MIMO-SAR Simulation and Image Reconstruction Toolbox (`code `_) - - *(V3) THz and Sub-THz Imaging Toolbox* (`code `_) -- Data-Driven Radar Image Enhancement (Computer Vision for Radar) - - (V1) SAR CNN Enhancement (`code `_) - - (V2) Improved SAR CNN Enhancement (`code `_) - - (V2.1) Improved SAR CNN Enhancement-v2 (`code `_) -- MATLAB User Interfaces for Controlling SAR Scanners - - Dual Radar (Also Applicable for Single Radar) - - *Dual Radar GUI* (`code `_) - - SAR User Interface - - RSAR GUI (`code `_) - - MATLAB SAR Scanner API (`code `_) - - Texas Instruments mmWave Studio - - TI mmWave Studio MATLAB GUI (`code `_) -- Embedded Synchronization Software for SAR Scanner Controllers - - (V1) TI Radar HW Trigger using ESP32 (`code `_) - - *(V2) Dual Radar Synchronizer* (`code `_) -- Documentation and Introduction to Radar and SAR Principles - - Introduction to Near-Field SAR in MATLAB (`code `_) - - *Introduction to MIMO-FMCW Radar* (`doc `_) -- Projects - - FCNN Audio Denoising (`code `_) - - Deep Learning-Enhanced BLE Ranging (`code `_) - - SVM-MUSIC Algorithm for DoA Estimation (`code `_) - - FIR and IIR Filtering for Audio Denoising (`code `_) - - Nyquist Sampling and Sunspots (`code `_) -- WISLAB-Specific - - `WISLAB-Helps `_ -- *Code for Papers* - - *Efficient 3-D Near-Field MIMO-SAR Imaging for Irregular Scanning Geometries* (`arXiv `_, `DOI `_, `code `_) - - *An FCNN-Based Super-Resolution mmWave Radar Framework for Contactless Musical Instrument Interface* (`arXiv `_, `DOI `_, `code `_) - -BibTeX Citations ----------------- - -- Deep Learning-Based Multiband Signal Fusion for 3-D SAR Super-Resolution (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @article{smith2023deep, - title = {Deep Learning-Based Multiband Signal Fusion for {3-D} {SAR} Super-Resolution}, - author = {Smith, J. W. and Torlak, M.}, - year = 2023, - month = apr, - journal = {IEEE Trans. Aerosp. Electron. Syst.}, - pages = {1--17} - } - -- Novel Hybrid-Learning Algorithms for Improved Millimeter-Wave Imaging Systems (`arXiv `_) - - .. code-block:: latex - - @phdthesis{smith2022novel, - title = {Novel Hybrid-Learning Algorithms for Improved Millimeter-Wave Imaging Systems}, - author = {Smith, J. W.}, - year = 2022, - month = apr, - address = {Richardson, Texas, USA}, - note = {Available at \url{https://arxiv.org/abs/2306.15341}}, - school = {Dept. Elect. Comput. Eng., Univ. Texas Dallas}, - type = {PhD dissertation} - } - -- Efficient CNN-based Super Resolution Algorithms for mmWave Mobile Radar Imaging (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @article{smith2022efficient, - title = {Efficient {3-D} Near-Field {MIMO-SAR} Imaging for Irregular Scanning Geometries}, - author = {Smith, J. W. and Torlak, M.}, - year = 2022, - month = jan, - journal = {IEEE Access}, - volume = 10, - pages = {10283--10294} - } - -- A Vision Transformer Approach for Efficient Near-Field Irregular SAR Super-Resolution (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @inproceedings{smith2022vision, - title = {A Vision Transformer Approach for Efficient Near-Field {SAR} Super-Resolution under Array Perturbation}, - author = {Smith, J. W. and Alimam, Y. and Vedula, G. and Torlak, M.}, - year = 2022, - month = apr, - booktitle = {Proc. IEEE Tex. Symp. Wirel. Microw. Circuits Syst. (WMCS)}, - address = {Waco, TX, USA}, - pages = {1--6} - } - -- Efficient 3-D Near-Field MIMO-SAR Imaging for Irregular Scanning Geometries (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @inproceedings{vasileiou2022efficient, - title = {Efficient {CNN}-Based Super Resolution Algorithms for {mmWave} Mobile Radar Imaging}, - author = {Vasileiou, C. and Smith, J. W. and Thiagarajan, S. and Nigh, M. and Makris, Y. and Torlak, M.}, - year = 2022, - month = oct, - booktitle = {Proc. IEEE Int. Conf. Image Process. (ICIP)}, - address = {Bourdeaux, France}, - pages = {3803--3807} - } - -- Improved Static Hand Gesture Classification on Deep Convolutional Neural Networks using Novel Sterile Training Technique (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @article{smith2021improved, - title = {Improved Static Hand Gesture Classification on Deep Convolutional Neural Networks Using Novel Sterile Training Technique}, - author = {Smith, J. W. and Thiagarajan, S. and Willis, R. and Makris, Y. and Torlak, M.}, - year = 2021, - month = jan, - journal = {IEEE Access}, - volume = 9, - pages = {10893--10902} - } - -- An FCNN-Based Super-Resolution mmWave Radar Framework for Contactless Musical Instrument Interface (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @article{smith2021fcnn, - title = {An {FCNN}-Based Super-Resolution {mmWave} Radar Framework for Contactless Musical Instrument Interface}, - author = {Smith, J. W. and Furxhi, O. and Torlak, M.}, - year = 2021, - month = may, - journal = {IEEE Trans. Multimedia}, - volume = 24, - pages = {2315--2328} - } - -- Near-Field MIMO-ISAR Millimeter-Wave Imaging (`arXiv `_, `DOI `_) - - .. code-block:: latex - - @inproceedings{smith2020near, - title = {Near-Field {MIMO-ISAR} Millimeter-Wave Imaging}, - author = {Smith, J. W. and Yanik, M. E. and Torlak, M.}, - year = 2020, - month = sep, - booktitle = {Proc. IEEE Radar Conf. (RadarConf)}, - address = {Florance, Italy}, - pages = {1--6} - } diff --git a/docs/source/changelog.md b/docs/source/changelog.md new file mode 100644 index 0000000..3139cd4 --- /dev/null +++ b/docs/source/changelog.md @@ -0,0 +1,2 @@ +```{include} ../../CHANGELOG.md +``` diff --git a/docs/source/concepts/activations.md b/docs/source/concepts/activations.md new file mode 100644 index 0000000..f7b73ae --- /dev/null +++ b/docs/source/concepts/activations.md @@ -0,0 +1,61 @@ +# Complex-valued activations + +Complex-valued activation functions must take into account the two +degrees-of-freedom inherent to complex-valued data, typically represented as +real / imaginary parts or magnitude / phase. Two generalised classes of +activation operate on those respective representations and are defined as +*Type-A* and *Type-B* functions. + +## Type-A — split on real / imaginary + +Type-A activations consist of two real-valued functions, $G_\mathbb{R}(\cdot)$ +and $G_\mathbb{I}(\cdot)$, applied to the real and imaginary parts of the input +tensor independently: + +$$ +G(\mathbf{z}) = G_\mathbb{R}(\mathbf{x}) + j\, G_\mathbb{I}(\mathbf{y}) +$$ + +where $\mathbf{z} = \mathbf{x} + j\mathbf{y}$. + +Under the hood, Type-A activations call +{func}`complextorch.nn.functional.apply_complex_split`. Examples in the +package: `CVSplitReLU`, `CVSplitTanh`, `CVSplitSigmoid`, `CELU`, `CCELU`, +`CGELU`. See the [activation reference](../api/complextorch/nn/modules/activation/index) +for the full list. + +## Type-B — split on magnitude / phase + +Type-B activations consist of two real-valued functions, $G_{||}(\cdot)$ and +$G_\angle(\cdot)$, applied to the magnitude (modulus) and phase (argument) of +the input tensor: + +$$ +G(\mathbf{z}) = G_{||}\!\left(|\mathbf{z}|\right) \,\exp\!\left(j\, G_\angle\!\left(\arg \mathbf{z}\right)\right). +$$ + +Type-B activations call +{func}`complextorch.nn.functional.apply_complex_polar`. Passing `phase_fun=None` +is an optimisation that skips the polar round-trip when the activation only +modifies magnitude. Examples: `modReLU`, `AdaptiveModReLU`, `CVPolarTanh`. + +## Fully complex + +Fully-complex activations fit neither the Type-A nor the Type-B +designation — they operate on the complex tensor directly. Use them when an +activation has a natural complex form (e.g., a learnable phase rotation). + +## ReLU variants + +A separate family — `zReLU`, `CReLU`, `zAbsReLU`, `zLeakyReLU` — generalises +the rectified linear unit to the complex plane. These are documented alongside +their classes in the API reference. + +## When to use which + +| Need | Reach for | +| --- | --- | +| Drop-in replacement for `nn.ReLU` / `nn.Tanh` | `CVSplitReLU` / `CVSplitTanh` (Type-A) | +| Preserve phase, modulate magnitude only | `modReLU`, `AdaptiveModReLU` (Type-B, `phase_fun=None`) | +| Phase-aware operation | Type-B with both `mag_fun` and `phase_fun` set | +| Learnable scalar phase shift | {class}`complextorch.nn.PhaseShift` | diff --git a/docs/source/concepts/native-vs-gauss.md b/docs/source/concepts/native-vs-gauss.md new file mode 100644 index 0000000..182b3e2 --- /dev/null +++ b/docs/source/concepts/native-vs-gauss.md @@ -0,0 +1,83 @@ +# Native vs. Gauss-trick modules + +Convolution and linear layers in `complextorch.nn` exist in two variants: + +| Native cfloat (recommended) | Gauss-trick (reference) | +| --- | --- | +| {class}`complextorch.nn.Linear` | {class}`complextorch.nn.gauss.Linear` | +| {class}`complextorch.nn.Conv1d` / `Conv2d` / `Conv3d` | `complextorch.nn.gauss.Conv1d` / `Conv2d` / `Conv3d` | +| `ConvTranspose1d` / `2d` / `3d` | `complextorch.nn.gauss.ConvTranspose1d` / `2d` / `3d` | + +```{note} +Up to `complextorch < 2.0` the Gauss-trick variants lived at the top level as +`SlowConv*` / `SlowLinear`. The prefix was a misleading legacy from when they +were *faster* than the naive split; they have since been moved to the +{mod}`complextorch.nn.gauss` subpackage and the `Slow` names removed. +``` + +## What's the difference? + +**Native cfloat modules** are thin wrappers around the corresponding `torch.nn` +module constructed with `dtype=torch.cfloat`. They rely on PyTorch's native +complex kernels (available since PyTorch 2.1) and are the recommended path for +all new code. + +```python +import torch +import complextorch as ctorch + +x = torch.randn(8, 5, 7, dtype=torch.cfloat) +y = ctorch.nn.Conv1d(5, 16, kernel_size=3)(x) # native cfloat kernel +``` + +**Gauss-trick modules** are the original hand-rolled implementations that +split each complex tensor into real and imaginary parts and apply Gauss' +multiplication trick: + +$$ +(R + jI)(x + jy) = (Rx - Iy) + j(Ry + Ix) +$$ + +with a three-multiply real-valued formulation under the hood. They predate +PyTorch's native complex support and are kept for two reasons: + +1. **Reference math** — the Gauss path is the easiest place to read the + real/imag split when you're learning the package internals or implementing + a new layer. +2. **Explicit split parameters** — `conv_r` / `conv_i` (or `linear_r` / + `linear_i`) are exposed as separate `nn.Module` children, which is useful + if you want to apply different parameterizations or constraints to each + half. + +## Which should I use? + +Use the **native cfloat** variant. The Gauss-trick path no longer offers a +speed advantage since PyTorch 2.1, so its only remaining role is as a +numerically-equivalent reference. The test suite under `tests/invariants/` +checks the two paths agree to floating-point tolerance on the same weights. + +If you're adding a new layer that has a native PyTorch complex equivalent, +follow the native pattern (wrap `torch.nn.X` with `dtype=torch.cfloat`) +rather than reimplementing the real/imag split. + +## The three composition primitives + +Most non-convolutional layers in `complextorch` are built on three helpers in +{mod}`complextorch.nn.functional`: + +- {func}`~complextorch.nn.functional.apply_complex` — the "naive" complex + linear lift: $(R(x_r) - I(x_i)) + j(R(x_i) + I(x_r))$. +- {func}`~complextorch.nn.functional.apply_complex_split` — **Type-A** split: + apply two separate functions to real and imaginary parts independently. Used + by `CVSplit*` activations, `Dropout`, `CVSoftMax`, `AdaptiveAvgPool*d`. +- {func}`~complextorch.nn.functional.apply_complex_polar` — **Type-B** polar + split: apply functions to magnitude and phase separately, recombine via + `torch.polar`. Used by `CVPolar*` / `modReLU` activations. + +See [Activations](activations.md) for the math behind Type-A / Type-B. + +```{tip} +Construct magnitude/phase tensors with `torch.polar(abs, angle)` — it's been +a PyTorch builtin since 1.8 and is the idiomatic call. `complextorch` does +not provide a `from_polar` helper. +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index a918199..f45e359 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,14 +1,12 @@ # Configuration file for the Sphinx documentation builder. # -# For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information - import re from pathlib import Path +# -- Project information ----------------------------------------------------- + project = "complextorch" copyright = "2025, Josiah W. Smith" author = "Josiah W. Smith" @@ -28,28 +26,141 @@ def _read_version() -> str: version = release # -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.autosummary", - "sphinx.ext.coverage", + "myst_nb", + "autoapi.extension", + "sphinx.ext.intersphinx", "sphinx.ext.napoleon", "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_copybutton", + "sphinx_design", + "sphinx_sitemap", + "sphinxext.opengraph", ] templates_path = ["_templates"] -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] + +# -- Source files ------------------------------------------------------------ + +# myst-nb owns .md and .ipynb (it subsumes myst-parser for Markdown). +source_suffix = { + ".rst": "restructuredtext", + ".md": "myst-nb", + ".ipynb": "myst-nb", +} + +# -- MyST / MyST-NB ---------------------------------------------------------- + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "substitution", + "tasklist", +] +myst_heading_anchors = 3 + +nb_execution_mode = "auto" +nb_execution_timeout = 120 +nb_execution_allow_errors = False +nb_execution_excludepatterns = [] -language = "python" +# -- Autoapi ----------------------------------------------------------------- -# -- Options for HTML output ------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +autoapi_type = "python" +autoapi_dirs = [str(Path(__file__).resolve().parents[2] / "complextorch")] +autoapi_root = "api" +autoapi_keep_files = False +# We wire the API tree into our own toctree in index.md, so don't let +# autoapi inject a second top-level entry. +autoapi_add_toctree_entry = False +autoapi_python_class_content = "both" # merge class + __init__ docstrings +autoapi_member_order = "groupwise" +# ``imported-members`` would cause every re-exported symbol to appear twice +# (once in the source module, once in the re-exporting __init__). Leaving +# it off relies on users following the canonical source module link. +autoapi_options = [ + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", +] +autoapi_ignore = ["*/_build/*", "*/tests/*"] + +# -- Napoleon (docstring parsing) -------------------------------------------- + +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_rtype = False + +# -- Intersphinx ------------------------------------------------------------- + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "torch": ("https://pytorch.org/docs/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), +} -html_theme = "sphinx_rtd_theme" +# -- HTML output ------------------------------------------------------------- + +html_theme = "pydata_sphinx_theme" html_static_path = ["_static"] +html_baseurl = "https://josiahwsmith10.github.io/complextorch/" +html_title = f"complextorch {release}" + +html_context = { + "github_user": "josiahwsmith10", + "github_repo": "complextorch", + "github_version": "main", + "doc_path": "docs/source", +} + +html_theme_options = { + "github_url": "https://github.com/josiahwsmith10/complextorch", + "use_edit_page_button": True, + "navigation_with_keys": False, + "show_toc_level": 2, + "show_nav_level": 2, + "navbar_align": "left", + "header_links_before_dropdown": 6, + "icon_links": [ + { + "name": "PyPI", + "url": "https://pypi.org/project/complextorch/", + "icon": "fa-brands fa-python", + }, + ], + "switcher": { + "json_url": "https://josiahwsmith10.github.io/complextorch/_static/switcher.json", + "version_match": release, + }, + "navbar_end": ["version-switcher", "theme-switcher", "navbar-icon-links"], + "footer_start": ["copyright", "sphinx-version"], + "footer_end": ["theme-version"], +} + +# -- OpenGraph & sitemap ----------------------------------------------------- + +ogp_site_url = html_baseurl +ogp_image = None # add a social card later if a logo is created +sitemap_url_scheme = "{link}" -import os -import sys +# -- sphinx-multiversion ----------------------------------------------------- -sys.path.insert(0, os.path.abspath("../../")) +# Whitelist only post-migration releases. Older tags (pre-2.1) used a different +# conf.py and dependency set; they remain accessible via PyPI but are not +# re-rendered here. +smv_tag_whitelist = r"^2\.[1-9]\d*\.\d+$" +smv_branch_whitelist = r"^main$" +smv_remote_whitelist = None +smv_released_pattern = r"^refs/tags/.*$" +smv_outputdir_format = "{ref.name}" +smv_prefer_remote_refs = False diff --git a/docs/source/examples/getting_started.md b/docs/source/examples/getting_started.md new file mode 100644 index 0000000..5ad01d5 --- /dev/null +++ b/docs/source/examples/getting_started.md @@ -0,0 +1,149 @@ +--- +file_format: mystnb +kernelspec: + name: python3 + display_name: Python 3 +--- + +# Getting started + +This notebook is **executed on every docs build** — if it stops running +against the latest `complextorch`, CI fails. Treat it as a smoke-test of the +public API as well as a tutorial. + +## 1 · Imports & version check + +```{code-cell} +import torch +import complextorch as ctorch + +print(f"torch {torch.__version__}") +print(f"complextorch {ctorch.__version__}") +``` + +## 2 · Building a complex tensor + +`complextorch` operates on complex-dtype `torch.Tensor` (typically +`torch.cfloat`). There is no special wrapper type — use PyTorch's built-ins +directly: + +```{code-cell} +torch.manual_seed(0) + +x = torch.randn(8, 5, 16, dtype=torch.cfloat) # (batch, channels, length) +print(x.shape, x.dtype) +print(x[0, 0, :3]) +``` + +You can construct from magnitude / phase via `torch.polar`: + +```{code-cell} +mag = torch.rand(8, 5, 16) +phase = torch.rand(8, 5, 16) * (2 * torch.pi) - torch.pi +z = torch.polar(mag, phase) +print(z.dtype, z[0, 0, 0]) +``` + +## 3 · Conv1d + Linear (the README example) + +The native cfloat modules (`Conv1d`, `Linear`, ...) are thin wrappers around +`torch.nn` with `dtype=torch.cfloat`. See +[Native vs. Gauss-trick modules](../concepts/native-vs-gauss.md) for the design rationale. + +```{code-cell} +conv = ctorch.nn.Conv1d(in_channels=5, out_channels=16, kernel_size=3) +fc = ctorch.nn.Linear(in_features=16 * 14, out_features=4) + +h = conv(x) # (8, 16, 14) +h_flat = h.reshape(h.size(0), -1) # (8, 16*14) +y = fc(h_flat) # (8, 4) + +print("conv output:", h.shape, h.dtype) +print("fc output: ", y.shape, y.dtype) +``` + +Both modules accept and emit complex tensors — and gradients flow through +them just like any real-valued `torch.nn` module: + +```{code-cell} +loss = y.abs().pow(2).mean() +loss.backward() + +total_grad_norm = sum(p.grad.abs().pow(2).sum() for p in conv.parameters()).sqrt() +print(f"loss = {loss.item():.4f}, conv grad norm = {total_grad_norm:.4f}") +``` + +## 4 · Type-A vs. Type-B activations + +The package implements two paradigms for complex activations (see +[Activations](../concepts/activations.md) for the math). Let's compare a +Type-A `CVSplitReLU` (independent real/imag) against a Type-B `modReLU` +(magnitude-only) on the same input. + +```{code-cell} +import matplotlib.pyplot as plt + +z = torch.complex( + real=torch.linspace(-2, 2, 200).repeat(200, 1), + imag=torch.linspace(-2, 2, 200).repeat(200, 1).T, +) + +split_relu = ctorch.nn.CVSplitReLU() +mod_relu = ctorch.nn.modReLU(bias=-0.5) + +with torch.no_grad(): + a = split_relu(z) + b = mod_relu(z) + +fig, axes = plt.subplots(2, 2, figsize=(8, 7), sharex=True, sharey=True) +for ax, data, title in zip( + axes.flat, + [a.abs(), a.angle(), b.abs(), b.angle()], + ["CVSplitReLU |·|", "CVSplitReLU ∠", "modReLU |·|", "modReLU ∠"], +): + im = ax.imshow(data, extent=[-2, 2, -2, 2], origin="lower", + cmap="twilight" if "∠" in title else "viridis") + ax.set_title(title) + fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) +axes[1, 0].set_xlabel("Re(z)"); axes[1, 1].set_xlabel("Re(z)") +axes[0, 0].set_ylabel("Im(z)"); axes[1, 0].set_ylabel("Im(z)") +plt.tight_layout(); +``` + +`CVSplitReLU` zeros the real/imag components independently — it doesn't +preserve phase. `modReLU` only modulates magnitude (`|z| - b`)+ and leaves the +phase untouched. + +## 5 · Welch's PSD on a complex signal + +{func}`complextorch.signal.pwelch` is a torch port of `scipy.signal.welch` +that's differentiable end-to-end — so it can sit inside a loss function. + +```{code-cell} +from complextorch.signal import pwelch + +t = torch.linspace(0, 1, 4096) +sig = torch.exp(1j * 2 * torch.pi * 50 * t).to(torch.cfloat) \ + + 0.5 * torch.exp(1j * 2 * torch.pi * 120 * t).to(torch.cfloat) \ + + 0.1 * torch.randn(4096, dtype=torch.cfloat) + +f, psd = pwelch(sig, fs=4096.0, window=256, n_overlap=128) + +plt.figure(figsize=(7, 3)) +plt.semilogy(f.numpy(), psd.numpy()) +plt.xlabel("Frequency (Hz)"); plt.ylabel("PSD"); plt.title("pwelch demo") +plt.tight_layout(); +``` + +The two tones at 50 Hz and 120 Hz should be clearly visible. Because `pwelch` +is autograd-friendly, you can use the PSD as a spectral loss for training a +complex-valued generator network. + +## Where next? + +- Browse the [API reference](../api/complextorch/index) for the full module + surface (`nn`, `signal`, `transforms`, `datasets`, `models`). +- Read the [Activations](../concepts/activations.md) deep-dive for Type-A / + Type-B / fully-complex / ReLU-variant theory. +- Check the [changelog](../changelog.md) for what landed in the current + release. diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md new file mode 100644 index 0000000..c998683 --- /dev/null +++ b/docs/source/examples/index.md @@ -0,0 +1,18 @@ +# Examples + +Runnable, re-executed-on-every-build examples that demonstrate `complextorch` +on small but realistic problems. If an example breaks against the latest +`complextorch` source, the docs build fails — so these are guaranteed not to +rot. + +```{toctree} +:maxdepth: 1 + +getting_started +``` + +## Suggesting an example + +Examples live as MyST-NB notebooks under `docs/source/examples/`. Open a PR +with a new `.md` file in that directory and add it to the toctree above; the +docs build will execute it on every commit. diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 0000000..75ad0f4 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,105 @@ +--- +sd_hide_title: true +--- + +# complextorch + +```{rst-class} lead +A lightweight complex-valued neural network package built on PyTorch. +``` + +`complextorch` provides drop-in `complextorch.nn.*` modules whose names mirror +`torch.nn.*` ({class}`torch.nn.Conv1d`, {class}`torch.nn.Linear`, ...), so a +real-valued PyTorch model can be ported to complex-valued by changing the +import. The library emphasises 1-D signal-processing / radar / comms workloads, +but most layers are also provided for 2-D and 3-D. + +::::{grid} 1 2 2 3 +:gutter: 3 +:margin: 4 4 0 0 + +:::{grid-item-card} {fas}`rocket` Getting started +:link: examples/getting_started +:link-type: doc + +A runnable notebook covering the README example, activation comparisons, +and an end-to-end Conv1d demo. +::: + +:::{grid-item-card} {fas}`book` Concepts +:link: concepts/activations +:link-type: doc + +Type-A / Type-B / fully-complex activations, and when to reach for the +native cfloat vs. Gauss-trick (real/imag split) modules. +::: + +:::{grid-item-card} {fas}`code` API reference +:link: api/complextorch/index +:link-type: doc + +Auto-generated reference for every public class and function in +`complextorch`, with cross-links into PyTorch's docs. +::: + +:::: + +## Install + +```sh +pip install complextorch +``` + +PyTorch is **not** installed automatically — install the wheel matching your +CUDA/CPU target from first. See +[Installation](installation.md) for source-install and development setup. + +## Why `complextorch`? + +- **Native cfloat wrappers.** {class}`complextorch.nn.Conv1d`, + {class}`complextorch.nn.Linear`, and friends are thin wrappers around + `torch.nn` modules with `dtype=torch.cfloat`. PyTorch ≥ 2.1 has fast complex + kernels — these are the recommended path. +- **Reference implementations on hand.** The {mod}`complextorch.nn.gauss` + subpackage keeps the original real/imag-split Gauss-trick implementations + ({class}`complextorch.nn.gauss.Conv1d`, etc.) around as reference math. +- **Three composition primitives.** Activations, pooling, losses, dropout, and + softmax are built on `apply_complex`, `apply_complex_split`, and + `apply_complex_polar` in {mod}`complextorch.nn.functional` — see + [Activations](concepts/activations.md) for the math. +- **Beyond layers.** Includes {mod}`complextorch.signal` (a torch port of + Welch's PSD), {mod}`complextorch.transforms` (torchcvnn-style transforms), + {mod}`complextorch.nn.init` (variance-correct complex initializers), + {mod}`complextorch.nn.relevance` (Variational Dropout & ARD), and + {mod}`complextorch.nn.masked` (fixed-mask sparsified layers). + +## Citation + +If `complextorch` helps your research, please cite the package and consider +citing the author's [PhD thesis](https://arxiv.org/abs/2306.15341) and related +papers — see [About](about.md) for the full list. + +```{toctree} +:hidden: +:caption: User guide + +installation +examples/index +concepts/activations +concepts/native-vs-gauss +``` + +```{toctree} +:hidden: +:caption: Reference + +api/complextorch/index +changelog +``` + +```{toctree} +:hidden: +:caption: About + +about +``` diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100755 index 7ef4d2a..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,36 +0,0 @@ -Welcome to complextorch's documentation! -======================================== - -:Author: Josiah W. Smith -:Version: |release| of 04/13/2025 - -A lightweight complex-valued neural network package built on PyTorch. - -This is a package built on `PyTorch `_ with the intention of implementing light-weight interfaces for common complex-valued neural network operations and architectures. -Notably, we include efficient implementations for linear, convolution, and attention modules in addition to activation functions and normalization layers such as batchnorm and layernorm. - -Although there is an emphasis on 1-D data tensors, due to a focus on signal processing, communications, and radar data, many of the routines are implemented for 2-D and 3-D data as well. - -Version 1.1 Release Notes: - -* Methods have been renamed to reflect identical names in PyTorch, e.g., `complextorch.nn.CVConv1d` was renamed to `complextorch.nn.Conv1d`. This change was implemented for quick conversion from PyTorch to `complextorch`. -* Use of `torch.Tensor` is now recommended over `complextorch.CVTensor`. Previous speed advantages of `complextorch.CVTensor` are no longer present if using a version of PyTorch newer than 2.1.0. -* Similarly, previous implementations of `complextorch.nn.Conv1d` (for 1-D, 2-D, 3-D, and transposed convolution) and `complextorch.nn.Linear` have been renamed with the prefix `Slow` as PyTorch's native convolution and linear operators now outperform that of `complextorch`. Now, `complextorch.nn.Conv1d`, for example, uses `torch.nn.Conv1d` with `dtype=torch.float` for maximum efficiency. - -.. toctree:: - :maxdepth: 3 - - installation - nn - -.. toctree:: - :maxdepth: 1 - - about - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/docs/source/installation.md b/docs/source/installation.md new file mode 100644 index 0000000..bc12f06 --- /dev/null +++ b/docs/source/installation.md @@ -0,0 +1,56 @@ +# Installation + +```{important} +**Install PyTorch first.** PyTorch is **not** installed automatically with +`complextorch` — you must install the wheel matching your CUDA / CPU target +from before installing this package. +``` + +## From PyPI + +```sh +pip install complextorch +``` + +[PyPI project page →](https://pypi.org/project/complextorch/) + +## From source + +For local development or to track `main`: + +```sh +git clone https://github.com/josiahwsmith10/complextorch.git +cd complextorch +pip install -e . +``` + +## Optional extras + +| Extra | Adds | When to use it | +| --- | --- | --- | +| `complextorch[test]` | `pytest`, `pytest-cov`, `pytest-xdist`, `hypothesis` | Running the test suite locally. | +| `complextorch[docs]` | Sphinx + PyData theme + MyST + autoapi + myst-nb + multiversion | Building these docs locally. | +| `complextorch[datasets]` | `h5py` | Future SAR / MRI dataset readers (`PolSFDataset`, `MICCAI2023`, ...). | +| `complextorch[datasets-alos]` | `rasterio` | ALOS-2 CEOS-format reader (needs system-level GDAL). | + +Install combinations with the usual pip syntax, e.g.: + +```sh +pip install -e .[test,docs] +``` + +## Verifying the install + +```python +import torch +import complextorch as ctorch + +print(ctorch.__version__) + +x = torch.randn(8, 5, 7, dtype=torch.cfloat) +y = ctorch.nn.Conv1d(5, 16, kernel_size=3)(x) +print(y.shape, y.dtype) # torch.Size([8, 16, 5]) torch.complex64 +``` + +If that runs without error, you're set. Continue to the +[Getting Started notebook](examples/getting_started.md). diff --git a/docs/source/installation.rst b/docs/source/installation.rst deleted file mode 100755 index d87913d..0000000 --- a/docs/source/installation.rst +++ /dev/null @@ -1,23 +0,0 @@ -Installation -============ - -IMPORTANT ---------- -Prior to installation, `install PyTorch `_ to your environment using your preferred method using the compute platform (CPU/GPU) settings for your machine. -PyTorch will not be automatically installed with the installation of complextorch and MUST be installed manually by the user. - - -Using `pip `_ ------------------------------------------------------ - -:: - - pip install complextorch - - -Using GitHub ------------- - -Useful if you want to modify the source code:: - - git clone https://github.com/josiahwsmith10/complextorch.git diff --git a/docs/source/nn.rst b/docs/source/nn.rst deleted file mode 100755 index d10f1cc..0000000 --- a/docs/source/nn.rst +++ /dev/null @@ -1,8 +0,0 @@ -NN -== - -.. toctree:: - :maxdepth: 2 - - nn/modules - nn/functional \ No newline at end of file diff --git a/docs/source/nn/functional.rst b/docs/source/nn/functional.rst deleted file mode 100755 index 926e4c1..0000000 --- a/docs/source/nn/functional.rst +++ /dev/null @@ -1,5 +0,0 @@ -Functional -========== - -.. automodule:: complextorch.nn.functional - :members: diff --git a/docs/source/nn/modules.rst b/docs/source/nn/modules.rst deleted file mode 100755 index 152fb3a..0000000 --- a/docs/source/nn/modules.rst +++ /dev/null @@ -1,19 +0,0 @@ -Modules -======= - -.. toctree:: - :maxdepth: 2 - - modules/activation - modules/attention - modules/batchnorm - modules/conv - modules/dropout - modules/fft - modules/layernorm - modules/linear - modules/loss - modules/manifold - modules/mask - modules/pooling - modules/softmax diff --git a/docs/source/nn/modules/activation.rst b/docs/source/nn/modules/activation.rst deleted file mode 100755 index f4423d1..0000000 --- a/docs/source/nn/modules/activation.rst +++ /dev/null @@ -1,31 +0,0 @@ -Activation -========== - -Complex-valued activation functions must take into account the 2 degrees-of-freedom inherent to complex-valued data, typically represented as real and imaginary parts or magnitude and phase. -Two common generalized classes of complex-valued activation functions operate on these respective representations and are defined as *Type-A* and *Type-B* functions. - -:doc:`Type-A <./activation/split_type_A>` activation functions consist of two real-valued functions, :math:`G_\mathbb{R}(\cdot)` and :math:`G_\mathbb{I}(\cdot)`, which are applied to the real and imaginary parts of the input tensor, respectively, as - -.. math:: - - G(\mathbf{z}) = G_\mathbb{R}(\mathbf{x}) + j G_\mathbb{I}(\mathbf{y}) - -where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}`. - -:doc:`Type-B <./activation/split_type_B>` activation functions consist of two real-valued functions, :math:`G_{||}(\cdot)` and :math:`G_\angle(\cdot)`, which are applied to the magnitude (modulus) and phase (angle, argument) of the input tensor, respectively, as - -.. math:: - - G(\mathbf{z}) = G_{||}(|\mathbf{z}|) * \exp(j G_\angle(\text{angle}(\mathbf{z}))). - -In contrast, :doc:`fully complex activation functions <./activation/fully_complex>` fit neither the :doc:`Split Type-A <./activation/split_type_A>` or :doc:`Split Type-B <./activation/split_type_B>` designation. - -The final designation of complex-valued activation functions detailed in this work are extensions of the :doc:`Rectified Linear Unit (ReLU) to the complex plane<./activation/complex_relu>`. - -.. toctree:: - :maxdepth: 2 - - activation/split_type_A - activation/split_type_B - activation/fully_complex - activation/complex_relu diff --git a/docs/source/nn/modules/activation/complex_relu.rst b/docs/source/nn/modules/activation/complex_relu.rst deleted file mode 100755 index f230c93..0000000 --- a/docs/source/nn/modules/activation/complex_relu.rst +++ /dev/null @@ -1,11 +0,0 @@ -Complex-Valued Rectified Linear Units -===================================== - -The *Rectified Linear Unit (ReLU)* is the most common activation function in modern data-driven algorithms. -Hence, it is oft-extended to the complex domain. -However, whereas its nonlinearity lends itself naturally to the real domain, its application to the complex domain, specifically its activation in different quadrants of the complex plane, has led to further investigation. - -These variants of the complex-valued ReLU are all *Type-A* split activation functions, meaning they apply a function separately to the real and imaginary parts of the input tensor, as detailed in :doc:`Split Type-A <./split_type_A>`. - -.. automodule:: complextorch.nn.modules.activation.complex_relu - :members: diff --git a/docs/source/nn/modules/activation/fully_complex.rst b/docs/source/nn/modules/activation/fully_complex.rst deleted file mode 100755 index 450c4b4..0000000 --- a/docs/source/nn/modules/activation/fully_complex.rst +++ /dev/null @@ -1,7 +0,0 @@ -Fully Complex Activation Functions -================================== - -These activation functions are fully-complex, meaning they fit neither the :doc:`Split Type-A <./split_type_A>` or :doc:`Split Type-B <./split_type_B>` designation. - -.. automodule:: complextorch.nn.modules.activation.fully_complex - :members: diff --git a/docs/source/nn/modules/activation/split_type_A.rst b/docs/source/nn/modules/activation/split_type_A.rst deleted file mode 100755 index 12ca75d..0000000 --- a/docs/source/nn/modules/activation/split_type_A.rst +++ /dev/null @@ -1,16 +0,0 @@ -Split Type-A Activation Functions -================================= - -*Type-A* activation functions consist of two real-valued functions, :math:`G_\mathbb{R}(\cdot)` and :math:`G_\mathbb{I}(\cdot)`, which are applied to the real and imaginary parts of the input tensor, respectively, as - -.. math:: - - G(\mathbf{z}) = G_\mathbb{R}(\mathbf{x}) + j G_\mathbb{I}(\mathbf{y}), - -where :math:`\mathbf{z} = \mathbf{x} + j\mathbf{y}`. - -In most cases, :math:`G_\mathbb{R}(\cdot) = G_\mathbb{I}(\cdot)`; however, :math:`G_\mathbb{R}(\cdot)` and :math:`G_\mathbb{I}(\cdot)` can also be distinct functions. -A generalized Type-A split activation function is defined in :class:`GeneralizedSplitActivation`, which accepts two real-valued torch.nn.Module objects for :math:`G_\mathbb{R}(\cdot)` and :math:`G_\mathbb{I}(\cdot)`, respectively. - -.. automodule:: complextorch.nn.modules.activation.split_type_A - :members: diff --git a/docs/source/nn/modules/activation/split_type_B.rst b/docs/source/nn/modules/activation/split_type_B.rst deleted file mode 100755 index 57d00d8..0000000 --- a/docs/source/nn/modules/activation/split_type_B.rst +++ /dev/null @@ -1,14 +0,0 @@ -Polar Type-B Activation Functions -================================= - -*Type-B* activation functions consist of two real-valued functions, :math:`G_{||}(\cdot)` and :math:`G_\angle(\cdot)`, which are applied to the magnitude (modulus) and phase (angle, argument) of the input tensor, respectively, as - -.. math:: - - G(\mathbf{z}) = G_{||}(|\mathbf{z}|) * \exp(j G_\angle(\text{angle}(\mathbf{z}))) - -A generalized Type-B split activation function is defined in :class:`GeneralizedPolarActivation`, which accepts two real-valued torch.nn.Module objects for :math:`G_{||}(\cdot)` and :math:`G_\angle(\cdot)`, respectively. - - -.. automodule:: complextorch.nn.modules.activation.split_type_B - :members: diff --git a/docs/source/nn/modules/attention.rst b/docs/source/nn/modules/attention.rst deleted file mode 100755 index 9e3333b..0000000 --- a/docs/source/nn/modules/attention.rst +++ /dev/null @@ -1,14 +0,0 @@ -Attention -========= - -Whereas attention-based models, such as transformers, have gained significant attention for natural language processing (NLP) and image processing, their potential for implementation in complex-valued problems such as signal processing remains relatively untapped. -Here, we include complex-valued variants of several attention-based techniques. - -.. automodule:: complextorch.nn.modules.attention - :members: - -.. toctree:: - :maxdepth: 2 - - attention/eca - attention/mca \ No newline at end of file diff --git a/docs/source/nn/modules/attention/eca.rst b/docs/source/nn/modules/attention/eca.rst deleted file mode 100755 index 74abb64..0000000 --- a/docs/source/nn/modules/attention/eca.rst +++ /dev/null @@ -1,5 +0,0 @@ -Efficient Channel Attention -=========================== - -.. automodule:: complextorch.nn.modules.attention.eca - :members: diff --git a/docs/source/nn/modules/attention/mca.rst b/docs/source/nn/modules/attention/mca.rst deleted file mode 100755 index 428be9a..0000000 --- a/docs/source/nn/modules/attention/mca.rst +++ /dev/null @@ -1,5 +0,0 @@ -Masked Attention -================ - -.. automodule:: complextorch.nn.modules.attention.mca - :members: diff --git a/docs/source/nn/modules/batchnorm.rst b/docs/source/nn/modules/batchnorm.rst deleted file mode 100755 index f205ff0..0000000 --- a/docs/source/nn/modules/batchnorm.rst +++ /dev/null @@ -1,19 +0,0 @@ -Batch Normalization -=================== - -Batch normalization is a crucial element to the convergence and robustness of many deep learning applications; however, its implementation must be carefully address for complex-valued data. -The complex-valued corollary to zero-mean unit variance normalization is known as whitening. - -Additional details can be found in the following paper: - - **J. A. Barrachina, C. Ren, G. Vieillard, C. Morisseau, and J.-P. Ovarlez. Theory and Implementation of Complex-Valued Neural Networks.** - - - Section 6 - - - https://arxiv.org/abs/2302.08286 - - -For other complex-valued normalization methods, see :doc:`layernorm <./layernorm>`. - -.. automodule:: complextorch.nn.modules.batchnorm - :members: diff --git a/docs/source/nn/modules/conv.rst b/docs/source/nn/modules/conv.rst deleted file mode 100755 index 95142ae..0000000 --- a/docs/source/nn/modules/conv.rst +++ /dev/null @@ -1,7 +0,0 @@ -Convolution Layers -================== - -Whereas `complextorch` 1.0.0 recommended our implementation of convolution, we now recommend using the PyTorch version as it outperforms ours post PyTorch version 2.1.0. - -.. automodule:: complextorch.nn.modules.conv - :members: diff --git a/docs/source/nn/modules/dropout.rst b/docs/source/nn/modules/dropout.rst deleted file mode 100755 index ec9338a..0000000 --- a/docs/source/nn/modules/dropout.rst +++ /dev/null @@ -1,5 +0,0 @@ -Dropout -======= - -.. automodule:: complextorch.nn.modules.dropout - :members: \ No newline at end of file diff --git a/docs/source/nn/modules/fft.rst b/docs/source/nn/modules/fft.rst deleted file mode 100755 index bc98628..0000000 --- a/docs/source/nn/modules/fft.rst +++ /dev/null @@ -1,5 +0,0 @@ -FFT -=== - -.. automodule:: complextorch.nn.modules.fft - :members: \ No newline at end of file diff --git a/docs/source/nn/modules/layernorm.rst b/docs/source/nn/modules/layernorm.rst deleted file mode 100755 index 931e600..0000000 --- a/docs/source/nn/modules/layernorm.rst +++ /dev/null @@ -1,16 +0,0 @@ -Layer Normalization -=================== - -Similar to :doc:`batch normalization <./batchnorm>`, layer normalization is a crucial element to the convergence and robustness of many deep learning applications; however, its implementation must be carefully address for complex-valued data. -The complex-valued corollary to zero-mean unit variance normalization is known as whitening. - -Additional details can be found in the following paper: - - **J. A. Barrachina, C. Ren, G. Vieillard, C. Morisseau, and J.-P. Ovarlez. Theory and Implementation of Complex-Valued Neural Networks.** - - - Section 6 - - - https://arxiv.org/abs/2302.08286 - -.. automodule:: complextorch.nn.modules.layernorm - :members: diff --git a/docs/source/nn/modules/linear.rst b/docs/source/nn/modules/linear.rst deleted file mode 100755 index 93d3fc8..0000000 --- a/docs/source/nn/modules/linear.rst +++ /dev/null @@ -1,5 +0,0 @@ -Linear -====== - -.. automodule:: complextorch.nn.modules.linear - :members: diff --git a/docs/source/nn/modules/loss.rst b/docs/source/nn/modules/loss.rst deleted file mode 100755 index 811d7c8..0000000 --- a/docs/source/nn/modules/loss.rst +++ /dev/null @@ -1,7 +0,0 @@ -Loss Functions -============== - -Similar to activation functions, two general types of loss functions have similar forms to Type-A and Type-B activations, operating on the real and imaginary or magnitude and phase, respectively. - -.. automodule:: complextorch.nn.modules.loss - :members: diff --git a/docs/source/nn/modules/manifold.rst b/docs/source/nn/modules/manifold.rst deleted file mode 100755 index e80d871..0000000 --- a/docs/source/nn/modules/manifold.rst +++ /dev/null @@ -1,32 +0,0 @@ -Manifold-Based Layers -===================== - -In a paper titled `SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold`, the authors R Chakraborty, Y Xing, and S Yu introduce a complex-valued convolution operator offering similar equivariance properties to the spatial equivariance of the traditional real-valued convolution operator. -By approaching the complex domain as a Riemannian homogeneous space consisting of the product of planar rotation and non-zero scaling, they define a convolution operator equivariant to phase shift and amplitude scaling. -Although their paper shows promising results in reducing the number of parameters of a complex-valued network for several problems, their work has not gained mainstream support. - -As the authors mention in the final bullet point in Section IV-A1, - - If :math:`d` is the manifold distance in (2) for the Euclidean - space that is also Riemannian, then wFM has exactly the - weighted average as its closed-form solution. That is, our - wFM convolution on the Euclidean manifold is reduced - to the standard convolution, although with the additional - convexity constraint on the weights. - -Hence, the implementation closely follows the conventional convolution operator with the exception of the weight normalization. - -Note: the weight normalization, although consistent with the authors' implementation, lacks adequate explanation from the literature and could be improved for further clarity. - -Based on work from the following paper: - - **R Chakraborty, Y Xing, S Yu. SurReal: Complex-Valued Learning as Principled Transformations on a Scaling and Rotation Manifold** - - - Eqs. (14)-(16) - - - https://arxiv.org/abs/1910.11334 - - - Modified from implementation: https://github.com/xingyifei2016/RotLieNet (yields consistent results as this implementation) - -.. automodule:: complextorch.nn.modules.manifold - :members: diff --git a/docs/source/nn/modules/mask.rst b/docs/source/nn/modules/mask.rst deleted file mode 100755 index 9cbea4c..0000000 --- a/docs/source/nn/modules/mask.rst +++ /dev/null @@ -1,5 +0,0 @@ -Magnitude Masking Layers -======================== - -.. automodule:: complextorch.nn.modules.mask - :members: diff --git a/docs/source/nn/modules/pooling.rst b/docs/source/nn/modules/pooling.rst deleted file mode 100755 index 91941a8..0000000 --- a/docs/source/nn/modules/pooling.rst +++ /dev/null @@ -1,5 +0,0 @@ -Pooling -======= - -.. automodule:: complextorch.nn.modules.pooling - :members: diff --git a/docs/source/nn/modules/softmax.rst b/docs/source/nn/modules/softmax.rst deleted file mode 100755 index 394e06b..0000000 --- a/docs/source/nn/modules/softmax.rst +++ /dev/null @@ -1,5 +0,0 @@ -Softmax Functions -================= - -.. automodule:: complextorch.nn.modules.softmax - :members: diff --git a/pyproject.toml b/pyproject.toml index 1b68d94..08913b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,139 @@ [build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" \ No newline at end of file +requires = ["setuptools>=68.2.2", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "complextorch" +dynamic = ["version"] +description = "A lightweight complex-valued neural network package built on PyTorch" +readme = "README.md" +license = {file = "LICENSE"} +requires-python = ">=3.10" +authors = [ + {name = "Josiah W. Smith", email = "josiah.radar@gmail.com"}, +] +keywords = ["deep learning", "pytorch", "complex valued neural networks"] +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +dependencies = [ + "numpy>=2.2.0", + "torch>=2.1.0", + "scipy>=1.10.0", +] + +[project.urls] +Homepage = "https://github.com/josiahwsmith10/complextorch" +Documentation = "https://josiahwsmith10.github.io/complextorch/latest/" +Repository = "https://github.com/josiahwsmith10/complextorch" +"Bug Tracker" = "https://github.com/josiahwsmith10/complextorch/issues" +Changelog = "https://github.com/josiahwsmith10/complextorch/blob/main/CHANGELOG.md" + +[project.optional-dependencies] +docs = [ + # The whole stack is pinned to Sphinx 7 because sphinx-multiversion 0.2.4 + # is incompatible with Sphinx >= 8 (Config.read signature changed). + # myst-parser >= 5 / myst-nb >= 2 require Sphinx >= 8, so they are also + # pinned. Migrate to sphinx-polyversion to unblock these floors. + "sphinx>=7.3,<8", + "pydata-sphinx-theme>=0.15,<0.16", + "myst-parser>=3.0,<5", + "myst-nb>=1.1,<2", + "sphinx-autoapi>=3.3,<4", + "sphinx-copybutton>=0.5", + "sphinx-design>=0.6,<0.7", + "sphinx-multiversion>=0.2", + "sphinx-sitemap>=2.6", + "sphinxext-opengraph>=0.9", + "matplotlib>=3.7", +] +test = [ + "pytest>=8", + "pytest-cov>=5", + "pytest-xdist>=3", + "hypothesis>=6", +] +# Tooling for contributors: linter/formatter + pre-commit driver. +dev = [ + "ruff>=0.15", + "pre-commit>=4", +] +# SAR + MRI dataset readers (PolSAR, MICCAI cardiac cine MRI, etc.) +datasets = [ + "h5py>=3.7", +] +# ALOS-2 CEOS-format reader (needs system-level GDAL via rasterio) +datasets-alos = [ + "rasterio>=1.3", +] + +[tool.setuptools.dynamic] +version = {attr = "complextorch.__version__"} + +[tool.setuptools.packages.find] +include = ["complextorch", "complextorch.*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-n auto --strict-markers --strict-config -ra" +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::PendingDeprecationWarning", + "ignore::UserWarning", +] + +[tool.coverage.run] +branch = false +source = ["complextorch"] +omit = ["*/__pycache__/*"] + +[tool.coverage.report] +fail_under = 100 +show_missing = true +skip_covered = false +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "@overload", + "if __name__ == .__main__.:", + "\\.\\.\\.", +] + +[tool.ruff] +target-version = "py310" +line-length = 88 +extend-exclude = ["docs/source/_build"] + +[tool.ruff.lint] +select = [ + "E", "W", # pycodestyle errors + warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear (likely bugs) + "UP", # pyupgrade (modern Python syntax) + "SIM", # flake8-simplify + "RUF", # ruff-specific + "C4", # flake8-comprehensions + "PT", # pytest-style + "PIE", # flake8-pie + "RET", # flake8-return + "TCH", # flake8-type-checking +] +ignore = [ + "E501", # line-too-long — handled by the formatter +] +# `×` (U+00D7) is intentional math notation in docstrings/comments (e.g. "2×2 whitening"). +allowed-confusables = ["×"] + +[tool.ruff.lint.per-file-ignores] +# Public API surface — manual import grouping is intentional; __all__ is the contract. +"complextorch/nn/__init__.py" = ["I001"] +"complextorch/nn/modules/activation/__init__.py" = ["I001"] +# Sphinx conf reads version before sys.path tweaks; keep authorial layout. +"docs/source/conf.py" = ["I001"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 5a4d30f..0000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy>=2.2.0 -setuptools>=68.2.2 -torch>=1.11.0+cu115 -deprecated>=1.2.18 diff --git a/setup.py b/setup.py deleted file mode 100755 index 72a88af..0000000 --- a/setup.py +++ /dev/null @@ -1,46 +0,0 @@ -import re -from pathlib import Path - -from setuptools import setup, find_packages - - -def _read_version() -> str: - init_py = Path(__file__).parent / "complextorch" / "__init__.py" - match = re.search( - r'^__version__\s*=\s*"([^"]+)"', init_py.read_text(), re.MULTILINE - ) - if not match: - raise RuntimeError("Cannot find __version__ in complextorch/__init__.py") - return match.group(1) - - -setup( - name="complextorch", - version=_read_version(), - author="Josiah W. Smith", - author_email="josiah.radar@gmail.com", - description="A lightweight complex-valued neural network package built on PyTorch", - long_description=open("README.md").read(), - long_description_content_type="text/markdown", - url="https://www.github.com/josiahwsmith10/complextorch", - packages=find_packages(include=["complextorch", "complextorch.*"]), - include_package_data=True, - project_urls={ - "Bug Tracker": "https://github.com/josiahwsmith10/complextorch/issues", - "Documentation": "https://complextorch.readthedocs.io/en/latest/index.html", - "GitHub": "https://github.com/josiahwsmith10/complextorch", - }, - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", - "Operating System :: OS Independent", - ], - package_dir={"": "."}, - python_requires=">=3.6", - install_requires=[ - "numpy>=2.2.0", - "setuptools>=68.2.2", - "torch>=1.11.0", - "deprecated>=1.2.18", - ], -) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b0cd56c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +"""Shared pytest fixtures for the complextorch test suite.""" + +from __future__ import annotations + +import random + +import numpy as np +import pytest +import torch + + +@pytest.fixture(autouse=True) +def _seed_everything(): + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + return + + +@pytest.fixture +def device() -> torch.device: + return torch.device("cpu") + + +@pytest.fixture +def cplx(): + def _make( + *shape: int, dtype: torch.dtype = torch.cfloat, requires_grad: bool = False + ) -> torch.Tensor: + t = torch.randn(*shape, dtype=dtype) + if requires_grad: + t.requires_grad_(True) + return t + + return _make diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/datasets/test_enums.py b/tests/datasets/test_enums.py new file mode 100644 index 0000000..6b4405a --- /dev/null +++ b/tests/datasets/test_enums.py @@ -0,0 +1,25 @@ +"""Tests for the CINEView / AccFactor enums in the datasets registry.""" + +from __future__ import annotations + +from complextorch.datasets import AccFactor, CINEView + + +def test_cineview_members(): + assert CINEView.SAX.value == "SAX" + assert CINEView.LAX.value == "LAX" + assert list(CINEView) == [CINEView.SAX, CINEView.LAX] + + +def test_accfactor_members(): + assert AccFactor.R4.value == 4 + assert AccFactor.R8.value == 8 + assert AccFactor.R10.value == 10 + + +def test_cineview_str_compat(): + assert CINEView.SAX == "SAX" + + +def test_accfactor_int_compat(): + assert AccFactor.R4 == 4 diff --git a/tests/datasets/test_sample.py b/tests/datasets/test_sample.py new file mode 100644 index 0000000..990310c --- /dev/null +++ b/tests/datasets/test_sample.py @@ -0,0 +1,36 @@ +"""Tests for the SAMPLE in-memory synthetic dataset.""" + +from __future__ import annotations + +import torch + +from complextorch.datasets import SAMPLE + + +def test_sample_basic(): + ds = SAMPLE(num_samples=8, channels=2, height=4, width=4, num_classes=5, seed=42) + assert len(ds) == 8 + chip, label = ds[0] + assert chip.shape == (2, 4, 4) + assert chip.is_complex() + assert isinstance(label, int) + assert 0 <= label < 5 + + +def test_sample_reproducibility(): + a = SAMPLE(num_samples=4, height=8, width=8, seed=7) + b = SAMPLE(num_samples=4, height=8, width=8, seed=7) + torch.testing.assert_close(a[2][0], b[2][0]) + + +def test_sample_transform_applied(): + sentinel = torch.zeros(1, 4, 4, dtype=torch.cfloat) + ds = SAMPLE(num_samples=2, height=4, width=4, transform=lambda _x: sentinel) + chip, _ = ds[1] + torch.testing.assert_close(chip, sentinel) + + +def test_sample_root_ignored(): + """root is accepted for API parity but not required to exist.""" + ds = SAMPLE(root="/nonexistent/path", num_samples=2) + assert len(ds) == 2 diff --git a/tests/datasets/test_slc.py b/tests/datasets/test_slc.py new file mode 100644 index 0000000..dd262f8 --- /dev/null +++ b/tests/datasets/test_slc.py @@ -0,0 +1,75 @@ +"""Tests for the generic SLCDataset reader.""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from complextorch.datasets import SLCDataset + + +def _seed_npy_dir(root, n_files=3, shape=(1, 4, 4)): + for i in range(n_files): + arr = (np.random.randn(*shape) + 1j * np.random.randn(*shape)).astype( + np.complex64 + ) + np.save(root / f"chip_{i:03d}.npy", arr) + + +def test_slc_dataset_npy(tmp_path): + _seed_npy_dir(tmp_path, n_files=4) + ds = SLCDataset(tmp_path) + assert len(ds) == 4 + chip = ds[0] + assert chip.shape == (1, 4, 4) + assert chip.is_complex() + + +def test_slc_dataset_pt(tmp_path): + for i in range(3): + t = torch.randn(1, 4, 4, dtype=torch.cfloat) + torch.save(t, tmp_path / f"chip_{i:03d}.pt") + ds = SLCDataset(tmp_path, suffix=".pt") + assert len(ds) == 3 + chip = ds[1] + assert chip.is_complex() + + +def test_slc_dataset_with_annotations(tmp_path): + _seed_npy_dir(tmp_path, n_files=3) + ann = tmp_path / "labels.txt" + ann.write_text("0\n1\n2\n") + ds = SLCDataset(tmp_path, annotation_file=ann) + chip, label = ds[2] + assert label == 2 + assert chip.is_complex() + + +def test_slc_dataset_with_blank_lines_in_annotations(tmp_path): + _seed_npy_dir(tmp_path, n_files=2) + ann = tmp_path / "labels.txt" + ann.write_text("0\n\n1\n\n") # blanks filtered + ds = SLCDataset(tmp_path, annotation_file=ann) + assert len(ds.labels) == 2 + + +def test_slc_dataset_missing_root_raises(tmp_path): + with pytest.raises(FileNotFoundError, match="root not found"): + SLCDataset(tmp_path / "does_not_exist") + + +def test_slc_dataset_annotation_mismatch_raises(tmp_path): + _seed_npy_dir(tmp_path, n_files=2) + ann = tmp_path / "labels.txt" + ann.write_text("0\n1\n2\n") # 3 labels, 2 files + with pytest.raises(ValueError, match="annotation count"): + SLCDataset(tmp_path, annotation_file=ann) + + +def test_slc_dataset_transform_applied(tmp_path): + _seed_npy_dir(tmp_path, n_files=2) + sentinel = torch.ones(1, 4, 4, dtype=torch.cfloat) * 5.0 + ds = SLCDataset(tmp_path, transform=lambda _x: sentinel) + chip = ds[0] + torch.testing.assert_close(chip, sentinel) diff --git a/tests/datasets/test_stubs.py b/tests/datasets/test_stubs.py new file mode 100644 index 0000000..fc3212b --- /dev/null +++ b/tests/datasets/test_stubs.py @@ -0,0 +1,46 @@ +"""Tests asserting NotImplementedError stubs for the heavy SAR/MRI loaders.""" + +from __future__ import annotations + +import pytest + +from complextorch.datasets import ( + MICCAI2023, + S1SLC, + ALOSDataset, + ATRNetSTAR, + Bretigny, + LeaderFile, + MSTARTargets, + PolSFDataset, + SARImage, + TrailerFile, + VolFile, +) + + +@pytest.mark.parametrize( + "cls", + [ + PolSFDataset, + Bretigny, + S1SLC, + MSTARTargets, + ATRNetSTAR, + MICCAI2023, + ALOSDataset, + VolFile, + LeaderFile, + TrailerFile, + SARImage, + ], +) +def test_stub_raises_not_implemented(cls): + with pytest.raises(NotImplementedError, match="not yet implemented"): + cls() + + +def test_stub_reference_points_upstream(): + """The error message must reference the upstream torchcvnn class.""" + with pytest.raises(NotImplementedError, match="torchcvnn"): + PolSFDataset() diff --git a/tests/invariants/__init__.py b/tests/invariants/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/invariants/test_casting_roundtrip.py b/tests/invariants/test_casting_roundtrip.py new file mode 100644 index 0000000..46c3f0d --- /dev/null +++ b/tests/invariants/test_casting_roundtrip.py @@ -0,0 +1,40 @@ +"""Property tests: casting modules are inverses of each other.""" + +from __future__ import annotations + +import torch +from hypothesis import given, settings +from hypothesis import strategies as st + +from complextorch.nn.modules.casting import ( + ComplexToConcatenated, + ComplexToInterleaved, + ConcatenatedToComplex, + InterleavedToComplex, + RealToComplex, +) + + +@given(d=st.integers(1, 8)) +@settings(max_examples=5, deadline=None) +def test_interleaved_roundtrip(d): + z = torch.randn(3, d, dtype=torch.cfloat) + back = InterleavedToComplex()(ComplexToInterleaved()(z)) + torch.testing.assert_close(back, z) + + +@given(d=st.integers(1, 8)) +@settings(max_examples=5, deadline=None) +def test_concatenated_roundtrip(d): + z = torch.randn(3, d, dtype=torch.cfloat) + back = ConcatenatedToComplex()(ComplexToConcatenated()(z)) + torch.testing.assert_close(back, z) + + +@given(d=st.integers(1, 8)) +@settings(max_examples=5, deadline=None) +def test_real_to_complex_zeros_imag(d): + x = torch.randn(3, d) + z = RealToComplex()(x) + torch.testing.assert_close(z.real, x) + torch.testing.assert_close(z.imag, torch.zeros_like(x)) diff --git a/tests/invariants/test_equivariance.py b/tests/invariants/test_equivariance.py new file mode 100644 index 0000000..047c3a7 --- /dev/null +++ b/tests/invariants/test_equivariance.py @@ -0,0 +1,108 @@ +"""U(1) equivariance / invariance verification across CDS + SurReal modules. + +Each test rotates the input by a global complex phase :math:`e^{j\\psi}` and +checks the documented commutation property of the module. + +.. note:: + :class:`wFMConv2d` and :class:`wFMReLU` (SurReal) operate on the + rotation+scaling manifold; their equivariance is defined relative to + that manifold's group action, not to the strict global phase rotation + used in this file. They are intentionally not covered here. +""" + +from __future__ import annotations + +import torch + +from complextorch.nn import ( + ComplexScaling, + EquivariantPhaseReLU, + MagBatchNorm2d, + PhaseConjConv2d, + PhaseDivConv2d, + PhaseShift, +) + + +def _rotor(psi: float = 1.3) -> torch.Tensor: + return torch.polar(torch.tensor(1.0), torch.tensor(psi)) + + +# --------------------------------------------------------------------------- +# Equivariant modules: M(x · e^{jψ}) = M(x) · e^{jψ} +# --------------------------------------------------------------------------- + + +def test_phase_shift_is_u1_equivariant(): + layer = PhaseShift(num_features=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x) * rotor, atol=1e-5, rtol=1e-5) + + +def test_complex_scaling_is_u1_equivariant(): + layer = ComplexScaling(num_features=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x) * rotor, atol=1e-5, rtol=1e-5) + + +def test_equivariant_phase_relu_is_u1_equivariant(): + layer = EquivariantPhaseReLU(num_channels=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x) * rotor, atol=1e-4, rtol=1e-4) + + +def test_mag_batchnorm_is_u1_equivariant(): + layer = MagBatchNorm2d(num_features=4) + layer.train() + # Warm running stats with one forward, then verify in eval mode where + # BN does not depend on the batch (so rotation acts purely on phase). + _ = layer(torch.randn(8, 4, 6, 6, dtype=torch.cfloat)) + layer.eval() + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x) * rotor, atol=1e-4, rtol=1e-4) + + +# --------------------------------------------------------------------------- +# Invariant modules: M(x · e^{jψ}) = M(x) +# --------------------------------------------------------------------------- + + +def test_phase_div_conv2d_is_u1_invariant(): + layer = PhaseDivConv2d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 5, 5, dtype=torch.cfloat) + 0.1 + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x), atol=1e-4, rtol=1e-4) + + +def test_phase_conj_conv2d_is_u1_invariant(): + """With a C-linear inner conv, ``x · conj(g(x))`` is invariant too.""" + layer = PhaseConjConv2d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 5, 5, dtype=torch.cfloat) + 0.1 + rotor = _rotor() + torch.testing.assert_close(layer(x * rotor), layer(x), atol=1e-4, rtol=1e-4) + + +# --------------------------------------------------------------------------- +# Composition: chain of equivariant ops should remain equivariant. +# --------------------------------------------------------------------------- + + +def test_composition_is_u1_equivariant(): + """Conv → ComplexScaling → EquivariantPhaseReLU is equivariant. + + Note: complextorch's stock ``Conv2d`` is *not* U(1)-equivariant in general + (a generic complex linear map mixes phase and magnitude). For a + composition that's *guaranteed* equivariant we use only modules that + preserve phase: ComplexScaling, EquivariantPhaseReLU. + """ + block = torch.nn.Sequential( + ComplexScaling(num_features=4), + EquivariantPhaseReLU(num_channels=4), + ) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + rotor = _rotor() + torch.testing.assert_close(block(x * rotor), block(x) * rotor, atol=1e-4, rtol=1e-4) diff --git a/tests/invariants/test_fft_roundtrip.py b/tests/invariants/test_fft_roundtrip.py new file mode 100644 index 0000000..d099dd2 --- /dev/null +++ b/tests/invariants/test_fft_roundtrip.py @@ -0,0 +1,18 @@ +"""Property tests: IFFT(FFT(x)) ≡ x.""" + +from __future__ import annotations + +import torch +from hypothesis import given, settings +from hypothesis import strategies as st + +from complextorch.nn.modules.fft import FFTBlock, IFFTBlock + + +@given(n=st.integers(4, 32)) +@settings(max_examples=5, deadline=None) +def test_fft_ifft_roundtrip(n): + fwd = FFTBlock(dim=-1, norm="ortho") + inv = IFFTBlock(dim=-1, norm="ortho") + x = torch.randn(2, n, dtype=torch.cfloat) + torch.testing.assert_close(inv(fwd(x)), x, atol=1e-4, rtol=1e-4) diff --git a/tests/invariants/test_native_gauss_equivalence.py b/tests/invariants/test_native_gauss_equivalence.py new file mode 100644 index 0000000..604f1e7 --- /dev/null +++ b/tests/invariants/test_native_gauss_equivalence.py @@ -0,0 +1,78 @@ +"""Property tests: native cfloat and Gauss-trick conv/linear agree given shared weights.""" + +from __future__ import annotations + +import pytest +import torch +from hypothesis import given, settings +from hypothesis import strategies as st + +from complextorch.nn.gauss.conv import ( + Conv1d as GaussConv1d, +) +from complextorch.nn.gauss.conv import ( + Conv2d as GaussConv2d, +) +from complextorch.nn.gauss.conv import ( + Conv3d as GaussConv3d, +) +from complextorch.nn.gauss.linear import Linear as GaussLinear +from complextorch.nn.modules.conv import Conv1d, Conv2d, Conv3d +from complextorch.nn.modules.linear import Linear + + +def _align_gauss_conv_with_native(gauss, native): + with torch.no_grad(): + gauss.conv_r.weight.copy_(native.conv.weight.real) + gauss.conv_i.weight.copy_(native.conv.weight.imag) + if native.conv.bias is not None: + gauss.bias_r.copy_(native.conv.bias.real) + gauss.bias_i.copy_(native.conv.bias.imag) + + +@pytest.mark.parametrize(("in_ch", "out_ch", "k"), [(2, 4, 3), (1, 1, 1), (3, 2, 5)]) +@given(batch=st.integers(1, 3)) +@settings(max_examples=3, deadline=None) +def test_conv1d_native_gauss_equivalence(in_ch, out_ch, k, batch): + native = Conv1d(in_ch, out_ch, kernel_size=k, bias=True) + gauss = GaussConv1d(in_ch, out_ch, kernel_size=k, bias=True) + _align_gauss_conv_with_native(gauss, native) + x = torch.randn(batch, in_ch, 16, dtype=torch.cfloat) + torch.testing.assert_close(native(x), gauss(x), rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize(("in_ch", "out_ch", "k"), [(2, 4, 3), (1, 1, 1)]) +@given(batch=st.integers(1, 2)) +@settings(max_examples=3, deadline=None) +def test_conv2d_native_gauss_equivalence(in_ch, out_ch, k, batch): + native = Conv2d(in_ch, out_ch, kernel_size=k, bias=True) + gauss = GaussConv2d(in_ch, out_ch, kernel_size=k, bias=True) + _align_gauss_conv_with_native(gauss, native) + x = torch.randn(batch, in_ch, 8, 8, dtype=torch.cfloat) + torch.testing.assert_close(native(x), gauss(x), rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize(("in_ch", "out_ch", "k"), [(2, 2, 3)]) +@settings(max_examples=2, deadline=None) +@given(batch=st.integers(1, 2)) +def test_conv3d_native_gauss_equivalence(in_ch, out_ch, k, batch): + native = Conv3d(in_ch, out_ch, kernel_size=k, bias=True) + gauss = GaussConv3d(in_ch, out_ch, kernel_size=k, bias=True) + _align_gauss_conv_with_native(gauss, native) + x = torch.randn(batch, in_ch, 4, 4, 4, dtype=torch.cfloat) + torch.testing.assert_close(native(x), gauss(x), rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize(("in_f", "out_f"), [(6, 8), (1, 1), (12, 5)]) +@given(batch=st.integers(1, 4)) +@settings(max_examples=3, deadline=None) +def test_linear_native_gauss_equivalence(in_f, out_f, batch): + native = Linear(in_f, out_f, bias=True) + gauss = GaussLinear(in_f, out_f, bias=True) + with torch.no_grad(): + gauss.linear_r.weight.copy_(native.linear.weight.real) + gauss.linear_i.weight.copy_(native.linear.weight.imag) + gauss.bias_r.copy_(native.linear.bias.real) + gauss.bias_i.copy_(native.linear.bias.imag) + x = torch.randn(batch, in_f, dtype=torch.cfloat) + torch.testing.assert_close(native(x), gauss(x), rtol=1e-4, atol=1e-4) diff --git a/tests/invariants/test_polar_roundtrip.py b/tests/invariants/test_polar_roundtrip.py new file mode 100644 index 0000000..4f26d1e --- /dev/null +++ b/tests/invariants/test_polar_roundtrip.py @@ -0,0 +1,18 @@ +"""Property tests: torch.polar(|z|, ∠z) ≡ z (away from |z|=0).""" + +from __future__ import annotations + +import pytest +import torch +from hypothesis import given, settings +from hypothesis import strategies as st + + +@pytest.mark.parametrize("shape", [(4,), (2, 3), (1, 4, 5)]) +@given(seed=st.integers(0, 10_000)) +@settings(max_examples=5, deadline=None) +def test_polar_roundtrip(shape, seed): + g = torch.Generator().manual_seed(seed) + z = torch.randn(*shape, generator=g, dtype=torch.cfloat) + 0.5 + back = torch.polar(z.abs(), z.angle()) + torch.testing.assert_close(back, z, atol=1e-5, rtol=1e-5) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_cds.py b/tests/models/test_cds.py new file mode 100644 index 0000000..0d5faee --- /dev/null +++ b/tests/models/test_cds.py @@ -0,0 +1,76 @@ +"""Smoke tests for the CDS reference models.""" + +from __future__ import annotations + +import torch + +from complextorch.models.cds import CDSMSTAR, CDSEquivariant, CDSInvariant + + +def test_cds_invariant_forward_backward(): + model = CDSInvariant(input_channels=2, num_classes=10, prototype_size=32) + x = torch.randn(2, 2, 32, 32, dtype=torch.cfloat) + logits = model(x) + assert logits.shape == (2, 10) + logits.sum().backward() + # At least one param should have a finite, non-zero gradient. + grads = [p.grad for p in model.parameters() if p.grad is not None] + assert any(g.abs().sum().item() > 0 for g in grads) + + +def test_cds_equivariant_forward_backward(): + model = CDSEquivariant(input_channels=2, num_classes=10, prototype_size=32) + x = torch.randn(2, 2, 32, 32, dtype=torch.cfloat) + logits = model(x) + assert logits.shape == (2, 10) + logits.sum().backward() + grads = [p.grad for p in model.parameters() if p.grad is not None] + assert any(g.abs().sum().item() > 0 for g in grads) + + +def test_cds_mstar_forward_backward(): + # MSTAR input: 1 complex channel; size 88x88 is the original SAR chip size. + model = CDSMSTAR(num_classes=10) + x = torch.randn(2, 1, 88, 88, dtype=torch.cfloat) + 0.1 + logits = model(x) + assert logits.shape == (2, 10) + logits.sum().backward() + + +def test_cds_models_accept_real_input(): + """All three CDS models auto-cast a real input to cfloat in forward.""" + for cls, x in [ + ( + CDSInvariant(input_channels=2, num_classes=4, prototype_size=8), + torch.randn(2, 2, 32, 32), + ), + ( + CDSEquivariant(input_channels=2, num_classes=4, prototype_size=8), + torch.randn(2, 2, 32, 32), + ), + (CDSMSTAR(num_classes=4), torch.randn(2, 1, 88, 88) + 0.1), + ]: + out = cls(x) + assert out.shape[0] == 2 + + +def test_cds_invariant_is_phase_invariant_in_eval(): + """The DivConv after wfm1 makes CDSInvariant invariant to global phase. + + Use eval() because the trailing real BatchNorm1d (which mixes magnitudes + of real/imag parts in a way that breaks pure invariance) needs running + stats fixed, AND in training mode batch-norm depends on the rotation. + """ + model = CDSInvariant(input_channels=2, num_classes=4, prototype_size=16) + x = torch.randn(2, 2, 32, 32, dtype=torch.cfloat) + # Warm up running stats. + model.train() + _ = model(x) + model.eval() + rotor = torch.polar(torch.tensor(1.0), torch.tensor(0.8)) + logits1 = model(x) + logits2 = model(x * rotor) + # Should match closely (the path before BN is genuinely invariant; BN + # operates on real+imag concatenation so post-BN there is slight rotation + # dependence). Loosen tolerance accordingly. + torch.testing.assert_close(logits1, logits2, atol=5e-2, rtol=5e-2) diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py new file mode 100644 index 0000000..988c47e --- /dev/null +++ b/tests/models/test_vit.py @@ -0,0 +1,81 @@ +"""Tests for the complex Vision Transformer (ViT) and preset factories.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.models import ViT, ViTLayer, vit_b, vit_h, vit_l, vit_s, vit_t + +# ---------- ViTLayer ---------- + + +def test_vit_layer_forward(): + layer = ViTLayer(dim=8, nhead=2, mlp_dim=16) + x = torch.randn(2, 5, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_vit_layer_invalid_dim(): + with pytest.raises(ValueError, match="must be divisible"): + ViTLayer(dim=8, nhead=3, mlp_dim=16) + + +# ---------- ViT (full) ---------- + + +def test_vit_forward_tiny(): + vit = ViT( + image_size=16, + patch_size=4, + in_channels=1, + num_classes=5, + dim=16, + depth=1, + heads=2, + mlp_dim=32, + ) + x = torch.randn(1, 1, 16, 16, dtype=torch.cfloat) + out = vit(x) + assert out.shape == (1, 5) + assert out.is_complex() + + +def test_vit_no_head(): + vit = ViT( + image_size=16, patch_size=4, num_classes=0, dim=16, depth=1, heads=2, mlp_dim=32 + ) + x = torch.randn(1, 1, 16, 16, dtype=torch.cfloat) + out = vit(x) + # nn.Identity head -> dim still complex + assert out.shape == (1, 16) + + +def test_vit_invalid_image_patch(): + with pytest.raises(ValueError, match="must be divisible"): + ViT( + image_size=17, + patch_size=4, + num_classes=5, + dim=8, + depth=1, + heads=2, + mlp_dim=16, + ) + + +# ---------- Presets (smoke: just instantiate; don't run forward on full sizes) ---------- + + +@pytest.mark.parametrize("factory", [vit_t, vit_s, vit_b, vit_l, vit_h]) +def test_vit_factories_instantiate(factory): + """Use small image_size to keep memory reasonable.""" + model = factory( + image_size=14 if factory is vit_h else 16, + patch_size=14 if factory is vit_h else 16, + num_classes=0, + ) + # Just confirm it's a ViT. + assert isinstance(model, ViT) diff --git a/tests/nn/__init__.py b/tests/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/gauss/__init__.py b/tests/nn/gauss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/gauss/test_conv.py b/tests/nn/gauss/test_conv.py new file mode 100644 index 0000000..11c482d --- /dev/null +++ b/tests/nn/gauss/test_conv.py @@ -0,0 +1,131 @@ +"""Tests for the Gauss-trick complex Conv* and ConvTranspose* layers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.gauss.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + + +@pytest.mark.parametrize( + ("cls", "x_shape", "kwargs"), + [ + ( + Conv1d, + (2, 2, 8), + {"in_channels": 2, "out_channels": 4, "kernel_size": 3, "padding": 1}, + ), + ( + Conv2d, + (2, 2, 6, 6), + {"in_channels": 2, "out_channels": 3, "kernel_size": 3, "padding": 1}, + ), + ( + Conv3d, + (1, 2, 4, 4, 4), + {"in_channels": 2, "out_channels": 2, "kernel_size": 3, "padding": 1}, + ), + ], +) +def test_gauss_conv_forward(cls, x_shape, kwargs): + layer = cls(**kwargs) + x = torch.randn(*x_shape, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + assert layer.weight.is_complex() + assert layer.bias.is_complex() + + +@pytest.mark.parametrize( + ("cls", "x_shape", "kwargs"), + [ + ( + Conv1d, + (2, 2, 8), + { + "in_channels": 2, + "out_channels": 4, + "kernel_size": 3, + "padding": 1, + "bias": False, + }, + ), + ( + Conv2d, + (1, 2, 6, 6), + { + "in_channels": 2, + "out_channels": 2, + "kernel_size": 3, + "padding": 1, + "bias": False, + }, + ), + ], +) +def test_gauss_conv_no_bias(cls, x_shape, kwargs): + layer = cls(**kwargs) + x = torch.randn(*x_shape, dtype=torch.cfloat) + out = layer(x) + assert layer.bias is None + assert out.is_complex() + + +@pytest.mark.parametrize( + ("cls", "x_shape", "kwargs"), + [ + ( + ConvTranspose1d, + (2, 2, 4), + {"in_channels": 2, "out_channels": 3, "kernel_size": 3, "stride": 2}, + ), + ( + ConvTranspose2d, + (1, 2, 4, 4), + {"in_channels": 2, "out_channels": 2, "kernel_size": 3, "stride": 2}, + ), + ( + ConvTranspose3d, + (1, 2, 2, 2, 2), + {"in_channels": 2, "out_channels": 2, "kernel_size": 3, "stride": 2}, + ), + ], +) +def test_gauss_convtranspose_forward(cls, x_shape, kwargs): + layer = cls(**kwargs) + x = torch.randn(*x_shape, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + assert layer.weight.is_complex() + assert layer.bias.is_complex() + + +@pytest.mark.parametrize( + ("cls", "x_shape", "kwargs"), + [ + ( + ConvTranspose1d, + (2, 2, 4), + {"in_channels": 2, "out_channels": 3, "kernel_size": 3, "bias": False}, + ), + ( + ConvTranspose2d, + (1, 2, 4, 4), + {"in_channels": 2, "out_channels": 2, "kernel_size": 3, "bias": False}, + ), + ], +) +def test_gauss_convtranspose_no_bias(cls, x_shape, kwargs): + layer = cls(**kwargs) + x = torch.randn(*x_shape, dtype=torch.cfloat) + out = layer(x) + assert layer.bias is None + assert out.is_complex() diff --git a/tests/nn/gauss/test_linear.py b/tests/nn/gauss/test_linear.py new file mode 100644 index 0000000..0417120 --- /dev/null +++ b/tests/nn/gauss/test_linear.py @@ -0,0 +1,30 @@ +"""Tests for the Gauss-trick complex Linear layer.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.gauss.linear import Linear + + +def test_gauss_linear_forward(): + layer = Linear(8, 4, bias=True) + x = torch.randn(3, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (3, 4) + assert out.is_complex() + # weight property + w = layer.weight + assert w.shape == (4, 8) + assert w.is_complex() + # bias property + b = layer.bias + assert b.shape == (4,) + + +def test_gauss_linear_no_bias(): + layer = Linear(8, 4, bias=False) + x = torch.randn(3, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (3, 4) + assert layer.bias is None diff --git a/tests/nn/masked/__init__.py b/tests/nn/masked/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/masked/test_base.py b/tests/nn/masked/test_base.py new file mode 100644 index 0000000..32226b5 --- /dev/null +++ b/tests/nn/masked/test_base.py @@ -0,0 +1,152 @@ +"""Tests for BaseMasked, deploy_masks, binarize_masks, is_sparse, named_masks.""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from complextorch.nn.masked import ( + LinearMasked, + binarize_masks, + deploy_masks, + is_sparse, + named_masks, +) + + +def test_baselayer_is_sparse_when_masked(): + layer = LinearMasked(4, 6) + assert not layer.is_sparse + layer.mask = torch.ones(6, 4) + assert layer.is_sparse + + +def test_mask_setter_with_non_tensor_raises(): + layer = LinearMasked(4, 6) + with pytest.raises(TypeError, match="Tensor or None"): + layer.mask = [1, 2, 3] + + +def test_setattr_passthrough_for_non_mask(): + layer = LinearMasked(4, 6) + layer.in_features = 42 # should not go through mask_ + assert layer.in_features == 42 + + +def test_mask_clear_via_none(): + layer = LinearMasked(4, 6) + layer.mask = torch.ones(6, 4) + assert layer.is_sparse + layer.mask = None + assert not layer.is_sparse + + +def test_deploy_masks_with_qualified_names(): + model = nn.Module() + model.linear = LinearMasked(4, 6) + mask = torch.zeros(6, 4) + mask[0:2] = 1.0 + deploy_masks(model, {"linear.mask": mask}) + assert model.linear.is_sparse + torch.testing.assert_close(model.linear.mask, mask) + + +def test_deploy_masks_strict_missing_module_raises(): + model = nn.Module() + with pytest.raises(KeyError, match="no BaseMasked"): + deploy_masks(model, {"nope.mask": torch.zeros(4)}, strict=True) + + +def test_deploy_masks_non_strict_silently_skips(): + model = nn.Module() + deploy_masks(model, {"nope.mask": torch.zeros(4)}, strict=False) + + +def test_deploy_masks_skips_non_mask_keys(): + model = nn.Module() + model.linear = LinearMasked(4, 6) + deploy_masks(model, {"linear.weight": torch.zeros(6, 4)}) + assert not model.linear.is_sparse + + +def test_binarize_masks_makes_binary(): + layer = LinearMasked(4, 6) + layer.mask = torch.full((6, 4), 0.3) + binarize_masks(layer) + assert ((layer.mask == 0) | (layer.mask == 1)).all() + + +def test_named_masks_iterates(): + model = nn.Module() + model.l1 = LinearMasked(4, 6) + model.l2 = LinearMasked(4, 6) + model.l1.mask = torch.ones(6, 4) + names = [n for n, _ in named_masks(model)] + assert "l1" in names + assert "l2" not in names # l2 doesn't have a mask + + +def test_is_sparse_function(): + layer = LinearMasked(4, 6) + assert not is_sparse(layer) + layer.mask = torch.ones(6, 4) + assert is_sparse(layer) + assert not is_sparse(nn.Linear(4, 6)) # not a BaseMasked + + +def test_state_dict_roundtrip_with_mask(): + layer = LinearMasked(4, 6) + mask = torch.eye(6, 4) + layer.mask = mask + state = layer.state_dict() + assert "mask" in state + other = LinearMasked(4, 6) + other.load_state_dict(state) + assert other.is_sparse + torch.testing.assert_close(other.mask, mask) + + +def test_state_dict_load_missing_mask_strict_false(): + layer = LinearMasked(4, 6) + state = { + "weight": layer.weight.detach(), + "bias": layer.bias.detach(), + } + layer.load_state_dict(state, strict=False) + + +def test_state_dict_load_missing_mask_strict_true_records_missing(): + layer = LinearMasked(4, 6) + state = { + "weight": layer.weight.detach(), + "bias": layer.bias.detach(), + } + missing, _ = layer.load_state_dict(state, strict=False) + assert "mask" in missing + + +def test_weight_masked_property_raises_without_mask(): + layer = LinearMasked(4, 6) + with pytest.raises(RuntimeError, match="no sparsity mask"): + _ = layer.weight_masked + + +def test_state_dict_load_when_layer_has_mask_and_state_has_mask(): + """Hits the 'mask_in_missing + state has mask' path: remove from missing_keys.""" + src = LinearMasked(4, 6) + src.mask = torch.eye(6, 4) + state = src.state_dict() + tgt = LinearMasked(4, 6) + tgt.mask = torch.ones(6, 4) # non-None so super sees it during load + tgt.load_state_dict(state) + torch.testing.assert_close(tgt.mask, src.mask) + + +def test_state_dict_load_strict_mask_missing_records_in_missing_keys(): + """Hits strict-mode branch where mask is absent from state_dict.""" + layer = LinearMasked(4, 6) + layer.mask = torch.ones(6, 4) # non-None so super sees it + state = {"weight": layer.weight.detach(), "bias": layer.bias.detach()} + missing, _ = layer.load_state_dict(state, strict=False) + assert "mask" in missing diff --git a/tests/nn/masked/test_layers.py b/tests/nn/masked/test_layers.py new file mode 100644 index 0000000..40388e5 --- /dev/null +++ b/tests/nn/masked/test_layers.py @@ -0,0 +1,124 @@ +"""Tests for the masked complex layers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.masked import ( + BilinearMasked, + Conv1dMasked, + Conv2dMasked, + Conv3dMasked, + LinearMasked, +) + +# ---------- LinearMasked / BilinearMasked ---------- + + +def test_linear_masked_dense_forward(): + layer = LinearMasked(4, 6, bias=True) + x = torch.randn(2, 4, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (2, 6) + assert out.is_complex() + + +def test_linear_masked_sparse_forward(): + layer = LinearMasked(4, 6, bias=True) + layer.mask = torch.zeros(6, 4) # zero out everything + x = torch.randn(2, 4, dtype=torch.cfloat) + out = layer(x) + # With zero mask, output equals bias broadcast. + torch.testing.assert_close(out, layer.bias.expand_as(out)) + + +def test_linear_masked_no_bias(): + layer = LinearMasked(4, 6, bias=False) + assert layer.bias is None + x = torch.randn(2, 4, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_bilinear_masked_forward(): + layer = BilinearMasked(4, 5, 6, bias=True) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 5, dtype=torch.cfloat) + out = layer(x1, x2) + assert out.shape == (2, 6) + assert out.is_complex() + + +def test_bilinear_masked_no_conjugate(): + layer = BilinearMasked(4, 5, 6, conjugate=False) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 5, dtype=torch.cfloat) + out = layer(x1, x2) + assert out.shape == (2, 6) + + +def test_bilinear_masked_sparse_forward(): + layer = BilinearMasked(4, 5, 6, bias=True) + layer.mask = torch.zeros(6, 4, 5) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 5, dtype=torch.cfloat) + out = layer(x1, x2) + torch.testing.assert_close(out, layer.bias.expand_as(out)) + + +def test_bilinear_masked_no_bias(): + layer = BilinearMasked(4, 5, 6, bias=False) + assert layer.bias is None + + +# ---------- ConvMasked ---------- + + +@pytest.mark.parametrize( + ("cls", "shape", "k"), + [ + (Conv1dMasked, (1, 2, 8), 3), + (Conv2dMasked, (1, 2, 6, 6), 3), + (Conv3dMasked, (1, 2, 4, 4, 4), 3), + ], +) +def test_conv_masked_dense_forward(cls, shape, k): + layer = cls(in_channels=2, out_channels=4, kernel_size=k, padding=1) + x = torch.randn(*shape, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_conv1d_masked_sparse(): + layer = Conv1dMasked(2, 4, kernel_size=3, padding=1) + layer.mask = torch.zeros(4, 2, 3) + x = torch.randn(1, 2, 8, dtype=torch.cfloat) + out = layer(x) + # Zero mask -> output is bias broadcast over output positions + expected = layer.bias.view(1, 4, 1).expand_as(out) + torch.testing.assert_close(out, expected) + + +def test_conv_masked_no_bias(): + layer = Conv1dMasked(2, 4, kernel_size=3, bias=False) + assert layer.bias is None + + +def test_conv_masked_invalid_padding_mode(): + with pytest.raises(ValueError, match="padding_mode"): + Conv1dMasked(2, 4, kernel_size=3, padding_mode="reflect") + + +def test_conv_masked_tuple_kernel_size(): + layer = Conv2dMasked(2, 4, kernel_size=(3, 5), padding=(1, 2)) + x = torch.randn(1, 2, 8, 10, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_conv_masked_str_padding(): + layer = Conv2dMasked(2, 4, kernel_size=3, padding="same") + x = torch.randn(1, 2, 8, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (1, 4, 8, 8) diff --git a/tests/nn/modules/__init__.py b/tests/nn/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/modules/activation/__init__.py b/tests/nn/modules/activation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/modules/activation/test_complex_relu.py b/tests/nn/modules/activation/test_complex_relu.py new file mode 100644 index 0000000..248edaa --- /dev/null +++ b/tests/nn/modules/activation/test_complex_relu.py @@ -0,0 +1,202 @@ +"""Tests for complex ReLU variants.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from complextorch.nn.modules.activation.complex_relu import ( + CPReLU, + CReLU, + CVSplitReLU, + EquivariantPhaseReLU, + GTReLU, + _PhaseHalfPlaneMask, + zAbsReLU, + zLeakyReLU, +) + + +@pytest.fixture +def z(): + return torch.randn(4, 8, dtype=torch.cfloat) + + +@pytest.mark.parametrize("cls", [CVSplitReLU, CReLU]) +def test_split_relu_zeros_negatives(cls, z): + out = cls(inplace=False)(z.clone()) + assert (out.real >= 0).all() + assert (out.imag >= 0).all() + + +def test_cprelu_learnable(z): + act = CPReLU() + out = act(z) + assert out.shape == z.shape + # Two PReLU modules, each with a learnable weight + params = list(act.parameters()) + assert len(params) == 2 + + +def test_zabsrelu_threshold_zero_passes_all(): + z = torch.randn(8, dtype=torch.cfloat) + out = zAbsReLU(a_init=0.0)(z) + torch.testing.assert_close(out, z) + + +def test_zabsrelu_high_threshold_zeros_small(): + z = torch.tensor([0.1 + 0j, 5.0 + 0j], dtype=torch.cfloat) + out = zAbsReLU(a_init=1.0)(z) + assert out[0].abs().item() == 0.0 + assert out[1].abs().item() == 5.0 + + +def test_zabsrelu_real_input_path(): + """Test the non-complex branch in forward.""" + x = torch.tensor([0.1, 5.0]) # real, not complex + out = zAbsReLU(a_init=1.0)(x) + assert out[0].abs().item() == 0.0 + + +def test_zleakyrelu_first_quadrant_passes(): + z = torch.tensor([1 + 1j, -1 + 1j, 1 - 1j, -1 - 1j], dtype=torch.cfloat) + out = zLeakyReLU(negative_slope=0.5)(z) + torch.testing.assert_close(out[0], z[0]) + torch.testing.assert_close(out[1], 0.5 * z[1]) + torch.testing.assert_close(out[2], 0.5 * z[2]) + torch.testing.assert_close(out[3], 0.5 * z[3]) + + +def test_zleakyrelu_extra_repr(): + s = zLeakyReLU(negative_slope=0.2).extra_repr() + assert "0.2" in s + + +# ---------- _PhaseHalfPlaneMask (custom autograd) ---------- + + +def test_phase_halfplane_mask_zeros_lower_halfplane(): + phase = torch.tensor([0.0, math.pi / 2, math.pi, 3 * math.pi / 2, -math.pi / 4]) + out = _PhaseHalfPlaneMask.apply(phase) + # 0, pi/2, pi → in [0, pi] (kept). 3pi/2 → mod 2pi = 3pi/2 (not kept). + # -pi/4 → mod 2pi = 7pi/4 (not kept). + assert out[0].item() == pytest.approx(0.0) + assert out[1].item() == pytest.approx(math.pi / 2) + assert out[2].item() == pytest.approx(math.pi) + assert out[3].item() == pytest.approx(0.0) + assert out[4].item() == pytest.approx(0.0) + + +def test_phase_halfplane_mask_gradient_is_mask(): + phase = torch.tensor([math.pi / 4, 3 * math.pi / 2], requires_grad=True) + out = _PhaseHalfPlaneMask.apply(phase) + out.sum().backward() + # Gradient is 1 where the mask is 1, 0 elsewhere + assert phase.grad[0].item() == pytest.approx(1.0) + assert phase.grad[1].item() == pytest.approx(0.0) + + +# ---------- GTReLU ---------- + + +def test_gtrelu_forward_shape(): + layer = GTReLU(num_channels=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_gtrelu_phase_scale_adds_parameter(): + layer = GTReLU(num_channels=4, phase_scale=True) + assert layer.lambd is not None + assert layer.lambd.shape == (4,) + + +def test_gtrelu_no_phase_scale_no_lambda(): + layer = GTReLU(num_channels=4, phase_scale=False) + assert layer.lambd is None + + +def test_gtrelu_global_scaling(): + layer = GTReLU(num_channels=4, global_scaling=True) + assert layer.alpha.shape == (1,) + assert layer.beta.shape == (1,) + + +def test_gtrelu_grad_flows(): + layer = GTReLU(num_channels=3, phase_scale=True) + x = torch.randn(2, 3, 4, dtype=torch.cfloat, requires_grad=True) + out = layer(x) + out.abs().sum().backward() + assert torch.isfinite(layer.alpha.grad).all() + assert torch.isfinite(layer.lambd.grad).all() + + +# ---------- EquivariantPhaseReLU ---------- + + +def test_equivariant_phase_relu_forward_shape(): + layer = EquivariantPhaseReLU(num_channels=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_equivariant_phase_relu_is_u1_equivariant(): + """Rotating the input by a global phase rotates the output by the same.""" + layer = EquivariantPhaseReLU(num_channels=4) + x = torch.randn(2, 4, 5, 5, dtype=torch.cfloat) + 0.1 + rotor = torch.polar(torch.tensor(1.0), torch.tensor(1.3)) + y1 = layer(x * rotor) + y2 = layer(x) * rotor + torch.testing.assert_close(y1, y2, atol=1e-4, rtol=1e-4) + + +def test_equivariant_phase_relu_grad_flows(): + layer = EquivariantPhaseReLU(num_channels=3) + x = torch.randn(2, 3, 4, dtype=torch.cfloat, requires_grad=True) + out = layer(x) + out.abs().sum().backward() + assert torch.isfinite(layer.phase_gain.grad).all() + + +def test_broadcast_channelwise_passthrough_for_1d_input(): + """1-D input -> the else branch returns t unchanged.""" + from complextorch.nn.modules.activation.complex_relu import _broadcast_channelwise + + t = torch.zeros(4) + out = _broadcast_channelwise(t, input_dim=1) + assert out is t + + +def test_gtrelu_real_input_auto_casts(): + from complextorch.nn.modules.activation.complex_relu import GTReLU + + layer = GTReLU(num_channels=4) + x = torch.randn(2, 4, 6) + out = layer(x) + assert out.is_complex() + + +def test_gtrelu_extra_repr(): + from complextorch.nn.modules.activation.complex_relu import GTReLU + + s = GTReLU(num_channels=4, global_scaling=True, phase_scale=True).extra_repr() + assert "num_channels=4" in s + assert "phase_scale=True" in s + + +def test_equivariant_phase_relu_real_input_auto_casts(): + layer = EquivariantPhaseReLU(num_channels=3) + x = torch.randn(2, 3, 4) + out = layer(x) + assert out.is_complex() + + +def test_equivariant_phase_relu_extra_repr(): + s = EquivariantPhaseReLU(num_channels=4).extra_repr() + assert "num_channels=4" in s diff --git a/tests/nn/modules/activation/test_fully_complex.py b/tests/nn/modules/activation/test_fully_complex.py new file mode 100644 index 0000000..ef9f478 --- /dev/null +++ b/tests/nn/modules/activation/test_fully_complex.py @@ -0,0 +1,70 @@ +"""Tests for fully-complex activation functions.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.activation.fully_complex import ( + CVCardiod, + CVSigLog, + CVSigmoid, + Mod, + zReLU, +) + + +@pytest.fixture +def z(): + return torch.randn(8, dtype=torch.cfloat) + + +def test_cvsigmoid_forward(z): + out = CVSigmoid()(z) + assert out.shape == z.shape + assert out.is_complex() + + +def test_zrelu_zeros_outside_first_quadrant(): + # In Q1 (angle in [0, pi/2]) -> passes; elsewhere -> 0 + z = torch.tensor([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j], dtype=torch.cfloat) + out = zReLU()(z) + assert out[0].abs().item() > 0 # Q1 -> passed + assert out[1].abs().item() == 0 # Q2 -> zero + assert out[2].abs().item() == 0 # Q3 -> zero + assert out[3].abs().item() == 0 # Q4 -> zero + + +def test_cvcardiod_real_positive_axis_unchanged(): + z = torch.tensor([1.0 + 0j, 2.0 + 0j]) + out = CVCardiod()(z) + # angle=0 -> 0.5*(1+cos(0))=1 -> output == z + torch.testing.assert_close(out, z) + + +def test_cvcardiod_zeros_negative_real(): + z = torch.tensor([-1.0 + 0j, -3.0 + 0j]) + out = CVCardiod()(z) + # angle=pi -> 0.5*(1+cos(pi))=0 -> output == 0 + torch.testing.assert_close(out, torch.zeros_like(z)) + + +def test_cvsiglog_default_params(z): + act = CVSigLog() + out = act(z) + assert out.shape == z.shape + expected = z / (1.0 + z.abs() / 1.0) + torch.testing.assert_close(out, expected) + + +def test_cvsiglog_custom_c_r(z): + act = CVSigLog(c=2.0, r=3.0) + out = act(z) + expected = z / (2.0 + z.abs() / 3.0) + torch.testing.assert_close(out, expected) + + +def test_mod_returns_magnitude(z): + out = Mod()(z) + torch.testing.assert_close(out, z.abs()) + assert not out.is_complex() diff --git a/tests/nn/modules/activation/test_split_type_a.py b/tests/nn/modules/activation/test_split_type_a.py new file mode 100644 index 0000000..b686642 --- /dev/null +++ b/tests/nn/modules/activation/test_split_type_a.py @@ -0,0 +1,81 @@ +"""Tests for split Type-A activation functions.""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from complextorch.nn.modules.activation.split_type_A import ( + CCELU, + CELU, + CGELU, + CSigmoid, + CTanh, + CVSplitAbs, + CVSplitCELU, + CVSplitELU, + CVSplitGELU, + CVSplitSigmoid, + CVSplitTanh, + GeneralizedSplitActivation, +) + + +@pytest.fixture +def z(): + return torch.randn(4, 8, dtype=torch.cfloat) + + +@pytest.mark.parametrize( + "cls", + [ + CVSplitTanh, + CTanh, + CVSplitSigmoid, + CSigmoid, + CVSplitAbs, + CVSplitELU, + CELU, + CVSplitCELU, + CCELU, + CVSplitGELU, + CGELU, + ], +) +def test_split_type_a_forward_preserves_shape_and_complex_dtype(cls, z): + act = cls() + out = act(z) + assert out.shape == z.shape + assert out.is_complex() + + +def test_generalized_split_activation_user_provided_modules(z): + act = GeneralizedSplitActivation(nn.ReLU(), nn.Tanh()) + out = act(z) + expected = torch.complex(torch.relu(z.real), torch.tanh(z.imag)) + torch.testing.assert_close(out, expected) + + +def test_cvsplit_abs_outputs_absolute_values(z): + out = CVSplitAbs()(z) + torch.testing.assert_close(out.real, z.real.abs()) + torch.testing.assert_close(out.imag, z.imag.abs()) + + +def test_cvsplit_elu_with_alpha_inplace(z): + act = CVSplitELU(alpha=0.5, inplace=False) + out = act(z) + assert out.shape == z.shape + + +def test_cvsplit_celu_with_alpha(z): + act = CVSplitCELU(alpha=2.0, inplace=False) + out = act(z) + assert out.shape == z.shape + + +def test_cvsplit_gelu_tanh_approximate(z): + act = CVSplitGELU(approximate="tanh") + out = act(z) + assert out.shape == z.shape diff --git a/tests/nn/modules/activation/test_split_type_b.py b/tests/nn/modules/activation/test_split_type_b.py new file mode 100644 index 0000000..5ee6f23 --- /dev/null +++ b/tests/nn/modules/activation/test_split_type_b.py @@ -0,0 +1,87 @@ +"""Tests for split Type-B (polar) activation functions.""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from complextorch.nn.modules.activation.split_type_B import ( + AdaptiveModReLU, + CVPolarLog, + CVPolarSquash, + CVPolarTanh, + GeneralizedPolarActivation, + modReLU, +) + + +@pytest.fixture +def z_channel(): + # (B, C, L) for AdaptiveModReLU + return torch.randn(2, 3, 6, dtype=torch.cfloat) + + +@pytest.mark.parametrize("cls", [CVPolarTanh, CVPolarSquash, CVPolarLog]) +def test_polar_activations_preserve_phase(cls): + z = torch.randn(8, dtype=torch.cfloat) + 1e-3 # avoid |z|=0 + out = cls()(z) + # Phase should be unchanged (these all use phase_fun=None) + torch.testing.assert_close(out.angle(), z.angle(), atol=1e-5, rtol=1e-5) + + +def test_cvpolar_squash_magnitude_bounded(): + z = torch.randn(16, dtype=torch.cfloat) + out = CVPolarSquash()(z) + # squash: x^2 / (1 + x^2) -> magnitude < 1 + assert out.abs().max().item() < 1.0 + + +def test_cvpolar_log_monotonic_in_magnitude(): + z = torch.tensor([1.0 + 0j, 2.0 + 0j, 5.0 + 0j]) + out = CVPolarLog()(z) + # log(|z|+1) is monotonic increasing + mags = out.abs() + assert mags[0] < mags[1] < mags[2] + + +def test_modrelu_static_bias_must_be_negative(): + with pytest.raises(AssertionError, match="smaller than 0"): + modReLU(bias=0.5) + + +def test_modrelu_static_bias_zeros_small_magnitudes(): + act = modReLU(bias=-0.5) + z = torch.tensor([0.1 + 0j, 1.0 + 0j], dtype=torch.cfloat) + out = act(z) + assert out[0].abs().item() == 0.0 # |0.1| - 0.5 < 0 -> zeroed + assert out[1].abs().item() > 0.0 + + +def test_modrelu_learnable_bias_can_be_positive(): + act = modReLU(bias=0.1, learnable=True) + assert isinstance(act.activation_mag.bias, nn.Parameter) + + +def test_modrelu_static_bias_is_buffer(): + act = modReLU(bias=-0.1, learnable=False) + assert not isinstance(act.activation_mag.bias, nn.Parameter) + assert "bias" in dict(act.activation_mag.named_buffers()) + + +def test_adaptive_modrelu_per_channel(z_channel): + act = AdaptiveModReLU(num_features=3, init=-0.1) + out = act(z_channel) + assert out.shape == z_channel.shape + assert out.is_complex() + assert act.activation_mag.bias.shape == (3,) + + +def test_generalized_polar_with_phase_fun(): + z = torch.randn(4, dtype=torch.cfloat) + 1e-3 + # Use a non-trivial phase function + act = GeneralizedPolarActivation(nn.Identity(), nn.Identity()) + out = act(z) + # phase passed through identity -> phase same; magnitude unchanged + torch.testing.assert_close(out.abs(), z.abs(), atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out.angle(), z.angle(), atol=1e-5, rtol=1e-5) diff --git a/tests/nn/modules/attention/__init__.py b/tests/nn/modules/attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/modules/attention/test_attention.py b/tests/nn/modules/attention/test_attention.py new file mode 100644 index 0000000..cbb3cec --- /dev/null +++ b/tests/nn/modules/attention/test_attention.py @@ -0,0 +1,53 @@ +"""Tests for ScaledDotProductAttention and MultiheadAttention.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.attention import ( + MultiheadAttention, + ScaledDotProductAttention, +) + + +def test_scaled_dot_product_complex_softmax(): + attn = ScaledDotProductAttention(temperature=2.0, softmax_on="complex") + q = torch.randn(2, 3, 5, 4, dtype=torch.cfloat) + k = torch.randn(2, 3, 7, 4, dtype=torch.cfloat) + v = torch.randn(2, 3, 7, 4, dtype=torch.cfloat) + out = attn(q, k, v) + assert out.shape == (2, 3, 5, 4) + assert out.is_complex() + + +def test_scaled_dot_product_real_softmax(): + attn = ScaledDotProductAttention(temperature=2.0, softmax_on="real") + q = torch.randn(2, 3, 5, 4, dtype=torch.cfloat) + k = torch.randn(2, 3, 7, 4, dtype=torch.cfloat) + v = torch.randn(2, 3, 7, 4, dtype=torch.cfloat) + out = attn(q, k, v) + assert out.shape == (2, 3, 5, 4) + assert out.is_complex() + + +def test_scaled_dot_product_invalid_softmax_on(): + with pytest.raises(ValueError, match="softmax_on must be"): + ScaledDotProductAttention(temperature=1.0, softmax_on="bogus") + + +def test_multihead_attention_forward(): + mha = MultiheadAttention(n_heads=2, d_model=8, d_k=4, d_v=4) + q = torch.randn(2, 5, 8, dtype=torch.cfloat) + out = mha(q, q, q) + assert out.shape == q.shape + assert out.is_complex() + + +def test_multihead_attention_cross(): + mha = MultiheadAttention(n_heads=2, d_model=8, d_k=4, d_v=4) + q = torch.randn(2, 5, 8, dtype=torch.cfloat) + k = torch.randn(2, 7, 8, dtype=torch.cfloat) + v = torch.randn(2, 7, 8, dtype=torch.cfloat) + out = mha(q, k, v) + assert out.shape == q.shape diff --git a/tests/nn/modules/attention/test_eca.py b/tests/nn/modules/attention/test_eca.py new file mode 100644 index 0000000..b6e98fa --- /dev/null +++ b/tests/nn/modules/attention/test_eca.py @@ -0,0 +1,28 @@ +"""Tests for EfficientChannelAttention{1,2,3}d.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.attention.eca import ( + EfficientChannelAttention1d, + EfficientChannelAttention2d, + EfficientChannelAttention3d, +) + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (EfficientChannelAttention1d, (1, 8, 16)), + (EfficientChannelAttention2d, (1, 8, 8, 8)), + (EfficientChannelAttention3d, (1, 8, 4, 4, 4)), + ], +) +def test_eca_forward(cls, shape): + eca = cls(channels=8) + x = torch.randn(*shape, dtype=torch.cfloat) + out = eca(x) + assert out.shape == x.shape + assert out.is_complex() diff --git a/tests/nn/modules/attention/test_mca.py b/tests/nn/modules/attention/test_mca.py new file mode 100644 index 0000000..e9815ff --- /dev/null +++ b/tests/nn/modules/attention/test_mca.py @@ -0,0 +1,33 @@ +"""Tests for MaskedChannelAttention{1,2,3}d.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.attention.mca import ( + MaskedChannelAttention1d, + MaskedChannelAttention2d, + MaskedChannelAttention3d, +) + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (MaskedChannelAttention1d, (1, 8, 16)), + (MaskedChannelAttention2d, (1, 8, 8, 8)), + (MaskedChannelAttention3d, (1, 8, 4, 4, 4)), + ], +) +def test_mca_forward(cls, shape): + mca = cls(channels=8, reduction_factor=2) + x = torch.randn(*shape, dtype=torch.cfloat) + out = mca(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_mca_invalid_reduction_factor(): + with pytest.raises(AssertionError, match="yield integer"): + MaskedChannelAttention1d(channels=8, reduction_factor=3) diff --git a/tests/nn/modules/test_batchnorm.py b/tests/nn/modules/test_batchnorm.py new file mode 100644 index 0000000..4fe2a12 --- /dev/null +++ b/tests/nn/modules/test_batchnorm.py @@ -0,0 +1,181 @@ +"""Tests for complex BatchNorm{1,2,3}d, NaiveBatchNorm{1,2,3}d, MagBatchNorm{1,2,3}d.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + MagBatchNorm1d, + MagBatchNorm2d, + MagBatchNorm3d, + NaiveBatchNorm1d, + NaiveBatchNorm2d, + NaiveBatchNorm3d, +) + +# ---------- Trabelsi BN (2x2 whitening) ---------- + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (BatchNorm1d, (16, 4, 10)), + (BatchNorm2d, (16, 4, 8, 8)), + (BatchNorm3d, (4, 4, 4, 4, 4)), + ], +) +def test_batchnorm_forward_train_and_eval(cls, shape): + bn = cls(num_features=4) + x = torch.randn(*shape, dtype=torch.cfloat) + bn.train() + out = bn(x) + assert out.shape == x.shape + assert out.is_complex() + # running stats should now be non-trivial + assert (bn.running_mean.abs().sum() > 0) or (bn.num_batches_tracked.item() == 1) + bn.eval() + out2 = bn(x) + assert out2.shape == x.shape + + +def test_batchnorm1d_2d_input(): + """BatchNorm1d accepts both 2D and 3D inputs.""" + bn = BatchNorm1d(num_features=4) + x = torch.randn(8, 4, dtype=torch.cfloat) + out = bn(x) + assert out.shape == (8, 4) + + +def test_batchnorm1d_invalid_dim(): + bn = BatchNorm1d(4) + with pytest.raises(ValueError, match="expected 2D or 3D"): + bn(torch.randn(2, 4, 4, 4, dtype=torch.cfloat)) + + +def test_batchnorm2d_invalid_dim(): + bn = BatchNorm2d(4) + with pytest.raises(ValueError, match="expected 4D"): + bn(torch.randn(2, 4, 4, dtype=torch.cfloat)) + + +def test_batchnorm3d_invalid_dim(): + bn = BatchNorm3d(4) + with pytest.raises(ValueError, match="expected 5D"): + bn(torch.randn(2, 4, 4, 4, dtype=torch.cfloat)) + + +def test_batchnorm_no_affine_no_track(): + bn = BatchNorm2d(4, affine=False, track_running_stats=False) + assert bn.weight is None + assert bn.bias is None + assert bn.running_mean is None + x = torch.randn(8, 4, 6, 6, dtype=torch.cfloat) + out = bn(x) + assert out.shape == x.shape + + +def test_batchnorm_momentum_none_uses_cumulative(): + bn = BatchNorm2d(4, momentum=None) + bn.train() + x = torch.randn(8, 4, 4, 4, dtype=torch.cfloat) + bn(x) + bn(x) + assert bn.num_batches_tracked.item() == 2 + + +def test_batchnorm_extra_repr(): + s = BatchNorm2d(4).extra_repr() + assert "4" in s # the format string starts with {num_features} + assert "affine=True" in s + + +def test_batchnorm_reset_running_stats_idempotent_when_untracked(): + bn = BatchNorm1d(4, track_running_stats=False) + bn.reset_running_stats() # no-op when not tracking + + +def test_batchnorm_reset_parameters_no_op_no_affine(): + bn = BatchNorm1d(4, affine=False) + bn.reset_parameters() # no-op + + +# ---------- Naive (split) BN ---------- + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (NaiveBatchNorm1d, (16, 4, 10)), + (NaiveBatchNorm2d, (16, 4, 8, 8)), + (NaiveBatchNorm3d, (4, 4, 4, 4, 4)), + ], +) +def test_naive_batchnorm_forward(cls, shape): + bn = cls(num_features=4) + x = torch.randn(*shape, dtype=torch.cfloat) + out = bn(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_naive_batchnorm_extra_repr(): + s = NaiveBatchNorm2d(4).extra_repr() + assert "num_features=4" in s + + +# ---------- MagBatchNorm (magnitude-only, equivariant) ---------- + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (MagBatchNorm1d, (16, 4, 10)), + (MagBatchNorm2d, (16, 4, 8, 8)), + (MagBatchNorm3d, (4, 4, 4, 4, 4)), + ], +) +def test_mag_batchnorm_forward(cls, shape): + bn = cls(num_features=4) + x = torch.randn(*shape, dtype=torch.cfloat) + out = bn(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_mag_batchnorm_equivariant_under_phase_rotation(): + bn = MagBatchNorm2d(num_features=4) + bn.eval() + bn.train() + _ = bn(torch.randn(8, 4, 6, 6, dtype=torch.cfloat)) + bn.eval() + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + rotor = torch.polar(torch.tensor(1.0), torch.tensor(0.9)) + y1 = bn(x * rotor) + y2 = bn(x) * rotor + torch.testing.assert_close(y1, y2, atol=1e-4, rtol=1e-4) + + +def test_mag_batchnorm_state_dict_uses_underlying_bn(): + bn = MagBatchNorm2d(num_features=4) + sd_keys = list(bn.state_dict().keys()) + # The underlying real BN's params should be exposed as bn.weight, bn.bias, etc. + assert any(k.startswith("bn.") for k in sd_keys) + + +def test_mag_batchnorm_extra_repr(): + s = MagBatchNorm2d(4).extra_repr() + assert "num_features=4" in s + + +def test_mag_batchnorm_real_input_passthrough(): + """Real input goes straight through the underlying real BN (else branch).""" + bn = MagBatchNorm2d(num_features=4) + bn.train() + x = torch.randn(8, 4, 6, 6) + out = bn(x) + assert not out.is_complex() + assert out.shape == x.shape diff --git a/tests/nn/modules/test_casting.py b/tests/nn/modules/test_casting.py new file mode 100644 index 0000000..0fa873d --- /dev/null +++ b/tests/nn/modules/test_casting.py @@ -0,0 +1,76 @@ +"""Tests for the layout-casting modules.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.casting import ( + ComplexToConcatenated, + ComplexToInterleaved, + ConcatenatedToComplex, + InterleavedToComplex, + RealToComplex, +) + + +def test_interleaved_roundtrip(): + to_complex = InterleavedToComplex() + to_real = ComplexToInterleaved() + x = torch.randn(2, 6) # last dim 6 = 2 * 3 + z = to_complex(x) + assert z.shape == (2, 3) + assert z.is_complex() + x_back = to_real(z) + torch.testing.assert_close(x_back, x) + + +def test_interleaved_to_complex_odd_dim_raises(): + with pytest.raises(ValueError, match="even"): + InterleavedToComplex()(torch.randn(2, 5)) + + +def test_complex_to_interleaved_real_input_raises(): + with pytest.raises(TypeError, match="expects a complex input"): + ComplexToInterleaved()(torch.randn(2, 4)) + + +def test_concatenated_roundtrip(): + to_complex = ConcatenatedToComplex() + to_real = ComplexToConcatenated() + z = torch.randn(2, 3, dtype=torch.cfloat) + x = to_real(z) + assert x.shape == (2, 6) + z_back = to_complex(x) + torch.testing.assert_close(z_back, z) + + +def test_concatenated_to_complex_odd_raises(): + with pytest.raises(ValueError, match="even"): + ConcatenatedToComplex()(torch.randn(2, 5)) + + +def test_complex_to_concatenated_real_input_raises(): + with pytest.raises(TypeError, match="expects a complex input"): + ComplexToConcatenated()(torch.randn(2, 4)) + + +def test_real_to_complex_from_real(): + lift = RealToComplex() + x = torch.randn(2, 4) + z = lift(x) + assert z.is_complex() + torch.testing.assert_close(z.real, x) + torch.testing.assert_close(z.imag, torch.zeros_like(x)) + + +def test_real_to_complex_complex_input_just_casts(): + lift = RealToComplex(dtype=torch.cdouble) + x = torch.randn(2, 4, dtype=torch.cfloat) + z = lift(x) + assert z.dtype == torch.cdouble + + +def test_real_to_complex_extra_repr(): + s = RealToComplex().extra_repr() + assert "dtype" in s diff --git a/tests/nn/modules/test_conv.py b/tests/nn/modules/test_conv.py new file mode 100644 index 0000000..9252d22 --- /dev/null +++ b/tests/nn/modules/test_conv.py @@ -0,0 +1,59 @@ +"""Tests for complex Conv* and ConvTranspose* layers (native cfloat wrappers).""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) + +# ---------- Fast variants ---------- + + +def test_conv1d_forward_shape(): + layer = Conv1d(2, 4, kernel_size=3, padding=1) + x = torch.randn(2, 2, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (2, 4, 8) + assert out.is_complex() + + +def test_conv2d_forward_shape(): + layer = Conv2d(2, 4, kernel_size=3, padding=1) + x = torch.randn(2, 2, 8, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (2, 4, 8, 8) + + +def test_conv3d_forward_shape(): + layer = Conv3d(2, 4, kernel_size=3, padding=1) + x = torch.randn(1, 2, 4, 4, 4, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (1, 4, 4, 4, 4) + + +def test_convtranspose1d_forward_shape(): + layer = ConvTranspose1d(2, 4, kernel_size=3, stride=2) + x = torch.randn(2, 2, 8, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_convtranspose2d_forward_shape(): + layer = ConvTranspose2d(2, 4, kernel_size=3, stride=2) + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_convtranspose3d_forward_shape(): + layer = ConvTranspose3d(2, 4, kernel_size=3, stride=2) + x = torch.randn(1, 2, 2, 2, 2, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() diff --git a/tests/nn/modules/test_dropout.py b/tests/nn/modules/test_dropout.py new file mode 100644 index 0000000..64deabc --- /dev/null +++ b/tests/nn/modules/test_dropout.py @@ -0,0 +1,76 @@ +"""Tests for complex Dropout / Dropout{1,2,3}d.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.dropout import Dropout, Dropout1d, Dropout2d, Dropout3d + + +def test_dropout_eval_passthrough(): + drop = Dropout(p=0.9) + drop.eval() + x = torch.randn(4, 8, dtype=torch.cfloat) + out = drop(x) + torch.testing.assert_close(out, x) + + +def test_dropout_train_changes_inputs(): + drop = Dropout(p=0.5) + drop.train() + x = torch.ones(1000, dtype=torch.cfloat) + out = drop(x) + # Many entries should now be zero + assert (out == 0).any() + + +def test_dropout1d_eval_is_identity(): + drop = Dropout1d(p=0.5) + drop.eval() + x = torch.randn(2, 4, 8, dtype=torch.cfloat) + out = drop(x) + torch.testing.assert_close(out, x) + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (Dropout1d, (2, 4, 8)), + (Dropout2d, (2, 4, 4, 4)), + (Dropout3d, (1, 4, 3, 3, 3)), + ], +) +def test_channel_dropout_complex_train(cls, shape): + drop = cls(p=0.5) + drop.train() + x = torch.randn(*shape, dtype=torch.cfloat) + out = drop(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_channel_dropout_p_zero_is_identity(): + drop = Dropout1d(p=0.0) + drop.train() + x = torch.randn(2, 4, 8, dtype=torch.cfloat) + out = drop(x) + torch.testing.assert_close(out, x) + + +def test_channel_dropout_invalid_p(): + with pytest.raises(ValueError, match="must be in"): + Dropout1d(p=1.0) + + +def test_channel_dropout_real_input_path(): + drop = Dropout1d(p=0.5) + drop.train() + x = torch.randn(2, 4, 8) + out = drop(x) + assert out.shape == x.shape + + +def test_channel_dropout_extra_repr(): + s = Dropout2d(p=0.3).extra_repr() + assert "p=0.3" in s diff --git a/tests/nn/modules/test_fft.py b/tests/nn/modules/test_fft.py new file mode 100644 index 0000000..d5a2b50 --- /dev/null +++ b/tests/nn/modules/test_fft.py @@ -0,0 +1,28 @@ +"""Tests for FFTBlock / IFFTBlock.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.fft import FFTBlock, IFFTBlock + + +def test_fft_round_trip(): + fwd = FFTBlock(dim=-1, norm="ortho") + inv = IFFTBlock(dim=-1, norm="ortho") + x = torch.randn(2, 8, dtype=torch.cfloat) + torch.testing.assert_close(inv(fwd(x)), x, atol=1e-5, rtol=1e-5) + + +def test_fft_with_n_param(): + fwd = FFTBlock(n=16, dim=-1) + x = torch.randn(2, 8, dtype=torch.cfloat) + out = fwd(x) + assert out.shape == (2, 16) + + +def test_ifft_with_n_param(): + inv = IFFTBlock(n=8, dim=-1) + x = torch.randn(2, 16, dtype=torch.cfloat) + out = inv(x) + assert out.shape == (2, 8) diff --git a/tests/nn/modules/test_groupnorm.py b/tests/nn/modules/test_groupnorm.py new file mode 100644 index 0000000..34eebab --- /dev/null +++ b/tests/nn/modules/test_groupnorm.py @@ -0,0 +1,61 @@ +"""Tests for complex GroupNorm.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.groupnorm import GroupNorm + + +def test_groupnorm_forward(): + gn = GroupNorm(num_groups=2, num_channels=8) + x = torch.randn(4, 8, 6, 6, dtype=torch.cfloat) + out = gn(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_groupnorm_no_affine(): + gn = GroupNorm(num_groups=2, num_channels=8, affine=False) + x = torch.randn(4, 8, 6, 6, dtype=torch.cfloat) + out = gn(x) + assert gn.weight is None + assert gn.bias is None + assert out.shape == x.shape + + +def test_groupnorm_indivisible_raises(): + with pytest.raises(ValueError, match="must be divisible"): + GroupNorm(num_groups=3, num_channels=8) + + +def test_groupnorm_real_input_raises(): + gn = GroupNorm(2, 4) + with pytest.raises(TypeError, match="expects a complex input"): + gn(torch.randn(2, 4, 4, 4)) + + +def test_groupnorm_wrong_channels_raises(): + gn = GroupNorm(2, 4) + with pytest.raises(ValueError, match="Expected 4 channels"): + gn(torch.randn(2, 8, 4, 4, dtype=torch.cfloat)) + + +def test_groupnorm_extra_repr(): + s = GroupNorm(2, 8).extra_repr() + assert "num_groups=2" in s + assert "num_channels=8" in s + + +def test_groupnorm_reset_parameters_no_op_no_affine(): + gn = GroupNorm(2, 4, affine=False) + gn.reset_parameters() + + +def test_groupnorm_1d_spatial(): + """GroupNorm with no spatial dims (B, C) works too.""" + gn = GroupNorm(num_groups=2, num_channels=8) + x = torch.randn(4, 8, 6, dtype=torch.cfloat) + out = gn(x) + assert out.shape == x.shape diff --git a/tests/nn/modules/test_layernorm.py b/tests/nn/modules/test_layernorm.py new file mode 100644 index 0000000..705b041 --- /dev/null +++ b/tests/nn/modules/test_layernorm.py @@ -0,0 +1,43 @@ +"""Tests for complex LayerNorm.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.layernorm import LayerNorm + + +def test_layernorm_int_normalized_shape(): + ln = LayerNorm(8) + x = torch.randn(4, 8, dtype=torch.cfloat) + out = ln(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_layernorm_list_normalized_shape(): + ln = LayerNorm([4, 8]) + x = torch.randn(2, 3, 4, 8, dtype=torch.cfloat) + out = ln(x) + assert out.shape == x.shape + + +def test_layernorm_torch_size_normalized_shape(): + ln = LayerNorm(torch.Size([8])) + x = torch.randn(2, 8, dtype=torch.cfloat) + out = ln(x) + assert out.shape == x.shape + + +def test_layernorm_no_affine(): + ln = LayerNorm(8, elementwise_affine=False) + x = torch.randn(4, 8, dtype=torch.cfloat) + out = ln(x) + assert out.shape == x.shape + assert ln.weight is None + assert ln.bias is None + + +def test_layernorm_reset_parameters_no_op_no_affine(): + ln = LayerNorm(4, elementwise_affine=False) + ln.reset_parameters() diff --git a/tests/nn/modules/test_linear.py b/tests/nn/modules/test_linear.py new file mode 100644 index 0000000..fa23b1f --- /dev/null +++ b/tests/nn/modules/test_linear.py @@ -0,0 +1,58 @@ +"""Tests for complex Linear and Bilinear.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.linear import Bilinear, Linear + + +def test_linear_forward_shape_and_dtype(): + layer = Linear(8, 4, bias=True) + x = torch.randn(3, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (3, 4) + assert out.is_complex() + + +def test_bilinear_hermitian_default(): + layer = Bilinear( + in1_features=4, in2_features=6, out_features=3, bias=True, conjugate=True + ) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 6, dtype=torch.cfloat) + out = layer(x1, x2) + assert out.shape == (2, 3) + assert out.is_complex() + # Hermitian form: when x1 is real, conjugate is the identity + x1r = torch.randn(2, 4, dtype=torch.float).to(torch.cfloat) + out_real = layer(x1r, x2) + expected = torch.einsum("...i,kij,...j->...k", x1r, layer.weight, x2) + layer.bias + torch.testing.assert_close(out_real, expected) + + +def test_bilinear_plain_no_conjugate(): + layer = Bilinear(4, 6, 3, conjugate=False, bias=True) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 6, dtype=torch.cfloat) + out = layer(x1, x2) + expected = torch.einsum("...i,kij,...j->...k", x1, layer.weight, x2) + layer.bias + torch.testing.assert_close(out, expected) + + +def test_bilinear_no_bias(): + layer = Bilinear(4, 6, 3, bias=False) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 6, dtype=torch.cfloat) + out = layer(x1, x2) + assert layer.bias is None + expected = torch.einsum("...i,kij,...j->...k", x1.conj(), layer.weight, x2) + torch.testing.assert_close(out, expected) + + +def test_bilinear_extra_repr(): + repr_str = Bilinear(4, 6, 3, conjugate=False).extra_repr() + assert "in1_features=4" in repr_str + assert "in2_features=6" in repr_str + assert "out_features=3" in repr_str + assert "conjugate=False" in repr_str diff --git a/tests/nn/modules/test_loss.py b/tests/nn/modules/test_loss.py new file mode 100644 index 0000000..8573a47 --- /dev/null +++ b/tests/nn/modules/test_loss.py @@ -0,0 +1,169 @@ +"""Tests for complex loss functions.""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from complextorch.nn.modules.loss import ( + SSIM, + CVCauchyError, + CVFourthPowError, + CVLogCoshError, + CVLogError, + CVQuadError, + GeneralizedPolarLoss, + GeneralizedSplitLoss, + MSELoss, + PerpLossSSIM, + SplitL1, + SplitMSE, + SplitSSIM, +) + +# -------- _reduce branches via parameterized losses -------- + +REDUCTION_LOSSES = [ + CVQuadError, + CVFourthPowError, + CVCauchyError, + CVLogCoshError, + MSELoss, +] + + +@pytest.mark.parametrize("cls", REDUCTION_LOSSES) +@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) +def test_reduction_branches(cls, reduction): + x = torch.randn(4, 8, dtype=torch.cfloat) + y = torch.randn(4, 8, dtype=torch.cfloat) + loss = cls(reduction=reduction) + out = loss(x, y) + if reduction == "none": + assert out.shape == x.shape + else: + assert out.dim() == 0 + + +@pytest.mark.parametrize("cls", REDUCTION_LOSSES) +def test_invalid_reduction(cls): + x = torch.randn(2, 4, dtype=torch.cfloat) + y = torch.randn(2, 4, dtype=torch.cfloat) + loss = cls(reduction="bogus") + with pytest.raises(ValueError, match="reduction must be"): + loss(x, y) + + +# -------- CVLogError (separate because log() requires non-zero input) -------- + + +@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) +def test_cvlogerror_reduction(reduction): + x = torch.randn(4, dtype=torch.cfloat) + 1.0 + y = torch.randn(4, dtype=torch.cfloat) + 1.0 + out = CVLogError(reduction=reduction)(x, y) + assert torch.isfinite(out).all() or reduction == "none" + + +def test_cvlogerror_invalid_reduction(): + x = torch.randn(2, dtype=torch.cfloat) + 1.0 + y = torch.randn(2, dtype=torch.cfloat) + 1.0 + with pytest.raises(ValueError, match="reduction must be"): + CVLogError(reduction="bogus")(x, y) + + +# -------- Split losses -------- + + +def test_split_l1(): + x = torch.randn(2, 8, dtype=torch.cfloat) + y = torch.randn(2, 8, dtype=torch.cfloat) + out = SplitL1()(x, y) + expected = nn.L1Loss()(x.real, y.real) + nn.L1Loss()(x.imag, y.imag) + torch.testing.assert_close(out, expected) + + +def test_split_mse(): + x = torch.randn(2, 8, dtype=torch.cfloat) + y = torch.randn(2, 8, dtype=torch.cfloat) + out = SplitMSE()(x, y) + expected = nn.MSELoss()(x.real, y.real) + nn.MSELoss()(x.imag, y.imag) + torch.testing.assert_close(out, expected) + + +def test_generalized_split_loss_user_provided(): + loss = GeneralizedSplitLoss(nn.L1Loss(), nn.MSELoss()) + x = torch.randn(2, 8, dtype=torch.cfloat) + y = torch.randn(2, 8, dtype=torch.cfloat) + out = loss(x, y) + expected = nn.L1Loss()(x.real, y.real) + nn.MSELoss()(x.imag, y.imag) + torch.testing.assert_close(out, expected) + + +# -------- Polar loss -------- + + +def test_generalized_polar_loss(): + loss = GeneralizedPolarLoss( + nn.MSELoss(), nn.MSELoss(), weight_mag=2.0, weight_phase=0.5 + ) + x = torch.randn(2, 8, dtype=torch.cfloat) + y = torch.randn(2, 8, dtype=torch.cfloat) + out = loss(x, y) + expected = 2.0 * nn.MSELoss()(x.abs(), y.abs()) + 0.5 * nn.MSELoss()( + x.angle(), y.angle() + ) + torch.testing.assert_close(out, expected) + + +# -------- SSIM family -------- + + +def test_ssim_default_reduction(): + s = SSIM() + x = torch.randn(1, 1, 32, 32) + y = torch.randn(1, 1, 32, 32) + out = s(x, y) + assert out.dim() == 0 + + +def test_ssim_full(): + s = SSIM() + x = torch.randn(1, 1, 32, 32) + y = torch.randn(1, 1, 32, 32) + out = s(x, y, full=True) + assert out.dim() == 4 + + +def test_ssim_with_data_range(): + s = SSIM() + x = torch.randn(1, 1, 32, 32) + y = torch.randn(1, 1, 32, 32) + data_range = torch.ones(1) + out = s(x, y, data_range=data_range) + assert out.dim() == 0 + + +def test_split_ssim(): + s = SplitSSIM() + x = torch.randn(1, 1, 32, 32, dtype=torch.cfloat) + y = torch.randn(1, 1, 32, 32, dtype=torch.cfloat) + out = s(x, y) + assert out.dim() == 0 + + +def test_perp_loss_ssim(): + loss = PerpLossSSIM() + x = torch.randn(1, 1, 32, 32, dtype=torch.cfloat) + y = torch.randn(1, 1, 32, 32, dtype=torch.cfloat) + out = loss(x, y) + assert torch.isfinite(out).all() + + +def test_cvcauchy_with_custom_c(): + loss = CVCauchyError(c=2.0, reduction="mean") + x = torch.randn(4, dtype=torch.cfloat) + y = torch.randn(4, dtype=torch.cfloat) + out = loss(x, y) + assert out.dim() == 0 diff --git a/tests/nn/modules/test_manifold.py b/tests/nn/modules/test_manifold.py new file mode 100644 index 0000000..c85a4b2 --- /dev/null +++ b/tests/nn/modules/test_manifold.py @@ -0,0 +1,121 @@ +"""Tests for wFM modules: convolutions, ReLU, distance linear.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.manifold import ( + wFMConv1d, + wFMConv2d, + wFMDistanceLinear, + wFMReLU, +) + + +def test_wfm_conv2d_forward(): + conv = wFMConv2d(in_channels=3, out_channels=5, kernel_size=(3, 3), padding=(1, 1)) + x = torch.randn(2, 3, 8, 8, dtype=torch.cfloat) + 0.1 + out = conv(x) + assert out.is_complex() + assert out.shape[0] == 2 + + +def test_wfm_conv2d_fold_cache_reused(): + conv = wFMConv2d(in_channels=3, out_channels=5, kernel_size=(3, 3), padding=(1, 1)) + x = torch.randn(2, 3, 8, 8, dtype=torch.cfloat) + 0.1 + conv(x) + n_before = len(conv._fold_cache) + conv(x) + n_after = len(conv._fold_cache) + assert n_before == n_after + + +def test_wfm_conv1d_forward(): + conv = wFMConv1d(in_channels=3, out_channels=5, kernel_size=3, padding=1) + x = torch.randn(2, 3, 8, dtype=torch.cfloat) + 0.1 + out = conv(x) + assert out.is_complex() + + +def test_wfm_conv1d_weight_properties(): + conv = wFMConv1d(in_channels=3, out_channels=5, kernel_size=3) + assert conv.weight_matrix_ang is conv.conv1d.weight_matrix_ang + assert conv.weight_matrix_mag is conv.conv1d.weight_matrix_mag + + +# ---------- wFMReLU ---------- + + +def test_wfm_relu_forward_shape(): + layer = wFMReLU(num_channels=4) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_wfm_relu_grad_flows(): + layer = wFMReLU(num_channels=3) + x = torch.randn(2, 3, 4, dtype=torch.cfloat) + 0.1 + out = layer(x) + out.abs().sum().backward() + assert layer.weight_phase.grad is not None + assert layer.weight_mag.grad is not None + assert torch.isfinite(layer.weight_phase.grad).all() + + +def test_wfm_relu_extra_repr(): + s = wFMReLU(num_channels=4).extra_repr() + assert "num_channels=4" in s + + +# ---------- wFMDistanceLinear ---------- + + +def test_wfm_distance_linear_returns_real(): + layer = wFMDistanceLinear(input_dim=4 * 6) + x = torch.randn(2, 4, 6, dtype=torch.cfloat) + 0.1 + out = layer(x) + assert out.shape == x.shape # preserves shape of complex input + assert not out.is_complex() + assert out.dtype == torch.float32 or out.dtype == torch.float64 + + +def test_wfm_distance_linear_grad_flows(): + layer = wFMDistanceLinear(input_dim=12) + x = torch.randn(2, 3, 4, dtype=torch.cfloat) + 0.1 + out = layer(x) + out.sum().backward() + assert layer.weights.grad is not None + assert torch.isfinite(layer.weights.grad).all() + + +def test_wfm_distance_linear_extra_repr(): + s = wFMDistanceLinear(input_dim=10).extra_repr() + assert "input_dim=10" in s + + +def test_wfm_relu_real_input_auto_casts(): + """Real input is auto-cast to cfloat (forward's else branch).""" + layer = wFMReLU(num_channels=3) + x = torch.randn(2, 3, 5) + 0.1 + out = layer(x) + assert out.is_complex() + + +def test_wfm_distance_linear_real_input_auto_casts(): + """wFMDistanceLinear also auto-casts real input.""" + layer = wFMDistanceLinear(input_dim=12) + x = torch.randn(2, 3, 4) + 0.1 + out = layer(x) + assert torch.isfinite(out).all() + + +def test_wfm_distance_linear_wrong_input_dim_raises(): + """flat.shape[1] != input_dim triggers a ValueError.""" + import pytest + + layer = wFMDistanceLinear(input_dim=10) + x = torch.randn(2, 3, 4, dtype=torch.cfloat) + 0.1 # 12 != 10 + with pytest.raises(ValueError, match="expects flattened input of size"): + layer(x) diff --git a/tests/nn/modules/test_mask.py b/tests/nn/modules/test_mask.py new file mode 100644 index 0000000..1a0f982 --- /dev/null +++ b/tests/nn/modules/test_mask.py @@ -0,0 +1,38 @@ +"""Tests for ComplexRatioMask / PhaseSigmoid / MagMinMaxNorm.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.mask import ComplexRatioMask, MagMinMaxNorm, PhaseSigmoid + + +def test_complex_ratio_mask_bounded_magnitude(): + m = ComplexRatioMask() + x = torch.randn(8, dtype=torch.cfloat) * 10 + out = m(x) + # sigmoid(|z|) <= 1 + assert out.abs().max().item() <= 1.0 + + +def test_phase_sigmoid_bounded_magnitude(): + m = PhaseSigmoid() + x = torch.randn(8, dtype=torch.cfloat) * 10 + out = m(x) + assert out.abs().max().item() <= 1.0 + + +def test_mag_min_max_norm_global(): + m = MagMinMaxNorm(dim=None) + x = torch.tensor([1 + 0j, 2 + 0j, 5 + 0j], dtype=torch.cfloat) + out = m(x) + assert out.abs().min().item() == 0.0 + assert out.abs().max().item() == 1.0 + + +def test_mag_min_max_norm_dim(): + m = MagMinMaxNorm(dim=-1) + x = torch.tensor([[1 + 0j, 2 + 0j, 5 + 0j]], dtype=torch.cfloat) + out = m(x) + assert out.abs().min().item() == 0.0 + assert out.abs().max().item() == 1.0 diff --git a/tests/nn/modules/test_phase.py b/tests/nn/modules/test_phase.py new file mode 100644 index 0000000..1f67d2a --- /dev/null +++ b/tests/nn/modules/test_phase.py @@ -0,0 +1,101 @@ +"""Tests for PhaseShift and ComplexScaling.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.phase import ComplexScaling, PhaseShift + + +def test_phase_shift_scalar(): + layer = PhaseShift(num_features=1) + x = torch.randn(4, 5, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + # |out| == |x|; phase rotated + torch.testing.assert_close(out.abs(), x.abs(), atol=1e-5, rtol=1e-5) + + +def test_phase_shift_per_channel(): + layer = PhaseShift(num_features=4) + x = torch.randn(2, 4, 6, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + torch.testing.assert_close(out.abs(), x.abs(), atol=1e-5, rtol=1e-5) + + +def test_phase_shift_higher_rank_phi(): + layer = PhaseShift(num_features=(4, 6)) + x = torch.randn(2, 4, 6, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + + +def test_phase_shift_broadcast_dim_neg(): + layer = PhaseShift(num_features=3, broadcast_dim=-1) + x = torch.randn(2, 4, 3, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + + +def test_phase_shift_real_input_auto_casts(): + layer = PhaseShift(num_features=1) + x = torch.randn(4, 5) + out = layer(x) + assert out.is_complex() + + +def test_phase_shift_extra_repr(): + s = PhaseShift(num_features=4).extra_repr() + assert "num_features=4" in s + + +# ---------- ComplexScaling ---------- + + +def test_complex_scaling_shape_per_channel(): + layer = ComplexScaling(num_features=4) + x = torch.randn(2, 4, 6, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_complex_scaling_scalar_broadcasts(): + layer = ComplexScaling(num_features=1) + x = torch.randn(3, 5, dtype=torch.cfloat) + out = layer(x) + assert out.shape == x.shape + + +def test_complex_scaling_real_input_auto_casts(): + layer = ComplexScaling(num_features=4) + x = torch.randn(2, 4, 6) + out = layer(x) + assert out.is_complex() + + +def test_complex_scaling_equivariant_under_phase_rotation(): + """ComplexScaling commutes with a global phase rotation.""" + layer = ComplexScaling(num_features=4) + x = torch.randn(2, 4, 6, dtype=torch.cfloat) + psi = torch.tensor(0.7) + rotor = torch.polar(torch.tensor(1.0), psi) + y1 = layer(x * rotor) + y2 = layer(x) * rotor + torch.testing.assert_close(y1, y2, atol=1e-5, rtol=1e-5) + + +def test_complex_scaling_grad_flows(): + layer = ComplexScaling(num_features=3) + x = torch.randn(2, 3, dtype=torch.cfloat, requires_grad=True) + out = layer(x) + out.abs().sum().backward() + assert layer.alpha.grad is not None + assert layer.beta.grad is not None + assert torch.isfinite(layer.alpha.grad).all() + + +def test_complex_scaling_extra_repr(): + s = ComplexScaling(num_features=4).extra_repr() + assert "num_features=4" in s diff --git a/tests/nn/modules/test_phase_modulation.py b/tests/nn/modules/test_phase_modulation.py new file mode 100644 index 0000000..9b5b1d4 --- /dev/null +++ b/tests/nn/modules/test_phase_modulation.py @@ -0,0 +1,116 @@ +"""Tests for PhaseDivConv / PhaseConjConv phase-modulation layers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.phase_modulation import ( + PhaseConjConv1d, + PhaseConjConv2d, + PhaseConjConv3d, + PhaseDivConv1d, + PhaseDivConv2d, + PhaseDivConv3d, + _center_crop, +) + + +@pytest.mark.parametrize( + ("cls", "shape", "kernel"), + [ + (PhaseDivConv1d, (2, 4, 16), 3), + (PhaseDivConv2d, (2, 4, 8, 8), 3), + (PhaseDivConv3d, (2, 4, 4, 4, 4), 3), + (PhaseConjConv1d, (2, 4, 16), 3), + (PhaseConjConv2d, (2, 4, 8, 8), 3), + (PhaseConjConv3d, (2, 4, 4, 4, 4), 3), + ], +) +def test_phase_modulation_forward(cls, shape, kernel): + layer = cls(in_channels=4, kernel_size=kernel, padding=kernel // 2) + x = torch.randn(*shape, dtype=torch.cfloat) + 0.1 + out = layer(x) + assert out.shape == x.shape + assert out.is_complex() + + +def test_phase_div_conv_is_u1_invariant(): + """Global phase rotation cancels in numerator and denominator.""" + layer = PhaseDivConv2d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 5, 5, dtype=torch.cfloat) + 0.1 + rotor = torch.polar(torch.tensor(1.0), torch.tensor(0.7)) + y1 = layer(x * rotor) + y2 = layer(x) + torch.testing.assert_close(y1, y2, atol=1e-4, rtol=1e-4) + + +def test_phase_conj_conv_is_u1_invariant(): + """PhaseConjConv with a C-linear inner conv is also U(1)-invariant. + + For complex-linear ``g``, ``g(e^{jψ} x) = e^{jψ} g(x)``, so + ``(e^{jψ} x) · conj(e^{jψ} g(x)) = x · conj(g(x))``. + """ + layer = PhaseConjConv2d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 5, 5, dtype=torch.cfloat) + 0.1 + rotor = torch.polar(torch.tensor(1.0), torch.tensor(0.7)) + y1 = layer(x * rotor) + y2 = layer(x) + torch.testing.assert_close(y1, y2, atol=1e-4, rtol=1e-4) + + +def test_phase_modulation_use_one_filter_false(): + """When use_one_filter=False, inner conv has out_channels=in_channels.""" + layer = PhaseDivConv2d( + in_channels=4, kernel_size=3, padding=1, use_one_filter=False + ) + x = torch.randn(2, 4, 6, 6, dtype=torch.cfloat) + 0.1 + out = layer(x) + assert out.shape == x.shape + + +def test_phase_modulation_center_crop_no_padding(): + """When inner conv shrinks spatial dims, input is center-cropped to match.""" + layer = PhaseDivConv2d(in_channels=3, kernel_size=3, padding=0) + x = torch.randn(2, 3, 8, 8, dtype=torch.cfloat) + 0.1 + out = layer(x) + # No padding + k=3 → shrink by 2 on each spatial dim → (8-2, 8-2) + assert out.shape == (2, 3, 6, 6) + + +def test_center_crop_passthrough_when_shape_matches(): + x = torch.randn(2, 3, 8, 8) + out = _center_crop(x, (8, 8)) + assert out is x or torch.equal(out, x) + + +def test_phase_modulation_grad_flows(): + layer = PhaseDivConv2d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 5, 5, dtype=torch.cfloat, requires_grad=True) + 0.1 + out = layer(x) + out.abs().sum().backward() + # Inner conv weight should have grad + conv_weight = layer.conv.conv.weight + assert conv_weight.grad is not None + assert torch.isfinite(conv_weight.grad).all() + + +def test_phase_modulation_extra_repr(): + s = PhaseDivConv2d(in_channels=4, kernel_size=3, padding=1).extra_repr() + assert "in_channels=4" in s + + +def test_phase_modulation_invalid_nd_raises(): + """Direct construction of the internal base with nd not in (1, 2, 3) is rejected.""" + from complextorch.nn.modules.phase_modulation import _PhaseDivConvNd + + with pytest.raises(ValueError, match="nd must be 1, 2, or 3"): + _PhaseDivConvNd(nd=4, in_channels=2, kernel_size=3) + + +def test_phase_modulation_real_input_auto_casts(): + """Real input is upcast to cfloat in forward.""" + layer = PhaseDivConv1d(in_channels=3, kernel_size=3, padding=1) + x = torch.randn(2, 3, 8) + 0.1 + out = layer(x) + assert out.is_complex() diff --git a/tests/nn/modules/test_pooling.py b/tests/nn/modules/test_pooling.py new file mode 100644 index 0000000..616aa6d --- /dev/null +++ b/tests/nn/modules/test_pooling.py @@ -0,0 +1,99 @@ +"""Tests for complex pooling layers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.pooling import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + MagMaxPool1d, + MagMaxPool2d, + MagMaxPool3d, +) + + +@pytest.mark.parametrize( + ("cls", "shape", "out_size"), + [ + (AdaptiveAvgPool1d, (2, 4, 8), 4), + (AdaptiveAvgPool2d, (2, 4, 8, 8), (4, 4)), + (AdaptiveAvgPool3d, (1, 4, 4, 4, 4), (2, 2, 2)), + ], +) +def test_adaptive_avg_pool(cls, shape, out_size): + pool = cls(out_size) + x = torch.randn(*shape, dtype=torch.cfloat) + out = pool(x) + assert out.is_complex() + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (AvgPool1d, (2, 4, 8)), + (AvgPool2d, (2, 4, 8, 8)), + (AvgPool3d, (1, 4, 4, 4, 4)), + ], +) +def test_avg_pool_complex(cls, shape): + pool = cls(kernel_size=2) + x = torch.randn(*shape, dtype=torch.cfloat) + out = pool(x) + assert out.is_complex() + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (AvgPool1d, (2, 4, 8)), + (AvgPool2d, (2, 4, 8, 8)), + (AvgPool3d, (1, 4, 4, 4, 4)), + ], +) +def test_avg_pool_real_passthrough(cls, shape): + pool = cls(kernel_size=2) + x = torch.randn(*shape) + out = pool(x) + assert not out.is_complex() + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (MagMaxPool1d, (2, 4, 8)), + (MagMaxPool2d, (2, 4, 8, 8)), + (MagMaxPool3d, (1, 4, 4, 4, 4)), + ], +) +def test_magmax_pool_complex(cls, shape): + pool = cls(kernel_size=2) + x = torch.randn(*shape, dtype=torch.cfloat) + out = pool(x) + assert out.is_complex() + + +def test_magmaxpool_real_input(): + pool = MagMaxPool1d(kernel_size=2) + x = torch.randn(2, 4, 8) + out = pool(x) + assert not out.is_complex() + + +def test_magmaxpool_return_indices(): + pool = MagMaxPool2d(kernel_size=2, return_indices=True) + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out, indices = pool(x) + assert out.is_complex() + assert indices.dtype == torch.int64 + + +def test_magmaxpool_extra_repr(): + s = MagMaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1).extra_repr() + assert "kernel_size=3" in s + assert "stride=2" in s diff --git a/tests/nn/modules/test_prototype.py b/tests/nn/modules/test_prototype.py new file mode 100644 index 0000000..fedc12e --- /dev/null +++ b/tests/nn/modules/test_prototype.py @@ -0,0 +1,111 @@ +"""Tests for PrototypeDistance.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.prototype import PrototypeDistance + + +def test_prototype_distance_forward_shape(): + head = PrototypeDistance(in_features=8, num_prototypes=10) + x = torch.randn(4, 8, dtype=torch.cfloat) + logits = head(x) + assert logits.shape == (4, 10) + assert not logits.is_complex() + + +def test_prototype_distance_closest_prototype_has_largest_logit(): + """If an input matches prototype k exactly, logit k should be largest.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + # Match prototype 1 exactly. + target_proto = head.prototypes[:, 1].detach().clone() + x = target_proto.unsqueeze(0) # [1, 4] + logits = head(x) + assert logits.argmax(dim=1).item() == 1 + + +def test_prototype_distance_rejects_non_2d_input(): + head = PrototypeDistance(in_features=4, num_prototypes=3) + with pytest.raises(ValueError, match="expects input of shape"): + head(torch.randn(2, 4, 6, dtype=torch.cfloat)) + + +def test_prototype_distance_with_reference_scalar_broadcasts(): + """E-type call: a single complex scalar per batch broadcasts over channels.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(2, 1, dtype=torch.cfloat) + logits = head(x, reference=ref) + assert logits.shape == (2, 3) + + +def test_prototype_distance_with_reference_per_channel(): + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(2, 4, dtype=torch.cfloat) + logits = head(x, reference=ref) + assert logits.shape == (2, 3) + + +def test_prototype_distance_etype_invariant_under_phase_rotation(): + """When both input and reference rotate by e^{j psi}, logits are unchanged.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(2, 1, dtype=torch.cfloat) + rotor = torch.polar(torch.tensor(1.0), torch.tensor(1.1)) + logits1 = head(x, reference=ref) + logits2 = head(x * rotor, reference=ref * rotor) + torch.testing.assert_close(logits1, logits2, atol=1e-5, rtol=1e-5) + + +def test_prototype_distance_grad_flows(): + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat, requires_grad=True) + logits = head(x) + logits.sum().backward() + assert head.prototypes.grad is not None + assert head.temperature.grad is not None + assert torch.isfinite(head.prototypes.grad).all() + + +def test_prototype_distance_extra_repr(): + s = PrototypeDistance(in_features=4, num_prototypes=10).extra_repr() + assert "in_features=4" in s + assert "num_prototypes=10" in s + + +def test_prototype_distance_real_input_auto_casts(): + """A real input tensor is upcast to cfloat in forward.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4) + logits = head(x) + assert logits.shape == (2, 3) + + +def test_prototype_distance_real_reference_auto_casts(): + """A real reference tensor is upcast to cfloat in forward.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(2, 1) # real + logits = head(x, reference=ref) + assert logits.shape == (2, 3) + + +def test_prototype_distance_reference_1d_promoted_to_2d(): + """A 1-D reference (one complex value per sample) is unsqueezed to 2-D.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(2, dtype=torch.cfloat) # [B] + logits = head(x, reference=ref) + assert logits.shape == (2, 3) + + +def test_prototype_distance_reference_shape_mismatch_raises(): + """A reference whose batch dim doesn't match the input is rejected.""" + head = PrototypeDistance(in_features=4, num_prototypes=3) + x = torch.randn(2, 4, dtype=torch.cfloat) + ref = torch.randn(5, 4, dtype=torch.cfloat) # wrong batch + with pytest.raises(ValueError, match="reference must broadcast"): + head(x, reference=ref) diff --git a/tests/nn/modules/test_rmsnorm.py b/tests/nn/modules/test_rmsnorm.py new file mode 100644 index 0000000..cab20f0 --- /dev/null +++ b/tests/nn/modules/test_rmsnorm.py @@ -0,0 +1,62 @@ +"""Tests for complex RMSNorm.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.rmsnorm import RMSNorm + + +def test_rmsnorm_int_shape(): + rms = RMSNorm(8) + x = torch.randn(4, 8, dtype=torch.cfloat) + out = rms(x) + assert out.shape == x.shape + + +def test_rmsnorm_tuple_shape(): + rms = RMSNorm((4, 8)) + x = torch.randn(2, 4, 8, dtype=torch.cfloat) + out = rms(x) + assert out.shape == x.shape + + +def test_rmsnorm_list_shape(): + rms = RMSNorm([4, 8]) + x = torch.randn(2, 4, 8, dtype=torch.cfloat) + out = rms(x) + assert out.shape == x.shape + + +def test_rmsnorm_no_affine(): + rms = RMSNorm(8, elementwise_affine=False) + x = torch.randn(4, 8, dtype=torch.cfloat) + out = rms(x) + assert rms.weight is None + assert out.shape == x.shape + + +def test_rmsnorm_real_input_raises(): + rms = RMSNorm(8) + with pytest.raises(TypeError, match="expects a complex input"): + rms(torch.randn(4, 8)) + + +def test_rmsnorm_extra_repr(): + s = RMSNorm(8).extra_repr() + assert "normalized_shape=(8,)" in s + + +def test_rmsnorm_reset_parameters_no_op_no_affine(): + rms = RMSNorm(8, elementwise_affine=False) + rms.reset_parameters() + + +def test_rmsnorm_unit_rms_when_no_affine(): + """Without affine, the RMS of |x|^2 over normalized dims should be ~1.""" + rms = RMSNorm(8, elementwise_affine=False) + x = torch.randn(16, 8, dtype=torch.cfloat) * 5 + out = rms(x) + rms_val = out.abs().pow(2).mean(dim=-1) + torch.testing.assert_close(rms_val, torch.ones_like(rms_val), atol=0.01, rtol=0.01) diff --git a/tests/nn/modules/test_rnn.py b/tests/nn/modules/test_rnn.py new file mode 100644 index 0000000..a8c339b --- /dev/null +++ b/tests/nn/modules/test_rnn.py @@ -0,0 +1,99 @@ +"""Tests for complex GRU/LSTM cells and multi-layer wrappers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.rnn import GRU, LSTM, GRUCell, LSTMCell + +# ---------- Cells ---------- + + +@pytest.mark.parametrize("batchnorm", [False, True]) +def test_grucell(batchnorm): + cell = GRUCell(input_size=4, hidden_size=6, batchnorm=batchnorm) + x = torch.randn(8, 4, dtype=torch.cfloat) + h = cell(x) + assert h.shape == (8, 6) + assert h.is_complex() + # With provided hx + h2 = cell(x, hx=h) + assert h2.shape == (8, 6) + + +@pytest.mark.parametrize("batchnorm", [False, True]) +def test_lstmcell(batchnorm): + cell = LSTMCell(input_size=4, hidden_size=6, batchnorm=batchnorm) + x = torch.randn(8, 4, dtype=torch.cfloat) + h, c = cell(x) + assert h.shape == (8, 6) + assert c.shape == (8, 6) + assert h.is_complex() + assert c.is_complex() + h2, c2 = cell(x, hx=(h, c)) + assert h2.shape == (8, 6) + assert c2.shape == (8, 6) + + +# ---------- Multi-layer ---------- + + +@pytest.mark.parametrize("batch_first", [False, True]) +@pytest.mark.parametrize("bidirectional", [False, True]) +def test_gru_multilayer(batch_first, bidirectional): + n_dir = 2 if bidirectional else 1 + gru = GRU( + input_size=4, + hidden_size=6, + num_layers=2, + batch_first=batch_first, + dropout=0.1, + bidirectional=bidirectional, + ) + if batch_first: + x = torch.randn(3, 5, 4, dtype=torch.cfloat) # (B, T, F) + else: + x = torch.randn(5, 3, 4, dtype=torch.cfloat) # (T, B, F) + out, h = gru(x) + assert out.is_complex() + assert h.shape == (2 * n_dir, 3, 6) + + +def test_gru_with_provided_hx(): + gru = GRU(input_size=4, hidden_size=6, num_layers=1) + x = torch.randn(5, 3, 4, dtype=torch.cfloat) + hx = torch.randn(1, 3, 6, dtype=torch.cfloat) + out, _h = gru(x, hx) + assert out.is_complex() + + +@pytest.mark.parametrize("batch_first", [False, True]) +@pytest.mark.parametrize("bidirectional", [False, True]) +def test_lstm_multilayer(batch_first, bidirectional): + n_dir = 2 if bidirectional else 1 + lstm = LSTM( + input_size=4, + hidden_size=6, + num_layers=2, + batch_first=batch_first, + dropout=0.1, + bidirectional=bidirectional, + ) + if batch_first: + x = torch.randn(3, 5, 4, dtype=torch.cfloat) + else: + x = torch.randn(5, 3, 4, dtype=torch.cfloat) + out, (h, c) = lstm(x) + assert out.is_complex() + assert h.shape == (2 * n_dir, 3, 6) + assert c.shape == (2 * n_dir, 3, 6) + + +def test_lstm_with_provided_hx(): + lstm = LSTM(input_size=4, hidden_size=6, num_layers=1) + x = torch.randn(5, 3, 4, dtype=torch.cfloat) + h0 = torch.randn(1, 3, 6, dtype=torch.cfloat) + c0 = torch.randn(1, 3, 6, dtype=torch.cfloat) + out, (_h, _c) = lstm(x, (h0, c0)) + assert out.is_complex() diff --git a/tests/nn/modules/test_softmax.py b/tests/nn/modules/test_softmax.py new file mode 100644 index 0000000..59321c4 --- /dev/null +++ b/tests/nn/modules/test_softmax.py @@ -0,0 +1,37 @@ +"""Tests for complex softmax variants.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.softmax import CVSoftMax, MagSoftMax, PhaseSoftMax + + +def test_cvsoftmax(): + sm = CVSoftMax(dim=-1) + x = torch.randn(3, 5, dtype=torch.cfloat) + out = sm(x) + assert out.shape == x.shape + # Real part sums to 1 along dim + torch.testing.assert_close( + out.real.sum(dim=-1), torch.ones(3), atol=1e-5, rtol=1e-5 + ) + torch.testing.assert_close( + out.imag.sum(dim=-1), torch.ones(3), atol=1e-5, rtol=1e-5 + ) + + +def test_magsoftmax_returns_real(): + sm = MagSoftMax(dim=-1) + x = torch.randn(3, 5, dtype=torch.cfloat) + out = sm(x) + assert not out.is_complex() + torch.testing.assert_close(out.sum(dim=-1), torch.ones(3), atol=1e-5, rtol=1e-5) + + +def test_phasesoftmax_preserves_phase(): + sm = PhaseSoftMax(dim=-1) + x = torch.randn(3, 5, dtype=torch.cfloat) + 0.5 # avoid |z|=0 + out = sm(x) + assert out.is_complex() + torch.testing.assert_close(out.angle(), x.angle(), atol=1e-5, rtol=1e-5) diff --git a/tests/nn/modules/test_transformer.py b/tests/nn/modules/test_transformer.py new file mode 100644 index 0000000..8a16e6c --- /dev/null +++ b/tests/nn/modules/test_transformer.py @@ -0,0 +1,113 @@ +"""Tests for complex Transformer encoder / decoder / full.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.modules.transformer import ( + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) + + +@pytest.mark.parametrize("activation", ["gelu", "relu"]) +def test_encoder_layer_forward(activation): + layer = TransformerEncoderLayer( + d_model=8, nhead=2, dim_feedforward=16, activation=activation + ) + src = torch.randn(2, 5, 8, dtype=torch.cfloat) + out = layer(src) + assert out.shape == src.shape + + +def test_encoder_layer_batch_first_false(): + layer = TransformerEncoderLayer( + d_model=8, nhead=2, dim_feedforward=16, batch_first=False + ) + src = torch.randn(5, 2, 8, dtype=torch.cfloat) + out = layer(src) + assert out.shape == src.shape + + +def test_encoder_layer_invalid_d_model_raises(): + with pytest.raises(ValueError, match="must be divisible"): + TransformerEncoderLayer(d_model=8, nhead=3) + + +def test_encoder_layer_invalid_activation(): + with pytest.raises(ValueError, match="Unknown activation"): + TransformerEncoderLayer(d_model=8, nhead=2, activation="bogus") + + +def test_encoder_stack(): + layer = TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16) + enc = TransformerEncoder(layer, num_layers=2) + src = torch.randn(2, 5, 8, dtype=torch.cfloat) + out = enc(src) + assert out.shape == src.shape + + +def test_encoder_stack_with_norm(): + layer = TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16) + from complextorch.nn.modules.layernorm import LayerNorm + + enc = TransformerEncoder(layer, num_layers=1, norm=LayerNorm(8)) + src = torch.randn(2, 5, 8, dtype=torch.cfloat) + out = enc(src) + assert out.shape == src.shape + + +def test_decoder_layer_forward(): + layer = TransformerDecoderLayer(d_model=8, nhead=2, dim_feedforward=16) + tgt = torch.randn(2, 5, 8, dtype=torch.cfloat) + mem = torch.randn(2, 7, 8, dtype=torch.cfloat) + out = layer(tgt, mem) + assert out.shape == tgt.shape + + +def test_decoder_layer_batch_first_false(): + layer = TransformerDecoderLayer( + d_model=8, nhead=2, dim_feedforward=16, batch_first=False + ) + tgt = torch.randn(5, 2, 8, dtype=torch.cfloat) + mem = torch.randn(7, 2, 8, dtype=torch.cfloat) + out = layer(tgt, mem) + assert out.shape == tgt.shape + + +def test_decoder_stack(): + layer = TransformerDecoderLayer(d_model=8, nhead=2, dim_feedforward=16) + dec = TransformerDecoder(layer, num_layers=2) + tgt = torch.randn(2, 5, 8, dtype=torch.cfloat) + mem = torch.randn(2, 7, 8, dtype=torch.cfloat) + out = dec(tgt, mem) + assert out.shape == tgt.shape + + +def test_decoder_stack_with_norm(): + layer = TransformerDecoderLayer(d_model=8, nhead=2, dim_feedforward=16) + from complextorch.nn.modules.layernorm import LayerNorm + + dec = TransformerDecoder(layer, num_layers=1, norm=LayerNorm(8)) + tgt = torch.randn(2, 5, 8, dtype=torch.cfloat) + mem = torch.randn(2, 7, 8, dtype=torch.cfloat) + out = dec(tgt, mem) + assert out.shape == tgt.shape + + +def test_full_transformer(): + model = Transformer( + d_model=8, + nhead=2, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=16, + ) + src = torch.randn(2, 5, 8, dtype=torch.cfloat) + tgt = torch.randn(2, 4, 8, dtype=torch.cfloat) + out = model(src, tgt) + assert out.shape == tgt.shape diff --git a/tests/nn/modules/test_upsampling.py b/tests/nn/modules/test_upsampling.py new file mode 100644 index 0000000..d2c296a --- /dev/null +++ b/tests/nn/modules/test_upsampling.py @@ -0,0 +1,55 @@ +"""Tests for Upsample / PolarUpsample.""" + +from __future__ import annotations + +import torch + +from complextorch.nn.modules.upsampling import PolarUpsample, Upsample + + +def test_upsample_split_complex(): + up = Upsample(scale_factor=2.0, mode="nearest") + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out = up(x) + assert out.shape == (1, 2, 8, 8) + assert out.is_complex() + + +def test_upsample_split_real_input(): + up = Upsample(scale_factor=2.0, mode="nearest") + x = torch.randn(1, 2, 4, 4) + out = up(x) + assert out.shape == (1, 2, 8, 8) + assert not out.is_complex() + + +def test_upsample_with_size(): + up = Upsample(size=(8, 8), mode="bilinear", align_corners=False) + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out = up(x) + assert out.shape == (1, 2, 8, 8) + + +def test_upsample_extra_repr(): + s = Upsample(scale_factor=2.0).extra_repr() + assert "scale_factor" in s + + +def test_polar_upsample_complex(): + up = PolarUpsample(scale_factor=2.0, mode="nearest") + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out = up(x) + assert out.shape == (1, 2, 8, 8) + assert out.is_complex() + + +def test_polar_upsample_real_input(): + up = PolarUpsample(scale_factor=2.0, mode="nearest") + x = torch.randn(1, 2, 4, 4) + out = up(x) + assert not out.is_complex() + + +def test_polar_upsample_extra_repr(): + s = PolarUpsample(scale_factor=2.0).extra_repr() + assert "scale_factor" in s diff --git a/tests/nn/relevance/__init__.py b/tests/nn/relevance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/relevance/test_base.py b/tests/nn/relevance/test_base.py new file mode 100644 index 0000000..f8495c8 --- /dev/null +++ b/tests/nn/relevance/test_base.py @@ -0,0 +1,91 @@ +"""Tests for BaseARD + named_penalties / named_relevance / compute_ard_masks.""" + +from __future__ import annotations + +import pytest +import torch.nn as nn + +from complextorch.nn.relevance import ( + BaseARD, + LinearVD, + compute_ard_masks, + named_penalties, + named_relevance, + penalties, +) + + +def test_base_ard_penalty_default_raises(): + base = BaseARD() + with pytest.raises(NotImplementedError): + _ = base.penalty + + +def test_base_ard_relevance_default_raises(): + base = BaseARD() + with pytest.raises(NotImplementedError): + base.relevance(threshold=0.0) + + +def test_named_penalties_sum(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + pairs = list(named_penalties(model, reduction="sum")) + assert pairs[0][0] == "l1" + assert pairs[0][1].dim() == 0 + + +def test_named_penalties_mean(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + pairs = list(named_penalties(model, reduction="mean")) + assert pairs[0][1].dim() == 0 + + +def test_named_penalties_no_reduction(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + pairs = list(named_penalties(model, reduction=None)) + assert pairs[0][1].shape == (6, 4) + + +def test_named_penalties_invalid_reduction(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + with pytest.raises(ValueError, match="reduction must be"): + list(named_penalties(model, reduction="bogus")) + + +def test_penalties_generator(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + ps = list(penalties(model)) + assert len(ps) == 1 + + +def test_named_relevance(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + pairs = list(named_relevance(model, threshold=0.0)) + assert pairs[0][0] == "l1" + assert pairs[0][1].shape == (6, 4) + + +def test_compute_ard_masks(): + model = nn.Module() + model.l1 = LinearVD(4, 6) + masks = compute_ard_masks(model, threshold=0.0) + assert "l1.mask" in masks + assert masks["l1.mask"].shape == (6, 4) + + +def test_compute_ard_masks_non_module_returns_empty(): + """Guard branch: non-Module input returns empty dict.""" + assert compute_ard_masks("not a module", threshold=0.0) == {} + + +def test_compute_ard_masks_root_layer(): + """Mask key for a root-level ARD layer is just 'mask'.""" + layer = LinearVD(4, 6) + masks = compute_ard_masks(layer, threshold=0.0) + assert "mask" in masks diff --git a/tests/nn/relevance/test_expi.py b/tests/nn/relevance/test_expi.py new file mode 100644 index 0000000..e00c215 --- /dev/null +++ b/tests/nn/relevance/test_expi.py @@ -0,0 +1,20 @@ +"""Tests for the differentiable Ei (exponential integral) wrapper.""" + +from __future__ import annotations + +import scipy.special +import torch + +from complextorch.nn.relevance._expi import torch_expi + + +def test_torch_expi_forward_matches_scipy(): + x = torch.tensor([-5.0, -1.0, 0.1, 1.0, 5.0]) + out = torch_expi(x).numpy() + expected = scipy.special.expi(x.numpy()) + assert (abs(out - expected) < 1e-5).all() + + +def test_torch_expi_gradcheck(): + x = torch.tensor([0.5, 1.5, -2.0], dtype=torch.double, requires_grad=True) + assert torch.autograd.gradcheck(torch_expi, (x,), eps=1e-6, atol=1e-4) diff --git a/tests/nn/relevance/test_layers.py b/tests/nn/relevance/test_layers.py new file mode 100644 index 0000000..d7d53a2 --- /dev/null +++ b/tests/nn/relevance/test_layers.py @@ -0,0 +1,144 @@ +"""Tests for the VD/ARD complex layers.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.nn.relevance import ( + BilinearARD, + BilinearVD, + Conv1dARD, + Conv1dVD, + Conv2dARD, + Conv2dVD, + Conv3dARD, + Conv3dVD, + LinearARD, + LinearVD, +) + +# ---------- LinearVD/ARD ---------- + + +@pytest.mark.parametrize("cls", [LinearVD, LinearARD]) +@pytest.mark.parametrize("training", [True, False]) +def test_linear_vd_ard_forward(cls, training): + layer = cls(4, 6, bias=True) + layer.train(training) + x = torch.randn(2, 4, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (2, 6) + assert out.is_complex() + + +def test_linear_vd_penalty_finite(): + layer = LinearVD(4, 6) + p = layer.penalty + assert torch.isfinite(p).all() + assert (p >= -1e-3).all() # KL is non-negative up to floating noise + + +def test_linear_ard_penalty_finite(): + layer = LinearARD(4, 6) + p = layer.penalty + assert torch.isfinite(p).all() + + +def test_linear_vd_relevance_all_dropped(): + layer = LinearVD(4, 6) + mask = layer.relevance(threshold=-1e6) # nothing passes + assert mask.sum().item() == 0 + + +def test_linear_vd_relevance_all_kept(): + layer = LinearVD(4, 6) + mask = layer.relevance(threshold=1e6) # everything passes + assert mask.sum().item() == mask.numel() + + +def test_linear_vd_sparsity_call(): + layer = LinearVD(4, 6) + pairs = layer.sparsity(threshold=0.0) + assert len(pairs) == 1 + + +def test_linear_vd_no_bias(): + layer = LinearVD(4, 6, bias=False) + assert layer.bias is None + out = layer(torch.randn(2, 4, dtype=torch.cfloat)) + assert out.is_complex() + + +# ---------- BilinearVD/ARD ---------- + + +@pytest.mark.parametrize("cls", [BilinearVD, BilinearARD]) +@pytest.mark.parametrize("training", [True, False]) +def test_bilinear_vd_ard_forward(cls, training): + layer = cls(4, 5, 6, bias=True) + layer.train(training) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 5, dtype=torch.cfloat) + out = layer(x1, x2) + assert out.shape == (2, 6) + + +def test_bilinear_vd_no_bias(): + layer = BilinearVD(4, 5, 6, bias=False) + assert layer.bias is None + + +def test_bilinear_vd_no_conjugate(): + layer = BilinearVD(4, 5, 6, conjugate=False) + x1 = torch.randn(2, 4, dtype=torch.cfloat) + x2 = torch.randn(2, 5, dtype=torch.cfloat) + out = layer(x1, x2) + assert out.shape == (2, 6) + + +# ---------- ConvVD/ARD ---------- + + +@pytest.mark.parametrize( + ("cls", "shape"), + [ + (Conv1dVD, (1, 2, 8)), + (Conv2dVD, (1, 2, 6, 6)), + (Conv3dVD, (1, 2, 4, 4, 4)), + (Conv1dARD, (1, 2, 8)), + (Conv2dARD, (1, 2, 6, 6)), + (Conv3dARD, (1, 2, 4, 4, 4)), + ], +) +@pytest.mark.parametrize("training", [True, False]) +def test_conv_vd_ard_forward(cls, shape, training): + layer = cls(2, 4, kernel_size=3, padding=1) + layer.train(training) + x = torch.randn(*shape, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_conv_vd_no_bias(): + layer = Conv1dVD(2, 4, kernel_size=3, bias=False) + assert layer.bias is None + + +def test_conv_vd_invalid_padding_mode(): + with pytest.raises(ValueError, match="padding_mode"): + Conv1dVD(2, 4, kernel_size=3, padding_mode="reflect") + + +def test_conv_vd_tuple_kernel_size(): + layer = Conv2dVD(2, 4, kernel_size=(3, 5), padding=(1, 2)) + x = torch.randn(1, 2, 8, 10, dtype=torch.cfloat) + out = layer(x) + assert out.is_complex() + + +def test_conv_vd_str_padding(): + layer = Conv2dVD(2, 4, kernel_size=3, padding="same") + x = torch.randn(1, 2, 8, 8, dtype=torch.cfloat) + out = layer(x) + assert out.shape == (1, 4, 8, 8) diff --git a/tests/nn/test_functional.py b/tests/nn/test_functional.py new file mode 100644 index 0000000..d1e9ad1 --- /dev/null +++ b/tests/nn/test_functional.py @@ -0,0 +1,201 @@ +"""Tests for complextorch.nn.functional primitives and norm helpers.""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +from complextorch.nn import functional as F + +# ---------- apply_complex (Gauss-trick lift) ---------- + + +def test_apply_complex_matches_manual_real_imag(): + real_m = nn.Linear(4, 3, bias=False) + imag_m = nn.Linear(4, 3, bias=False) + x = torch.randn(2, 4, dtype=torch.cfloat) + out = F.apply_complex(real_m, imag_m, x) + expected = torch.complex( + real_m(x.real) - imag_m(x.imag), + real_m(x.imag) + imag_m(x.real), + ) + torch.testing.assert_close(out, expected) + assert out.is_complex() + + +# ---------- apply_complex_split ---------- + + +def test_apply_complex_split_with_identity_returns_input(): + x = torch.randn(2, 5, dtype=torch.cfloat) + out = F.apply_complex_split(lambda t: t, lambda t: t, x) + torch.testing.assert_close(out, x) + + +def test_apply_complex_split_independent_functions(): + x = torch.randn(3, dtype=torch.cfloat) + out = F.apply_complex_split(torch.relu, torch.tanh, x) + expected = torch.complex(torch.relu(x.real), torch.tanh(x.imag)) + torch.testing.assert_close(out, expected) + + +# ---------- apply_complex_polar ---------- + + +def test_apply_complex_polar_phase_none_preserves_phase(): + x = torch.randn(4, dtype=torch.cfloat) + out = F.apply_complex_polar(torch.abs, None, x) + torch.testing.assert_close(out.angle(), x.angle(), atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out.abs(), x.abs().abs()) # mag_fun=abs is idempotent + + +def test_apply_complex_polar_phase_none_zero_magnitude_safe(): + """The clamp(min=1e-12) prevents division by zero at z=0.""" + x = torch.zeros(3, dtype=torch.cfloat) + out = F.apply_complex_polar(torch.abs, None, x) + assert torch.isfinite(out.real).all() + assert torch.isfinite(out.imag).all() + + +def test_apply_complex_polar_with_phase_fun(): + x = torch.randn(4, dtype=torch.cfloat) + out = F.apply_complex_polar(torch.abs, lambda p: p * 0, x) + torch.testing.assert_close( + out.imag, torch.zeros_like(out.imag), atol=1e-6, rtol=1e-6 + ) + + +# ---------- inv_sqrtm2x2 ---------- + + +def test_inv_sqrtm2x2_symmetric_gives_inverse_squareroot(): + a = torch.tensor([2.0]) + d = torch.tensor([3.0]) + b = torch.tensor([0.5]) + w, x, y, z = F.inv_sqrtm2x2(a, b, None, d, symmetric=True) + assert y is None + A = torch.tensor([[2.0, 0.5], [0.5, 3.0]]) + B = torch.tensor([[w.item(), x.item()], [x.item(), z.item()]]) + recovered = B @ B @ A + torch.testing.assert_close(recovered, torch.eye(2), atol=1e-5, rtol=1e-5) + + +def test_inv_sqrtm2x2_non_symmetric_branch(): + a, b, c, d = (torch.tensor([v]) for v in (2.5, 0.1, 0.2, 3.0)) + w, x, y, z = F.inv_sqrtm2x2(a, b, c, d, symmetric=False) + A = torch.tensor([[2.5, 0.1], [0.2, 3.0]]) + B = torch.tensor([[w.item(), x.item()], [y.item(), z.item()]]) + recovered = B @ B @ A + torch.testing.assert_close(recovered, torch.eye(2), atol=1e-5, rtol=1e-5) + + +# ---------- whiten2x2_batch_norm ---------- + + +def _stack_re_im(z: torch.Tensor) -> torch.Tensor: + return torch.stack([z.real, z.imag], dim=0) + + +def test_whiten2x2_batch_norm_training_no_running_stats(): + # Need many samples for the whitening to converge on identity covariance. + z = torch.randn(256, 4, 16, dtype=torch.cfloat) * 3.0 + 0.5 + x = _stack_re_im(z) + out = F.whiten2x2_batch_norm(x, training=True) + assert out.shape == x.shape + var = out.var(dim=(1, 2), unbiased=False) + torch.testing.assert_close(var, torch.ones_like(var), atol=0.05, rtol=0.05) + # Cross-covariance between real/imag should be near zero per feature. + cov_ri = (out[0] * out[1]).mean(dim=(0, 2)) + assert cov_ri.abs().max().item() < 0.1 + + +def test_whiten2x2_batch_norm_updates_running_stats(): + z = torch.randn(8, 4, 6, dtype=torch.cfloat) + x = _stack_re_im(z) + running_mean = torch.zeros(2, 4) + running_cov = torch.eye(2).unsqueeze(-1).repeat(1, 1, 4) * 0.5 + rm_before = running_mean.clone() + rc_before = running_cov.clone() + F.whiten2x2_batch_norm( + x, training=True, running_mean=running_mean, running_cov=running_cov + ) + assert not torch.allclose(running_mean, rm_before) + assert not torch.allclose(running_cov, rc_before) + + +def test_whiten2x2_batch_norm_eval_uses_running_stats(): + z_train = torch.randn(8, 4, 6, dtype=torch.cfloat) + x_train = _stack_re_im(z_train) + running_mean = torch.zeros(2, 4) + running_cov = torch.eye(2).unsqueeze(-1).repeat(1, 1, 4).clone() + F.whiten2x2_batch_norm( + x_train, training=True, running_mean=running_mean, running_cov=running_cov + ) + rm_after_train = running_mean.clone() + rc_after_train = running_cov.clone() + # Eval pass should NOT update the running stats + z_eval = torch.randn(8, 4, 6, dtype=torch.cfloat) + x_eval = _stack_re_im(z_eval) + F.whiten2x2_batch_norm( + x_eval, training=False, running_mean=running_mean, running_cov=running_cov + ) + torch.testing.assert_close(running_mean, rm_after_train) + torch.testing.assert_close(running_cov, rc_after_train) + + +# ---------- batch_norm wrapper ---------- + + +def test_batch_norm_without_affine(): + z = torch.randn(8, 4, 10, dtype=torch.cfloat) + out = F.batch_norm(z, training=True) + assert out.is_complex() + assert out.shape == z.shape + + +def test_batch_norm_with_affine(): + z = torch.randn(8, 4, 10, dtype=torch.cfloat) + weight = torch.eye(2).unsqueeze(-1).repeat(1, 1, 4) + bias = torch.zeros(2, 4) + out = F.batch_norm(z, weight=weight, bias=bias, training=True) + assert out.shape == z.shape + + +def test_batch_norm_with_running_stats_and_affine(): + z = torch.randn(8, 4, 10, dtype=torch.cfloat) + running_mean = torch.zeros(2, 4) + running_var = torch.eye(2).unsqueeze(-1).repeat(1, 1, 4) + weight = torch.eye(2).unsqueeze(-1).repeat(1, 1, 4) + bias = torch.zeros(2, 4) + out = F.batch_norm( + z, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + training=False, + ) + assert out.shape == z.shape + + +# ---------- whiten2x2_layer_norm + layer_norm ---------- + + +def test_layer_norm_without_affine(): + z = torch.randn(8, 16, dtype=torch.cfloat) + out = F.layer_norm(z, normalized_shape=[16]) + assert out.shape == z.shape + + +def test_layer_norm_with_affine(): + z = torch.randn(8, 16, dtype=torch.cfloat) + weight = torch.eye(2).unsqueeze(-1).repeat(1, 1, 16) + bias = torch.zeros(2, 16) + out = F.layer_norm(z, normalized_shape=[16], weight=weight, bias=bias) + assert out.shape == z.shape + + +def test_layer_norm_multi_dim_normalized_shape(): + z = torch.randn(4, 8, 8, dtype=torch.cfloat) + out = F.layer_norm(z, normalized_shape=[8, 8]) + assert out.shape == z.shape diff --git a/tests/nn/test_init.py b/tests/nn/test_init.py new file mode 100644 index 0000000..c0a1255 --- /dev/null +++ b/tests/nn/test_init.py @@ -0,0 +1,141 @@ +"""Tests for complextorch.nn.init complex initializers.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from complextorch.nn import init + +# ---------- _get_fans + _check_complex via public functions ---------- + + +def test_kaiming_rejects_real_tensor(): + with pytest.raises(TypeError, match="expects a complex tensor"): + init.kaiming_normal_(torch.zeros(4, 4)) + + +def test_kaiming_normal_fan_in(): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu") + # Expected per-part std: sqrt(2) / sqrt(2 * 16) = 0.25 + assert w.is_complex() + assert abs(w.real.std().item() - 0.25) < 0.1 + + +def test_kaiming_normal_fan_out(): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.kaiming_normal_(w, mode="fan_out", nonlinearity="linear") + assert torch.isfinite(w).all() + + +def test_kaiming_normal_3d_tensor_fan_calc(): + """tensor.dim() > 2 branch in _get_fans.""" + w = torch.empty(8, 4, 3, dtype=torch.cfloat) + init.kaiming_normal_(w) + assert w.shape == (8, 4, 3) + + +def test_kaiming_normal_1d_tensor_fan_calc(): + """tensor.dim() < 2 branch in _get_fans.""" + w = torch.empty(10, dtype=torch.cfloat) + init.kaiming_normal_(w) + assert w.shape == (10,) + + +def test_kaiming_uniform(): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.kaiming_uniform_(w) + assert torch.isfinite(w).all() + + +def test_kaiming_invalid_nonlinearity_raises(): + w = torch.empty(4, 4, dtype=torch.cfloat) + with pytest.raises(ValueError, match="Unsupported nonlinearity"): + init.kaiming_normal_(w, nonlinearity="bogus") + + +@pytest.mark.parametrize( + "nl", + [ + "linear", + "tanh", + "relu", + "leaky_relu", + "selu", + "sigmoid", + "conv1d", + "conv2d", + "conv3d", + ], +) +def test_gain_branches_via_kaiming(nl): + w = torch.empty(4, 4, dtype=torch.cfloat) + init.kaiming_normal_(w, nonlinearity=nl) + assert torch.isfinite(w).all() + + +# ---------- Xavier ---------- + + +def test_xavier_normal(): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.xavier_normal_(w, gain=1.0) + expected_std = 1.0 / math.sqrt(8 + 16) + assert abs(w.real.std().item() - expected_std) < 0.1 + + +def test_xavier_uniform(): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.xavier_uniform_(w) + assert torch.isfinite(w).all() + + +# ---------- Trabelsi standard (polar Rayleigh) ---------- + + +@pytest.mark.parametrize("kind", ["glorot", "xavier", "he", "kaiming"]) +def test_trabelsi_standard_kinds(kind): + w = torch.empty(16, 32, dtype=torch.cfloat) + init.trabelsi_standard_(w, kind=kind) + assert torch.isfinite(w).all() + # Phases should span (-pi, pi) -> at least non-trivial spread + assert w.angle().std().item() > 0.5 + + +def test_trabelsi_standard_invalid_kind(): + w = torch.empty(4, 4, dtype=torch.cfloat) + with pytest.raises(ValueError, match="Unknown kind"): + init.trabelsi_standard_(w, kind="bogus") + + +# ---------- Trabelsi independent (semi-unitary) ---------- + + +@pytest.mark.parametrize("kind", ["glorot", "he"]) +def test_trabelsi_independent_kinds(kind): + w = torch.empty(8, 16, dtype=torch.cfloat) + init.trabelsi_independent_(w, kind=kind) + assert torch.isfinite(w).all() + # Approximately semi-unitary scaled by `scale`: w @ w.conj().T = scale^2 * I + prod = w @ w.conj().T + diag = torch.diag(prod).real + off = prod - torch.diag(torch.diag(prod)) + # diagonal entries should be ~equal + assert (diag.std() / diag.mean()).item() < 0.1 + # off-diagonal should be small relative to diagonal + assert off.abs().mean().item() < 0.1 * diag.mean().item() + + +def test_trabelsi_independent_invalid_kind(): + w = torch.empty(4, 4, dtype=torch.cfloat) + with pytest.raises(ValueError, match="Unknown kind"): + init.trabelsi_independent_(w, kind="weird") + + +def test_trabelsi_independent_requires_2d(): + w = torch.empty(10, dtype=torch.cfloat) + with pytest.raises(ValueError, match="at least 2 dims"): + init.trabelsi_independent_(w) diff --git a/tests/nn/utils/__init__.py b/tests/nn/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/utils/test_sparsity.py b/tests/nn/utils/test_sparsity.py new file mode 100644 index 0000000..c080c09 --- /dev/null +++ b/tests/nn/utils/test_sparsity.py @@ -0,0 +1,86 @@ +"""Tests for SparsityStats / named_sparsity / sparsity.""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from complextorch.nn.masked import LinearMasked +from complextorch.nn.utils import SparsityStats, named_sparsity, sparsity + + +def test_sparsity_stats_default_raises(): + class M(SparsityStats): + pass + + m = M() + with pytest.raises(NotImplementedError): + m.sparsity() + + +def test_named_sparsity_no_subscriber_yields_nothing(): + model = nn.Linear(4, 6) + pairs = list(named_sparsity(model)) + assert pairs == [] + + +def test_named_sparsity_on_masked_layer(): + model = nn.Module() + model.l1 = LinearMasked(4, 6) + model.l1.mask = torch.zeros(6, 4) + pairs = list(named_sparsity(model)) + assert len(pairs) == 1 + name, (n_zeros, n_total) = pairs[0] + assert "weight" in name + assert n_zeros == 24 # 6*4 weights, all zero + assert n_total == 24 + + +def test_sparsity_global(): + model = nn.Module() + model.l1 = LinearMasked(4, 6) + model.l1.mask = torch.zeros(6, 4) + s = sparsity(model) + assert s == 1.0 + + +def test_sparsity_no_params_returns_zero(): + """Empty module -> 0/0 -> 0.0 via the `if total` guard.""" + model = nn.Module() + s = sparsity(model) + assert s == 0.0 + + +def test_named_sparsity_dense_masked_layer_reports_zero(): + """LinearMasked without a mask set -> n_dropped == 0.""" + model = nn.Module() + model.l1 = LinearMasked(4, 6) + pairs = list(named_sparsity(model)) + assert len(pairs) == 1 + _, (n_zeros, n_total) = pairs[0] + assert n_zeros == 0 + assert n_total == 24 + + +def test_named_sparsity_skips_duplicate_pids_and_unknown_pids(): + """Hit the 'pid in seen or pid not in pid_to_name: continue' branch.""" + + class DupAndForeign(SparsityStats): + def __init__(self) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(3)) + + def sparsity(self, **kwargs): + # Real pid twice + a foreign pid. + return [ + (id(self.weight), 1), + (id(self.weight), 1), # duplicate -> 'pid in seen' branch + (99999999, 7), # foreign -> 'pid not in pid_to_name' branch + ] + + model = nn.Module() + model.s = DupAndForeign() + pairs = list(named_sparsity(model)) + # Only the unique, known pid yields a result. + assert len(pairs) == 1 diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..309b154 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,24 @@ +"""Top-level package smoke tests.""" + +from __future__ import annotations + +import re + +import complextorch + + +def test_version_is_semver(): + assert re.match(r"^\d+\.\d+\.\d+", complextorch.__version__) + + +def test_public_subpackages_importable(): + assert complextorch.nn is not None + assert complextorch.signal is not None + assert complextorch.transforms is not None + assert complextorch.datasets is not None + assert complextorch.models is not None + + +def test_author_metadata(): + assert isinstance(complextorch.__author__, str) + assert complextorch.__author__ diff --git a/tests/test_signal.py b/tests/test_signal.py new file mode 100644 index 0000000..950d94e --- /dev/null +++ b/tests/test_signal.py @@ -0,0 +1,96 @@ +"""Tests for complextorch.signal.pwelch.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.signal import pwelch + + +def test_pwelch_real_default_onesided(): + x = torch.randn(2048) + freqs, psd = pwelch(x, window=256, fs=1.0) + assert freqs.shape == (129,) + assert psd.shape == (129,) + assert torch.all(psd >= 0) + + +def test_pwelch_complex_two_sided(): + x = torch.randn(2048, dtype=torch.cfloat) + freqs, psd = pwelch(x, window=256, fs=2.0) + assert freqs.shape == (256,) + assert psd.shape == (256,) + + +def test_pwelch_user_window_tensor(): + x = torch.randn(512) + win = torch.hann_window(64) + freqs, psd = pwelch(x, window=win, fs=1.0) + assert freqs.shape[0] == 33 + assert psd.shape[0] == 33 + + +def test_pwelch_scaling_spectrum(): + x = torch.randn(1024) + _, psd_density = pwelch(x, window=128, scaling="density") + _, psd_spectrum = pwelch(x, window=128, scaling="spectrum") + assert psd_density.shape == psd_spectrum.shape + assert not torch.allclose(psd_density, psd_spectrum) + + +def test_pwelch_detrend_none(): + x = torch.randn(512) + 10.0 # large DC offset + _, psd_none = pwelch(x, window=64, detrend="none") + _, psd_const = pwelch(x, window=64, detrend="constant") + assert psd_none[0] > psd_const[0] # detrending removes DC spike + + +def test_pwelch_overlap_eq_window_raises(): + with pytest.raises(ValueError, match="must be smaller than window length"): + pwelch(torch.randn(64), window=32, n_overlap=32) + + +def test_pwelch_invalid_detrend_raises(): + with pytest.raises(ValueError, match="detrend must be"): + pwelch(torch.randn(64), window=16, detrend="linear") + + +def test_pwelch_invalid_scaling_raises(): + with pytest.raises(ValueError, match="scaling must be"): + pwelch(torch.randn(64), window=16, scaling="bogus") + + +def test_pwelch_onesided_with_complex_raises(): + x = torch.randn(128, dtype=torch.cfloat) + with pytest.raises(ValueError, match="invalid for complex signals"): + pwelch(x, window=16, return_onesided=True) + + +def test_pwelch_window_longer_than_signal(): + """Window length is clamped to signal length.""" + x = torch.randn(8) + freqs, _psd = pwelch(x, window=64) + assert freqs.shape[0] == 5 # rfft(8) → 5 + + +def test_pwelch_batched_input(): + x = torch.randn(4, 3, 512) + freqs, psd = pwelch(x, window=64, fs=1.0) + assert psd.shape == (4, 3, 33) + assert freqs.shape == (33,) + + +def test_pwelch_two_freq_bins_path(): + """Hit the elif psd.shape[-1] == 2 branch (win_len=2 → rfft returns 2 bins).""" + x = torch.randn(16) + freqs, psd = pwelch(x, window=2, n_overlap=0) + assert freqs.shape[0] == 2 + assert psd.shape[0] == 2 + + +def test_pwelch_differentiable(): + x = torch.randn(256, requires_grad=True) + _, psd = pwelch(x, window=64) + psd.sum().backward() + assert x.grad is not None diff --git a/tests/transforms/__init__.py b/tests/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/transforms/test_functional.py b/tests/transforms/test_functional.py new file mode 100644 index 0000000..e6fe1da --- /dev/null +++ b/tests/transforms/test_functional.py @@ -0,0 +1,66 @@ +"""Tests for transforms.functional helpers (polsar_dict_to_array, rescale_intensity).""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.transforms.functional import polsar_dict_to_array, rescale_intensity + + +def test_polsar_dict_to_array_default_order(): + d = { + "HH": torch.randn(4, 4, dtype=torch.cfloat), + "HV": torch.randn(4, 4, dtype=torch.cfloat), + "VH": torch.randn(4, 4, dtype=torch.cfloat), + "VV": torch.randn(4, 4, dtype=torch.cfloat), + } + arr = polsar_dict_to_array(d) + assert arr.shape == (4, 4, 4) + torch.testing.assert_close(arr[0], d["HH"]) + + +def test_polsar_dict_to_array_partial_order(): + d = { + "HH": torch.randn(4, 4, dtype=torch.cfloat), + "VV": torch.randn(4, 4, dtype=torch.cfloat), + } + arr = polsar_dict_to_array(d, order=("HH", "VV")) + assert arr.shape == (2, 4, 4) + + +def test_polsar_dict_to_array_no_match(): + with pytest.raises(ValueError, match="none of the requested"): + polsar_dict_to_array({"foo": torch.zeros(4, 4)}, order=("HH",)) + + +def test_rescale_intensity_default_range(): + x = torch.tensor([0.0, 5.0, 10.0]) + out = rescale_intensity(x) + torch.testing.assert_close(out, torch.tensor([0.0, 0.5, 1.0])) + + +def test_rescale_intensity_explicit_range(): + x = torch.tensor([2.0, 4.0, 6.0]) + out = rescale_intensity(x, in_range=(0.0, 10.0), out_range=(-1.0, 1.0)) + expected = torch.tensor([-0.6, -0.2, 0.2]) + torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5) + + +def test_rescale_intensity_complex_raises(): + with pytest.raises(TypeError, match="real tensor"): + rescale_intensity(torch.zeros(3, dtype=torch.cfloat)) + + +def test_rescale_intensity_degenerate_range(): + x = torch.ones(4) + out = rescale_intensity(x) # in_range = (1, 1) + # All values clamp to out_lo (0.0 by default) + torch.testing.assert_close(out, torch.zeros(4)) + + +def test_rescale_intensity_clamps_values(): + x = torch.tensor([-1.0, 5.0, 11.0]) + out = rescale_intensity(x, in_range=(0.0, 10.0)) + # -1 -> clamp to 0 -> 0.0; 11 -> clamp to 10 -> 1.0 + torch.testing.assert_close(out, torch.tensor([0.0, 0.5, 1.0])) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py new file mode 100644 index 0000000..352ad7f --- /dev/null +++ b/tests/transforms/test_transforms.py @@ -0,0 +1,296 @@ +"""Tests for the class-based transforms.""" + +from __future__ import annotations + +import pytest +import torch + +from complextorch.transforms import ( + FFT2, + HWC2CHW, + IFFT2, + Amplitude, + CenterCrop, + FFTResize, + LogAmplitude, + Normalize, + PadIfNeeded, + PolSAR, + RandomPhase, + RealImaginary, + SpatialResize, + ToImaginary, + ToReal, + ToTensor, + Unsqueeze, +) + +# ---------- Casting / shape ---------- + + +def test_to_tensor_from_list(): + out = ToTensor()([[1.0, 2.0], [3.0, 4.0]]) + assert out.is_complex() + assert out.shape == (2, 2) + + +def test_to_tensor_extra_repr(): + # torch.cfloat is an alias for torch.complex64 + assert "complex64" in ToTensor(dtype=torch.cfloat).extra_repr() + + +def test_unsqueeze(): + out = Unsqueeze(dim=0)(torch.zeros(3)) + assert out.shape == (1, 3) + assert "dim=0" in Unsqueeze(0).extra_repr() + + +def test_hwc2chw(): + x = torch.randn(8, 8, 3) + out = HWC2CHW()(x) + assert out.shape == (3, 8, 8) + + +def test_hwc2chw_invalid_dim(): + with pytest.raises(ValueError, match="3-D tensor"): + HWC2CHW()(torch.zeros(2, 3, 8, 8)) + + +# ---------- Magnitude / component extraction ---------- + + +def test_log_amplitude_preserve_phase(): + x = torch.randn(3, 4, dtype=torch.cfloat) + 0.1 + out = LogAmplitude(scale=2.0, preserve_phase=True)(x) + assert out.is_complex() + # Phase preserved + torch.testing.assert_close(out.angle(), x.angle(), atol=1e-5, rtol=1e-5) + + +def test_log_amplitude_magnitude_only(): + x = torch.randn(3, 4, dtype=torch.cfloat) + out = LogAmplitude(preserve_phase=False)(x) + assert not out.is_complex() + + +def test_log_amplitude_real_input(): + x = torch.randn(3, 4) + out = LogAmplitude()(x) + # Real input -> still real (preserve_phase requires complex) + assert not out.is_complex() + + +def test_amplitude(): + x = torch.randn(3, 4, dtype=torch.cfloat) + torch.testing.assert_close(Amplitude()(x), x.abs()) + + +def test_to_real_complex_and_real(): + x = torch.randn(3, 4, dtype=torch.cfloat) + torch.testing.assert_close(ToReal()(x), x.real) + x_r = torch.randn(3, 4) + torch.testing.assert_close(ToReal()(x_r), x_r) + + +def test_to_imaginary_complex_and_real(): + x = torch.randn(3, 4, dtype=torch.cfloat) + torch.testing.assert_close(ToImaginary()(x), x.imag) + x_r = torch.randn(3, 4) + torch.testing.assert_close(ToImaginary()(x_r), torch.zeros_like(x_r)) + + +def test_real_imaginary_stack(): + x = torch.randn(2, 8, 8, dtype=torch.cfloat) + out = RealImaginary()(x) + assert out.shape == (4, 8, 8) + + +def test_real_imaginary_real_passthrough(): + x = torch.randn(2, 8, 8) + out = RealImaginary()(x) + torch.testing.assert_close(out, x) + + +# ---------- Normalize ---------- + + +def test_normalize_forward(): + mean = torch.zeros(3, dtype=torch.cfloat) + cov = torch.eye(2).unsqueeze(0).expand(3, 2, 2).clone() + norm = Normalize(mean=mean, covariance=cov) + x = torch.randn(3, 4, 4, dtype=torch.cfloat) + out = norm(x) + assert out.shape == x.shape + + +def test_normalize_invalid_covariance_shape(): + with pytest.raises(ValueError, match="covariance must have shape"): + Normalize( + mean=torch.zeros(3, dtype=torch.cfloat), covariance=torch.zeros(3, 2, 3) + ) + + +def test_normalize_wrong_channel_dim(): + mean = torch.zeros(3, dtype=torch.cfloat) + cov = torch.eye(2).unsqueeze(0).expand(3, 2, 2).clone() + norm = Normalize(mean=mean, covariance=cov) + with pytest.raises(ValueError, match=r"with C=3"): + norm(torch.randn(5, 4, 4, dtype=torch.cfloat)) + + +# ---------- RandomPhase ---------- + + +def test_random_phase_preserves_magnitude(): + x = torch.randn(3, 4, dtype=torch.cfloat) + out = RandomPhase()(x) + torch.testing.assert_close(out.abs(), x.abs(), atol=1e-5, rtol=1e-5) + + +def test_random_phase_centered(): + x = torch.randn(3, 4, dtype=torch.cfloat) + out = RandomPhase(centered=True)(x) + torch.testing.assert_close(out.abs(), x.abs(), atol=1e-5, rtol=1e-5) + + +def test_random_phase_real_input_casts(): + x = torch.randn(3, 4) + out = RandomPhase()(x) + assert out.is_complex() + + +# ---------- Spatial ---------- + + +def test_pad_if_needed_smaller(): + x = torch.randn(3, 4, 4, dtype=torch.cfloat) + out = PadIfNeeded(min_h=8, min_w=8)(x) + assert out.shape == (3, 8, 8) + + +def test_pad_if_needed_already_large_returns_input(): + x = torch.randn(3, 8, 8, dtype=torch.cfloat) + out = PadIfNeeded(min_h=4, min_w=4)(x) + assert out.shape == (3, 8, 8) + + +def test_pad_if_needed_real(): + x = torch.randn(3, 4, 4) + out = PadIfNeeded(min_h=8, min_w=8)(x) + assert out.shape == (3, 8, 8) + + +def test_pad_if_needed_invalid_dim_raises(): + """_check_chw rejects 1-D or 2-D inputs.""" + with pytest.raises(ValueError, match="C, H, W"): + PadIfNeeded(min_h=4, min_w=4)(torch.zeros(8)) + + +def test_center_crop(): + x = torch.randn(3, 8, 8, dtype=torch.cfloat) + out = CenterCrop(4, 4)(x) + assert out.shape == (3, 4, 4) + + +def test_center_crop_too_small_raises(): + with pytest.raises(ValueError, match="larger than input"): + CenterCrop(16, 16)(torch.randn(3, 4, 4)) + + +def test_spatial_resize_complex_3d(): + x = torch.randn(2, 4, 4, dtype=torch.cfloat) + out = SpatialResize(8, 8)(x) + assert out.shape == (2, 8, 8) + + +def test_spatial_resize_complex_4d(): + x = torch.randn(1, 2, 4, 4, dtype=torch.cfloat) + out = SpatialResize(8, 8)(x) + assert out.shape == (1, 2, 8, 8) + + +def test_spatial_resize_real(): + x = torch.randn(2, 4, 4) + out = SpatialResize(8, 8)(x) + assert out.shape == (2, 8, 8) + + +# ---------- Spectral ---------- + + +def test_fft_ifft_round_trip(): + x = torch.randn(1, 2, 8, 8, dtype=torch.cfloat) + out = IFFT2()(FFT2()(x)) + torch.testing.assert_close(out, x, atol=1e-5, rtol=1e-5) + + +def test_fft_resize_downsize(): + x = torch.randn(1, 2, 16, 16, dtype=torch.cfloat) + out = FFTResize(8, 8)(x) + assert out.shape == (1, 2, 8, 8) + + +def test_fft_resize_upsize(): + x = torch.randn(1, 2, 8, 8, dtype=torch.cfloat) + out = FFTResize(16, 16)(x) + assert out.shape == (1, 2, 16, 16) + + +def test_fft_resize_no_energy_preserve(): + x = torch.randn(1, 2, 8, 8, dtype=torch.cfloat) + out = FFTResize(16, 16, energy_preserving=False)(x) + assert out.shape == (1, 2, 16, 16) + + +def test_fft_resize_real_input(): + x = torch.randn(1, 2, 8, 8) + out = FFTResize(4, 4)(x) + assert out.shape == (1, 2, 4, 4) + + +# ---------- PolSAR ---------- + + +@pytest.mark.parametrize( + ("in_c", "out_c"), + [ + (1, 1), + (2, 1), + (2, 2), + (3, 1), + (3, 2), + (3, 3), + (4, 1), + (4, 2), + (4, 3), + (4, 4), + ], +) +def test_polsar_combinations(in_c, out_c): + x = torch.randn(in_c, 8, 8, dtype=torch.cfloat) + out = PolSAR(out_channels=out_c)(x) + assert out.shape[0] == out_c + + +def test_polsar_out_channels_invalid(): + with pytest.raises(ValueError, match="must be in"): + PolSAR(out_channels=5) + + +def test_polsar_too_few_dims_raises(): + with pytest.raises(ValueError, match="at least 3 dims"): + PolSAR(out_channels=1)(torch.randn(8)) + + +def test_polsar_invalid_combos(): + with pytest.raises(ValueError, match="single-channel"): + PolSAR(out_channels=2)(torch.randn(1, 8, 8, dtype=torch.cfloat)) + with pytest.raises(ValueError, match="2-channel"): + PolSAR(out_channels=3)(torch.randn(2, 8, 8, dtype=torch.cfloat)) + with pytest.raises(ValueError, match="3-channel"): + PolSAR(out_channels=4)(torch.randn(3, 8, 8, dtype=torch.cfloat)) + + +def test_polsar_unsupported_channel_count(): + with pytest.raises(ValueError, match="unsupported input channel"): + PolSAR(out_channels=1)(torch.randn(5, 8, 8, dtype=torch.cfloat))