diff --git a/docs/tutorials/defining_input_params.ipynb b/docs/tutorials/defining_input_params.ipynb index 7b3f7f36..e852da88 100644 --- a/docs/tutorials/defining_input_params.ipynb +++ b/docs/tutorials/defining_input_params.ipynb @@ -692,7 +692,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "id": "0635b285", "metadata": {}, "outputs": [ @@ -705,7 +705,7 @@ } ], "source": [ - "inputs_read = p21c.InputParameters.from_template(tomlfile, random_seed=1)\n", + "inputs_read = p21c.InputParameters.from_template(tomlfile)\n", "\n", "print(\"Original == New? \", inputs==inputs_read)" ] @@ -734,13 +734,17 @@ "metadata": {}, "source": [ "By default, the `write_template` function creates a **full** TOML file with all \n", - "of the possible parameters specified in it. In fact, you can create a TOML file listing\n", - "all of the default parameters very easily:" + "of the possible parameters specified in it, including the `random_seed` and `node_redshifts`. This is why we didn't have to specify the `random_seed` when loading\n", + "`inputs_read` from our template.\n", + "\n", + "In fact, you can create a TOML file listing all of the default parameters very easily,\n", + "and avoid writing the `random_seed` so that you are reminded to be explicit about the \n", + "seed when reading:" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "id": "dba550e0", "metadata": {}, "outputs": [], @@ -749,7 +753,7 @@ "\n", "# For the sake of the tutorial, write the custom toml into a temporary directory\n", "tomlfile = Path(mkdtemp()) / 'default_parameters.toml'\n", - "p21c.write_template(inputs, tomlfile)" + "p21c.write_template(inputs, tomlfile, only_structs=True)" ] }, { diff --git a/src/py21cmfast/_templates.py b/src/py21cmfast/_templates.py index ac512abd..9908a614 100644 --- a/src/py21cmfast/_templates.py +++ b/src/py21cmfast/_templates.py @@ -13,12 +13,12 @@ from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal +from typing import Literal import tomlkit from .input_serialization import deserialize_inputs, prepare_inputs_for_serialization -from .wrapper.inputs import InputParameters +from .wrapper.inputs import InputParameters, InputStruct TEMPLATE_PATH = Path(__file__).parent / "templates/" MANIFEST = TEMPLATE_PATH / "manifest.toml" @@ -68,7 +68,7 @@ def load_template_file(template_name: str | Path): def create_params_from_template( template_name: str | Path | Sequence[str | Path], **kwargs -) -> dict[str, dict[str, Any]]: +) -> dict[str, InputStruct]: """ Construct the required InputStruct instances for a run from a given template. @@ -105,15 +105,32 @@ def create_params_from_template( templates = template_name full_template = defaultdict(dict) + + random_seed = None + node_redshifts = None + for tmpl in templates: thist = load_template_file(tmpl) for k, v in thist.items(): - full_template[k] |= v - return deserialize_inputs(full_template, **kwargs) + if k == "random_seed": + random_seed = v + elif k == "node_redshifts": + node_redshifts = v + else: + full_template[k] |= v + out = deserialize_inputs(full_template, **kwargs) + if random_seed is not None: + out["random_seed"] = random_seed + if node_redshifts is not None: + out["node_redshifts"] = node_redshifts + return out def write_template( - inputs: InputParameters, template_file: Path | str, mode: TOMLMode = "full" + inputs: InputParameters, + template_file: Path | str, + mode: TOMLMode = "full", + only_structs: bool | None = None, ): """Write a set of input parameters to a template file. @@ -124,7 +141,14 @@ def write_template( template_file The path of the output. """ - inputs_dct = prepare_inputs_for_serialization(inputs, mode=mode) + assert mode in ("full", "minimal"), "mode must be 'full' or 'minimal'" + + if only_structs is None: + only_structs = mode == "minimal" + + inputs_dct = prepare_inputs_for_serialization( + inputs, mode=mode, only_structs=only_structs + ) inputs_dct.pop("CosmoTables", None) template_file = Path(template_file) diff --git a/src/py21cmfast/input_serialization.py b/src/py21cmfast/input_serialization.py index 06d8762a..26db8a47 100644 --- a/src/py21cmfast/input_serialization.py +++ b/src/py21cmfast/input_serialization.py @@ -107,6 +107,11 @@ def prepare_inputs_for_serialization( # we can simply leave it out of the written TOML (or HDF5) when its value is None, # and it will anyway be set to its own default if read back in. out = {} + if "random_seed" in dct: + out["random_seed"] = dct.pop("random_seed") + if "node_redshifts" in dct: + out["node_redshifts"] = dct.pop("node_redshifts") + for structname, structvals in dct.items(): this = {} clsname = snake_to_camel(structname) diff --git a/src/py21cmfast/wrapper/inputs.py b/src/py21cmfast/wrapper/inputs.py index 69d9073c..3386316e 100644 --- a/src/py21cmfast/wrapper/inputs.py +++ b/src/py21cmfast/wrapper/inputs.py @@ -1799,7 +1799,7 @@ def evolve_input_structs(self, **kwargs): def from_template( cls, name: str | Path | Sequence[str | Path], - random_seed: int, + random_seed: int | None = None, node_redshifts: tuple[float] | None = None, **kwargs, ): @@ -1827,13 +1827,15 @@ def from_template( """ from .._templates import create_params_from_template - cls_kw = {"random_seed": random_seed} - if node_redshifts is not None: - cls_kw["node_redshifts"] = node_redshifts - dct = create_params_from_template(name, **kwargs) dct.pop("cosmo_tables") - return cls(**dct, **cls_kw) + + if random_seed is not None: + dct["random_seed"] = random_seed + if node_redshifts is not None: + dct["node_redshifts"] = node_redshifts + + return cls(**dct) def clone(self, **kwargs): """Generate a copy of the InputParameter structure with specified changes.""" diff --git a/tests/test_templates.py b/tests/test_templates.py index 8fe158cb..e95e6b9b 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -120,3 +120,18 @@ def test_roundtrip(self, template, tmp_path: Path, mode: str): new = InputParameters.from_template(pth, random_seed=1, K_MAX_FOR_CLASS=1.0) assert new == inputs + + def test_writing_non_structs(self, tmp_path): + """Test that writing with only_structs=False does include non-struct parameters.""" + inputs = InputParameters.from_template("simple", random_seed=1) + pth = tmp_path / "tmp.toml" + tmpl.write_template(inputs, pth, mode="full", only_structs=False) + + with pth.open("r") as fl: + contents = fl.read() + + assert "random_seed" in contents + assert "node_redshifts" in contents + + new = InputParameters.from_template(pth) + assert new == inputs diff --git a/uv.lock b/uv.lock new file mode 100644 index 00000000..bda02073 --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.13"