Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions docs/tutorials/defining_input_params.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"id": "0635b285",
"metadata": {},
"outputs": [
Expand All @@ -705,7 +705,7 @@
}
],
"source": [
"inputs_read = p21c.InputParameters.from_template(tomlfile, random_seed=1)\n",
"inputs_read = p21c.InputParameters.from_template(tomlfile)\n",
"\n",
"print(\"Original == New? \", inputs==inputs_read)"
]
Expand Down Expand Up @@ -734,13 +734,17 @@
"metadata": {},
"source": [
"By default, the `write_template` function creates a **full** TOML file with all \n",
"of the possible parameters specified in it. In fact, you can create a TOML file listing\n",
"all of the default parameters very easily:"
"of the possible parameters specified in it, including the `random_seed` and `node_redshifts`. This is why we didn't have to specify the `random_seed` when loading\n",
"`inputs_read` from our template.\n",
"\n",
"In fact, you can create a TOML file listing all of the default parameters very easily,\n",
"and avoid writing the `random_seed` so that you are reminded to be explicit about the \n",
"seed when reading:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": null,
"id": "dba550e0",
"metadata": {},
"outputs": [],
Expand All @@ -749,7 +753,7 @@
"\n",
"# For the sake of the tutorial, write the custom toml into a temporary directory\n",
"tomlfile = Path(mkdtemp()) / 'default_parameters.toml'\n",
"p21c.write_template(inputs, tomlfile)"
"p21c.write_template(inputs, tomlfile, only_structs=True)"
]
},
{
Expand Down
38 changes: 31 additions & 7 deletions src/py21cmfast/_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal
from typing import Literal

import tomlkit

from .input_serialization import deserialize_inputs, prepare_inputs_for_serialization
from .wrapper.inputs import InputParameters
from .wrapper.inputs import InputParameters, InputStruct

TEMPLATE_PATH = Path(__file__).parent / "templates/"
MANIFEST = TEMPLATE_PATH / "manifest.toml"
Expand Down Expand Up @@ -68,7 +68,7 @@ def load_template_file(template_name: str | Path):

def create_params_from_template(
template_name: str | Path | Sequence[str | Path], **kwargs
) -> dict[str, dict[str, Any]]:
) -> dict[str, InputStruct]:
"""
Construct the required InputStruct instances for a run from a given template.

Expand Down Expand Up @@ -105,15 +105,32 @@ def create_params_from_template(
templates = template_name

full_template = defaultdict(dict)

random_seed = None
node_redshifts = None

for tmpl in templates:
thist = load_template_file(tmpl)
for k, v in thist.items():
full_template[k] |= v
return deserialize_inputs(full_template, **kwargs)
if k == "random_seed":
random_seed = v
elif k == "node_redshifts":
node_redshifts = v
else:
full_template[k] |= v
out = deserialize_inputs(full_template, **kwargs)
if random_seed is not None:
out["random_seed"] = random_seed
if node_redshifts is not None:
out["node_redshifts"] = node_redshifts
return out


def write_template(
inputs: InputParameters, template_file: Path | str, mode: TOMLMode = "full"
inputs: InputParameters,
template_file: Path | str,
mode: TOMLMode = "full",
only_structs: bool | None = None,
):
"""Write a set of input parameters to a template file.

Expand All @@ -124,7 +141,14 @@ def write_template(
template_file
The path of the output.
"""
inputs_dct = prepare_inputs_for_serialization(inputs, mode=mode)
assert mode in ("full", "minimal"), "mode must be 'full' or 'minimal'"

if only_structs is None:
only_structs = mode == "minimal"

inputs_dct = prepare_inputs_for_serialization(
inputs, mode=mode, only_structs=only_structs
)
inputs_dct.pop("CosmoTables", None)

template_file = Path(template_file)
Expand Down
5 changes: 5 additions & 0 deletions src/py21cmfast/input_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def prepare_inputs_for_serialization(
# we can simply leave it out of the written TOML (or HDF5) when its value is None,
# and it will anyway be set to its own default if read back in.
out = {}
if "random_seed" in dct:
out["random_seed"] = dct.pop("random_seed")
if "node_redshifts" in dct:
out["node_redshifts"] = dct.pop("node_redshifts")

for structname, structvals in dct.items():
this = {}
clsname = snake_to_camel(structname)
Expand Down
14 changes: 8 additions & 6 deletions src/py21cmfast/wrapper/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,7 @@ def evolve_input_structs(self, **kwargs):
def from_template(
cls,
name: str | Path | Sequence[str | Path],
random_seed: int,
random_seed: int | None = None,
node_redshifts: tuple[float] | None = None,
**kwargs,
):
Expand Down Expand Up @@ -1827,13 +1827,15 @@ def from_template(
"""
from .._templates import create_params_from_template

cls_kw = {"random_seed": random_seed}
if node_redshifts is not None:
cls_kw["node_redshifts"] = node_redshifts

dct = create_params_from_template(name, **kwargs)
dct.pop("cosmo_tables")
return cls(**dct, **cls_kw)

if random_seed is not None:
dct["random_seed"] = random_seed
if node_redshifts is not None:
dct["node_redshifts"] = node_redshifts

return cls(**dct)

def clone(self, **kwargs):
"""Generate a copy of the InputParameter structure with specified changes."""
Expand Down
15 changes: 15 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,18 @@ def test_roundtrip(self, template, tmp_path: Path, mode: str):

new = InputParameters.from_template(pth, random_seed=1, K_MAX_FOR_CLASS=1.0)
assert new == inputs

def test_writing_non_structs(self, tmp_path):
"""Test that writing with only_structs=False does include non-struct parameters."""
inputs = InputParameters.from_template("simple", random_seed=1)
pth = tmp_path / "tmp.toml"
tmpl.write_template(inputs, pth, mode="full", only_structs=False)

with pth.open("r") as fl:
contents = fl.read()

assert "random_seed" in contents
assert "node_redshifts" in contents

new = InputParameters.from_template(pth)
assert new == inputs
3 changes: 3 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading