Skip to content
This repository was archived by the owner on Nov 23, 2024. It is now read-only.

Commit 97ae47a

Browse files
Marsmaennchen221megalinter-botlars-reimann
authored
feat: added MNIST, Fashion-MNIST and KMNIST datasets (#164)
Closes #161 Closes #162 Closes #163 ### Summary of Changes feat: added MNIST, Fashion-MNIST and KMNIST datasets build: bump safe-ds to ^0.24.0 --------- Co-authored-by: megalinter-bot <[email protected]> Co-authored-by: Lars Reimann <[email protected]>
1 parent 90de957 commit 97ae47a

File tree

8 files changed

+366
-4
lines changed

8 files changed

+366
-4
lines changed

poetry.lock

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ packages = [
1414

1515
[tool.poetry.dependencies]
1616
python = "^3.11,<3.13"
17-
safe-ds = ">=0.17,<0.27"
17+
safe-ds = ">=0.24,<0.27"
1818

1919
[tool.poetry.group.dev.dependencies]
2020
pytest = ">=7.2.1,<9.0.0"

src/safeds_datasets/image/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Image datasets."""
2+
3+
from ._mnist import load_fashion_mnist, load_kmnist, load_mnist
4+
5+
__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""MNIST like Datasets."""
2+
3+
from ._mnist import load_fashion_mnist, load_kmnist, load_mnist
4+
5+
__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"]
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import gzip
2+
import os
3+
import struct
4+
import sys
5+
import urllib.request
6+
from array import array
7+
from pathlib import Path
8+
from typing import TYPE_CHECKING
9+
from urllib.error import HTTPError
10+
11+
import torch
12+
from safeds._config import _init_default_device
13+
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
14+
from safeds.data.labeled.containers import ImageDataset
15+
from safeds.data.tabular.containers import Column
16+
17+
if TYPE_CHECKING:
18+
from safeds.data.image.containers import ImageList
19+
20+
_mnist_links: list[str] = ["http://yann.lecun.com/exdb/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/"]
21+
_mnist_files: dict[str, str] = {
22+
"train-images-idx3": "train-images-idx3-ubyte.gz",
23+
"train-labels-idx1": "train-labels-idx1-ubyte.gz",
24+
"test-images-idx3": "t10k-images-idx3-ubyte.gz",
25+
"test-labels-idx1": "t10k-labels-idx1-ubyte.gz",
26+
}
27+
_mnist_labels: dict[int, str] = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}
28+
_mnist_folder: str = "mnist"
29+
30+
_fashion_mnist_links: list[str] = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
31+
_fashion_mnist_files: dict[str, str] = _mnist_files
32+
_fashion_mnist_labels: dict[int, str] = {
33+
0: "T-shirt/top",
34+
1: "Trouser",
35+
2: "Pullover",
36+
3: "Dress",
37+
4: "Coat",
38+
5: "Sandal",
39+
6: "Shirt",
40+
7: "Sneaker",
41+
8: "Bag",
42+
9: "Ankle boot",
43+
}
44+
_fashion_mnist_folder: str = "fashion-mnist"
45+
46+
_kuzushiji_mnist_links: list[str] = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
47+
_kuzushiji_mnist_files: dict[str, str] = _mnist_files
48+
_kuzushiji_mnist_labels: dict[int, str] = {
49+
0: "\u304a",
50+
1: "\u304d",
51+
2: "\u3059",
52+
3: "\u3064",
53+
4: "\u306a",
54+
5: "\u306f",
55+
6: "\u307e",
56+
7: "\u3084",
57+
8: "\u308c",
58+
9: "\u3092",
59+
}
60+
_kuzushiji_mnist_folder: str = "kmnist"
61+
62+
63+
def load_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
64+
"""
65+
Load the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ datasets.
66+
67+
Parameters
68+
----------
69+
path:
70+
the path were the files are stored or will be downloaded to
71+
download:
72+
whether the files should be downloaded to the given path
73+
74+
Returns
75+
-------
76+
train_dataset, test_dataset:
77+
The train and test datasets.
78+
79+
Raises
80+
------
81+
FileNotFoundError
82+
if a file of the dataset cannot be found
83+
"""
84+
path = Path(path) / _mnist_folder
85+
path.mkdir(parents=True, exist_ok=True)
86+
path_files = os.listdir(path)
87+
missing_files = []
88+
for file_path in _mnist_files.values():
89+
if file_path not in path_files:
90+
missing_files.append(file_path)
91+
if len(missing_files) > 0:
92+
if download:
93+
_download_mnist_like(
94+
path,
95+
{name: f_path for name, f_path in _mnist_files.items() if f_path in missing_files},
96+
_mnist_links,
97+
)
98+
else:
99+
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
100+
return _load_mnist_like(path, _mnist_files, _mnist_labels)
101+
102+
103+
def load_fashion_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
104+
"""
105+
Load the `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ datasets.
106+
107+
Parameters
108+
----------
109+
path:
110+
the path were the files are stored or will be downloaded to
111+
download:
112+
whether the files should be downloaded to the given path
113+
114+
Returns
115+
-------
116+
train_dataset, test_dataset:
117+
The train and test datasets.
118+
119+
Raises
120+
------
121+
FileNotFoundError
122+
if a file of the dataset cannot be found
123+
"""
124+
path = Path(path) / _fashion_mnist_folder
125+
path.mkdir(parents=True, exist_ok=True)
126+
path_files = os.listdir(path)
127+
missing_files = []
128+
for file_path in _fashion_mnist_files.values():
129+
if file_path not in path_files:
130+
missing_files.append(file_path)
131+
if len(missing_files) > 0:
132+
if download:
133+
_download_mnist_like(
134+
path,
135+
{name: f_path for name, f_path in _fashion_mnist_files.items() if f_path in missing_files},
136+
_fashion_mnist_links,
137+
)
138+
else:
139+
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
140+
return _load_mnist_like(path, _fashion_mnist_files, _fashion_mnist_labels)
141+
142+
143+
def load_kmnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
144+
"""
145+
Load the `Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ datasets.
146+
147+
Parameters
148+
----------
149+
path:
150+
the path were the files are stored or will be downloaded to
151+
download:
152+
whether the files should be downloaded to the given path
153+
154+
Returns
155+
-------
156+
train_dataset, test_dataset:
157+
The train and test datasets.
158+
159+
Raises
160+
------
161+
FileNotFoundError
162+
if a file of the dataset cannot be found
163+
"""
164+
path = Path(path) / _kuzushiji_mnist_folder
165+
path.mkdir(parents=True, exist_ok=True)
166+
path_files = os.listdir(path)
167+
missing_files = []
168+
for file_path in _kuzushiji_mnist_files.values():
169+
if file_path not in path_files:
170+
missing_files.append(file_path)
171+
if len(missing_files) > 0:
172+
if download:
173+
_download_mnist_like(
174+
path,
175+
{name: f_path for name, f_path in _kuzushiji_mnist_files.items() if f_path in missing_files},
176+
_kuzushiji_mnist_links,
177+
)
178+
else:
179+
raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files]}")
180+
return _load_mnist_like(path, _kuzushiji_mnist_files, _kuzushiji_mnist_labels)
181+
182+
183+
def _load_mnist_like(
184+
path: str | Path,
185+
files: dict[str, str],
186+
labels: dict[int, str],
187+
) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
188+
_init_default_device()
189+
190+
path = Path(path)
191+
test_labels: Column | None = None
192+
train_labels: Column | None = None
193+
test_image_list: ImageList | None = None
194+
train_image_list: ImageList | None = None
195+
for file_name, file_path in files.items():
196+
if "idx1" in file_name:
197+
with gzip.open(path / file_path, mode="rb") as label_file:
198+
magic, size = struct.unpack(">II", label_file.read(8))
199+
if magic != 2049:
200+
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.") # pragma: no cover
201+
if "train" in file_name:
202+
train_labels = Column(
203+
file_name,
204+
[labels[label_index] for label_index in array("B", label_file.read())],
205+
)
206+
else:
207+
test_labels = Column(
208+
file_name,
209+
[labels[label_index] for label_index in array("B", label_file.read())],
210+
)
211+
else:
212+
with gzip.open(path / file_path, mode="rb") as image_file:
213+
magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
214+
if magic != 2051:
215+
raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.") # pragma: no cover
216+
image_data = array("B", image_file.read())
217+
image_tensor = torch.empty(size, 1, rows, cols, dtype=torch.uint8)
218+
for i in range(size):
219+
image_tensor[i, 0] = torch.frombuffer(
220+
image_data[i * rows * cols : (i + 1) * rows * cols],
221+
dtype=torch.uint8,
222+
).reshape(rows, cols)
223+
image_list = _SingleSizeImageList()
224+
image_list._tensor = image_tensor
225+
image_list._tensor_positions_to_indices = list(range(size))
226+
image_list._indices_to_tensor_positions = image_list._calc_new_indices_to_tensor_positions()
227+
if "train" in file_name:
228+
train_image_list = image_list
229+
else:
230+
test_image_list = image_list
231+
if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
232+
raise ValueError # pragma: no cover
233+
return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](
234+
test_image_list,
235+
test_labels,
236+
32,
237+
)
238+
239+
240+
def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[str]) -> None:
241+
path = Path(path)
242+
for file_name, file_path in files.items():
243+
for link in links:
244+
try:
245+
print(f"Trying to download file {file_name} via {link + file_path}") # noqa: T201
246+
urllib.request.urlretrieve(link + file_path, path / file_path, reporthook=_report_download_progress)
247+
print() # noqa: T201
248+
break
249+
except HTTPError as e:
250+
print(f"An error occurred while downloading: {e}") # noqa: T201 # pragma: no cover
251+
252+
253+
def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None:
254+
percentage = min(((current_packages * package_size) / file_size) * 100, 100)
255+
sys.stdout.write(f"\rDownloading... {percentage:.1f}%")
256+
sys.stdout.flush()

tests/safeds_datasets/image/__init__.py

Whitespace-only changes.

tests/safeds_datasets/image/_mnist/__init__.py

Whitespace-only changes.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
import tempfile
3+
from pathlib import Path
4+
5+
import pytest
6+
import torch
7+
from safeds.data.labeled.containers import ImageDataset
8+
from safeds_datasets.image import _mnist, load_fashion_mnist, load_kmnist, load_mnist
9+
10+
11+
class TestMNIST:
12+
13+
def test_should_download_and_return_mnist(self) -> None:
14+
with tempfile.TemporaryDirectory() as tmpdirname:
15+
train, test = load_mnist(tmpdirname, download=True)
16+
files = os.listdir(Path(tmpdirname) / _mnist._mnist._mnist_folder)
17+
for mnist_file in _mnist._mnist._mnist_files.values():
18+
assert mnist_file in files
19+
assert isinstance(train, ImageDataset)
20+
assert isinstance(test, ImageDataset)
21+
assert len(train) == 60_000
22+
assert len(test) == 10_000
23+
assert (
24+
train.get_input()._as_single_size_image_list()._tensor.dtype
25+
== test.get_input()._as_single_size_image_list()._tensor.dtype
26+
== torch.uint8
27+
)
28+
train_output = train.get_output()
29+
test_output = test.get_output()
30+
assert (
31+
set(train_output.get_distinct_values())
32+
== set(test_output.get_distinct_values())
33+
== set(_mnist._mnist._mnist_labels.values())
34+
)
35+
36+
def test_should_raise_if_file_not_found(self) -> None:
37+
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
38+
load_mnist(tmpdirname, download=False)
39+
40+
41+
class TestFashionMNIST:
42+
43+
def test_should_download_and_return_mnist(self) -> None:
44+
with tempfile.TemporaryDirectory() as tmpdirname:
45+
train, test = load_fashion_mnist(tmpdirname, download=True)
46+
files = os.listdir(Path(tmpdirname) / _mnist._mnist._fashion_mnist_folder)
47+
for mnist_file in _mnist._mnist._fashion_mnist_files.values():
48+
assert mnist_file in files
49+
assert isinstance(train, ImageDataset)
50+
assert isinstance(test, ImageDataset)
51+
assert len(train) == 60_000
52+
assert len(test) == 10_000
53+
assert (
54+
train.get_input()._as_single_size_image_list()._tensor.dtype
55+
== test.get_input()._as_single_size_image_list()._tensor.dtype
56+
== torch.uint8
57+
)
58+
train_output = train.get_output()
59+
test_output = test.get_output()
60+
assert (
61+
set(train_output.get_distinct_values())
62+
== set(test_output.get_distinct_values())
63+
== set(_mnist._mnist._fashion_mnist_labels.values())
64+
)
65+
66+
def test_should_raise_if_file_not_found(self) -> None:
67+
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
68+
load_fashion_mnist(tmpdirname, download=False)
69+
70+
71+
class TestKMNIST:
72+
73+
def test_should_download_and_return_mnist(self) -> None:
74+
with tempfile.TemporaryDirectory() as tmpdirname:
75+
train, test = load_kmnist(tmpdirname, download=True)
76+
files = os.listdir(Path(tmpdirname) / _mnist._mnist._kuzushiji_mnist_folder)
77+
for mnist_file in _mnist._mnist._kuzushiji_mnist_files.values():
78+
assert mnist_file in files
79+
assert isinstance(train, ImageDataset)
80+
assert isinstance(test, ImageDataset)
81+
assert len(train) == 60_000
82+
assert len(test) == 10_000
83+
assert (
84+
train.get_input()._as_single_size_image_list()._tensor.dtype
85+
== test.get_input()._as_single_size_image_list()._tensor.dtype
86+
== torch.uint8
87+
)
88+
train_output = train.get_output()
89+
test_output = test.get_output()
90+
assert (
91+
set(train_output.get_distinct_values())
92+
== set(test_output.get_distinct_values())
93+
== set(_mnist._mnist._kuzushiji_mnist_labels.values())
94+
)
95+
96+
def test_should_raise_if_file_not_found(self) -> None:
97+
with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
98+
load_kmnist(tmpdirname, download=False)

0 commit comments

Comments
 (0)