
Commit 4fd761f

wconstab authored and pytorchmergebot committed
[DTensor] Wrap sharding prop error with contextual exception (pytorch#161574)
Mainly, this helps tell the user more about the operator that failed to run when the failure happens during sharding propagation. Previously, only this exception would be raised:

```
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
```

Now you get both the above exception as well as:

```
The above exception was the direct cause of the following exception:

RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```

<stacktrace omitted>

<details><summary>detailed error</summary>

```
======================================================================
ERROR: test_linear (__main__.TestDTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 668, in wrapper
    self._join_processes(fn)
  File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 932, in _join_processes
    self._check_return_codes(fn, elapsed_time)
  File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 972, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 4 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 150, in dispatch
    self.sharding_propagator.propagate(op_info)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 309, in propagate
    OutputSharding, self.propagate_op_sharding(op_info.schema)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
    return self.cache(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 329, in propagate_op_sharding_non_cached
    op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 673, in reshape_strategy
    input_tgt_placements, output_placements = propagate_shape_and_sharding(
  File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 601, in propagate_shape_and_sharding
    in_dim = get_in_dim_to_shard(cmd)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 537, in get_in_dim_to_shard
    raise RuntimeError(
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 816, in run_test
    getattr(self, test_name)()
  File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 670, in wrapper
    fn()
  File "/data/users/whc/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
    method(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 490, in wrapper
    raise e
  File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 487, in wrapper
    func(self, *args, **kwargs)  # type: ignore[misc]
  File "/data/users/whc/pytorch/test.py", line 60, in test_linear
    print("results: ", distributed_linear(distributed_input))
  File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
  File "/data/users/whc/pytorch/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn
    return fn(*args, **kwargs)
  File "/data/users/whc/pytorch/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 163, in dispatch
    raise RuntimeError(
RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```
</details>

Pull Request resolved: pytorch#161574
Approved by: https://github.com/zpcore, https://github.com/XilunWu
1 parent a8270dd commit 4fd761f
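To see the new behavior end to end, here is a hypothetical repro sketch (not part of the PR). It assumes a 2-rank process group is already initialized and uses the same dim-1 shard plus `view(-1)` pattern as the updated `test_illegal_views` test; the mesh size and tensor shape are illustrative.

```python
# Hypothetical repro sketch, not from the PR. Assumes torch.distributed is already
# initialized with world size 2 so that init_device_mesh can build a 1-D mesh.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cpu", (2,))

# Shard dim 1, then flatten: only the leftmost dim of a Flatten may be sharded,
# so sharding propagation for aten.view.default rejects this.
dtensor = distribute_tensor(torch.randn(8, 4), mesh, [Shard(1)])

try:
    dtensor.view(-1)
except RuntimeError as e:
    print(e)            # wrapper: "Sharding propagation failed for Op(op=aten.view.default, ...)"
    print(e.__cause__)  # original: "Attempted to flatten sharded dimension 1, ..."
```

Catching the wrapper and inspecting `__cause__` is also how a caller could still assert on the specific underlying message if it needs to.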

File tree

3 files changed (+8, -10 lines)


test/distributed/tensor/test_math_ops.py

Lines changed: 1 addition & 1 deletion
@@ -724,7 +724,7 @@ def test_foreach_add_different_mesh(self):
         self.assertEqual(out0.device_mesh, mesh_x)
         self.assertEqual(out1.device_mesh, mesh_y)
 
-        with self.assertRaisesRegex(ValueError, "computation across different mesh"):
+        with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"):
             torch.ops.aten._foreach_add(
                 [replica_inp00, replica_inp01], [replica_inp10, replica_inp11]
             )

test/distributed/tensor/test_view_ops.py

Lines changed: 3 additions & 9 deletions
@@ -228,18 +228,14 @@ def test_illegal_views(self):
             shard.view(-1)
 
         shard = dtensor.redistribute(device_mesh=device_mesh, placements=[Shard(dim=1)])
-        with self.assertRaisesRegex(
-            RuntimeError, "Attempted to flatten sharded dimension"
-        ):
+        with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"):
             shard.view(-1)
 
         # 8 is the uneven case since mesh dim is 6
         tensor = torch.randn((8, 256))
         dtensor = distribute_tensor(tensor, device_mesh, [Replicate()])
         shard = dtensor.redistribute(device_mesh=device_mesh, placements=[Shard(dim=0)])
-        with self.assertRaisesRegex(
-            RuntimeError, "Attempted to flatten unevenly sharded dimension"
-        ):
+        with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"):
             shard.view(-1)
 
     @with_comms
@@ -637,9 +633,7 @@ def test_view_redistribution(self):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
         dtensor_x = distribute_tensor(x, mesh, (Shard(0),))
 
-        with self.assertRaisesRegex(
-            RuntimeError, "Attempted to flatten unevenly sharded dimension"
-        ):
+        with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"):
            dtensor_x.view(-1, 8)
 
     @with_comms

torch/distributed/tensor/_dispatch.py

Lines changed: 4 additions & 0 deletions
@@ -159,6 +159,10 @@ def dispatch(
                 return out
             else:
                 raise
+        except Exception as e:
+            raise RuntimeError(
+                f"Sharding propagation failed for {op_info.schema}"
+            ) from e
 
         output_sharding = op_info.output_sharding
         logger.debug("output_sharding for %s: %s", op_call, output_sharding)
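The wrapping relies on standard Python exception chaining (`raise ... from e`): the original error stays attached as `__cause__`, and both tracebacks are printed, joined by "The above exception was the direct cause of the following exception". A self-contained sketch of the pattern follows; the function names are stand-ins, not the real DTensor dispatcher.

```python
# Standalone illustration of the `raise ... from e` pattern used above.
def propagate(op_schema: str) -> None:
    # Simulates the sharding propagator rejecting an op.
    raise RuntimeError("Attempted to flatten sharded dimension 1")

def dispatch(op_schema: str) -> None:
    try:
        propagate(op_schema)
    except Exception as e:
        # Chain a contextual error onto the original one.
        raise RuntimeError(f"Sharding propagation failed for {op_schema}") from e

try:
    dispatch("aten.view.default")
except RuntimeError as e:
    assert isinstance(e.__cause__, RuntimeError)  # original error is preserved
    print(f"{e} | caused by: {e.__cause__}")
```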
