diff --git a/docs/source-pytorch/accelerators/tpu_intermediate.rst b/docs/source-pytorch/accelerators/tpu_intermediate.rst
index 8b1f26ec96e7e..5e6ed64df33e1 100644
--- a/docs/source-pytorch/accelerators/tpu_intermediate.rst
+++ b/docs/source-pytorch/accelerators/tpu_intermediate.rst
@@ -22,7 +22,7 @@ for TPU use
 
 .. code-block:: python
 
-    import torch_xla.core.xla_model as xm
+    from torch_xla import runtime as xr
 
 
     def train_dataloader(self):
@@ -32,7 +32,7 @@ for TPU use
         sampler = None
         if use_tpu:
             sampler = torch.utils.data.distributed.DistributedSampler(
-                dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
+                dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=True
             )
 
         loader = DataLoader(dataset, sampler=sampler, batch_size=32)
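
For readers migrating their own dataloaders, here is a minimal, self-contained sketch of the updated snippet, assuming ``torch_xla`` >= 2.x (where the ``torch_xla.runtime`` module is available) and the torchvision MNIST dataset used in the surrounding docs; ``use_tpu`` is a placeholder flag carried over from the original example, not a real API.

.. code-block:: python

    import os

    import torch
    from torch.utils.data import DataLoader
    from torch_xla import runtime as xr
    from torchvision import transforms
    from torchvision.datasets import MNIST

    use_tpu = True  # placeholder flag, as in the docs snippet

    dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())

    sampler = None
    if use_tpu:
        # xr.world_size() / xr.global_ordinal() are the runtime-module
        # replacements for the deprecated xm.xrt_world_size() / xm.get_ordinal()
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal(), shuffle=True
        )

    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

Note that ``shuffle=True`` goes on the sampler rather than the ``DataLoader``: a ``DataLoader`` raises an error if given both a custom ``sampler`` and ``shuffle=True``.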