@@ -880,13 +880,14 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
880880 # The user might suggest to go more relaxed about this via `opt_minmem`,
881881 # in which case we extend the halo based on the surrounding
882882 # Functions to minimize support variables such as strides etc
883- halo = {i .dim : Size (abs (i .lower ), abs (i .upper )) for i in writeto }
883+ min_halo = {i .dim : Size (abs (i .lower ), abs (i .upper )) for i in writeto }
884884
885885 if opt_minmem :
886886 functions = []
887887 else :
888888 functions = retrieve_functions (pivot )
889889
890+ halo = dict (min_halo )
890891 for f in functions :
891892 for d , h0 in list (halo .items ()):
892893 try :
@@ -895,25 +896,26 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
895896 continue
896897 halo [d ] = Size (max (h0 .left , h1 .left ), max (h0 .right , h1 .right ))
897898
899+ shift = [halo [d ].left - min_halo [d ].left for d in writeto .itdims ]
898900 halo = tuple (halo .values ())
899901
900902 # The indices used to write into the Array
901903 indices = []
902- for i in writeto :
904+ for i , s in zip ( writeto , shift ) :
903905 try :
904906 # E.g., `xs`
905907 sub_iterators = writeto .sub_iterators [i .dim ]
906908 assert len (sub_iterators ) <= 1
907909 indices .append (sub_iterators [0 ])
908910 except (KeyError , IndexError ):
909911 # E.g., `z` -- a non-shifted Dimension
910- indices .append (i .dim - i .lower )
912+ indices .append (i .dim - i .lower + s )
911913
912914 dtype = sympy_dtype (pivot , base = meta .dtype )
913915 obj = make (name = name , dimensions = dimensions , halo = halo , dtype = dtype )
914916 expression = Eq (obj [indices ], uxreplace (pivot , subs ))
915917
916- callback = lambda idx : obj [idx ]
918+ callback = lambda idx : obj [[ i + s for i , s in zip ( idx , shift )] ]
917919 else :
918920 # Degenerate case: scalar expression
919921 assert writeto .size == 0
0 commit comments