Skip to content

Commit 0683907

Browse files
committed
compiler: Pass along cgroup, not exprs
1 parent 174aa6d commit 0683907

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

devito/passes/clusters/aliases.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
from devito.exceptions import CompilationError
99
from devito.finite_differences import EvalDerivative, IndexDerivative, Weights
10-
from devito.ir import (SEQUENTIAL, PARALLEL_IF_PVT, SEPARABLE, Forward,
11-
IterationSpace, Interval, Cluster, ExprGeometry, Queue,
12-
IntervalGroup, LabeledVector, Vector, normalize_properties,
13-
relax_properties, unbounded, minimum, maximum, extrema,
14-
vmax, vmin)
10+
from devito.ir import (
11+
SEQUENTIAL, PARALLEL_IF_PVT, SEPARABLE, Forward, IterationSpace, Interval,
12+
Cluster, ClusterGroup, ExprGeometry, Queue, IntervalGroup, LabeledVector,
13+
Vector, normalize_properties, relax_properties, unbounded, minimum, maximum,
14+
extrema, vmax, vmin
15+
)
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (Uxmapper, estimate_cost, search, reuse_if_untouched,
1718
retrieve_functions, uxreplace, sympy_dtype)
@@ -117,20 +118,18 @@ def __init__(self, sregistry, options, platform):
117118
self.opt_min_dtype = options['scalar-min-type']
118119
self.opt_multisubdomain = True
119120

120-
def _aliases_from_clusters(self, clusters, exclude, meta):
121-
exprs = flatten([c.exprs for c in clusters])
122-
121+
def _aliases_from_clusters(self, cgroup, exclude, meta):
123122
# [Clusters]_n -> [Schedule]_m
124123
variants = []
125-
for mapper in self._generate(exprs, exclude):
124+
for mapper in self._generate(cgroup, exclude):
126125
# Clusters -> AliasList
127126
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
128-
pexprs, aliases = choose(found, exprs, mapper, self.opt_mingain)
127+
exprs, aliases = choose(found, cgroup, mapper, self.opt_mingain)
129128

130129
# AliasList -> Schedule
131130
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
132131

133-
variants.append(Variant(schedule, pexprs))
132+
variants.append(Variant(schedule, exprs))
134133

135134
if not variants:
136135
return []
@@ -152,7 +151,7 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
152151
processed = optimize_clusters_msds(processed)
153152

154153
# [Clusters]_k -> [Clusters]_{k+n}
155-
for c in clusters:
154+
for c in cgroup:
156155
n = len(c.exprs)
157156
cexprs, exprs = exprs[:n], exprs[n:]
158157

@@ -170,9 +169,9 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
170169
def process(self, clusters):
171170
raise NotImplementedError
172171

173-
def _generate(self, exprs, exclude):
172+
def _generate(self, cgroup, exclude):
174173
"""
175-
Generate one or more extractions from ``exprs``. An extraction is a
174+
Generate one or more extractions from a ClusterGroup. An extraction is a
176175
set of CIRE candidates which may be turned into aliases. Two different
177176
extractions may contain overlapping sub-expressions and, therefore,
178177
should be processed and evaluated indipendently. An extraction won't
@@ -260,7 +259,7 @@ def callback(self, clusters, prefix, xtracted=None):
260259
if not g:
261260
continue
262261

263-
made = self._aliases_from_clusters(g, exclude, ak)
262+
made = self._aliases_from_clusters(ClusterGroup(g), exclude, ak)
264263

265264
if made:
266265
idx = processed.index(g[0])
@@ -285,7 +284,9 @@ def _select(self, variants):
285284

286285
class CireInvariantsElementary(CireInvariants):
287286

288-
def _generate(self, exprs, exclude):
287+
def _generate(self, cgroup, exclude):
288+
exprs = cgroup.exprs
289+
289290
# E.g., extract `sin(x)` and `sqrt(x)` from `a*sin(x)*sqrt(x)`
290291
rule = lambda e: e.is_Function or (e.is_Pow and e.exp.is_Number and 0 < e.exp < 1)
291292
cbk_search = lambda e: search(e, rule, 'all', 'bfs_first_hit')
@@ -308,7 +309,9 @@ def cbk_search(expr):
308309

309310
class CireInvariantsDivs(CireInvariants):
310311

311-
def _generate(self, exprs, exclude):
312+
def _generate(self, cgroup, exclude):
313+
exprs = cgroup.exprs
314+
312315
# E.g., extract `1/h_x`
313316
rule = lambda e: e.is_Pow and (not e.exp.is_Number or e.exp < 0)
314317
cbk_search = lambda e: search(e, rule, 'all', 'bfs_first_hit')
@@ -339,13 +342,17 @@ def process(self, clusters):
339342
# TODO: to process third- and higher-order derivatives, we could
340343
# extend this by calling `_aliases_from_clusters` repeatedly until
341344
# `made` is empty. To be investigated
342-
made = self._aliases_from_clusters([c], exclude, self._lookup_key(c))
345+
made = self._aliases_from_clusters(
346+
ClusterGroup(c), exclude, self._lookup_key(c)
347+
)
343348

344349
processed.extend(flatten(made) or [c])
345350

346351
return processed
347352

348-
def _generate(self, exprs, exclude):
353+
def _generate(self, cgroup, exclude):
354+
exprs = cgroup.exprs
355+
349356
# E.g., extract `u.dx*a*b` and `u.dx*a*c` from
350357
# `[(u.dx*a*b).dy`, `(u.dx*a*c).dy]`
351358
basextr = self._do_generate(exprs, exclude, self._cbk_search,
@@ -600,14 +607,15 @@ def collect(extracted, ispace, minstorage):
600607
return aliases
601608

602609

603-
def choose(aliases, exprs, mapper, mingain):
610+
def choose(aliases, cgroup, mapper, mingain):
604611
"""
605612
Analyze the detected aliases and, after applying a cost model to rule out
606613
the aliases with a bad memory/flops trade-off, inject them into the original
607614
expressions.
608615
"""
609-
aliases = AliasList(aliases)
616+
exprs = cgroup.exprs
610617

618+
aliases = AliasList(aliases)
611619
if not aliases:
612620
return exprs, aliases
613621

0 commit comments

Comments
 (0)