Skip to content

Commit a97516f

Browse files
committed
compiler: Add shift to center aliases wrt halo
1 parent e4d5e1b commit a97516f

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

devito/passes/clusters/aliases.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -880,13 +880,14 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
880880
# The user might suggest to go more relaxed about this via `opt_minmem`,
881881
# in which case we extend the halo based on the surrounding
882882
# Functions to minimize support variables such as strides etc
883-
halo = {i.dim: Size(abs(i.lower), abs(i.upper)) for i in writeto}
883+
min_halo = {i.dim: Size(abs(i.lower), abs(i.upper)) for i in writeto}
884884

885885
if opt_minmem:
886886
functions = []
887887
else:
888888
functions = retrieve_functions(pivot)
889889

890+
halo = dict(min_halo)
890891
for f in functions:
891892
for d, h0 in list(halo.items()):
892893
try:
@@ -895,25 +896,26 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
895896
continue
896897
halo[d] = Size(max(h0.left, h1.left), max(h0.right, h1.right))
897898

899+
shift = [halo[d].left - min_halo[d].left for d in writeto.itdims]
898900
halo = tuple(halo.values())
899901

900902
# The indices used to write into the Array
901903
indices = []
902-
for i in writeto:
904+
for i, s in zip(writeto, shift):
903905
try:
904906
# E.g., `xs`
905907
sub_iterators = writeto.sub_iterators[i.dim]
906908
assert len(sub_iterators) <= 1
907909
indices.append(sub_iterators[0])
908910
except (KeyError, IndexError):
909911
# E.g., `z` -- a non-shifted Dimension
910-
indices.append(i.dim - i.lower)
912+
indices.append(i.dim - i.lower + s)
911913

912914
dtype = sympy_dtype(pivot, base=meta.dtype)
913915
obj = make(name=name, dimensions=dimensions, halo=halo, dtype=dtype)
914916
expression = Eq(obj[indices], uxreplace(pivot, subs))
915917

916-
callback = lambda idx: obj[idx]
918+
callback = lambda idx: obj[[i + s for i, s in zip(idx, shift)]]
917919
else:
918920
# Degenerate case: scalar expression
919921
assert writeto.size == 0

0 commit comments

Comments
 (0)