@@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
443443 if d in fixed :
444444 continue
445445
446- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
447- rpeer = FieldFromPointer (name , nb )
448- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
449- lpeer = FieldFromPointer (name , nb )
446+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
450447
451448 if (d , LEFT ) in hse .halos :
452449 # Sending to left, receiving from right
@@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed):
491488
492489 return mapper
493490
491+ def _make_peers (self , d , distributor , nb ):
492+ rname = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
493+ rpeer = FieldFromPointer (rname , nb )
494+ lname = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
495+ lpeer = FieldFromPointer (lname , nb )
496+
497+ return rpeer , lpeer
498+
494499 def _call_haloupdate (self , name , f , hse , * args ):
495500 comm = f .grid .distributor ._obj_comm
496501 nb = f .grid .distributor ._obj_neighborhood
@@ -537,7 +542,7 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
537542class Basic2HaloExchangeBuilder (BasicHaloExchangeBuilder ):
538543
539544 """
540- A BasicHaloExchangeBuilder making use of pre-allocated buffers for
545+ A BasicHaloExchangeBuilder using pre-allocated buffers for
541546 message size.
542547
543548 Generates:
@@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
616621 if d in fixed :
617622 continue
618623
619- name = '' .join ('r' if i is d else 'c' for i in distributor .dimensions )
620- rpeer = FieldFromPointer (name , nb )
621- name = '' .join ('l' if i is d else 'c' for i in distributor .dimensions )
622- lpeer = FieldFromPointer (name , nb )
624+ rpeer , lpeer = self ._make_peers (d , distributor , nb )
623625
624626 if (d , LEFT ) in hse .halos :
625627 # Sending to left, receiving from right
@@ -1297,6 +1299,7 @@ def _as_number(self, v, args):
12971299 return int (subs_op_args (v , args ))
12981300
12991301 def _allocate_buffers (self , f , shape , entry ):
1302+ # Allocate the send/recv buffers
13001303 entry .sizes = (c_int * len (shape ))(* shape )
13011304 size = reduce (mul , shape )* dtype_len (self .target .dtype )
13021305 ctype = dtype_to_ctype (f .dtype )
@@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None):
14291432 if d in fixed :
14301433 continue
14311434
1432- if (d , LEFT ) in self .halos :
1433- entry = self .value [i ]
1434- i = i + 1
1435- # Sending to left, receiving from right
1436- shape = mapper [(d , LEFT , OWNED )]
1437- # Allocate the send/recv buffers
1438- self ._allocate_buffers (f , shape , entry )
1439-
1440- if (d , RIGHT ) in self .halos :
1441- entry = self .value [i ]
1442- i = i + 1
1443- # Sending to right, receiving from left
1444- shape = mapper [(d , RIGHT , OWNED )]
1445- # Allocate the send/recv buffers
1446- self ._allocate_buffers (f , shape , entry )
1435+ for side in (LEFT , RIGHT ):
1436+ if (d , side ) in self .halos :
1437+ entry = self .value [i ]
1438+ i += 1
1439+ shape = mapper [(d , side , OWNED )]
1440+ self ._allocate_buffers (f , shape , entry )
14471441
14481442 return {self .name : self .value }
14491443
0 commit comments