-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
32 lines (24 loc) · 907 Bytes
/
data_loader.py
File metadata and controls
32 lines (24 loc) · 907 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import torch
from skimage.color import rgb2lab
from torchvision import datasets

from utils import *
class GrayscaleImageFolder(datasets.ImageFolder):
    """ImageFolder variant that yields Lab-space tensors for colorization.

    Each item is a tuple ``(img_gray, img_ab, img_smooth)``:

    * ``img_gray`` -- the Lab L channel divided by 100, as a float tensor of
      shape ``(1, H, W)``.
    * ``img_ab`` -- the Lab a/b channels moved to channels-first layout,
      shifted by +128 and divided by 255, as a float tensor of shape
      ``(2, H, W)``.
    * ``img_smooth`` -- the result of ``compute_smoothed`` (from ``utils``)
      applied to the channels-last ``(H, W, 2)`` a/b array; its type and
      shape are not visible from this file.

    The class label derived from the folder structure is read but not
    returned.
    """

    # Axis order that turns an (H, W, C) numpy array into (C, H, W).
    _HWC_TO_CHW = (2, 0, 1)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img_gray, img_ab, img_smooth)``."""
        path, _target = self.imgs[index]
        img = self.loader(path)

        # ``transform`` is optional on ImageFolder (defaults to None); calling
        # it unconditionally crashed datasets built without one.
        if self.transform is not None:
            img = self.transform(img)  # e.g. a crop/resize
        # NOTE(review): assumes the (optionally transformed) image converts to
        # an (H, W, 3) RGB array, e.g. a PIL image rather than a CHW tensor --
        # confirm against the transforms used by callers.
        img_original = np.asarray(img)

        # RGB -> Lab; split the lightness channel from the two colour channels.
        img_lab = rgb2lab(img_original)
        img_gray = img_lab[:, :, 0] / 100  # skimage's L spans 0..100 -> 0..1
        img_ab = img_lab[:, :, 1:3]
        # Computed on the channels-last a/b array, before conversion to torch.
        img_smooth = compute_smoothed(img_ab)

        # numpy -> torch, channels-first, with the a/b values rescaled.
        img_ab = torch.from_numpy(img_ab.transpose(self._HWC_TO_CHW)).float()
        img_ab = (img_ab + 128) / 255
        img_gray = torch.from_numpy(img_gray).unsqueeze(0).float()
        return img_gray, img_ab, img_smooth