Skip to content

Commit 5b9f157

Browse files
committed
More linting...
1 parent 631d61c commit 5b9f157

File tree

2 files changed

+103
-102
lines changed

2 files changed

+103
-102
lines changed

examples/PyMPDATA_examples/Jarecka_et_al_2015/formulae.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
from scipy.integrate import odeint
99

10+
from PyMPDATA.impl.enumerations import ARG_DATA, ARG_FOCUS, MAX_DIM_NUM
11+
1012

1113
def amplitude(x, y, lx, ly):
1214
A = 1 / lx / ly
@@ -34,3 +36,95 @@ def d2_el_lamb_lamb_t_evol(times, lamb_x0, lamb_y0):
3436
result, info = odeint(deriv, yinit, times, full_output=True)
3537
assert info["message"] == "Integration successful."
3638
return result
39+
40+
41+
def make_rhs_indexers(ats, grid_step, time_step, options):
42+
@numba.njit(**options.jit_flags)
43+
def rhs(m, _0, h, _1, _2, _3):
44+
retval = (
45+
m
46+
- ((ats(*h, +1) - ats(*h, -1)) / 2) / 2 * ats(*h, 0) * time_step / grid_step
47+
)
48+
return retval
49+
50+
return rhs
51+
52+
53+
def make_rhs(grid_step, time_step, axis, options, traversals):
54+
indexers = traversals.indexers[traversals.n_dims]
55+
apply_scalar = traversals.apply_scalar(loop=False)
56+
57+
formulae_rhs = tuple(
58+
(
59+
make_rhs_indexers(
60+
ats=indexers.ats[axis],
61+
grid_step=grid_step[axis],
62+
time_step=time_step,
63+
options=options,
64+
),
65+
None,
66+
None,
67+
)
68+
)
69+
70+
@numba.njit(**options.jit_flags)
71+
def apply(traversals_data, momentum, h):
72+
null_scalarfield, null_scalarfield_bc = traversals_data.null_scalar_field
73+
null_vectorfield, null_vectorfield_bc = traversals_data.null_vector_field
74+
return apply_scalar(
75+
*formulae_rhs,
76+
*momentum.field,
77+
*null_vectorfield,
78+
null_vectorfield_bc,
79+
*h.field,
80+
h.bc,
81+
*null_scalarfield,
82+
null_scalarfield_bc,
83+
*null_scalarfield,
84+
null_scalarfield_bc,
85+
*null_scalarfield,
86+
null_scalarfield_bc,
87+
traversals_data.buffer
88+
)
89+
90+
return apply
91+
92+
93+
def make_interpolate_indexers(ati, options):
94+
@numba.njit(**options.jit_flags)
95+
def interpolate(momentum_x, _, momentum_y):
96+
momenta = (momentum_x[ARG_FOCUS], (momentum_x[ARG_DATA], momentum_y[ARG_DATA]))
97+
return ati(*momenta, 0.5)
98+
99+
return interpolate
100+
101+
102+
def make_interpolate(options, traversals):
103+
indexers = traversals.indexers[traversals.n_dims]
104+
apply_vector = traversals.apply_vector()
105+
106+
formulae_interpolate = tuple(
107+
(
108+
make_interpolate_indexers(ati=indexers.ati[i], options=options)
109+
if indexers.ati[i] is not None
110+
else None
111+
)
112+
for i in range(MAX_DIM_NUM)
113+
)
114+
115+
@numba.njit(**options.jit_flags)
116+
def apply(traversals_data, momentum_x, momentum_y, advector):
117+
null_vectorfield, null_vectorfield_bc = traversals_data.null_vector_field
118+
return apply_vector(
119+
*formulae_interpolate,
120+
*advector.field,
121+
*momentum_x.field,
122+
momentum_x.bc,
123+
*null_vectorfield,
124+
null_vectorfield_bc,
125+
*momentum_y.field,
126+
momentum_y.bc,
127+
traversals_data.buffer
128+
)
129+
130+
return apply

examples/PyMPDATA_examples/Jarecka_et_al_2015/simulation.py

Lines changed: 9 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -4,112 +4,19 @@
44

55
from PyMPDATA import ScalarField, Solver, Stepper, VectorField
66
from PyMPDATA.boundary_conditions import Constant
7-
from PyMPDATA.impl.enumerations import ARG_DATA, ARG_FOCUS, INNER, MAX_DIM_NUM, OUTER
7+
from PyMPDATA.impl.enumerations import INNER, OUTER
88
from PyMPDATA.impl.formulae_divide import make_divide_or_zero
99

1010

11-
def make_rhs_indexers(ats, grid_step, time_step, options):
12-
@numba.njit(**options.jit_flags)
13-
def rhs(m, _0, h, _1, _2, _3):
14-
retval = (
15-
m
16-
- ((ats(*h, +1) - ats(*h, -1)) / 2) / 2 * ats(*h, 0) * time_step / grid_step
17-
)
18-
return retval
19-
20-
return rhs
21-
22-
23-
def make_rhs(grid_step, time_step, axis, options, traversals):
24-
indexers = traversals.indexers[traversals.n_dims]
25-
apply_scalar = traversals.apply_scalar(loop=False)
26-
27-
formulae_rhs = tuple(
28-
(
29-
make_rhs_indexers(
30-
ats=indexers.ats[axis],
31-
grid_step=grid_step[axis],
32-
time_step=time_step,
33-
options=options,
34-
),
35-
None,
36-
None,
37-
)
38-
)
39-
40-
@numba.njit(**options.jit_flags)
41-
def apply(traversals_data, momentum, h):
42-
null_scalarfield, null_scalarfield_bc = traversals_data.null_scalar_field
43-
null_vectorfield, null_vectorfield_bc = traversals_data.null_vector_field
44-
return apply_scalar(
45-
*formulae_rhs,
46-
*momentum.field,
47-
*null_vectorfield,
48-
null_vectorfield_bc,
49-
*h.field,
50-
h.bc,
51-
*null_scalarfield,
52-
null_scalarfield_bc,
53-
*null_scalarfield,
54-
null_scalarfield_bc,
55-
*null_scalarfield,
56-
null_scalarfield_bc,
57-
traversals_data.buffer
58-
)
59-
60-
return apply
61-
62-
63-
def make_interpolate_indexers(ati, options):
64-
@numba.njit(**options.jit_flags)
65-
def interpolate(momentum_x, _, momentum_y):
66-
momenta = (momentum_x[ARG_FOCUS], (momentum_x[ARG_DATA], momentum_y[ARG_DATA]))
67-
return ati(*momenta, 0.5)
68-
69-
return interpolate
70-
71-
72-
def make_interpolate(options, traversals):
73-
indexers = traversals.indexers[traversals.n_dims]
74-
apply_vector = traversals.apply_vector()
75-
76-
formulae_interpolate = tuple(
77-
(
78-
make_interpolate_indexers(ati=indexers.ati[i], options=options)
79-
if indexers.ati[i] is not None
80-
else None
81-
)
82-
for i in range(MAX_DIM_NUM)
83-
)
84-
85-
@numba.njit(**options.jit_flags)
86-
def apply(traversals_data, momentum_x, momentum_y, advector):
87-
null_scalarfield, null_scalarfield_bc = traversals_data.null_scalar_field
88-
null_vectorfield, null_vectorfield_bc = traversals_data.null_vector_field
89-
return apply_vector(
90-
*formulae_interpolate,
91-
*advector.field,
92-
*momentum_x.field,
93-
momentum_x.bc,
94-
*null_vectorfield,
95-
null_vectorfield_bc,
96-
*momentum_y.field,
97-
momentum_y.bc,
98-
traversals_data.buffer
99-
)
100-
101-
return apply
102-
103-
10411
def make_hooks(*, traversals, options, grid_step, time_step):
10512

10613
divide_or_zero = make_divide_or_zero(options, traversals)
107-
interpolate = make_interpolate(options, traversals)
108-
rhs_x = make_rhs(grid_step, time_step, OUTER, options, traversals)
109-
rhs_y = make_rhs(grid_step, time_step, INNER, options, traversals)
14+
interpolate = formulae.make_interpolate(options, traversals)
15+
rhs_x = formulae.make_rhs(grid_step, time_step, OUTER, options, traversals)
16+
rhs_y = formulae.make_rhs(grid_step, time_step, INNER, options, traversals)
11017

11118
@numba.experimental.jitclass([])
112-
class AnteStep:
19+
class AnteStep: # pylint:disable=too-few-public-methods
11320
def __init__(self):
11421
pass
11522

@@ -118,7 +25,7 @@ def call(
11825
traversals_data,
11926
advectees,
12027
advector,
121-
step,
28+
_,
12229
index,
12330
todo_outer,
12431
todo_mid3d,
@@ -143,11 +50,11 @@ def call(
14350
rhs_y(traversals_data, advectees[index], advectees[0])
14451

14552
@numba.experimental.jitclass([])
146-
class PostStep:
53+
class PostStep: # pylint:disable=too-few-public-methods
14754
def __init__(self):
14855
pass
14956

150-
def call(self, traversals_data, advectees, step, index):
57+
def call(self, traversals_data, advectees, _, index):
15158
if index == 0:
15259
pass
15360
if index == 1:
@@ -208,7 +115,7 @@ def run(self):
208115
output.append(
209116
{
210117
k: self.solver.advectee[k].get().copy()
211-
for k in self.advectees.keys()
118+
for k in self.advectees.keys() # pylint:disable=consider-iterating-dictionary
212119
}
213120
)
214121
return output

0 commit comments

Comments
 (0)