diff --git a/pysr/export_torch.py b/pysr/export_torch.py index ef9aeaeb8..fc27053df 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -179,10 +179,28 @@ def __repr__(self): return f"{type(self).__name__}(expression={self._expression_string})" def forward(self, X): + if X.dim() != 2: + raise ValueError( + "Expected a 2D input tensor `X` with shape (L, nfeatures)." + ) + if self._selection is not None: X = X[:, self._selection] - symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)} - return self._node(symbols) + + preserve_2d_output = X.shape[1] == 1 + symbols = { + symbol: X[:, i : i + 1] for i, symbol in enumerate(self.symbols_in) + } + output = self._node(symbols) + + if ( + not preserve_2d_output + and output.dim() == 2 + and output.shape[1] == 1 + ): + output = output.squeeze(-1) + + return output SingleSymPyModule = _SingleSymPyModule diff --git a/pysr/test/test_torch.py b/pysr/test/test_torch.py index e66e7d960..80e5e5f7f 100644 --- a/pysr/test/test_torch.py +++ b/pysr/test/test_torch.py @@ -172,11 +172,11 @@ def test_avoid_simplification(self): feature_names_in=["x1"], extra_sympy_mappings={"square": lambda x: x**2}, ) - m = pysr.export_torch.sympy2torch(ex, ["x1"]) + m = sympy2torch(ex, ["x1"]) rng = np.random.RandomState(0) X = rng.randn(10, 1) np.testing.assert_almost_equal( - m(torch.tensor(X)).detach().numpy(), + m(torch.tensor(X)).detach().numpy().flatten(), np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0], decimal=3, ) @@ -184,14 +184,78 @@ def test_avoid_simplification(self): def test_issue_656(self): # Should correctly map numeric symbols to floats E_plus_x1 = sympy.exp(1) + sympy.symbols("x1") - m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"]) + m = sympy2torch(E_plus_x1, ["x1"]) X = np.random.randn(10, 1) np.testing.assert_almost_equal( - m(self.torch.tensor(X)).detach().numpy(), + m(self.torch.tensor(X)).detach().numpy().flatten(), np.exp(1) + X[:, 0], decimal=3, ) + def test_issue_571_single_feature_shape(self): + """Issue #571: 1-feature torch module preserves (L, 1) output shape.""" + x = sympy.symbols("x") + m = sympy2torch(x + 1, [x]) + X = self.torch.randn(32, 1) + y = m(X) + self.assertEqual(tuple(y.shape), (32, 1)) + np.testing.assert_almost_equal( + y.detach().numpy().flatten(), + (X[:, 0] + 1).detach().numpy(), + decimal=6, + ) + + def test_issue_571_multifeature_output_is_1d(self): + """Issue #571: multi-feature torch modules keep 1D outputs (L,) by default.""" + x, y = sympy.symbols("x y") + m = sympy2torch(x + y, [x, y]) + X = self.torch.randn(32, 2) + out = m(X) + self.assertEqual(tuple(out.shape), (32,)) + np.testing.assert_almost_equal( + out.detach().numpy(), + (X[:, 0] + X[:, 1]).detach().numpy(), + decimal=6, + ) + + def test_issue_571_composition(self): + """Issue #571: composing 1-feature modules into a 2-feature module works.""" + x = sympy.symbols("x") + a, b = sympy.symbols("a b") + m1 = sympy2torch(x + 1, [x]) + m2 = sympy2torch(2 * x, [x]) + m3 = sympy2torch(a + b, [a, b]) + + X = self.torch.randn(32, 1) + y1 = m1(X) + y2 = m2(X) + self.assertEqual(tuple(y1.shape), (32, 1)) + self.assertEqual(tuple(y2.shape), (32, 1)) + + stacked = self.torch.cat([y1, y2], dim=1) + y3 = m3(stacked) + np.testing.assert_almost_equal( + y3.detach().numpy(), + (3 * X[:, 0] + 1).detach().numpy(), + decimal=6, + ) + + def test_issue_571_reject_1d_input(self): + """Issue #571: torch module rejects 1D inputs (expects (L, nfeatures)).""" + x = sympy.symbols("x") + m = sympy2torch(x + 1, [x]) + X = self.torch.randn(32, 1) + with self.assertRaises(ValueError): + m(X[:, 0]) + + def test_issue_571_selection_list_keeps_2d(self): + """Issue #571: selection=[i] keeps (L, 1) shape after feature selection.""" + x = sympy.symbols("x") + m = sympy2torch(x + 1, [x], selection=[0]) + X = self.torch.randn(32, 2) + out = m(X) + self.assertEqual(tuple(out.shape), (32, 1)) + def test_constant_arguments(self): # Test that functions with constant arguments work correctly # Regression test for https://github.com/MilesCranmer/PySR/issues/656 @@ -203,14 +267,14 @@ def test_constant_arguments(self): ] for expr, expected in test_cases: - m = pysr.export_torch.sympy2torch(expr, []) + m = sympy2torch(expr, []) result = m(self.torch.randn(10, 1)) np.testing.assert_almost_equal(result.item(), expected, decimal=3) # Test with variables: sqrt(2) * x x = sympy.symbols("x") expr = sympy.sqrt(2) * x - m = pysr.export_torch.sympy2torch(expr, [x]) + m = sympy2torch(expr, [x]) X = np.random.randn(10, 1) np.testing.assert_almost_equal( m(self.torch.tensor(X)).detach().numpy().flatten(),