77
88from devito .exceptions import CompilationError
99from devito .finite_differences import EvalDerivative , IndexDerivative , Weights
10- from devito .ir import (SEQUENTIAL , PARALLEL_IF_PVT , SEPARABLE , Forward ,
11- IterationSpace , Interval , Cluster , ExprGeometry , Queue ,
12- IntervalGroup , LabeledVector , Vector , normalize_properties ,
13- relax_properties , unbounded , minimum , maximum , extrema ,
14- vmax , vmin )
10+ from devito .ir import (
11+ SEQUENTIAL , PARALLEL_IF_PVT , SEPARABLE , Forward , IterationSpace , Interval ,
12+ Cluster , ClusterGroup , ExprGeometry , Queue , IntervalGroup , LabeledVector ,
13+ Vector , normalize_properties , relax_properties , unbounded , minimum , maximum ,
14+ extrema , vmax , vmin
15+ )
1516from devito .passes .clusters .cse import _cse
1617from devito .symbolics import (Uxmapper , estimate_cost , search , reuse_if_untouched ,
1718 retrieve_functions , uxreplace , sympy_dtype )
@@ -117,20 +118,18 @@ def __init__(self, sregistry, options, platform):
117118 self .opt_min_dtype = options ['scalar-min-type' ]
118119 self .opt_multisubdomain = True
119120
120- def _aliases_from_clusters (self , clusters , exclude , meta ):
121- exprs = flatten ([c .exprs for c in clusters ])
122-
121+ def _aliases_from_clusters (self , cgroup , exclude , meta ):
123122 # [Clusters]_n -> [Schedule]_m
124123 variants = []
125- for mapper in self ._generate (exprs , exclude ):
124+ for mapper in self ._generate (cgroup , exclude ):
126125 # Clusters -> AliasList
127126 found = collect (mapper .extracted , meta .ispace , self .opt_minstorage )
128- pexprs , aliases = choose (found , exprs , mapper , self .opt_mingain )
127+ exprs , aliases = choose (found , cgroup , mapper , self .opt_mingain )
129128
130129 # AliasList -> Schedule
131130 schedule = lower_aliases (aliases , meta , self .opt_maxpar )
132131
133- variants .append (Variant (schedule , pexprs ))
132+ variants .append (Variant (schedule , exprs ))
134133
135134 if not variants :
136135 return []
@@ -152,7 +151,7 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
152151 processed = optimize_clusters_msds (processed )
153152
154153 # [Clusters]_k -> [Clusters]_{k+n}
155- for c in clusters :
154+ for c in cgroup :
156155 n = len (c .exprs )
157156 cexprs , exprs = exprs [:n ], exprs [n :]
158157
@@ -170,9 +169,9 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
170169 def process (self , clusters ):
171170 raise NotImplementedError
172171
173- def _generate (self , exprs , exclude ):
172+ def _generate (self , cgroup , exclude ):
174173 """
175- Generate one or more extractions from ``exprs`` . An extraction is a
174+ Generate one or more extractions from a ClusterGroup . An extraction is a
176175 set of CIRE candidates which may be turned into aliases. Two different
177176 extractions may contain overlapping sub-expressions and, therefore,
178177 should be processed and evaluated indipendently. An extraction won't
@@ -260,7 +259,7 @@ def callback(self, clusters, prefix, xtracted=None):
260259 if not g :
261260 continue
262261
263- made = self ._aliases_from_clusters (g , exclude , ak )
262+ made = self ._aliases_from_clusters (ClusterGroup ( g ) , exclude , ak )
264263
265264 if made :
266265 idx = processed .index (g [0 ])
@@ -285,7 +284,9 @@ def _select(self, variants):
285284
286285class CireInvariantsElementary (CireInvariants ):
287286
288- def _generate (self , exprs , exclude ):
287+ def _generate (self , cgroup , exclude ):
288+ exprs = cgroup .exprs
289+
289290 # E.g., extract `sin(x)` and `sqrt(x)` from `a*sin(x)*sqrt(x)`
290291 rule = lambda e : e .is_Function or (e .is_Pow and e .exp .is_Number and 0 < e .exp < 1 )
291292 cbk_search = lambda e : search (e , rule , 'all' , 'bfs_first_hit' )
@@ -308,7 +309,9 @@ def cbk_search(expr):
308309
309310class CireInvariantsDivs (CireInvariants ):
310311
311- def _generate (self , exprs , exclude ):
312+ def _generate (self , cgroup , exclude ):
313+ exprs = cgroup .exprs
314+
312315 # E.g., extract `1/h_x`
313316 rule = lambda e : e .is_Pow and (not e .exp .is_Number or e .exp < 0 )
314317 cbk_search = lambda e : search (e , rule , 'all' , 'bfs_first_hit' )
@@ -339,13 +342,17 @@ def process(self, clusters):
339342 # TODO: to process third- and higher-order derivatives, we could
340343 # extend this by calling `_aliases_from_clusters` repeatedly until
341344 # `made` is empty. To be investigated
342- made = self ._aliases_from_clusters ([c ], exclude , self ._lookup_key (c ))
345+ made = self ._aliases_from_clusters (
346+ ClusterGroup (c ), exclude , self ._lookup_key (c )
347+ )
343348
344349 processed .extend (flatten (made ) or [c ])
345350
346351 return processed
347352
348- def _generate (self , exprs , exclude ):
353+ def _generate (self , cgroup , exclude ):
354+ exprs = cgroup .exprs
355+
349356 # E.g., extract `u.dx*a*b` and `u.dx*a*c` from
350357 # `[(u.dx*a*b).dy`, `(u.dx*a*c).dy]`
351358 basextr = self ._do_generate (exprs , exclude , self ._cbk_search ,
@@ -600,14 +607,15 @@ def collect(extracted, ispace, minstorage):
600607 return aliases
601608
602609
603- def choose (aliases , exprs , mapper , mingain ):
610+ def choose (aliases , cgroup , mapper , mingain ):
604611 """
605612 Analyze the detected aliases and, after applying a cost model to rule out
606613 the aliases with a bad memory/flops trade-off, inject them into the original
607614 expressions.
608615 """
609- aliases = AliasList ( aliases )
616+ exprs = cgroup . exprs
610617
618+ aliases = AliasList (aliases )
611619 if not aliases :
612620 return exprs , aliases
613621
0 commit comments