|
23 | 23 |
|
24 | 24 | The FashionMNIST features are in PIL Image format, and the labels are integers. |
25 | 25 | For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors. |
26 | | -To make these transformations, we use ``ToTensor`` and ``Lambda``. |
| 26 | +To make these transformations, we use the ``torchvision.transforms.v2`` API along with ``torch.nn.functional.one_hot``. |
27 | 27 | """ |
28 | 28 |
|
29 | 29 | import torch |
| 30 | +import torch.nn.functional as F |
30 | 31 | from torchvision import datasets |
31 | | -from torchvision.transforms import ToTensor, Lambda |
| 32 | +from torchvision.transforms import v2 |
32 | 33 |
|
33 | 34 | ds = datasets.FashionMNIST( |
34 | 35 | root="data", |
35 | 36 | train=True, |
36 | 37 | download=True, |
37 | | - transform=ToTensor(), |
38 | | - target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) |
| 38 | + transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]), |
| 39 | + target_transform=v2.Lambda( |
| 40 | + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() |
| 41 | + ), |
39 | 42 | ) |
40 | 43 |
|
41 | 44 | ################################################# |
42 | | -# ToTensor() |
| 45 | +# ToImage() and ToDtype() |
43 | 46 | # ------------------------------- |
44 | 47 | # |
45 | | -# `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_ |
46 | | -# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor``. and scales |
47 | | -# the image's pixel intensity values in the range [0., 1.] |
| 48 | +# The ``torchvision.transforms.v2`` API replaces the legacy ``ToTensor`` transform with a two-step pipeline. |
| 49 | +# `v2.ToImage <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToImage.html>`_ |
| 50 | +# converts a PIL image or NumPy ``ndarray`` into a ``torchvision.tv_tensors.Image`` tensor, and |
| 51 | +# `v2.ToDtype <https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToDtype.html>`_ |
| 52 | +# with ``scale=True`` casts it to ``float32`` and scales the pixel intensity values to the range [0., 1.]. |
48 | 53 | # |
49 | 54 |
|
50 | 55 | ############################################## |
51 | 56 | # Lambda Transforms |
52 | 57 | # ------------------------------- |
53 | 58 | # |
54 | | -# Lambda transforms apply any user-defined lambda function. Here, we define a function |
55 | | -# to turn the integer into a one-hot encoded tensor. |
56 | | -# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls |
57 | | -# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_ which assigns a |
58 | | -# ``value=1`` on the index as given by the label ``y``. |
| 59 | +# Lambda transforms apply any user-defined lambda function. Here, we use |
| 60 | +# `torch.nn.functional.one_hot <https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html>`_ |
| 61 | +# to turn the integer label into a one-hot encoded tensor of size 10 (the number of labels in our dataset), |
| 62 | +# then cast it to ``float`` to match the expected dtype. |
59 | 63 |
|
60 | | -target_transform = Lambda(lambda y: torch.zeros( |
61 | | - 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) |
| 64 | +target_transform = v2.Lambda( |
| 65 | + lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() |
| 66 | +) |
62 | 67 |
|
63 | 68 | ###################################################################### |
64 | 69 | # -------------- |
|
67 | 72 | ################################################################# |
68 | 73 | # Further Reading |
69 | 74 | # ~~~~~~~~~~~~~~~~~ |
70 | | -# - `torchvision.transforms API <https://pytorch.org/vision/stable/transforms.html>`_ |
| 75 | +# - `Getting started with transforms v2 <https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html>`_ |
| 76 | +# - `torchvision.transforms.v2 API <https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended>`_ |
0 commit comments