Skip to content

Commit 1187b4b

Browse files
committed
Merge branch 'master' into pass-op-arguments
2 parents 98ebc03 + 5f1ff64 commit 1187b4b

37 files changed

+1122
-609
lines changed

devito/core/cpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def _normalize_kwargs(cls, **kwargs):
6060
o['par-dynamic-work'] = oo.pop('par-dynamic-work', cls.PAR_DYNAMIC_WORK)
6161
o['par-nested'] = oo.pop('par-nested', cls.PAR_NESTED)
6262

63+
# Distributed parallelism
64+
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)
65+
6366
# Misc
6467
o['expand'] = oo.pop('expand', cls.EXPAND)
6568
o['optcomms'] = oo.pop('optcomms', True)

devito/core/gpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def _normalize_kwargs(cls, **kwargs):
7575
o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs)))
7676
o['gpu-create'] = as_tuple(oo.pop('gpu-create', ()))
7777

78+
# Distributed parallelism
79+
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)
80+
7881
# Misc
7982
o['expand'] = oo.pop('expand', cls.EXPAND)
8083
o['optcomms'] = oo.pop('optcomms', True)

devito/core/operator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ class BasicOperator(Operator):
100100
The supported MPI modes.
101101
"""
102102

103+
DIST_DROP_UNWRITTEN = True
104+
"""
105+
Drop halo exchanges for read-only Functions, even in the presence of
106+
stencil-like data accesses.
107+
"""
108+
103109
INDEX_MODE = "int64"
104110
"""
105111
The type of the expression used to compute array indices. Either `int64`
@@ -281,7 +287,7 @@ def _specialize_iet(cls, graph, **kwargs):
281287
# from HaloSpot optimization)
282288
# Note that if MPI is disabled then this pass will act as a no-op
283289
if 'mpi' not in passes:
284-
passes_mapper['mpi'](graph)
290+
passes_mapper['mpi'](graph, **kwargs)
285291

286292
# Run passes
287293
applied = []
@@ -300,17 +306,17 @@ def _specialize_iet(cls, graph, **kwargs):
300306
if 'init' not in passes:
301307
passes_mapper['init'](graph)
302308

303-
# Enforce pthreads if CPU-GPU orchestration requested
304-
if 'orchestrate' in passes and 'pthreadify' not in passes:
305-
passes_mapper['pthreadify'](graph, sregistry=sregistry)
306-
307309
# Symbol definitions
308310
cls._Target.DataManager(**kwargs).process(graph)
309311

310312
# Linearize n-dimensional Indexeds
311313
if 'linearize' not in passes and options['linearize']:
312314
passes_mapper['linearize'](graph)
313315

316+
# Enforce pthreads if CPU-GPU orchestration requested
317+
if 'orchestrate' in passes and 'pthreadify' not in passes:
318+
passes_mapper['pthreadify'](graph, sregistry=sregistry)
319+
314320
return graph
315321

316322

devito/ir/clusters/algorithms.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def callback(self, clusters, prefix):
281281
mapper[size][si].add(iaf)
282282

283283
# Construct the ModuloDimensions
284-
mds = OrderedDict()
284+
mds = []
285285
for size, v in mapper.items():
286286
for si, iafs in list(v.items()):
287287
# Offsets are sorted so that the semantic order (t0, t1, t2) follows
@@ -290,15 +290,10 @@ def callback(self, clusters, prefix):
290290
# sorting offsets {-1, 0, 1} as {0, -1, 1} assigning -inf to 0
291291
siafs = sorted(iafs, key=lambda i: -np.inf if i - si == 0 else (i - si))
292292

293-
# Create the ModuloDimensions. Note that if `size < len(iafs)` then
294-
# the same ModuloDimension may be used for multiple offsets
295-
for iaf in siafs[:size]:
293+
for iaf in siafs:
296294
name = '%s%d' % (si.name, len(mds))
297295
offset = uxreplace(iaf, {si: d.root})
298-
md = ModuloDimension(name, si, offset, size, origin=iaf)
299-
300-
key = lambda i: i.subs(si, 0) % size
301-
mds[md] = [i for i in siafs if key(i) == key(iaf)]
296+
mds.append(ModuloDimension(name, si, offset, size, origin=iaf))
302297

303298
# Replacement rule for ModuloDimensions
304299
def rule(size, e):
@@ -320,11 +315,8 @@ def rule(size, e):
320315
exprs = c.exprs
321316
groups = as_mapper(mds, lambda d: d.modulo)
322317
for size, v in groups.items():
323-
mapper = {}
324-
for md in v:
325-
mapper.update({i: md for i in mds[md]})
326-
327-
func = partial(xreplace_indices, mapper=mapper, key=partial(rule, size))
318+
subs = {md.origin: md for md in v}
319+
func = partial(xreplace_indices, mapper=subs, key=partial(rule, size))
328320
exprs = [e.apply(func) for e in exprs]
329321

330322
# Augment IterationSpace

devito/ir/clusters/cluster.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
DataSpace, Guards, Properties, Scope, detect_accesses,
1010
detect_io, normalize_properties, normalize_syncs,
1111
sdims_min, sdims_max)
12-
from devito.mpi.halo_scheme import HaloTouch
12+
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1313
from devito.symbolics import estimate_cost
1414
from devito.tools import as_tuple, flatten, frozendict, infer_dtype
1515

@@ -26,7 +26,7 @@ class Cluster(object):
2626
exprs : expr-like or list of expr-like
2727
An ordered sequence of expressions computing a tensor.
2828
ispace : IterationSpace, optional
29-
The cluster iteration space.
29+
The Cluster iteration space.
3030
guards : dict, optional
3131
Mapper from Dimensions to expr-like, representing the conditions under
3232
which the Cluster should be computed.
@@ -37,9 +37,12 @@ class Cluster(object):
3737
Mapper from Dimensions to lists of SyncOps, that is ordered sequences of
3838
synchronization operations that must be performed in order to compute the
3939
Cluster asynchronously.
40+
halo_scheme : HaloScheme, optional
41+
The halo exchanges required by the Cluster.
4042
"""
4143

42-
def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None):
44+
def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None,
45+
halo_scheme=None):
4346
ispace = ispace or IterationSpace([])
4447

4548
self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
@@ -57,6 +60,8 @@ def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None)
5760
properties = properties.drop(d)
5861
self._properties = properties
5962

63+
self._halo_scheme = halo_scheme
64+
6065
def __repr__(self):
6166
return "Cluster([%s])" % ('\n' + ' '*9).join('%s' % i for i in self.exprs)
6267

@@ -91,7 +96,9 @@ def from_clusters(cls, *clusters):
9196
raise ValueError("Cannot build a Cluster from Clusters with "
9297
"non-compatible synchronization operations")
9398

94-
return Cluster(exprs, ispace, guards, properties, syncs)
99+
halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters])
100+
101+
return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme)
95102

96103
def rebuild(self, *args, **kwargs):
97104
"""
@@ -110,7 +117,8 @@ def rebuild(self, *args, **kwargs):
110117
ispace=kwargs.get('ispace', self.ispace),
111118
guards=kwargs.get('guards', self.guards),
112119
properties=kwargs.get('properties', self.properties),
113-
syncs=kwargs.get('syncs', self.syncs))
120+
syncs=kwargs.get('syncs', self.syncs),
121+
halo_scheme=kwargs.get('halo_scheme', self.halo_scheme))
114122

115123
@property
116124
def exprs(self):
@@ -144,6 +152,10 @@ def properties(self):
144152
def syncs(self):
145153
return self._syncs
146154

155+
@property
156+
def halo_scheme(self):
157+
return self._halo_scheme
158+
147159
@cached_property
148160
def free_symbols(self):
149161
return set().union(*[e.free_symbols for e in self.exprs])

devito/ir/iet/nodes.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,13 @@ def functions(self):
288288
retval.append(s.function)
289289
except AttributeError:
290290
continue
291+
291292
if self.base is not None:
292293
retval.append(self.base.function)
294+
293295
if self.retobj is not None:
294296
retval.append(self.retobj.function)
297+
295298
return tuple(filter_ordered(retval))
296299

297300
@cached_property
@@ -309,10 +312,15 @@ def expr_symbols(self):
309312
retval.extend(i.free_symbols)
310313
except AttributeError:
311314
pass
315+
312316
if self.base is not None:
313317
retval.append(self.base)
314-
if self.retobj is not None:
318+
319+
if isinstance(self.retobj, Indexed):
320+
retval.extend(self.retobj.free_symbols)
321+
elif self.retobj is not None:
315322
retval.append(self.retobj)
323+
316324
return tuple(filter_ordered(retval))
317325

318326
@property
@@ -744,6 +752,8 @@ class CallableBody(Node):
744752
maps : Transfer or list of Transfer, optional
745753
Data maps for `body` (a data map may e.g. trigger a data transfer from
746754
host to device).
755+
strides : list of Nodes, optional
756+
Statements defining symbols used to access linearized arrays.
747757
objs : list of Definitions, optional
748758
Object definitions for `body`.
749759
unmaps : Transfer or list of Transfer, optional
@@ -756,19 +766,21 @@ class CallableBody(Node):
756766

757767
is_CallableBody = True
758768

759-
_traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps', 'objs',
760-
'body', 'unmaps', 'unbundles', 'frees']
769+
_traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps',
770+
'strides', 'objs', 'body', 'unmaps', 'unbundles', 'frees']
761771

762-
def __init__(self, body, init=(), unpacks=(), allocs=(), casts=(),
772+
def __init__(self, body, init=(), unpacks=(), strides=(), allocs=(), casts=(),
763773
bundles=(), objs=(), maps=(), unmaps=(), unbundles=(), frees=()):
764774
# Sanity check
765775
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"
766776

767777
self.body = as_tuple(body)
768-
self.init = as_tuple(init)
778+
769779
self.unpacks = as_tuple(unpacks)
780+
self.init = as_tuple(init)
770781
self.allocs = as_tuple(allocs)
771782
self.casts = as_tuple(casts)
783+
self.strides = as_tuple(strides)
772784
self.bundles = as_tuple(bundles)
773785
self.maps = as_tuple(maps)
774786
self.objs = as_tuple(objs)
@@ -1201,6 +1213,10 @@ def __getattr__(self, name):
12011213
def functions(self):
12021214
return as_tuple(self.nthreads)
12031215

1216+
@property
1217+
def expr_symbols(self):
1218+
return as_tuple(self.nthreads)
1219+
12041220
@property
12051221
def root(self):
12061222
return self.body[0]

0 commit comments

Comments
 (0)