Skip to content

Commit 36978de

Browse files
committed
test: use tempfile.TemporaryDirectory to generate test_folder
1 parent dc37588 commit 36978de

File tree

6 files changed

+766
-773
lines changed

6 files changed

+766
-773
lines changed

src/struphy/models/tests/utils_testing.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import os
33
import shutil
4+
import tempfile
45
from types import ModuleType
56

67
import pytest
@@ -68,45 +69,45 @@ def call_test(model_name: str, module: ModuleType = None, verbose=True):
6869
assert isinstance(model, StruphyModel)
6970

7071
# generate paramater file for testing
71-
test_folder = os.path.join(os.getcwd(), "struphy_model_test")
72-
path = os.path.join(test_folder, f"params_{model_name}.py")
73-
74-
if rank == 0:
75-
model.generate_default_parameter_file(path=path, prompt=False)
76-
del model
77-
MPI.COMM_WORLD.Barrier()
78-
79-
# set environment options
80-
env = EnvironmentOptions(out_folders=test_folder, sim_folder=f"{model_name}")
81-
82-
# read parameters
83-
params_in = import_parameters_py(path)
84-
base_units = params_in.base_units
85-
time_opts = params_in.time_opts
86-
domain = params_in.domain
87-
equil = params_in.equil
88-
grid = params_in.grid
89-
derham_opts = params_in.derham_opts
90-
model = params_in.model
91-
92-
# test
93-
main.run(
94-
model,
95-
params_path=path,
96-
env=env,
97-
base_units=base_units,
98-
time_opts=time_opts,
99-
domain=domain,
100-
equil=equil,
101-
grid=grid,
102-
derham_opts=derham_opts,
103-
verbose=verbose,
104-
)
105-
106-
MPI.COMM_WORLD.Barrier()
107-
if rank == 0:
108-
path_out = os.path.join(test_folder, model_name)
109-
main.pproc(path=path_out)
110-
main.load_data(path=path_out)
111-
shutil.rmtree(test_folder)
112-
MPI.COMM_WORLD.Barrier()
72+
with tempfile.TemporaryDirectory() as test_folder:
73+
path = os.path.join(test_folder, f"params_{model_name}.py")
74+
75+
if rank == 0:
76+
model.generate_default_parameter_file(path=path, prompt=False)
77+
del model
78+
MPI.COMM_WORLD.Barrier()
79+
80+
# set environment options
81+
env = EnvironmentOptions(out_folders=test_folder, sim_folder=f"{model_name}")
82+
83+
# read parameters
84+
params_in = import_parameters_py(path)
85+
base_units = params_in.base_units
86+
time_opts = params_in.time_opts
87+
domain = params_in.domain
88+
equil = params_in.equil
89+
grid = params_in.grid
90+
derham_opts = params_in.derham_opts
91+
model = params_in.model
92+
93+
# test
94+
main.run(
95+
model,
96+
params_path=path,
97+
env=env,
98+
base_units=base_units,
99+
time_opts=time_opts,
100+
domain=domain,
101+
equil=equil,
102+
grid=grid,
103+
derham_opts=derham_opts,
104+
verbose=verbose,
105+
)
106+
107+
MPI.COMM_WORLD.Barrier()
108+
if rank == 0:
109+
path_out = os.path.join(test_folder, model_name)
110+
main.pproc(path=path_out)
111+
main.load_data(path=path_out)
112+
shutil.rmtree(test_folder)
113+
MPI.COMM_WORLD.Barrier()

src/struphy/models/tests/verification/test_verif_EulerSPH.py

Lines changed: 128 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import shutil
3+
import tempfile
34

45
import cunumpy as xp
56
import pytest
@@ -33,135 +34,133 @@ def test_soundwave_1d(nx: int, plot_pts: int, do_plot: bool = False):
3334
from struphy.models.fluid import EulerSPH
3435

3536
# environment options
36-
test_folder = os.path.join(os.getcwd(), "struphy_verification_tests")
37-
out_folders = os.path.join(test_folder, "EulerSPH")
38-
env = EnvironmentOptions(out_folders=out_folders, sim_folder="soundwave_1d")
39-
40-
# units
41-
base_units = BaseUnits(kBT=1.0)
42-
43-
# time stepping
44-
time_opts = Time(dt=0.03125, Tend=2.5, split_algo="Strang")
45-
46-
# geometry
47-
r1 = 2.5
48-
domain = domains.Cuboid(r1=r1)
49-
50-
# fluid equilibrium (can be used as part of initial conditions)
51-
equil = None
52-
53-
# grid
54-
grid = None
55-
56-
# derham options
57-
derham_opts = None
58-
59-
# light-weight model instance
60-
model = EulerSPH(with_B0=False)
61-
62-
# species parameters
63-
model.euler_fluid.set_phys_params()
64-
65-
loading_params = LoadingParameters(ppb=8, loading="tesselation")
66-
weights_params = WeightsParameters()
67-
boundary_params = BoundaryParameters()
68-
model.euler_fluid.set_markers(
69-
loading_params=loading_params,
70-
weights_params=weights_params,
71-
boundary_params=boundary_params,
72-
)
73-
model.euler_fluid.set_sorting_boxes(
74-
boxes_per_dim=(nx, 1, 1),
75-
dims_maks=(True, False, False),
76-
)
77-
78-
bin_plot = BinningPlot(slice="e1", n_bins=(32,), ranges=(0.0, 1.0))
79-
kd_plot = KernelDensityPlot(pts_e1=plot_pts, pts_e2=1)
80-
model.euler_fluid.set_save_data(
81-
binning_plots=(bin_plot,),
82-
kernel_density_plots=(kd_plot,),
83-
)
84-
85-
# propagator options
86-
from struphy.ode.utils import ButcherTableau
87-
88-
butcher = ButcherTableau(algo="forward_euler")
89-
model.propagators.push_eta.options = model.propagators.push_eta.Options(butcher=butcher)
90-
if model.with_B0:
91-
model.propagators.push_vxb.options = model.propagators.push_vxb.Options()
92-
model.propagators.push_sph_p.options = model.propagators.push_sph_p.Options(kernel_type="gaussian_1d")
93-
94-
# background, perturbations and initial conditions
95-
background = equils.ConstantVelocity()
96-
model.euler_fluid.var.add_background(background)
97-
perturbation = perturbations.ModesSin(ls=(1,), amps=(1.0e-2,))
98-
model.euler_fluid.var.add_perturbation(del_n=perturbation)
99-
100-
# start run
101-
main.run(
102-
model,
103-
params_path=None,
104-
env=env,
105-
base_units=base_units,
106-
time_opts=time_opts,
107-
domain=domain,
108-
equil=equil,
109-
grid=grid,
110-
derham_opts=derham_opts,
111-
verbose=True,
112-
)
113-
114-
# post processing
115-
if MPI.COMM_WORLD.Get_rank() == 0:
116-
main.pproc(env.path_out)
117-
118-
# diagnostics
119-
simdata = main.load_data(env.path_out)
120-
121-
ee1, ee2, ee3 = simdata.n_sph["euler_fluid"]["view_0"]["grid_n_sph"]
122-
n_sph = simdata.n_sph["euler_fluid"]["view_0"]["n_sph"]
123-
124-
if do_plot:
125-
ppb = 8
126-
dt = time_opts.dt
127-
end_time = time_opts.Tend
128-
Nt = int(end_time // dt)
129-
x = ee1 * r1
130-
131-
plt.figure(figsize=(10, 8))
132-
interval = Nt / 10
133-
plot_ct = 0
134-
for i in range(0, Nt + 1):
135-
if i % interval == 0:
136-
print(f"{i =}")
137-
plot_ct += 1
138-
ax = plt.gca()
139-
140-
if plot_ct <= 6:
141-
style = "-"
142-
else:
143-
style = "."
144-
plt.plot(x.squeeze(), n_sph[i, :, 0, 0], style, label=f"time={i * dt:4.2f}")
145-
plt.xlim(0, 2.5)
146-
plt.legend()
147-
ax.set_xticks(xp.linspace(0, 2.5, nx + 1))
148-
ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
149-
plt.grid(c="k")
150-
plt.xlabel("x")
151-
plt.ylabel(r"$\rho$")
152-
153-
plt.title(f"standing sound wave ($c_s = 1$) for {nx =} and {ppb =}")
154-
if plot_ct == 11:
155-
break
156-
157-
plt.show()
158-
159-
error = xp.max(xp.abs(n_sph[0] - n_sph[-1]))
160-
print(f"SPH sound wave {error =}.")
161-
assert error < 6e-4
162-
print("Assertion passed.")
163-
164-
shutil.rmtree(test_folder)
37+
with tempfile.TemporaryDirectory() as test_folder:
38+
out_folders = os.path.join(test_folder, "EulerSPH")
39+
env = EnvironmentOptions(out_folders=out_folders, sim_folder="soundwave_1d")
40+
41+
# units
42+
base_units = BaseUnits(kBT=1.0)
43+
44+
# time stepping
45+
time_opts = Time(dt=0.03125, Tend=2.5, split_algo="Strang")
46+
47+
# geometry
48+
r1 = 2.5
49+
domain = domains.Cuboid(r1=r1)
50+
51+
# fluid equilibrium (can be used as part of initial conditions)
52+
equil = None
53+
54+
# grid
55+
grid = None
56+
57+
# derham options
58+
derham_opts = None
59+
60+
# light-weight model instance
61+
model = EulerSPH(with_B0=False)
62+
63+
# species parameters
64+
model.euler_fluid.set_phys_params()
65+
66+
loading_params = LoadingParameters(ppb=8, loading="tesselation")
67+
weights_params = WeightsParameters()
68+
boundary_params = BoundaryParameters()
69+
model.euler_fluid.set_markers(
70+
loading_params=loading_params,
71+
weights_params=weights_params,
72+
boundary_params=boundary_params,
73+
)
74+
model.euler_fluid.set_sorting_boxes(
75+
boxes_per_dim=(nx, 1, 1),
76+
dims_maks=(True, False, False),
77+
)
78+
79+
bin_plot = BinningPlot(slice="e1", n_bins=(32,), ranges=(0.0, 1.0))
80+
kd_plot = KernelDensityPlot(pts_e1=plot_pts, pts_e2=1)
81+
model.euler_fluid.set_save_data(
82+
binning_plots=(bin_plot,),
83+
kernel_density_plots=(kd_plot,),
84+
)
85+
86+
# propagator options
87+
from struphy.ode.utils import ButcherTableau
88+
89+
butcher = ButcherTableau(algo="forward_euler")
90+
model.propagators.push_eta.options = model.propagators.push_eta.Options(butcher=butcher)
91+
if model.with_B0:
92+
model.propagators.push_vxb.options = model.propagators.push_vxb.Options()
93+
model.propagators.push_sph_p.options = model.propagators.push_sph_p.Options(kernel_type="gaussian_1d")
94+
95+
# background, perturbations and initial conditions
96+
background = equils.ConstantVelocity()
97+
model.euler_fluid.var.add_background(background)
98+
perturbation = perturbations.ModesSin(ls=(1,), amps=(1.0e-2,))
99+
model.euler_fluid.var.add_perturbation(del_n=perturbation)
100+
101+
# start run
102+
main.run(
103+
model,
104+
params_path=None,
105+
env=env,
106+
base_units=base_units,
107+
time_opts=time_opts,
108+
domain=domain,
109+
equil=equil,
110+
grid=grid,
111+
derham_opts=derham_opts,
112+
verbose=True,
113+
)
114+
115+
# post processing
116+
if MPI.COMM_WORLD.Get_rank() == 0:
117+
main.pproc(env.path_out)
118+
119+
# diagnostics
120+
simdata = main.load_data(env.path_out)
121+
122+
ee1, ee2, ee3 = simdata.n_sph["euler_fluid"]["view_0"]["grid_n_sph"]
123+
n_sph = simdata.n_sph["euler_fluid"]["view_0"]["n_sph"]
124+
125+
if do_plot:
126+
ppb = 8
127+
dt = time_opts.dt
128+
end_time = time_opts.Tend
129+
Nt = int(end_time // dt)
130+
x = ee1 * r1
131+
132+
plt.figure(figsize=(10, 8))
133+
interval = Nt / 10
134+
plot_ct = 0
135+
for i in range(0, Nt + 1):
136+
if i % interval == 0:
137+
print(f"{i =}")
138+
plot_ct += 1
139+
ax = plt.gca()
140+
141+
if plot_ct <= 6:
142+
style = "-"
143+
else:
144+
style = "."
145+
plt.plot(x.squeeze(), n_sph[i, :, 0, 0], style, label=f"time={i * dt:4.2f}")
146+
plt.xlim(0, 2.5)
147+
plt.legend()
148+
ax.set_xticks(xp.linspace(0, 2.5, nx + 1))
149+
ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
150+
plt.grid(c="k")
151+
plt.xlabel("x")
152+
plt.ylabel(r"$\rho$")
153+
154+
plt.title(f"standing sound wave ($c_s = 1$) for {nx =} and {ppb =}")
155+
if plot_ct == 11:
156+
break
157+
158+
plt.show()
159+
160+
error = xp.max(xp.abs(n_sph[0] - n_sph[-1]))
161+
print(f"SPH sound wave {error =}.")
162+
assert error < 6e-4
163+
print("Assertion passed.")
165164

166165

167166
if __name__ == "__main__":

0 commit comments

Comments
 (0)