Skip to content

Commit 3ce3395

Browse files
authored
Merge pull request #2792 from devitocodes/tweak-cire-halo
compiler: Add cire-minmem optoption
2 parents 59c357b + 532d25c commit 3ce3395

File tree

16 files changed

+216
-74
lines changed

16 files changed

+216
-74
lines changed

devito/core/cpu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _normalize_kwargs(cls, **kwargs):
6161
o['cire-maxpar'] = oo.pop('cire-maxpar', False)
6262
o['cire-ftemps'] = oo.pop('cire-ftemps', False)
6363
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN)
64+
o['cire-minmem'] = oo.pop('cire-minmem', cls.CIRE_MINMEM)
6465
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)
6566

6667
# Shared-memory parallelism
@@ -75,6 +76,7 @@ def _normalize_kwargs(cls, **kwargs):
7576

7677
# Code generation options for derivatives
7778
o['expand'] = oo.pop('expand', cls.EXPAND)
79+
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
7880
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
7981
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
8082

@@ -150,7 +152,7 @@ class Cpu64AdvOperator(Cpu64OperatorMixin, CoreOperator):
150152
@classmethod
151153
@timed_pass(name='specializing.DSL')
152154
def _specialize_dsl(cls, expressions, **kwargs):
153-
expressions = collect_derivatives(expressions)
155+
expressions = collect_derivatives(expressions, **kwargs)
154156

155157
return expressions
156158

@@ -253,7 +255,7 @@ class Cpu64CustomOperator(Cpu64OperatorMixin, CustomOperator):
253255
@classmethod
254256
def _make_dsl_passes_mapper(cls, **kwargs):
255257
return {
256-
'collect-derivs': collect_derivatives,
258+
'deriv-collect': collect_derivatives,
257259
}
258260

259261
@classmethod
@@ -308,7 +310,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
308310

309311
_known_passes = (
310312
# DSL
311-
'collect-derivs',
313+
'deriv-collect',
312314
# Expressions
313315
'buffering',
314316
# Clusters

devito/core/gpu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def _normalize_kwargs(cls, **kwargs):
6868
o['cire-maxpar'] = oo.pop('cire-maxpar', True)
6969
o['cire-ftemps'] = oo.pop('cire-ftemps', False)
7070
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN)
71+
o['cire-minmem'] = oo.pop('cire-minmem', cls.CIRE_MINMEM)
7172
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)
7273

7374
# GPU parallelism
@@ -88,6 +89,7 @@ def _normalize_kwargs(cls, **kwargs):
8889

8990
# Code generation options for derivatives
9091
o['expand'] = oo.pop('expand', cls.EXPAND)
92+
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9193
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9294
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
9395

@@ -188,7 +190,7 @@ class DeviceAdvOperator(DeviceOperatorMixin, CoreOperator):
188190
@classmethod
189191
@timed_pass(name='specializing.DSL')
190192
def _specialize_dsl(cls, expressions, **kwargs):
191-
expressions = collect_derivatives(expressions)
193+
expressions = collect_derivatives(expressions, **kwargs)
192194

193195
return expressions
194196

@@ -280,7 +282,7 @@ class DeviceCustomOperator(DeviceOperatorMixin, CustomOperator):
280282
@classmethod
281283
def _make_dsl_passes_mapper(cls, **kwargs):
282284
return {
283-
'collect-derivs': collect_derivatives,
285+
'deriv-collect': collect_derivatives,
284286
}
285287

286288
@classmethod
@@ -330,7 +332,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
330332

331333
_known_passes = (
332334
# DSL
333-
'collect-derivs',
335+
'deriv-collect',
334336
# Expressions
335337
'buffering',
336338
# Clusters

devito/core/operator.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ class BasicOperator(Operator):
6969
intensity of the generated kernel.
7070
"""
7171

72+
CIRE_MINMEM = True
73+
"""
74+
Minimize memory consumption when allocating temporaries for CIRE-optimized
75+
expressions. This may come at the cost of slighly worse performance due to
76+
the potential need for extra registers to hold a greater number of support
77+
variables (e.g., strides).
78+
"""
79+
7280
SCALAR_MIN_TYPE = np.float16
7381
"""
7482
Minimum datatype for a scalar arising from a common sub-expression or CIRE temp.
@@ -115,6 +123,12 @@ class BasicOperator(Operator):
115123
finite-difference derivatives.
116124
"""
117125

126+
DERIV_COLLECT = True
127+
"""
128+
Factorize finite-difference derivatives exploiting the linearity of the FD
129+
operators.
130+
"""
131+
118132
DERIV_SCHEDULE = 'basic'
119133
"""
120134
The schedule to use for the computation of finite-difference derivatives.
@@ -288,7 +302,7 @@ def _specialize_dsl(cls, expressions, **kwargs):
288302
# Call passes
289303
for i in passes:
290304
try:
291-
expressions = passes_mapper[i](expressions)
305+
expressions = passes_mapper[i](expressions, **kwargs)
292306
except KeyError:
293307
pass
294308

devito/finite_differences/differentiable.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,20 @@ def _evaluate(self, **kwargs):
948948

949949
return EvalDerivative(*expr.args, base=self.base)
950950

951+
def _subs(self, old, new, **hints):
952+
# We have to work around SymPy's weak implementation of `subs` when
953+
# it gets to replacing sub-operations such as `a*b*c` (i.e., potentially
954+
# `self`'s `base`) within say `a*b*c*w[i0]` (i.e., the corresponding
955+
# `self.expr`), because depending on the complexity of `a/b/c`, SymPy
956+
# may fail to identify the sub-expression to be replaced (note: if
957+
# `a/b/c` are atoms or Indexeds, it's generally fine)
958+
959+
if not old.is_Mul or \
960+
old is not self.base:
961+
return super()._subs(old, new, **hints)
962+
963+
return self._rebuild(new * self.weights)
964+
951965

952966
class DiffDerivative(IndexDerivative, DifferentiableOp):
953967
pass

devito/ir/clusters/cluster.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,16 @@ class ClusterGroup(tuple):
470470

471471
def __new__(cls, clusters, ispace=None):
472472
obj = super().__new__(cls, flatten(as_tuple(clusters)))
473-
obj._ispace = ispace
473+
474+
if ispace is not None:
475+
obj._ispace = ispace
476+
else:
477+
# Best effort attempt to infer a common IterationSpace
478+
try:
479+
obj._ispace, = {c.ispace for c in obj}
480+
except ValueError:
481+
obj._ispace = None
482+
474483
return obj
475484

476485
@classmethod

0 commit comments

Comments
 (0)