99 DataSpace , Guards , Properties , Scope , detect_accesses ,
1010 detect_io , normalize_properties , normalize_syncs ,
1111 sdims_min , sdims_max )
12- from devito .mpi .halo_scheme import HaloTouch
12+ from devito .mpi .halo_scheme import HaloScheme , HaloTouch
1313from devito .symbolics import estimate_cost
1414from devito .tools import as_tuple , flatten , frozendict , infer_dtype
1515
@@ -26,7 +26,7 @@ class Cluster(object):
2626 exprs : expr-like or list of expr-like
2727 An ordered sequence of expressions computing a tensor.
2828 ispace : IterationSpace, optional
29- The cluster iteration space.
29+ The Cluster iteration space.
3030 guards : dict, optional
3131 Mapper from Dimensions to expr-like, representing the conditions under
3232 which the Cluster should be computed.
@@ -37,9 +37,12 @@ class Cluster(object):
3737 Mapper from Dimensions to lists of SyncOps, that is ordered sequences of
3838 synchronization operations that must be performed in order to compute the
3939 Cluster asynchronously.
40+ halo_scheme : HaloScheme, optional
41+ The halo exchanges required by the Cluster.
4042 """
4143
42- def __init__ (self , exprs , ispace = None , guards = None , properties = None , syncs = None ):
44+ def __init__ (self , exprs , ispace = None , guards = None , properties = None , syncs = None ,
45+ halo_scheme = None ):
4346 ispace = ispace or IterationSpace ([])
4447
4548 self ._exprs = tuple (ClusterizedEq (e , ispace = ispace ) for e in as_tuple (exprs ))
@@ -57,6 +60,8 @@ def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None)
5760 properties = properties .drop (d )
5861 self ._properties = properties
5962
63+ self ._halo_scheme = halo_scheme
64+
6065 def __repr__ (self ):
6166 return "Cluster([%s])" % ('\n ' + ' ' * 9 ).join ('%s' % i for i in self .exprs )
6267
@@ -91,7 +96,9 @@ def from_clusters(cls, *clusters):
9196 raise ValueError ("Cannot build a Cluster from Clusters with "
9297 "non-compatible synchronization operations" )
9398
94- return Cluster (exprs , ispace , guards , properties , syncs )
99+ halo_scheme = HaloScheme .union ([c .halo_scheme for c in clusters ])
100+
101+ return Cluster (exprs , ispace , guards , properties , syncs , halo_scheme )
95102
96103 def rebuild (self , * args , ** kwargs ):
97104 """
@@ -110,7 +117,8 @@ def rebuild(self, *args, **kwargs):
110117 ispace = kwargs .get ('ispace' , self .ispace ),
111118 guards = kwargs .get ('guards' , self .guards ),
112119 properties = kwargs .get ('properties' , self .properties ),
113- syncs = kwargs .get ('syncs' , self .syncs ))
120+ syncs = kwargs .get ('syncs' , self .syncs ),
121+ halo_scheme = kwargs .get ('halo_scheme' , self .halo_scheme ))
114122
115123 @property
116124 def exprs (self ):
@@ -144,6 +152,10 @@ def properties(self):
144152 def syncs (self ):
145153 return self ._syncs
146154
155+ @property
156+ def halo_scheme (self ):
157+ return self ._halo_scheme
158+
147159 @cached_property
148160 def free_symbols (self ):
149161 return set ().union (* [e .free_symbols for e in self .exprs ])
0 commit comments