Skip to content

Commit 05dc99a

Browse files
committed
Add traversals_data as argument to stepper loop, disable mpi_dim=INNER for ShallowWaterScenario
1 parent 40e1cbb commit 05dc99a

File tree

5 files changed

+42
-18
lines changed

5 files changed

+42
-18
lines changed

PyMPDATA/solver.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,21 @@ class grouping user-supplied stepper, fields and post-step/post-iter hooks,
1515

1616
@numba.experimental.jitclass([])
1717
class AnteStepNull: # pylint: disable=too-few-public-methods
18-
"""do-nothing version of the post-step hook"""
18+
"""do-nothing version of the ante-step hook"""
1919

2020
def __init__(self):
2121
pass
2222

2323
def call(
24-
self, advectee, advector, step, index, todo_outer, todo_mid3d, todo_inner
24+
self,
25+
traversals_data,
26+
advectee,
27+
advector,
28+
step,
29+
index,
30+
todo_outer,
31+
todo_mid3d,
32+
todo_inner,
2533
): # pylint: disable-next=unused-argument,disable=too-many-arguments
2634
"""think of it as a `__call__` method (which Numba does not allow)"""
2735

@@ -33,7 +41,9 @@ class PostStepNull: # pylint: disable=too-few-public-methods
3341
def __init__(self):
3442
pass
3543

36-
def call(self, psi, step, index): # pylint: disable-next=unused-argument
44+
def call(
45+
self, traversals_data, psi, step, index
46+
): # pylint: disable-next=unused-argument
3747
"""think of it as a `__call__` method (which Numba does not allow)"""
3848

3949

@@ -44,7 +54,9 @@ class PostIterNull: # pylint: disable=too-few-public-methods
4454
def __init__(self):
4555
pass
4656

47-
def call(self, flux, g_factor, step, iteration): # pylint: disable=unused-argument
57+
def call(
58+
self, traversals_data, flux, g_factor, step, iteration
59+
): # pylint: disable=unused-argument
4860
"""think of it as a `__call__` method (which Numba does not allow)"""
4961

5062

PyMPDATA/stepper.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,14 @@ def step(
193193
for step in range(n_steps):
194194
for index, advectee in enumerate(advectees):
195195
ante_step.call(
196-
advectees, advector, step, index, todo_outer, todo_mid3d, todo_inner
196+
traversals_data,
197+
advectees,
198+
advector,
199+
step,
200+
index,
201+
todo_outer,
202+
todo_mid3d,
203+
todo_inner,
197204
)
198205
if non_zero_mu_coeff:
199206
advector_orig = advector
@@ -261,11 +268,12 @@ def step(
261268
traversals_data, flux, advectee, advector_nonos
262269
)
263270
upwind(traversals_data, advectee, flux, g_factor)
264-
post_iter.call(flux.field, g_factor.field, step, iteration)
271+
post_iter.call(
272+
traversals_data, flux.field, g_factor.field, step, iteration
273+
)
265274
if non_zero_mu_coeff:
266275
advector = advector_orig
267-
268-
post_step.call(advectees, step, index)
276+
post_step.call(traversals_data, advectees, step, index)
269277
return (clock() - time) / n_steps if n_steps > 0 else np.nan
270278

271279
return step, traversals

examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def apply(traversals_data, momentum_x, momentum_y, advector):
100100

101101
return apply
102102

103+
103104
def make_hooks(*, traversals, options, grid_step, time_step):
104-
traversals_data = traversals.data
105105

106106
divide_or_zero = make_divide_or_zero(options, traversals)
107107
interpolate = make_interpolate(options, traversals)
@@ -115,6 +115,7 @@ def __init__(self):
115115

116116
def call(
117117
self,
118+
traversals_data,
118119
advectees,
119120
advector,
120121
step,
@@ -145,7 +146,8 @@ def call(
145146
class PostStep:
146147
def __init__(self):
147148
pass
148-
def call(self, advectees, step, index):
149+
150+
def call(self, traversals_data, advectees, step, index):
149151
if index == 0:
150152
pass
151153
if index == 1:
@@ -189,7 +191,7 @@ def __init__(self, settings):
189191
traversals=stepper.traversals,
190192
options=settings.options,
191193
grid_step=(s.dx, None, s.dy),
192-
time_step=s.dt
194+
time_step=s.dt,
193195
)
194196

195197
self.solver = Solver(stepper, self.advectees, self.advector)

scenarios_mpi/shallow_water.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def initial_condition(x, y, lx, ly):
3737
"""returns advectee array for a given grid indices"""
3838
# pylint: disable=invalid-name
3939
A = 1 / lx / ly
40-
h = A * (1 - (x / lx) ** 2 - (y / ly) ** 2) * 6.25
40+
h = A * (1 - (x / lx) ** 2 - (y / ly) ** 2)
4141
return np.where(h > 0, h, 0)
4242

4343
# pylint: disable=invalid-name
@@ -124,7 +124,7 @@ def data(self, key):
124124

125125
def _solver_advance(self, n_steps):
126126
for _ in range(n_steps):
127-
self.solvers.advance(1, ante_step=self.ante_step)
127+
self.solvers.advance(1, ante_step=self.ante_step, post_step=self.post_step)
128128
return -1
129129

130130
@staticmethod

tests_mpi/contract_tests/test_single_vs_multi_node.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222

2323
OPTIONS_KWARGS = (
2424
{"n_iters": 1},
25-
# {"n_iters": 2, "third_order_terms": True},
26-
# {"n_iters": 2, "nonoscillatory": True, "infinite_gauge": True},
27-
# {"n_iters": 3},
25+
{"n_iters": 2, "third_order_terms": True},
26+
{"n_iters": 2, "nonoscillatory": True, "infinite_gauge": True},
27+
{"n_iters": 3},
2828
)
2929

3030
COURANT_FIELD_MULTIPLIER = (
@@ -37,7 +37,7 @@
3737

3838
CARTESIAN_OUTPUT_STEPS = range(0, 24, 2)
3939

40-
SHALLOW_WATER_OUTPUT_STEPS = range(0, 2, 1)
40+
SHALLOW_WATER_OUTPUT_STEPS = range(0, 48, 4)
4141

4242
SPHERICAL_OUTPUT_STEPS = range(0, 2000, 100)
4343

@@ -103,7 +103,9 @@ def test_single_vs_multi_node( # pylint: disable=too-many-arguments,too-many-br
103103
pytest.skip("threading requires Numba JIT to be enabled")
104104

105105
if scenario_class is ShallowWaterScenario and (
106-
options_kwargs["n_iters"] == 3 or options_kwargs.get("third_order_terms", False)
106+
options_kwargs["n_iters"] == 3
107+
or options_kwargs.get("third_order_terms", False)
108+
or mpi_dim == INNER
107109
):
108110
pytest.skip("Unsupported method for simulation")
109111
# pylint: disable=too-many-boolean-expressions

0 commit comments

Comments
 (0)