Skip to content

Commit 5b17f1d

Browse files
committed
add support to midi files
1 parent 17f40a3 commit 5b17f1d

File tree

10 files changed

+378
-3
lines changed

10 files changed

+378
-3
lines changed

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@
145145
"Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced
146146
]
147147

148+
MIDI_REQUIRE = [
149+
"pretty-midi>=0.2.0",
150+
]
151+
148152
BENCHMARKS_REQUIRE = [
149153
"tensorflow==2.12.0",
150154
"torch==2.0.1",
@@ -213,6 +217,7 @@
213217
EXTRAS_REQUIRE = {
214218
"audio": AUDIO_REQUIRE,
215219
"vision": VISION_REQUIRE,
220+
"midi": MIDI_REQUIRE,
216221
"tensorflow": [
217222
"tensorflow>=2.6.0",
218223
],

src/datasets/arrow_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
7979
from .data_files import sanitize_patterns
8080
from .download.streaming_download_manager import xgetsize
81-
from .features import Audio, ClassLabel, Features, Image, List, Value, Video
81+
from .features import Audio, ClassLabel, Features, Image, List, Midi, Value, Video
8282
from .features.features import (
8383
FeatureType,
8484
_align_features,
@@ -5358,7 +5358,7 @@ def _estimate_nbytes(self) -> int:
53585358

53595359
def extra_nbytes_visitor(array, feature):
53605360
nonlocal extra_nbytes
5361-
if isinstance(feature, (Audio, Image, Video)):
5361+
if isinstance(feature, (Audio, Image, Video, Midi)):
53625362
for x in array.to_pylist():
53635363
if x is not None and x["bytes"] is None and x["path"] is not None:
53645364
size = xgetsize(x["path"])

src/datasets/arrow_writer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from fsspec.core import url_to_fs
2626

2727
from . import config
28-
from .features import Audio, Features, Image, Pdf, Value, Video
28+
from .features import Audio, Features, Image, Midi, Pdf, Value, Video
2929
from .features.features import (
3030
FeatureType,
3131
List,
@@ -78,6 +78,8 @@ def set_batch_size(feature: FeatureType) -> None:
7878
batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS)
7979
elif isinstance(feature, Video) and config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS is not None:
8080
batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS)
81+
elif isinstance(feature, Midi) and config.ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS is not None:
82+
batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS)
8183
elif (
8284
isinstance(feature, Value)
8385
and feature.dtype == "binary"
@@ -118,6 +120,8 @@ def set_batch_size(feature: FeatureType) -> None:
118120
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
119121
elif isinstance(feature, Video) and config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS is not None:
120122
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
123+
elif isinstance(feature, Midi) and config.PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS is not None:
124+
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS)
121125
elif (
122126
isinstance(feature, Value)
123127
and feature.dtype == "binary"

src/datasets/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,14 @@
198198
PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = None
199199
PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = None
200200
PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = None
201+
PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS = None
201202

202203
# Arrow configuration
203204
ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS = 100
204205
ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS = 100
205206
ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS = 100
206207
ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS = 10
208+
ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS = 100
207209

208210
# Offline mode
209211
_offline = os.environ.get("HF_DATASETS_OFFLINE")

src/datasets/features/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
"Video",
1717
"Pdf",
1818
"Nifti",
19+
"Midi",
1920
]
2021
from .audio import Audio
2122
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, List, Sequence, Value
2223
from .image import Image
24+
from .midi import Midi
2325
from .nifti import Nifti
2426
from .pdf import Pdf
2527
from .translation import Translation, TranslationVariableLanguages

src/datasets/features/features.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
4343
from .audio import Audio
4444
from .image import Image, encode_pil_image
45+
from .midi import Midi
4546
from .nifti import Nifti
4647
from .pdf import Pdf, encode_pdfplumber_pdf
4748
from .translation import Translation, TranslationVariableLanguages
@@ -1431,6 +1432,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Uni
14311432
Video.__name__: Video,
14321433
Pdf.__name__: Pdf,
14331434
Nifti.__name__: Nifti,
1435+
Midi.__name__: Midi,
14341436
}
14351437

14361438

0 commit comments

Comments
 (0)