-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
38 lines (32 loc) · 1.33 KB
/
dataset.py
File metadata and controls
38 lines (32 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
IMG_SIZE = 64


def load_transformed_dataset(root="./data/", img_size=None):
    """Load the Food-101 train and test splits as one dataset scaled to [-1, 1].

    Both splits are downloaded into *root* if missing and share one pipeline:
    resize to a square, random horizontal flip, conversion to a float tensor
    in [0, 1], then an affine rescale to [-1, 1].

    Args:
        root: Directory torchvision downloads to / reads the dataset from.
        img_size: Side length of the square resize; ``None`` means ``IMG_SIZE``.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the train split followed by
        the test split.
    """
    if img_size is None:
        img_size = IMG_SIZE
    data_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # Scales data into [0, 1]
        # (t - 0.5) / 0.5 == 2t - 1, i.e. [0, 1] -> [-1, 1]. Unlike a local
        # lambda, Normalize can be pickled, so DataLoader workers started with
        # the "spawn" method (Windows/macOS default) can receive this dataset.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train = torchvision.datasets.Food101(root=root, download=True,
                                         transform=data_transform)
    test = torchvision.datasets.Food101(root=root, download=True,
                                        transform=data_transform, split='test')
    return torch.utils.data.ConcatDataset([train, test])
def _tensor_to_uint8_hwc(image):
    """Convert an image tensor with values in [-1, 1] to a uint8 HWC NumPy array.

    A 4-D input is treated as a batch and only its first element is kept.
    The tensor is detached and moved to CPU so that tensors requiring grad or
    living on an accelerator can be converted. Values are clamped to the
    valid range first, because out-of-range samples would otherwise overflow
    the uint8 cast and wrap to the wrong intensity.
    """
    # Take first image of batch
    if image.dim() == 4:
        image = image[0]
    image = image.detach().cpu()
    image = ((image + 1) / 2).clamp(0, 1)  # [-1, 1] -> [0, 1]
    image = image.permute(1, 2, 0)  # CHW to HWC
    # Truncating cast, same as before for in-range values.
    return (image * 255.).numpy().astype(np.uint8)


def show_tensor_image(image, image_name):
    """Draw an image tensor with matplotlib and save the figure to disk.

    Args:
        image: Tensor with values expected in [-1, 1], laid out CHW, or NCHW
            (only the first image of the batch is drawn). The channel layout
            is assumed from the permute; TODO confirm against callers.
        image_name: File stem; the current figure is written to
            ``./{image_name}.png``.

    Draws into matplotlib's current axes and does not create or close a
    figure, so callers that arrange subplots before calling keep working.
    """
    plt.imshow(transforms.ToPILImage()(_tensor_to_uint8_hwc(image)))
    plt.savefig(f'./{image_name}.png')