From fbc92284642a003ad8bdcd66e616070e1143abb0 Mon Sep 17 00:00:00 2001
From: Oscar Esteban <code@oscaresteban.es>
Date: Wed, 20 Jul 2022 17:04:59 +0200
Subject: [PATCH] ENH: Collapse linear and nonlinear transforms chains

Very undertested, but currently there is a test that uses a "collapsed"
transform on an ITK's .h5 file with one affine and one nonlinear.

BSpline transforms not currently supported.

Resolves #89.
---
 nitransforms/linear.py            | 16 +++++++---------
 nitransforms/manip.py             | 16 +++++++---------
 nitransforms/tests/test_linear.py |  6 +++---
 nitransforms/tests/test_manip.py  |  8 +++++++-
 4 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/nitransforms/linear.py b/nitransforms/linear.py
index 9c430d3b..239f0ebc 100644
--- a/nitransforms/linear.py
+++ b/nitransforms/linear.py
@@ -123,19 +123,17 @@ def __matmul__(self, b):
         True
 
         >>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
-        >>> xfm1 @ np.eye(4) == xfm1
+        >>> xfm1 @ Affine() == xfm1
         True
 
         """
-        if not isinstance(b, self.__class__):
-            _b = self.__class__(b)
-        else:
-            _b = b
+        if isinstance(b, self.__class__):
+            return self.__class__(
+                b.matrix @ self.matrix,
+                reference=b.reference,
+            )
 
-        retval = self.__class__(self.matrix.dot(_b.matrix))
-        if _b.reference:
-            retval.reference = _b.reference
-        return retval
+        return b @ self
 
     @property
     def matrix(self):
diff --git a/nitransforms/manip.py b/nitransforms/manip.py
index 233f5adf..58d15058 100644
--- a/nitransforms/manip.py
+++ b/nitransforms/manip.py
@@ -8,7 +8,6 @@
 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
 """Common interface for transforms."""
 from collections.abc import Iterable
-import numpy as np
 
 from .base import (
     TransformBase,
@@ -140,9 +139,9 @@ def map(self, x, inverse=False):
 
         return x
 
-    def asaffine(self, indices=None):
+    def collapse(self):
         """
-        Combine a succession of linear transforms into one.
+        Combine a succession of transforms into one.
 
         Example
         ------
@@ -150,7 +149,7 @@ def asaffine(self, indices=None):
         ...     Affine.from_matvec(vec=(2, -10, 3)),
         ...     Affine.from_matvec(vec=(-2, 10, -3)),
         ... ])
-        >>> chain.asaffine()
+        >>> chain.collapse()
         array([[1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
@@ -160,7 +159,7 @@ def asaffine(self, indices=None):
         ...     Affine.from_matvec(vec=(1, 2, 3)),
         ...     Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
         ... ])
-        >>> chain.asaffine()
+        >>> chain.collapse()
         array([[0., 1., 0., 2.],
                [0., 0., 1., 3.],
                [1., 0., 0., 1.],
@@ -168,7 +167,7 @@ def asaffine(self, indices=None):
 
         >>> np.allclose(
         ...     chain.map((4, -2, 1)),
-        ...     chain.asaffine().map((4, -2, 1)),
+        ...     chain.collapse().map((4, -2, 1)),
         ... )
         True
 
@@ -178,9 +177,8 @@ def asaffine(self, indices=None):
             The indices of the values to extract.
 
         """
-        affines = self.transforms if indices is None else np.take(self.transforms, indices)
-        retval = affines[0]
-        for xfm in affines[1:]:
+        retval = self.transforms[-1]
+        for xfm in reversed(self.transforms[:-1]):
             retval = xfm @ retval
         return retval
 
diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py
index eea77b7f..f3f83b38 100644
--- a/nitransforms/tests/test_linear.py
+++ b/nitransforms/tests/test_linear.py
@@ -372,10 +372,10 @@ def test_mulmat_operator(testdata_path):
     mat2 = from_matvec(np.eye(3), (4, 2, -1))
     aff = nitl.Affine(mat1, reference=ref)
 
-    composed = aff @ mat2
+    composed = aff @ nitl.Affine(mat2)
     assert composed.reference is None
-    assert composed == nitl.Affine(mat1.dot(mat2))
+    assert composed == nitl.Affine(mat2 @ mat1)
 
     composed = nitl.Affine(mat2) @ aff
     assert composed.reference == aff.reference
-    assert composed == nitl.Affine(mat2.dot(mat1), reference=ref)
+    assert composed == nitl.Affine(mat1 @ mat2, reference=ref)
diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py
index 6dee540e..59f7f3b7 100644
--- a/nitransforms/tests/test_manip.py
+++ b/nitransforms/tests/test_manip.py
@@ -60,6 +60,12 @@ def test_itk_h5(tmp_path, testdata_path):
     # A certain tolerance is necessary because of resampling at borders
     assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL
 
+    col_moved = xfm.collapse().apply(img_fname, order=0)
+    col_moved.to_filename("nt_collapse_resampled.nii.gz")
+    diff = sw_moved.get_fdata() - col_moved.get_fdata()
+    # A certain tolerance is necessary because of resampling at borders
+    assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL
+
 
 @pytest.mark.parametrize("ext0", ["lta", "tfm"])
 @pytest.mark.parametrize("ext1", ["lta", "tfm"])
@@ -81,7 +87,7 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2):
         ]
     )
     assert np.allclose(
-        chain.asaffine().matrix,
+        chain.collapse().matrix,
         Affine.from_filename(
             data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}",
             fmt=f"{FMT[ext2]}",