Skip to content

Commit 86c6f71

Browse files
Revert "[Ez][BE]: Remove accidental classvar (pytorch#153540)"
This reverts commit e0dece5. Reverted pytorch#153540 on behalf of https://github.com/jeanschmidt due to Broken internal tests, @albanD may you help the author get his PR merged? D74804063 ([comment](pytorch#153540 (comment)))
1 parent 4d073af commit 86c6f71

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# mypy: allow-untyped-defs
22
from dataclasses import dataclass
33
from typing import cast, Optional, TYPE_CHECKING, Union
4-
from typing_extensions import TypeAlias
54

65
import torch
76
import torch.distributed as dist
@@ -23,8 +22,6 @@
2322
# from run-time to resolve circular dependency.
2423
from torch.distributed._shard.sharded_tensor import ShardedTensor
2524

26-
_ShardingDim: TypeAlias = Union[int, str]
27-
2825

2926
@dataclass
3027
class ChunkShardingSpec(ShardingSpec):
@@ -53,7 +50,9 @@ class ChunkShardingSpec(ShardingSpec):
5350
:class:`torch.distributed._remote_device`
5451
"""
5552

56-
dim: _ShardingDim
53+
ShardingDim = Union[int, str]
54+
55+
dim: ShardingDim
5756
placements: list[Union[torch.distributed._remote_device, str]]
5857

5958
def __post_init__(self):

0 commit comments

Comments
 (0)