
[BUG]: Cannot load a model with custom extra_jax_mappings from .pkl file; yields EOFError: Ran out of input and ValueError: Failed to evaluate the expression #1198

@jc-umana-FI


What happened?

Hello PySR team!

TL;DR:

Runs that pass extra_jax_mappings= with a custom function not already listed in the _jnp_func_lookup section of export_jax.py produce an empty pickle file, so the model cannot be loaded later with PySRRegressor.from_file(). Minimal reproducible code can be found below.


The Issue

I've been testing PySR's ability to recover known functions and their predefined approximations, which I know for certain will show up in more complicated expressions corresponding to physical systems (e.g. Bessel functions). Ultimately, I would like to warm-start from the models that can identify these terms and later use the JAX exports of those models for automatic differentiation and similar tasks. To do this, I need to export the models to pickle files and load them back from those same files.

The problem I've observed in many of my test runs is that whenever a custom implementation of a known function or its approximation is turned into a sympy.Function() or passed through lambdify(), loading the file in the run directory with PySRRegressor.from_file() fails with EOFError: Ran out of input. On closer inspection of the file path, the pickle file is completely empty, while hall_of_fame.csv contains the text versions of the expressions.
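For anyone trying to confirm the same symptom, a minimal sketch of the check I did by hand is below. It only uses the standard library and lists which files in a run directory are zero bytes. The helper names (summarize_run_directory, empty_files) are made up for this illustration and are not part of PySR.

```python
import os


def summarize_run_directory(run_directory):
    """Return {filename: size_in_bytes} for the regular files in a directory."""
    sizes = {}
    for name in sorted(os.listdir(run_directory)):
        path = os.path.join(run_directory, name)
        if os.path.isfile(path):
            sizes[name] = os.path.getsize(path)
    return sizes


def empty_files(run_directory):
    """Return the names of all zero-byte files in a directory."""
    sizes = summarize_run_directory(run_directory)
    return [name for name, size in sizes.items() if size == 0]
```

In my case, calling empty_files() on the run directory would be expected to list checkpoint.pkl but not hall_of_fame.csv.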

If you try to use model.predict(), you will get a ValueError because the custom operator that should have been saved to the pickle file is not defined.
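The NameError behind that ValueError can be reproduced with SymPy alone, without PySR. The sketch below is only an illustration of the lambdify behaviour, not of what PySR does internally: an undefined sympy.Function has a name but no numerical body, so the generated code calls a name that does not exist unless it is supplied through the modules argument. The variable names here are chosen for the example.

```python
import numpy as np
import sympy

x = sympy.Symbol("x")


def cos_approx(v):
    # Same polynomial as the custom operator in this report
    return 1 - (v**2) / 2 + (v**4) / 24 + (v**6) / 720


# Undefined SymPy function: a name with no numerical definition attached
sp_cos_approx = sympy.Function("cos_approx")
expr_opaque = sp_cos_approx(x) / 0.5768407

inputs = np.array([0.0, 1.0])

# Case 1: plain lambdify. The generated code calls cos_approx(...), which
# is not defined in the NumPy namespace, so evaluation raises NameError.
f_opaque = sympy.lambdify(x, expr_opaque, modules="numpy")
try:
    f_opaque(inputs)
    opaque_error = None
except NameError as err:
    opaque_error = str(err)

# Case 2: pass the Python implementation by name through `modules`.
f_named = sympy.lambdify(
    x, expr_opaque, modules=[{"cos_approx": cos_approx}, "numpy"]
)
values_named = f_named(inputs)

# Case 3: expand the operator into an ordinary SymPy expression first,
# so no custom name is left in the generated code.
expr_expanded = cos_approx(x) / 0.5768407
f_expanded = sympy.lambdify(x, expr_expanded, modules="numpy")
values_expanded = f_expanded(inputs)
```

Cases 2 and 3 should agree numerically. Case 3 corresponds to the style suggested by the error message, where the mapping is a callable that returns a regular SymPy expression.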

The only way to evaluate a chosen expression from the model is shown in the provided example, and it works only while the fitted model object is still in memory. There is seemingly no way to save the model and re-evaluate it once your local variables are cleared.
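As a possible stopgap (not a fix, and not something PySR provides), the text form of an expression could be turned back into a NumPy callable with SymPy directly, since hall_of_fame.csv still holds the expressions as text. The sketch below assumes the equation string has already been copied out of that file by hand. The helper rebuild_callable and the feature names are hypothetical, and the equation string is taken from the traceback above.

```python
import numpy as np
import sympy


def cos_approx(v):
    # Works on SymPy symbols and NumPy arrays alike
    return 1 - (v**2) / 2 + (v**4) / 24 + (v**6) / 720


def rebuild_callable(equation_text, feature_names, custom_ops):
    """Parse an equation string and return (sympy_expr, numpy_callable).

    custom_ops maps operator names to Python callables that return SymPy
    expressions, so custom operators are expanded while parsing.
    """
    expr = sympy.sympify(equation_text, locals=custom_ops)
    symbols = [sympy.Symbol(name) for name in feature_names]
    return expr, sympy.lambdify(symbols, expr, modules="numpy")


equation_text = "x5*x5 + cos_approx(x10)/0.5768407"
expr, f = rebuild_callable(
    equation_text, ["x5", "x7", "x10"], {"cos_approx": cos_approx}
)

# One column per feature, in the same order as feature_names
X_demo = np.array([[1.0, 0.0, 0.0], [2.0, 3.0, 1.0]])
pred = f(*X_demo.T)
```

This gives NumPy predictions only. It does not restore the PySRRegressor object or its JAX export, which is what I actually need.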

Version

1.5.10

Operating System

Linux

Package Manager

pip

Interface

Jupyter Notebook

Relevant log output

--------------------- ERROR from PySRRegressor.from_file() ----------------------------------

Attempting to load model from /mnt/ceph/users/jumana/PySR_models/models_summer_2026/sympfunc_examp_test/checkpoint.pkl...
---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
Cell In[7], line 1
----> 1 model_pull = PySRRegressor.from_file(run_directory=f'/mnt/ceph/users/jumana/PySR_models/models_summer_2026/sympfunc_examp_test')

File ~/code/dft/cluster/qcqed_venv_04162025/lib/python3.10/site-packages/pysr/sr.py:1139, in PySRRegressor.from_file(cls, equation_file, run_directory, binary_operators, unary_operators, n_features_in, feature_names_in, selection_mask, nout, **pysr_kwargs)
   1137 assert n_features_in is None
   1138 with open(pkl_filename, "rb") as f:
-> 1139     model = cast("PySRRegressor", pkl.load(f))
   1141 # Update any parameters if necessary, such as
   1142 # extra_sympy_mappings:
   1143 model.set_params(**pysr_kwargs)

EOFError: Ran out of input


--------------------- ERROR from model.predict() ----------------------------------

NameError                                 Traceback (most recent call last)
File ~/code/dft/cluster/qcqed_venv_04162025/lib/python3.10/site-packages/pysr/sr.py:2437, in PySRRegressor.predict(self, X, index, category)
   2436     else:
-> 2437         return cast(ndarray, best_equation["lambda_format"](X, *args))
   2438 except Exception as error:

File ~/code/dft/cluster/qcqed_venv_04162025/lib/python3.10/site-packages/pysr/export_numpy.py:50, in CallableEquation.__call__(self, X)
     48         X = X[:, self._selection]
---> 50 return self._lambda(*X.T) * np.ones(expected_shape)

File <lambdifygenerated-1>:2, in _lambdifygenerated(x5, x7, x10)
      1 def _lambdifygenerated(x5, x7, x10):
----> 2     return x5*x5 + cos_approx(x10)/0.5768407

NameError: name 'cos_approx' is not defined

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[6], line 1
----> 1 model.predict(X.values, index=5)

File ~/code/dft/cluster/qcqed_venv_04162025/lib/python3.10/site-packages/pysr/sr.py:2439, in PySRRegressor.predict(self, X, index, category)
   2437         return cast(ndarray, best_equation["lambda_format"](X, *args))
   2438 except Exception as error:
-> 2439     raise ValueError(
   2440         "Failed to evaluate the expression. "
   2441         "If you are using a custom operator, make sure to define it in `extra_sympy_mappings`, "
   2442         "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
   2443         "`lambda x: 1/x` is a valid SymPy function defining the operator. "
   2444         "You can then run `model.refresh()` to re-load the expressions."
   2445     ) from error

ValueError: Failed to evaluate the expression. If you are using a custom operator, make sure to define it in `extra_sympy_mappings`, e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where `lambda x: 1/x` is a valid SymPy function defining the operator. You can then run `model.refresh()` to re-load the expressions.

Extra Info


Minimal Reproducible Code (Using a Custom Approximation to cosine -- from the test_jax.py source code)

Please replace YOUR/WORKING/PATH/HERE with your own path.

The model training:

import numpy as np
import pandas as pd
import sympy
from pysr import PySRRegressor

# Synthetic data: 11 standard-normal features named k10..k20
rstate = np.random.RandomState(0)
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})

# Custom approximation to cosine (taken from test_jax.py)
def cos_approx(x):
    return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720

# Undefined SymPy function standing in for the custom operator
sp_cos_approx = sympy.Function("cos_approx")

y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])

model = PySRRegressor(
    progress=False,
    unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
    select_k_features=3,
    maxsize=10,
    early_stop_condition=1e-5,
    extra_sympy_mappings={"cos_approx": sp_cos_approx},
    extra_jax_mappings={
        sp_cos_approx: "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
    },
    random_state=0,
    deterministic=True,
    parallelism="serial",
    run_id="sympfunc_examp_test",
    output_directory="YOUR/WORKING/PATH/HERE",
    output_jax_format=True,
)

np.random.seed(0)
model.fit(X.values, y.values)
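One guess on my side, which I have not confirmed, is that the checkpoint ends up empty because something in the model fails to pickle after the file has already been opened for writing. A quick way to test individual objects is sketched below with the standard library only. The helper pickling_problems is hypothetical and not part of PySR.

```python
import pickle


def pickling_problems(named_objects):
    """Return {name: error description} for objects that fail a pickle round trip."""
    problems = {}
    for name, obj in named_objects.items():
        try:
            pickle.loads(pickle.dumps(obj))
        except Exception as err:  # pickle can raise several exception types
            problems[name] = f"{type(err).__name__}: {err}"
    return problems


# Stand-in objects for illustration: a string mapping pickles fine,
# while a plain Python lambda does not.
candidates = {
    "string_mapping": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)",
    "python_lambda": lambda x: 1 - x**2 / 2,
}
report = pickling_problems(candidates)
```

The same helper could be given sp_cos_approx, the extra_sympy_mappings dict, and the extra_jax_mappings dict from the training code to see which of them, if any, fails to round-trip.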

EOFError reproducing code -- model retrieval attempt

model_pull = PySRRegressor.from_file(run_directory='YOUR/WORKING/PATH/HERE/sympfunc_examp_test')

ValueError reproducing code -- simple prediction attempt

model.predict(X.values, index=None)

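Currently, the only way to evaluate a selected (unsaved) model:

from functools import partial

# model.jax() returns a dict with "callable" and "parameters" entries
jax_model = model.jax(index=None)
f, parameters = jax_model["callable"], jax_model["parameters"]
jax_prediction = partial(f, parameters=parameters)
jax_output = jax_prediction(X.values)
jax_output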

Desired Method of Model Retrieval Upon Saving:

import jax.numpy as jnp

def get_funct(model, *args, index=None):
    jax_model = model.jax(index)
    jax_func = jax_model["callable"](jnp.asarray([*args])[None, ...], jax_model["parameters"])
    return jax_func[0]

funct_val_check_from_model_pull = get_funct(model_pull, X.values.reshape(-1, 1), index=None)

For the sake of completely laying out the problem at hand, here is the Bessel function approximation I mentioned earlier:

def sp_J_1_approx(w):
    func = sympy.sqrt(2 / (sympy.pi * sympy.Abs(w))) * sympy.cos(sympy.Abs(w) - ((3 * sympy.pi) / 4))
    return func

J_1_approx = sympy.Function("sp_J_1_approx")

and the set of parameters I initially attempted this with:

unary_operators=["cos", "sqrt", "J_1_approx(x) = Float32((sqrt(2 / (pi * abs(x))) * cos(abs(x) - ((3 * pi) / 4))))"],
extra_sympy_mappings={"J_1_approx": lambda x: sympy.sqrt(2 / (sympy.pi * sympy.Abs(x))) * sympy.cos(sympy.Abs(x) - ((3 * sympy.pi) / 4))},
extra_jax_mappings={J_1_approx: "(jnp.sqrt(2 / (jnp.pi * jnp.abs(w))) * jnp.cos(jnp.abs(w) - ((3 * jnp.pi) / 4)))"},

Please let me know if there are any known workarounds, if I'm doing something wrong, or if this is just a bug that needs fixing. Thank you for taking the time to read this, and I look forward to your response!

Labels: bug