@@ -73,7 +73,7 @@ def __new__(
                 cat_tensor_shape[1] += shard.size()[1]
 
         # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[0] += shard.size()[0]
 
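For context on the corrected comment, here is a standalone sketch (not part of the diff) of what this row-wise branch computes; plain tensors stand in for the wrapper's local shards:

```python
import torch

# Two row-wise (1D) shards of a logical 1D tensor of length 7.
local_shards = [torch.zeros(3), torch.zeros(4)]

# Mirror the branch above: for 1D shards the logical size is the sum of
# shard lengths along dim 0 ("concat" on the first tensor dimension).
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1 and local_shards[0].ndim == 1:
    for shard in local_shards[1:]:
        cat_tensor_shape[0] += shard.size()[0]

print(cat_tensor_shape)  # [7]
```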
@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.copy_.default: cls.handle_copy_,
             aten.zeros_like.default: cls.handle_zeros_like,
             aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }
 
         if func in dispatcher:
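The new entry follows the existing table-lookup pattern: `__torch_dispatch__` maps an aten overload to a static handler. Below is a self-contained sketch of that pattern (not the real `LocalShardsWrapper`; a plain tensor and a stand-in handler are used for illustration):

```python
import torch

aten = torch.ops.aten

def handle_constant_pad_nd(args, kwargs):
    # Stand-in handler: pad a plain dense tensor with a constant value.
    pad_value = args[2] if len(args) > 2 else 0.0
    return torch.nn.functional.pad(args[0], args[1], mode="constant", value=pad_value)

dispatcher = {
    aten.constant_pad_nd.default: handle_constant_pad_nd,
}

func = aten.constant_pad_nd.default
if func in dispatcher:
    out = dispatcher[func]((torch.ones(2, 3), [1, 1, 0, 0]), None)
    print(out.shape)  # torch.Size([2, 5]): last dim padded by 1 on each side
```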
@@ -279,6 +280,195 @@ def handle_new_empty(args, kwargs):
             self_ls.local_offsets(),
         )
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to LocalShardsWrapper.
+
+        The padding is based on the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type and the dimension being padded.
+            - For instance, padding the left-most column of CW shards only pads the first CW shard.
+            - Padding the top row applies to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D Row-wise sharding: [pad_top, pad_bottom]
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[torch.Size], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1=RW or 2=CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(offset)
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(torch.Size(new_offset))
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(torch.Size(new_offset))
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=offset,
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
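To illustrate the padding and offset rules the new handlers implement, here is a standalone sketch (assumptions: plain tensors stand in for column-wise shards, and offsets are tracked as plain tuples rather than `torch.Size`) that mirrors the 2D column-wise branch and checks it against padding the logical tensor directly:

```python
import torch
import torch.nn.functional as F

# Two column-wise shards of a logical (2, 5) tensor: columns [0:3] and [3:5].
logical = torch.arange(10.0).reshape(2, 5)
shards = [logical[:, :3], logical[:, 3:]]
offsets = [(0, 0), (0, 3)]

pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0  # pad_spec = [1, 2, 1, 0]

# Mirror handle_constant_pad_nd: top/bottom padding applies to every CW shard,
# left padding only to the first shard, right padding only to the last.
padded = [F.pad(s, [0, 0, pad_top, pad_bottom], value=0.0) for s in shards]
padded[0] = F.pad(padded[0], [pad_left, 0, 0, 0], value=0.0)
padded[-1] = F.pad(padded[-1], [0, pad_right, 0, 0], value=0.0)

# Mirror _compute_updated_metadata: only non-first shards shift their column offset.
new_offsets = [(r, c if i == 0 else c + pad_left) for i, (r, c) in enumerate(offsets)]

# Re-concatenating the padded shards matches padding the logical tensor directly.
assert torch.equal(
    torch.cat(padded, dim=1),
    F.pad(logical, [pad_left, pad_right, pad_top, pad_bottom], value=0.0),
)
print(new_offsets)  # [(0, 0), (0, 4)]
```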