Skip to content
22 changes: 20 additions & 2 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,28 @@ def __repr__(self):
return f"{type(self).__name__}(expression={self._expression_string})"

def forward(self, X):
if X.dim() != 2:
raise ValueError(
"Expected a 2D input tensor `X` with shape (L, nfeatures)."
)

if self._selection is not None:
X = X[:, self._selection]
symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
return self._node(symbols)

preserve_2d_output = X.shape[1] == 1
symbols = {
symbol: X[:, i : i + 1] for i, symbol in enumerate(self.symbols_in)
}
output = self._node(symbols)

if (
not preserve_2d_output
and output.dim() == 2
and output.shape[1] == 1
):
output = output.squeeze(-1)

return output

SingleSymPyModule = _SingleSymPyModule

Expand Down
76 changes: 70 additions & 6 deletions pysr/test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,26 +172,90 @@ def test_avoid_simplification(self):
feature_names_in=["x1"],
extra_sympy_mappings={"square": lambda x: x**2},
)
m = pysr.export_torch.sympy2torch(ex, ["x1"])
m = sympy2torch(ex, ["x1"])
rng = np.random.RandomState(0)
X = rng.randn(10, 1)
np.testing.assert_almost_equal(
m(torch.tensor(X)).detach().numpy(),
m(torch.tensor(X)).detach().numpy().flatten(),
np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
decimal=3,
)

def test_issue_656(self):
# Should correctly map numeric symbols to floats
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
m = sympy2torch(E_plus_x1, ["x1"])
X = np.random.randn(10, 1)
np.testing.assert_almost_equal(
m(self.torch.tensor(X)).detach().numpy(),
m(self.torch.tensor(X)).detach().numpy().flatten(),
np.exp(1) + X[:, 0],
decimal=3,
)

def test_issue_571_single_feature_shape(self):
"""Issue #571: 1-feature torch module preserves (L, 1) output shape."""
x = sympy.symbols("x")
m = sympy2torch(x + 1, [x])
X = self.torch.randn(32, 1)
y = m(X)
self.assertEqual(tuple(y.shape), (32, 1))
np.testing.assert_almost_equal(
y.detach().numpy().flatten(),
(X[:, 0] + 1).detach().numpy(),
decimal=6,
)

def test_issue_571_multifeature_output_is_1d(self):
"""Issue #571: multi-feature torch modules keep 1D outputs (L,) by default."""
x, y = sympy.symbols("x y")
m = sympy2torch(x + y, [x, y])
X = self.torch.randn(32, 2)
out = m(X)
self.assertEqual(tuple(out.shape), (32,))
np.testing.assert_almost_equal(
out.detach().numpy(),
(X[:, 0] + X[:, 1]).detach().numpy(),
decimal=6,
)

def test_issue_571_composition(self):
"""Issue #571: composing 1-feature modules into a 2-feature module works."""
x = sympy.symbols("x")
a, b = sympy.symbols("a b")
m1 = sympy2torch(x + 1, [x])
m2 = sympy2torch(2 * x, [x])
m3 = sympy2torch(a + b, [a, b])

X = self.torch.randn(32, 1)
y1 = m1(X)
y2 = m2(X)
self.assertEqual(tuple(y1.shape), (32, 1))
self.assertEqual(tuple(y2.shape), (32, 1))

stacked = self.torch.cat([y1, y2], dim=1)
y3 = m3(stacked)
np.testing.assert_almost_equal(
y3.detach().numpy(),
(3 * X[:, 0] + 1).detach().numpy(),
decimal=6,
)

def test_issue_571_reject_1d_input(self):
"""Issue #571: torch module rejects 1D inputs (expects (L, nfeatures))."""
x = sympy.symbols("x")
m = sympy2torch(x + 1, [x])
X = self.torch.randn(32, 1)
with self.assertRaises(ValueError):
m(X[:, 0])

def test_issue_571_selection_list_keeps_2d(self):
"""Issue #571: selection=[i] keeps (L, 1) shape after feature selection."""
x = sympy.symbols("x")
m = sympy2torch(x + 1, [x], selection=[0])
X = self.torch.randn(32, 2)
out = m(X)
self.assertEqual(tuple(out.shape), (32, 1))

def test_constant_arguments(self):
# Test that functions with constant arguments work correctly
# Regression test for https://github.com/MilesCranmer/PySR/issues/656
Expand All @@ -203,14 +267,14 @@ def test_constant_arguments(self):
]

for expr, expected in test_cases:
m = pysr.export_torch.sympy2torch(expr, [])
m = sympy2torch(expr, [])
result = m(self.torch.randn(10, 1))
np.testing.assert_almost_equal(result.item(), expected, decimal=3)

# Test with variables: sqrt(2) * x
x = sympy.symbols("x")
expr = sympy.sqrt(2) * x
m = pysr.export_torch.sympy2torch(expr, [x])
m = sympy2torch(expr, [x])
X = np.random.randn(10, 1)
np.testing.assert_almost_equal(
m(self.torch.tensor(X)).detach().numpy().flatten(),
Expand Down
Loading