44
55from PyMPDATA import ScalarField , Solver , Stepper , VectorField
66from PyMPDATA .boundary_conditions import Constant
7- from PyMPDATA .impl .enumerations import ARG_DATA , ARG_FOCUS , INNER , MAX_DIM_NUM , OUTER
7+ from PyMPDATA .impl .enumerations import INNER , OUTER
88from PyMPDATA .impl .formulae_divide import make_divide_or_zero
99
1010
11- def make_rhs_indexers (ats , grid_step , time_step , options ):
12- @numba .njit (** options .jit_flags )
13- def rhs (m , _0 , h , _1 , _2 , _3 ):
14- retval = (
15- m
16- - ((ats (* h , + 1 ) - ats (* h , - 1 )) / 2 ) / 2 * ats (* h , 0 ) * time_step / grid_step
17- )
18- return retval
19-
20- return rhs
21-
22-
23- def make_rhs (grid_step , time_step , axis , options , traversals ):
24- indexers = traversals .indexers [traversals .n_dims ]
25- apply_scalar = traversals .apply_scalar (loop = False )
26-
27- formulae_rhs = tuple (
28- (
29- make_rhs_indexers (
30- ats = indexers .ats [axis ],
31- grid_step = grid_step [axis ],
32- time_step = time_step ,
33- options = options ,
34- ),
35- None ,
36- None ,
37- )
38- )
39-
40- @numba .njit (** options .jit_flags )
41- def apply (traversals_data , momentum , h ):
42- null_scalarfield , null_scalarfield_bc = traversals_data .null_scalar_field
43- null_vectorfield , null_vectorfield_bc = traversals_data .null_vector_field
44- return apply_scalar (
45- * formulae_rhs ,
46- * momentum .field ,
47- * null_vectorfield ,
48- null_vectorfield_bc ,
49- * h .field ,
50- h .bc ,
51- * null_scalarfield ,
52- null_scalarfield_bc ,
53- * null_scalarfield ,
54- null_scalarfield_bc ,
55- * null_scalarfield ,
56- null_scalarfield_bc ,
57- traversals_data .buffer
58- )
59-
60- return apply
61-
62-
63- def make_interpolate_indexers (ati , options ):
64- @numba .njit (** options .jit_flags )
65- def interpolate (momentum_x , _ , momentum_y ):
66- momenta = (momentum_x [ARG_FOCUS ], (momentum_x [ARG_DATA ], momentum_y [ARG_DATA ]))
67- return ati (* momenta , 0.5 )
68-
69- return interpolate
70-
71-
72- def make_interpolate (options , traversals ):
73- indexers = traversals .indexers [traversals .n_dims ]
74- apply_vector = traversals .apply_vector ()
75-
76- formulae_interpolate = tuple (
77- (
78- make_interpolate_indexers (ati = indexers .ati [i ], options = options )
79- if indexers .ati [i ] is not None
80- else None
81- )
82- for i in range (MAX_DIM_NUM )
83- )
84-
85- @numba .njit (** options .jit_flags )
86- def apply (traversals_data , momentum_x , momentum_y , advector ):
87- null_scalarfield , null_scalarfield_bc = traversals_data .null_scalar_field
88- null_vectorfield , null_vectorfield_bc = traversals_data .null_vector_field
89- return apply_vector (
90- * formulae_interpolate ,
91- * advector .field ,
92- * momentum_x .field ,
93- momentum_x .bc ,
94- * null_vectorfield ,
95- null_vectorfield_bc ,
96- * momentum_y .field ,
97- momentum_y .bc ,
98- traversals_data .buffer
99- )
100-
101- return apply
102-
103-
10411def make_hooks (* , traversals , options , grid_step , time_step ):
10512
10613 divide_or_zero = make_divide_or_zero (options , traversals )
107- interpolate = make_interpolate (options , traversals )
108- rhs_x = make_rhs (grid_step , time_step , OUTER , options , traversals )
109- rhs_y = make_rhs (grid_step , time_step , INNER , options , traversals )
14+ interpolate = formulae . make_interpolate (options , traversals )
15+ rhs_x = formulae . make_rhs (grid_step , time_step , OUTER , options , traversals )
16+ rhs_y = formulae . make_rhs (grid_step , time_step , INNER , options , traversals )
11017
11118 @numba .experimental .jitclass ([])
112- class AnteStep :
19+ class AnteStep : # pylint:disable=too-few-public-methods
11320 def __init__ (self ):
11421 pass
11522
@@ -118,7 +25,7 @@ def call(
11825 traversals_data ,
11926 advectees ,
12027 advector ,
121- step ,
28+ _ ,
12229 index ,
12330 todo_outer ,
12431 todo_mid3d ,
@@ -143,11 +50,11 @@ def call(
14350 rhs_y (traversals_data , advectees [index ], advectees [0 ])
14451
14552 @numba .experimental .jitclass ([])
146- class PostStep :
53+ class PostStep : # pylint:disable=too-few-public-methods
14754 def __init__ (self ):
14855 pass
14956
150- def call (self , traversals_data , advectees , step , index ):
57+ def call (self , traversals_data , advectees , _ , index ):
15158 if index == 0 :
15259 pass
15360 if index == 1 :
@@ -208,7 +115,7 @@ def run(self):
208115 output .append (
209116 {
210117 k : self .solver .advectee [k ].get ().copy ()
211- for k in self .advectees .keys ()
118+ for k in self .advectees .keys () # pylint:disable=consider-iterating-dictionary
212119 }
213120 )
214121 return output
0 commit comments