Skip to content

Commit 7c897db

Browse files
Merge pull request #2076 from devitocodes/hotfix-minmax
compiler: Fix and improve MIN/MAX codegen
2 parents a2518b7 + b1d3670 commit 7c897db

File tree

11 files changed

+157
-78
lines changed

11 files changed

+157
-78
lines changed

devito/finite_differences/elementary.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,13 @@ def root(x):
9393
class Min(sympy.Min, Evaluable):
9494

9595
def _evaluate(self, **kwargs):
96-
args = self._evaluate_args(**kwargs)
97-
assert len(args) == 2
98-
return self.func(args[0], args[1], evaluate=False)
96+
return self.func(*self._evaluate_args(**kwargs), evaluate=False)
9997

10098

10199
class Max(sympy.Max, Evaluable):
102100

103101
def _evaluate(self, **kwargs):
104-
args = self._evaluate_args(**kwargs)
105-
assert len(args) == 2
106-
return self.func(args[0], args[1], evaluate=False)
102+
return self.func(*self._evaluate_args(**kwargs), evaluate=False)
107103

108104

109105
def Id(x):

devito/ir/iet/visitors.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import cgen as c
1313
from sympy import IndexedBase
14+
from sympy.core.function import Application
1415

1516
from devito.exceptions import VisitorException
1617
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
@@ -24,9 +25,9 @@
2425
IndexedData, DeviceMap)
2526

2627

27-
__all__ = ['FindNodes', 'FindSections', 'FindSymbols', 'MapExprStmts', 'MapNodes',
28-
'IsPerfectIteration', 'printAST', 'CGen', 'CInterface', 'Transformer',
29-
'Uxreplace']
28+
__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols',
29+
'MapExprStmts', 'MapNodes', 'IsPerfectIteration', 'printAST', 'CGen',
30+
'CInterface', 'Transformer', 'Uxreplace']
3031

3132

3233
class Visitor(GenericVisitor):
@@ -953,6 +954,55 @@ def visit_Node(self, o, ret=None):
953954
return ret
954955

955956

957+
class FindApplications(Visitor):

    """
    Collect the SymPy applied functions (i.e., instances of `Application`)
    encountered while visiting a tree. A different target class may be
    supplied via `cls` to refine the search. Objects that are instances of
    `Basic` are never reported.
    """

    def __init__(self, cls=Application):
        super().__init__()

        def match(candidate):
            return isinstance(candidate, cls) and not isinstance(candidate, Basic)

        # Predicate handed to `find` on every visited expression
        self.match = match

    @classmethod
    def default_retval(cls):
        return set()

    def visit_object(self, o, **kwargs):
        # Unknown objects contribute nothing
        return self.default_retval()

    def visit_tuple(self, o, ret=None):
        found = ret if ret else self.default_retval()
        for item in o:
            found.update(self._visit(item, ret=found))
        return found

    def visit_Node(self, o, ret=None):
        found = ret if ret else self.default_retval()
        for child in o.children:
            found.update(self._visit(child, ret=found))
        return found

    def visit_Expression(self, o, **kwargs):
        return o.expr.find(self.match)

    def visit_Iteration(self, o, **kwargs):
        found = self._visit(o.children) or self.default_retval()
        # The symbolic bounds may carry matching objects too
        for bound in ('symbolic_min', 'symbolic_max'):
            found.update(getattr(o, bound).find(self.match))
        return found

    def visit_Call(self, o, **kwargs):
        found = self.default_retval()
        for argument in o.arguments:
            try:
                found.update(argument.find(self.match))
            except (AttributeError, TypeError):
                # Arguments that cannot be searched are skipped
                pass
        return found
1004+
1005+
9561006
class IsPerfectIteration(Visitor):
9571007

9581008
"""

devito/mpi/routines.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,6 @@ def msgs(self):
5858
def regions(self):
5959
return [i for i in self._regions.values() if i is not None]
6060

61-
@property
62-
def headers(self):
63-
"""
64-
No headers needed by default
65-
"""
66-
return {}
67-
6861
def make(self, hs):
6962
"""
7063
Construct Callables and Calls implementing distributed-memory halo
@@ -517,14 +510,6 @@ class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder):
517510
remainder()
518511
"""
519512

520-
@property
521-
def headers(self):
522-
"""
523-
Overlap Mode uses MIN/MAX that need to be defined
524-
"""
525-
return {'headers': [('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),
526-
('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))]}
527-
528513
def _make_msg(self, f, hse, key):
529514
# Only retain the halos required by the Diag scheme
530515
halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple))

devito/operator/operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from devito.operator.registry import operator_selector
2020
from devito.mpi import MPI
2121
from devito.parameters import configuration
22-
from devito.passes import Graph, lower_index_derivatives, generate_implicit, instrument
22+
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
23+
generate_macros, instrument)
2324
from devito.symbolics import estimate_cost
2425
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
2526
filter_sorted, frozendict, is_integer, split, timed_pass,
@@ -450,6 +451,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
450451
# specialization further Sections may be introduced
451452
instrument(graph, profiler=profiler, sregistry=sregistry)
452453

454+
# Extract the necessary macros from the symbolic objects
455+
generate_macros(graph)
456+
453457
return graph.root, graph
454458

455459
# Read-only properties exposed to the outside world

devito/passes/iet/misc.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
from functools import singledispatch
2+
13
import cgen
4+
import sympy
25

3-
from devito.ir import (Any, Forward, List, Prodder, FindNodes, Transformer,
4-
filter_iterations, retrieve_iteration_tree)
6+
from devito.finite_differences import Max, Min
7+
from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
8+
Transformer, filter_iterations, retrieve_iteration_tree)
59
from devito.passes.iet.engine import iet_pass
6-
from devito.symbolics import evalrel
10+
from devito.symbolics import evalrel, has_integer_args
711
from devito.tools import split
812

9-
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions']
13+
__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
14+
'generate_macros']
1015

1116

1217
@iet_pass
@@ -124,9 +129,35 @@ def relax_incr_dimensions(iet, options=None, **kwargs):
124129
if mapper:
125130
iet = Transformer(mapper, nested=True).visit(iet)
126131

127-
headers = [('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),
128-
('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))]
129-
else:
130-
headers = []
132+
return iet, {}
133+
134+
135+
@iet_pass
def generate_macros(iet):
    """
    Gather the macros required by the applied functions found in `iet` and
    return them as the `headers` entry of the pass metadata. The IET itself
    is returned unchanged.
    """
    headers = set()
    for application in FindApplications().visit(iet):
        headers.update(_generate_macros(application))

    return iet, {'headers': headers}
141+
142+
143+
@singledispatch
def _generate_macros(expr):
    """
    Return the set of `(signature, body)` macro definitions associated with
    `expr`. Types without a registered handler yield no macros.
    """
    return set()


@_generate_macros.register(Min)
@_generate_macros.register(sympy.Min)
def _(expr):
    # MIN is produced only for a two-argument Min over integer arguments
    if not has_integer_args(*expr.args) or len(expr.args) != 2:
        return set()
    return {('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))'))}


@_generate_macros.register(Max)
@_generate_macros.register(sympy.Max)
def _(expr):
    # MAX is produced only for a two-argument Max over integer arguments
    if not has_integer_args(*expr.args) or len(expr.args) != 2:
        return set()
    return {('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))}

devito/passes/iet/mpi.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ def make_mpi(iet, mpimode=None, **kwargs):
292292

293293
efuncs = sync_heb.efuncs + user_heb.efuncs
294294
iet = Transformer(mapper, nested=True).visit(iet)
295-
headers = user_heb.headers
296295

297296
# Must drop the PARALLEL tag from the Iterations within which halo
298297
# exchanges are performed
@@ -307,9 +306,8 @@ def make_mpi(iet, mpimode=None, **kwargs):
307306
for n in tree[:tree.index(i)+1]})
308307
break
309308
iet = Transformer(mapper, nested=True).visit(iet)
310-
headers.update({'includes': ['mpi.h'], 'efuncs': efuncs})
311309

312-
return iet, headers
310+
return iet, {'includes': ['mpi.h'], 'efuncs': efuncs}
313311

314312

315313
def mpiize(graph, **kwargs):

devito/symbolics/inspection.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from devito.symbolics.queries import q_routine
1212
from devito.tools import as_tuple, prod
1313

14-
__all__ = ['compare_ops', 'estimate_cost']
14+
__all__ = ['compare_ops', 'estimate_cost', 'has_integer_args']
1515

1616

1717
def compare_ops(e1, e2):
@@ -235,3 +235,28 @@ def _(expr, estimate):
235235
flops *= prod(i._size for i in expr.dimensions)
236236

237237
return flops, False
238+
239+
240+
def _is_integer_typed(arg):
    """True if the single object `arg` is of integer type, False otherwise."""
    try:
        # Objects exposing a `dtype` are classified by it
        return bool(np.issubdtype(arg.dtype, np.integer))
    except AttributeError:
        pass

    # Plain Python numbers: `float.is_integer` (and, in recent Pythons,
    # `int.is_integer`) is a bound method, hence always truthy -- classify
    # them by their type instead
    if isinstance(arg, (int, float)):
        return isinstance(arg, int)

    # SymPy-style assumption; an undetermined (None) value counts as False
    return bool(arg.is_integer)


def has_integer_args(*args):
    """
    True if all `args` are of integer type, False otherwise.

    Parameters
    ----------
    *args
        The objects to inspect. With a single object, its `dtype` is used if
        available, otherwise its `is_integer` attribute. With several objects,
        each one exposing a non-empty `args` is judged through those `args`;
        any other object is judged on its own.

    Returns
    -------
    bool
        Always a proper bool; False if no `args` are supplied.
    """
    if not args:
        return False

    if len(args) == 1:
        return _is_integer_typed(args[0])

    for a in args:
        try:
            operands = a.args
            if len(operands) > 0:
                ok = has_integer_args(*operands)
            else:
                ok = has_integer_args(a)
        except AttributeError:
            # Either `a` has no `args`, or one of its operands could not be
            # classified; fall back to inspecting `a` itself
            ok = has_integer_args(a)
        if not ok:
            return False
    return True

devito/symbolics/printer.py

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sympy.printing.c import C99CodePrinter
1313

1414
from devito.arch.compiler import AOMPCompiler
15+
from devito.symbolics.inspection import has_integer_args
1516

1617
__all__ = ['ccode']
1718

@@ -102,14 +103,16 @@ def _print_Mod(self, expr):
102103
return '%'.join(args)
103104

104105
def _print_Min(self, expr):
105-
"""Print Min using devito defined header Min"""
106-
func = 'MIN' if has_integer_args(*expr.args) else 'fmin'
107-
return "%s(%s)" % (func, self._print(expr.args)[1:-1])
106+
if has_integer_args(*expr.args) and len(expr.args) == 2:
107+
return "MIN(%s)" % self._print(expr.args)[1:-1]
108+
else:
109+
return super()._print_Min(expr)
108110

109111
def _print_Max(self, expr):
110-
"""Print Max using devito defined header Max"""
111-
func = 'MAX' if has_integer_args(*expr.args) else 'fmax'
112-
return "%s(%s)" % (func, self._print(expr.args)[1:-1])
112+
if has_integer_args(*expr.args) and len(expr.args) == 2:
113+
return "MAX(%s)" % self._print(expr.args)[1:-1]
114+
else:
115+
return super()._print_Max(expr)
113116

114117
def _print_Abs(self, expr):
115118
"""Print an absolute value. Use `abs` if can infer it is an Integer"""
@@ -258,30 +261,3 @@ def ccode(expr, **settings):
258261
# to always use the correct one from our printer
259262
if Version(sympy.__version__) >= Version("1.11"):
260263
setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add)
261-
262-
263-
# Check arguments type
264-
def has_integer_args(*args):
265-
"""
266-
Check if expression is Integer.
267-
Used to choose the function printed in the c-code
268-
"""
269-
if len(args) == 0:
270-
return False
271-
272-
if len(args) == 1:
273-
try:
274-
return np.issubdtype(args[0].dtype, np.integer)
275-
except AttributeError:
276-
return args[0].is_integer
277-
278-
res = True
279-
for a in args:
280-
try:
281-
if len(a.args) > 0:
282-
res = res and has_integer_args(*a.args)
283-
else:
284-
res = res and has_integer_args(a)
285-
except AttributeError:
286-
res = res and has_integer_args(a)
287-
return res

examples/performance/00_overview.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,10 +1184,9 @@
11841184
"output_type": "stream",
11851185
"text": [
11861186
"#define _POSIX_C_SOURCE 200809L\n",
1187-
"#define MIN(a,b) (((a) < (b)) ? (a) : (b))\n",
1188-
"#define MAX(a,b) (((a) > (b)) ? (a) : (b))\n",
11891187
"#define START_TIMER(S) struct timeval start_ ## S , end_ ## S ; gettimeofday(&start_ ## S , NULL);\n",
11901188
"#define STOP_TIMER(S,T) gettimeofday(&end_ ## S, NULL); T->S += (double)(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;\n",
1189+
"#define MIN(a,b) (((a) < (b)) ? (a) : (b))\n",
11911190
"\n",
11921191
"#include \"stdlib.h\"\n",
11931192
"#include \"math.h\"\n",
@@ -1622,10 +1621,9 @@
16221621
"output_type": "stream",
16231622
"text": [
16241623
"#define _POSIX_C_SOURCE 200809L\n",
1625-
"#define MIN(a,b) (((a) < (b)) ? (a) : (b))\n",
1626-
"#define MAX(a,b) (((a) > (b)) ? (a) : (b))\n",
16271624
"#define START_TIMER(S) struct timeval start_ ## S , end_ ## S ; gettimeofday(&start_ ## S , NULL);\n",
16281625
"#define STOP_TIMER(S,T) gettimeofday(&end_ ## S, NULL); T->S += (double)(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;\n",
1626+
"#define MIN(a,b) (((a) < (b)) ? (a) : (b))\n",
16291627
"\n",
16301628
"#include \"stdlib.h\"\n",
16311629
"#include \"math.h\"\n",

tests/test_linearize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def test_codegen_quality0():
169169
assert all('const long' in str(i) for i in exprs[:-2])
170170

171171
# Only three access macros necessary, namely `uL0`, `bufL0`, `bufL1`
172-
# MIN/MAX for the efunc args
172+
# for the efunc args
173173
# (the other three obviously are _POSIX_C_SOURCE, START_TIMER, STOP_TIMER)
174-
assert len(op._headers) == 8
174+
assert len(op._headers) == 6
175175

176176

177177
def test_codegen_quality1():
@@ -193,7 +193,7 @@ def test_codegen_quality1():
193193

194194
# Only two access macros necessary, namely `uL0` and `r1L0` (the other four
195195
# are _POSIX_C_SOURCE, START_TIMER, STOP_TIMER and, presumably, whichever of MIN/MAX is used)
196-
assert len(op._headers) == 7
196+
assert len(op._headers) == 6
197197

198198

199199
def test_pow():

0 commit comments

Comments
 (0)