Skip to content

Commit d31c648

Browse files
authored
Add context manager to h5py.File() calls (#135)
This should avoid errors such as the one reported in #131 (see comment there). Also added some `if rank == 0:` guards to print statements in the DESC equilibrium.
1 parent 0d71021 commit d31c648

File tree

8 files changed

+469
-510
lines changed

8 files changed

+469
-510
lines changed

src/struphy/fields_background/equils.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from time import time
99

1010
import cunumpy as xp
11+
from psydac.ddm.mpi import MockMPI
12+
from psydac.ddm.mpi import mpi as MPI
1113
from scipy.integrate import odeint, quad
1214
from scipy.interpolate import RectBivariateSpline, UnivariateSpline
1315
from scipy.optimize import fsolve, minimize
@@ -31,6 +33,17 @@
3133
from struphy.fields_background.mhd_equil.eqdsk import readeqdsk
3234
from struphy.utils.utils import read_state, subp_run
3335

36+
if isinstance(MPI, MockMPI):
37+
comm = None
38+
rank = 0
39+
size = 1
40+
Barrier = lambda: None
41+
else:
42+
comm = MPI.COMM_WORLD
43+
rank = comm.Get_rank()
44+
size = comm.Get_size()
45+
Barrier = comm.Barrier
46+
3447

3548
class HomogenSlab(CartesianMHDequilibrium):
3649
r"""
@@ -1748,7 +1761,8 @@ def __init__(
17481761
# default input file
17491762
if file is None:
17501763
file = "AUGNLED_g031213.00830.high"
1751-
print(f"EQDSK: taking default file {file}.")
1764+
if rank == 0:
1765+
print(f"EQDSK: taking default file {file}.")
17521766

17531767
# no rescaling if units are not provided
17541768
if units is None:
@@ -2131,9 +2145,11 @@ def __init__(
21312145
import pytest
21322146

21332147
with pytest.raises(SystemExit) as exc:
2134-
print("Simulation aborted, gvec must be installed (pip install gvec)!")
2148+
if rank == 0:
2149+
print("Simulation aborted, gvec must be installed (pip install gvec)!")
21352150
sys.exit(1)
2136-
print(f"{exc.value.code =}")
2151+
if rank == 0:
2152+
print(f"{exc.value.code =}")
21372153

21382154
import gvec
21392155

@@ -2410,13 +2426,15 @@ def __init__(
24102426
desc_spec = importlib.util.find_spec("desc")
24112427

24122428
if desc_spec is None:
2413-
print("Simulation aborted, desc-opt must be installed!")
2414-
print("Install with:\npip install desc-opt")
2429+
if rank == 0:
2430+
print("Simulation aborted, desc-opt must be installed!")
2431+
print("Install with:\npip install desc-opt")
24152432
sys.exit(1)
24162433

24172434
import desc
24182435

2419-
print(f"DESC import: {time() - t} seconds")
2436+
if rank == 0:
2437+
print(f"DESC import: {time() - t} seconds")
24202438
from struphy.geometry.domains import DESCunit
24212439

24222440
# no rescaling if units are not provided
@@ -2449,7 +2467,8 @@ def __init__(
24492467
else:
24502468
self._eq = desc.io.load(eq_name)
24512469

2452-
print(f"Eq. load: {time() - t} seconds")
2470+
if rank == 0:
2471+
print(f"Eq. load: {time() - t} seconds")
24532472
self._rmin = self.params["rmin"]
24542473
self._use_nfp = self.params["use_nfp"]
24552474

@@ -2878,7 +2897,7 @@ def desc_eval(
28782897
assert xp.all(theta == theta1)
28792898
assert xp.all(zeta == zeta1)
28802899

2881-
if verbose:
2900+
if verbose and rank == 0:
28822901
# import sys
28832902
print(f"\n{nfp =}")
28842903
print(f"{self.eq.axis =}")
@@ -2900,7 +2919,8 @@ def desc_eval(
29002919

29012920
# make c-contiguous
29022921
out = xp.ascontiguousarray(out)
2903-
print(f"desc_eval for {var}: {time() - ttime} seconds")
2922+
if rank == 0:
2923+
print(f"desc_eval for {var}: {time() - ttime} seconds")
29042924
return out
29052925

29062926

src/struphy/initial/utilities.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,15 @@ def __init__(self, derham, name, species, **params):
5454
else:
5555
key = "restart/" + species + "_" + name
5656

57-
if isinstance(data.file[key], h5py.Dataset):
58-
self._vector = data.file[key][-1]
57+
with h5py.File(data.file_path, "a") as file:
58+
if isinstance(file[key], h5py.Dataset):
59+
self._vector = file[key][-1]
5960

60-
else:
61-
self._vector = []
62-
63-
for n in range(3):
64-
self._vector += [data.file[key + "/" + str(n + 1)][-1]]
61+
else:
62+
self._vector = []
6563

66-
data.file.close()
64+
for n in range(3):
65+
self._vector += [file[key + "/" + str(n + 1)][-1]]
6766

6867
@property
6968
def vector(self):

src/struphy/io/output_handling.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,10 @@ def __init__(self, path_out, file_name=None, comm=None):
3838
self._file_name = file_name
3939

4040
# file path
41-
file_path = os.path.join(path_out, "data/", self._file_name)
41+
self._file_path = os.path.join(path_out, "data/", self._file_name)
4242

4343
# check if file already exists
44-
file_exists = os.path.exists(file_path)
45-
46-
# open/create file
47-
self._file = h5py.File(file_path, "a")
44+
file_exists = os.path.exists(self.file_path)
4845

4946
# dictionary with pairs (dataset key : object ID)
5047
self._dset_dict = {}
@@ -53,9 +50,10 @@ def __init__(self, path_out, file_name=None, comm=None):
5350
if file_exists:
5451
dataset_keys = []
5552

56-
self._file.visit(
57-
lambda key: dataset_keys.append(key) if isinstance(self._file[key], h5py.Dataset) else None,
58-
)
53+
with h5py.File(self.file_path, "a") as file:
54+
file.visit(
55+
lambda key: dataset_keys.append(key) if isinstance(self._file[key], h5py.Dataset) else None,
56+
)
5957

6058
for key in dataset_keys:
6159
self._dset_dict[key] = None
@@ -66,9 +64,9 @@ def file_name(self):
6664
return self._file_name
6765

6866
@property
69-
def file(self):
70-
"""The hdf5 file."""
71-
return self._file
67+
def file_path(self):
68+
"""The absolute path to the hdf5 file."""
69+
return self._file_path
7270

7371
@property
7472
def dset_dict(self):
@@ -103,20 +101,21 @@ def add_data(self, data_dict):
103101

104102
# create new dataset otherwise and save array
105103
else:
106-
# scalar values are saved as 1d arrays of size 1
107-
if val.size == 1:
108-
assert val.ndim == 1
109-
self._file.create_dataset(key, (1,), maxshape=(None,), dtype=val.dtype, chunks=True)
110-
self._file[key][0] = val[0]
111-
else:
112-
self._file.create_dataset(
113-
key,
114-
(1,) + val.shape,
115-
maxshape=(None,) + val.shape,
116-
dtype=val.dtype,
117-
chunks=True,
118-
)
119-
self._file[key][0] = val
104+
with h5py.File(self.file_path, "a") as file:
105+
# scalar values are saved as 1d arrays of size 1
106+
if val.size == 1:
107+
assert val.ndim == 1
108+
file.create_dataset(key, (1,), maxshape=(None,), dtype=val.dtype, chunks=True)
109+
file[key][0] = val[0]
110+
else:
111+
file.create_dataset(
112+
key,
113+
(1,) + val.shape,
114+
maxshape=(None,) + val.shape,
115+
dtype=val.dtype,
116+
chunks=True,
117+
)
118+
file[key][0] = val
120119

121120
# set object ID
122121
self._dset_dict[key] = id(val)
@@ -130,25 +129,26 @@ def save_data(self, keys=None):
130129
keys : list
131130
Keys to the data objects specified when using "add_data". Default saves all specified data objects.
132131
"""
133-
134-
# loop over all keys
135-
if keys is None:
136-
for key in self._dset_dict:
137-
self._file[key].resize(self._file[key].shape[0] + 1, axis=0)
138-
self._file[key][-1] = ctypes.cast(self._dset_dict[key], ctypes.py_object).value
139-
140-
# only loop over given keys
141-
else:
142-
for key in keys:
143-
self._file[key].resize(self._file[key].shape[0] + 1, axis=0)
144-
self._file[key][-1] = ctypes.cast(self._dset_dict[key], ctypes.py_object).value
132+
with h5py.File(self.file_path, "a") as file:
133+
# loop over all keys
134+
if keys is None:
135+
for key in self._dset_dict:
136+
file[key].resize(file[key].shape[0] + 1, axis=0)
137+
file[key][-1] = ctypes.cast(self._dset_dict[key], ctypes.py_object).value
138+
139+
# only loop over given keys
140+
else:
141+
for key in keys:
142+
file[key].resize(file[key].shape[0] + 1, axis=0)
143+
file[key][-1] = ctypes.cast(self._dset_dict[key], ctypes.py_object).value
145144

146145
def info(self):
147146
"""Print info of data sets to screen."""
148147

149148
for key in self._dset_dict:
150-
print(f"\nData set name: {key}")
151-
print("Shape:", self._file[key].shape)
152-
print("Attributes:")
153-
for attr, val in self._file[key].attrs.items():
154-
print(attr, val)
149+
with h5py.File(self.file_path, "a") as file:
150+
print(f"\nData set name: {key}")
151+
print("Shape:", file[key].shape)
152+
print("Attributes:")
153+
for attr, val in file[key].attrs.items():
154+
print(attr, val)

src/struphy/main.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,10 @@ def run(
280280
if restart:
281281
model.initialize_from_restart(data)
282282

283-
time_state["value"][0] = data.file["restart/time/value"][-1]
284-
time_state["value_sec"][0] = data.file["restart/time/value_sec"][-1]
285-
time_state["index"][0] = data.file["restart/time/index"][-1]
283+
with h5py.File(data.file_path, "a") as file:
284+
time_state["value"][0] = file["restart/time/value"][-1]
285+
time_state["value_sec"][0] = file["restart/time/value_sec"][-1]
286+
time_state["index"][0] = file["restart/time/index"][-1]
286287

287288
total_steps = str(int(round((Tend - time_state["value"][0]) / dt)))
288289
else:
@@ -317,7 +318,6 @@ def run(
317318
if break_cond_1 or break_cond_2:
318319
# save restart data (other data already saved below)
319320
data.save_data(keys=save_keys_end)
320-
data.file.close()
321321
end_simulation = time.time()
322322
if rank == 0:
323323
print(f"\nTime steps done: {time_state['index'][0]}")
@@ -474,37 +474,34 @@ def pproc(
474474
return
475475

476476
# check for fields and kinetic data in hdf5 file that need post processing
477-
file = h5py.File(os.path.join(path, "data/", "data_proc0.hdf5"), "r")
477+
with h5py.File(os.path.join(path, "data/", "data_proc0.hdf5"), "r") as file:
478+
# save time grid at which post-processing data is created
479+
xp.save(os.path.join(path_pproc, "t_grid.npy"), file["time/value"][::step].copy())
478480

479-
# save time grid at which post-processing data is created
480-
xp.save(os.path.join(path_pproc, "t_grid.npy"), file["time/value"][::step].copy())
481-
482-
if "feec" in file.keys():
483-
exist_fields = True
484-
else:
485-
exist_fields = False
486-
487-
if "kinetic" in file.keys():
488-
exist_kinetic = {"markers": False, "f": False, "n_sph": False}
489-
kinetic_species = []
490-
kinetic_kinds = []
491-
for name in file["kinetic"].keys():
492-
kinetic_species += [name]
493-
kinetic_kinds += [next(iter(model.species[name].variables.values())).space]
494-
495-
# check for saved markers
496-
if "markers" in file["kinetic"][name]:
497-
exist_kinetic["markers"] = True
498-
# check for saved distribution function
499-
if "f" in file["kinetic"][name]:
500-
exist_kinetic["f"] = True
501-
# check for saved sph density
502-
if "n_sph" in file["kinetic"][name]:
503-
exist_kinetic["n_sph"] = True
504-
else:
505-
exist_kinetic = None
506-
507-
file.close()
481+
if "feec" in file.keys():
482+
exist_fields = True
483+
else:
484+
exist_fields = False
485+
486+
if "kinetic" in file.keys():
487+
exist_kinetic = {"markers": False, "f": False, "n_sph": False}
488+
kinetic_species = []
489+
kinetic_kinds = []
490+
for name in file["kinetic"].keys():
491+
kinetic_species += [name]
492+
kinetic_kinds += [next(iter(model.species[name].variables.values())).space]
493+
494+
# check for saved markers
495+
if "markers" in file["kinetic"][name]:
496+
exist_kinetic["markers"] = True
497+
# check for saved distribution function
498+
if "f" in file["kinetic"][name]:
499+
exist_kinetic["f"] = True
500+
# check for saved sph density
501+
if "n_sph" in file["kinetic"][name]:
502+
exist_kinetic["n_sph"] = True
503+
else:
504+
exist_kinetic = None
508505

509506
# field post-processing
510507
if exist_fields:

0 commit comments

Comments
 (0)