
Commit c2faee4

guangyey, albanD, and svekars authored
[3/N] Refine beginner tutorial by accelerator api (#3170)
* [3/N] Refine beginner tutorial by accelerator api

Co-authored-by: albanD <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 3f302a3 commit c2faee4

File tree

4 files changed: +44 −47 lines changed

beginner_source/chatbot_tutorial.py

Lines changed: 7 additions & 5 deletions
@@ -108,8 +108,10 @@
 import json
 
 
-USE_CUDA = torch.cuda.is_available()
-device = torch.device("cuda" if USE_CUDA else "cpu")
+# If the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__ is available,
+# we will use it. Otherwise, we use the CPU.
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")
 
 
 ######################################################################
@@ -1318,16 +1320,16 @@ def evaluateInput(encoder, decoder, searcher, voc):
 encoder_optimizer.load_state_dict(encoder_optimizer_sd)
 decoder_optimizer.load_state_dict(decoder_optimizer_sd)
 
-# If you have CUDA, configure CUDA to call
+# If you have an accelerator, configure it to call
 for state in encoder_optimizer.state.values():
     for k, v in state.items():
         if isinstance(v, torch.Tensor):
-            state[k] = v.cuda()
+            state[k] = v.to(device)
 
 for state in decoder_optimizer.state.values():
     for k, v in state.items():
         if isinstance(v, torch.Tensor):
-            state[k] = v.to(device)
+            state[k] = v.to(device)
 
 # Run training iterations
 print("Starting Training!")

beginner_source/introyt/tensors_deeper_tutorial.py

Lines changed: 22 additions & 26 deletions
@@ -632,34 +632,33 @@
 # does this *without* changing ``a`` - you can see that when we print
 # ``a`` again at the end, it retains its ``requires_grad=True`` property.
 #
-# Moving to GPU
+# Moving to `Accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
 # -------------
 #
-# One of the major advantages of PyTorch is its robust acceleration on
-# CUDA-compatible Nvidia GPUs. (“CUDA” stands for *Compute Unified Device
-# Architecture*, which is Nvidia’s platform for parallel computing.) So
-# far, everything we’ve done has been on CPU. How do we move to the faster
+# One of the major advantages of PyTorch is its robust acceleration on an
+# `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+# such as CUDA, MPS, MTIA, or XPU.
+# So far, everything we’ve done has been on CPU. How do we move to the faster
 # hardware?
 #
-# First, we should check whether a GPU is available, with the
+# First, we should check whether an accelerator is available, with the
 # ``is_available()`` method.
 #
 # .. note::
-#      If you do not have a CUDA-compatible GPU and CUDA drivers
-#      installed, the executable cells in this section will not execute any
-#      GPU-related code.
+#      If you do not have an accelerator, the executable cells in this section will not execute any
+#      accelerator-related code.
 #
 
-if torch.cuda.is_available():
-    print('We have a GPU!')
+if torch.accelerator.is_available():
+    print('We have an accelerator!')
 else:
     print('Sorry, CPU only.')
 
 
 ##########################################################################
-# Once we’ve determined that one or more GPUs is available, we need to put
-# our data someplace where the GPU can see it. Your CPU does computation
-# on data in your computer’s RAM. Your GPU has dedicated memory attached
+# Once we’ve determined that one or more accelerators is available, we need to put
+# our data someplace where the accelerator can see it. Your CPU does computation
+# on data in your computer’s RAM. Your accelerator has dedicated memory attached
 # to it. Whenever you want to perform a computation on a device, you must
 # move *all* the data needed for that computation to memory accessible by
 # that device. (Colloquially, “moving the data to memory accessible by the
@@ -669,34 +668,31 @@
 # may do it at creation time:
 #
 
-if torch.cuda.is_available():
-    gpu_rand = torch.rand(2, 2, device='cuda')
+if torch.accelerator.is_available():
+    gpu_rand = torch.rand(2, 2, device=torch.accelerator.current_accelerator())
     print(gpu_rand)
 else:
     print('Sorry, CPU only.')
 
 
 ##########################################################################
 # By default, new tensors are created on the CPU, so we have to specify
-# when we want to create our tensor on the GPU with the optional
+# when we want to create our tensor on the accelerator with the optional
 # ``device`` argument. You can see when we print the new tensor, PyTorch
 # informs us which device it’s on (if it’s not on CPU).
 #
-# You can query the number of GPUs with ``torch.cuda.device_count()``. If
-# you have more than one GPU, you can specify them by index:
+# You can query the number of accelerators with ``torch.accelerator.device_count()``. If
+# you have more than one accelerator, you can specify them by index, take CUDA for example:
 # ``device='cuda:0'``, ``device='cuda:1'``, etc.
 #
 # As a coding practice, specifying our devices everywhere with string
 # constants is pretty fragile. In an ideal world, your code would perform
-# robustly whether you’re on CPU or GPU hardware. You can do this by
+# robustly whether you’re on CPU or accelerator hardware. You can do this by
 # creating a device handle that can be passed to your tensors instead of a
 # string:
 #
 
-if torch.cuda.is_available():
-    my_device = torch.device('cuda')
-else:
-    my_device = torch.device('cpu')
+my_device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')
 print('Device: {}'.format(my_device))
 
 x = torch.rand(2, 2, device=my_device)
@@ -718,12 +714,12 @@
 # It is important to know that in order to do computation involving two or
 # more tensors, *all of the tensors must be on the same device*. The
 # following code will throw a runtime error, regardless of whether you
-# have a GPU device available:
+# have an accelerator device available, take CUDA for example:
 #
 # .. code-block:: python
 #
 #     x = torch.rand(2, 2)
-#     y = torch.rand(2, 2, device='gpu')
+#     y = torch.rand(2, 2, device='cuda')
 #     z = x + y  # exception will be thrown
 #
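
As a companion to the hunks above, here is a short runnable sketch of the device-handle idiom and the same-device rule, again assuming PyTorch 2.6+ with the torch.accelerator namespace:

    import torch

    # Device handle: the current accelerator if present, else the CPU.
    my_device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')

    x = torch.rand(2, 2)                    # created on the CPU by default
    y = torch.rand(2, 2, device=my_device)  # created on the chosen device

    try:
        z = x + y  # raises a RuntimeError whenever my_device is not the CPU
    except RuntimeError as err:
        print(f"Mixed-device op failed as expected: {err}")

    z = x.to(my_device) + y  # fine: both operands now share a device
    print(z)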

beginner_source/knowledge_distillation_tutorial.py

Lines changed: 4 additions & 2 deletions
@@ -37,8 +37,10 @@
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
-# Check if GPU is available, and if not, use the CPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Check if the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+# is available, and if not, use the CPU
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")
 
 ######################################################################
 # Loading CIFAR-10
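
One detail worth noting about the replacement line: torch.accelerator.current_accelerator() returns a torch.device object, and the .type attribute used here is its plain string form ("cuda", "mps", ...). Either can be passed as a device argument; a small sketch, with PyTorch 2.6+ assumed:

    import torch

    if torch.accelerator.is_available():
        acc = torch.accelerator.current_accelerator()  # a torch.device, e.g. device(type='cuda')
        print(acc, acc.type)                 # the .type string is what this tutorial stores
        t = torch.ones(3, device=acc)        # the device object works
        u = torch.ones(3, device=acc.type)   # the string form works too
    else:
        print("No accelerator found; tensors default to the CPU.")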

beginner_source/nn_tutorial.py

Lines changed: 11 additions & 14 deletions
@@ -132,7 +132,7 @@
 # we'll write `log_softmax` and use it. Remember: although PyTorch
 # provides lots of prewritten loss functions, activation functions, and
 # so forth, you can easily write your own using plain python. PyTorch will
-# even create fast GPU or vectorized CPU code for your function
+# even create fast accelerator or vectorized CPU code for your function
 # automatically.
 
 def log_softmax(x):
@@ -827,38 +827,35 @@ def __iter__(self):
 fit(epochs, model, loss_func, opt, train_dl, valid_dl)
 
 ###############################################################################
-# Using your GPU
+# Using your `Accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
 # ---------------
 #
-# If you're lucky enough to have access to a CUDA-capable GPU (you can
+# If you're lucky enough to have access to an accelerator such as CUDA (you can
 # rent one for about $0.50/hour from most cloud providers) you can
-# use it to speed up your code. First check that your GPU is working in
+# use it to speed up your code. First check that your accelerator is working in
 # Pytorch:
 
-print(torch.cuda.is_available())
+# If the current accelerator is available, we will use it. Otherwise, we use the CPU.
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")
 
-###############################################################################
-# And then create a device object for it:
-
-dev = torch.device(
-    "cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 ###############################################################################
-# Let's update ``preprocess`` to move batches to the GPU:
+# Let's update ``preprocess`` to move batches to the accelerator:
 
 
 def preprocess(x, y):
-    return x.view(-1, 1, 28, 28).to(dev), y.to(dev)
+    return x.view(-1, 1, 28, 28).to(device), y.to(device)
 
 
 train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
 train_dl = WrappedDataLoader(train_dl, preprocess)
 valid_dl = WrappedDataLoader(valid_dl, preprocess)
 
 ###############################################################################
-# Finally, we can move our model to the GPU.
+# Finally, we can move our model to the accelerator.
 
-model.to(dev)
+model.to(device)
 opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
 
 ###############################################################################