From 97cdf4e4c2a07fe7499658e0d4fc970d9ac66ded Mon Sep 17 00:00:00 2001
From: Margus Niitsoo <velochy@gmail.com>
Date: Mon, 21 Apr 2025 17:39:03 +0300
Subject: [PATCH] Generalize ordered transform

---
 pymc/distributions/transforms.py      | 38 +++++++++++++++++++++------
 tests/distributions/test_transform.py | 29 ++++++++++++++++++--
 2 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index be0df56541..ebdaf3c3e1 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs):
 
 
 class Ordered(Transform):
+    """
+    Transforms a vector of values into a vector of ordered values.
+
+    Parameters
+    ----------
+    positive: If True, all values are positive. This has better geometry than just chaining with a log transform.
+    ascending: If True, the values are in ascending order (default). If False, the values are in descending order.
+    """
+
     name = "ordered"
 
-    def __init__(self, ndim_supp=None):
+    def __init__(self, ndim_supp=None, positive=False, ascending=True):
         if ndim_supp is not None:
             warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
+        self.positive = positive
+        self.ascending = ascending
 
     def backward(self, value, *inputs):
-        x = pt.zeros(value.shape)
-        x = pt.set_subtensor(x[..., 0], value[..., 0])
-        x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
-        return pt.cumsum(x, axis=-1)
+        if self.positive:  # Transform both initial value and deltas to be positive
+            x = pt.exp(value)
+        else:  # Transform only deltas to be positive
+            x = pt.empty(value.shape)
+            x = pt.set_subtensor(x[..., 0], value[..., 0])
+            x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:]))
+        x = pt.cumsum(x, axis=-1)  # Add deltas cumulatively to initial value
+        if not self.ascending:
+            x = x[..., ::-1]
+        return x
 
     def forward(self, value, *inputs):
-        y = pt.zeros(value.shape)
-        y = pt.set_subtensor(y[..., 0], value[..., 0])
+        if not self.ascending:
+            value = value[..., ::-1]
+        y = pt.empty(value.shape)
+        y = pt.set_subtensor(y[..., 0], pt.log(value[..., 0]) if self.positive else value[..., 0])
         y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
         return y
 
     def log_jac_det(self, value, *inputs):
-        return pt.sum(value[..., 1:], axis=-1)
+        if self.positive:
+            return pt.sum(value, axis=-1)
+        else:
+            return pt.sum(value[..., 1:], axis=-1)
 
 
 class SumTo1(Transform):
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index 12d9b438b5..e28052bab9 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -103,7 +103,7 @@ def check_jacobian_det(
         x = make_comparable(x)
 
     if not elemwise:
-        jac = pt.log(pt.nlinalg.det(jacobian(x, [y])))
+        jac = pt.log(pt.abs(pt.nlinalg.det(jacobian(x, [y]))))
     else:
         jac = pt.log(pt.abs(pt.diag(jacobian(x, [y]))))
 
@@ -115,7 +115,7 @@ def check_jacobian_det(
     )
 
     for yval in domain.vals:
-        assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
+        assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol, atol=tol)
 
 
 def test_simplex():
@@ -281,6 +281,31 @@ def test_ordered():
     vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
     assert_array_equal(np.diff(vals) >= 0, True)
 
+    # Check that positive=True creates positive and still ordered values
+    vals = get_values(tr.Ordered(positive=True), Vector(R, 3), pt.vector, floatX(np.zeros(3)))
+    assert_array_equal(vals > 0, True)
+    assert_array_equal(np.diff(vals) >= 0, True)
+
+    # Check that positive=True and ascending=False creates descending values
+    vals = get_values(
+        tr.Ordered(positive=True, ascending=False), Vector(R, 3), pt.vector, floatX(np.zeros(3))
+    )
+    assert_array_equal(vals > 0, True)
+    assert_array_equal(np.diff(vals) <= 0, True)
+
+    # Check that forward and backward are still inverses
+    ord, vals = tr.Ordered(positive=True, ascending=False), np.array([0.3, 0.2, 0.1])
+    assert_allclose(vals, ord.backward(ord.forward(vals)).eval())
+
+    # Check the jacobian for positive=True and ascending=False
+    check_jacobian_det(
+        tr.Ordered(positive=True, ascending=False),
+        Vector(R, 2),
+        pt.vector,
+        floatX(np.array([1, 1])),
+        elemwise=False,
+    )
+
 
 def test_chain_values():
     chain_tranf = tr.Chain([tr.logodds, tr.ordered])