diff --git a/setup.py b/setup.py index 2f626763113..4c31ec7df52 100644 --- a/setup.py +++ b/setup.py @@ -145,6 +145,10 @@ "Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced ] +MIDI_REQUIRE = [ + "pretty-midi>=0.2.0", +] + BENCHMARKS_REQUIRE = [ "tensorflow==2.12.0", "torch==2.0.1", @@ -187,6 +191,7 @@ "Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced "torchcodec>=0.7.0; python_version < '3.14'", # minium version to get windows support, torchcodec doesn't have wheels for 3.14 yet "nibabel>=5.3.1", + "pretty-midi>=0.2.0", ] NUMPY2_INCOMPATIBLE_LIBRARIES = [ @@ -213,6 +218,7 @@ EXTRAS_REQUIRE = { "audio": AUDIO_REQUIRE, "vision": VISION_REQUIRE, + "midi": MIDI_REQUIRE, "tensorflow": [ "tensorflow>=2.6.0", ], diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 36b744a024a..7120652283b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -78,7 +78,7 @@ from .arrow_writer import ArrowWriter, OptimizedTypedSequence from .data_files import sanitize_patterns from .download.streaming_download_manager import xgetsize -from .features import Audio, ClassLabel, Features, Image, List, Value, Video +from .features import Audio, ClassLabel, Features, Image, List, Midi, Value, Video from .features.features import ( FeatureType, _align_features, @@ -5358,7 +5358,7 @@ def _estimate_nbytes(self) -> int: def extra_nbytes_visitor(array, feature): nonlocal extra_nbytes - if isinstance(feature, (Audio, Image, Video)): + if isinstance(feature, (Audio, Image, Video, Midi)): for x in array.to_pylist(): if x is not None and x["bytes"] is None and x["path"] is not None: size = xgetsize(x["path"]) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 3174f5cf206..946d465d997 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -25,7 +25,7 @@ from fsspec.core import url_to_fs from . 
import config -from .features import Audio, Features, Image, Pdf, Value, Video +from .features import Audio, Features, Image, Midi, Pdf, Value, Video from .features.features import ( FeatureType, List, @@ -78,6 +78,8 @@ def set_batch_size(feature: FeatureType) -> None: batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS) elif isinstance(feature, Video) and config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS is not None: batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS) + elif isinstance(feature, Midi) and config.ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS is not None: + batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS) elif ( isinstance(feature, Value) and feature.dtype == "binary" @@ -118,6 +120,8 @@ def set_batch_size(feature: FeatureType) -> None: batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS) elif isinstance(feature, Video) and config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS is not None: batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS) + elif isinstance(feature, Midi) and config.PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS is not None: + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS) elif ( isinstance(feature, Value) and feature.dtype == "binary" diff --git a/src/datasets/config.py b/src/datasets/config.py index b6412682727..8742f71d0ad 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -198,12 +198,14 @@ PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = None PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = None PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = None +PARQUET_ROW_GROUP_SIZE_FOR_MIDI_DATASETS = None # Arrow configuration ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS = 100 ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS = 100 ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS = 100 ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS = 10 +ARROW_RECORD_BATCH_SIZE_FOR_MIDI_DATASETS = 100 # Offline mode 
_offline = os.environ.get("HF_DATASETS_OFFLINE") diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index 40a3568039a..937f80729f2 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -16,10 +16,12 @@ "Video", "Pdf", "Nifti", + "Midi", ] from .audio import Audio from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, List, Sequence, Value from .image import Image +from .midi import Midi from .nifti import Nifti from .pdf import Pdf from .translation import Translation, TranslationVariableLanguages diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 88259767ae0..bc3e2d6f52d 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -42,6 +42,7 @@ from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image +from .midi import Midi from .nifti import Nifti from .pdf import Pdf, encode_pdfplumber_pdf from .translation import Translation, TranslationVariableLanguages @@ -1431,6 +1432,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Uni Video.__name__: Video, Pdf.__name__: Pdf, Nifti.__name__: Nifti, + Midi.__name__: Midi, } diff --git a/src/datasets/features/midi.py b/src/datasets/features/midi.py new file mode 100644 index 00000000000..e5bc38c9d97 --- /dev/null +++ b/src/datasets/features/midi.py @@ -0,0 +1,309 @@ +import os +from dataclasses import dataclass, field +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union + +import numpy as np +import pyarrow as pa + +from .. 
@dataclass
class Midi:
    """Midi [`Feature`] to extract MIDI data from a MIDI file.

    Input: The Midi feature accepts as input:
    - A `str`: Absolute path to the MIDI file (i.e. random access is allowed).
    - A `pathlib.Path`: path to the MIDI file (i.e. random access is allowed).
    - A `bytes` or `bytearray`: raw content of the MIDI file.
    - A `dict` with the keys:

        - `path`: String with relative path of the MIDI file to the archive file.
        - `bytes`: Bytes content of the MIDI file.

      This is useful for parquet or webdataset files which embed MIDI files.

    - A `dict` with the keys:

        - `notes`: Array containing the MIDI notes data, one `[pitch, velocity, start_time, end_time]` row per note
        - `tempo`: Float corresponding to the tempo of the MIDI file (optional)
        - `resolution`: Integer corresponding to the ticks per quarter note (optional)

    Output: The Midi features output data as a dictionary with keys:

    - `notes`: Array containing the MIDI notes data (pitch, velocity, start_time, end_time)
    - `tempo`: Float corresponding to the initial tempo of the MIDI file
    - `resolution`: Integer corresponding to the ticks per quarter note
    - `instruments`: List of instrument information (`program` and `name`)
    - `path`: The path stored alongside the sample (may be `None`)

    Args:
        decode (`bool`, defaults to `True`):
            Whether to decode the MIDI data. If `False`,
            returns the underlying dictionary in the format `{"path": midi_path, "bytes": midi_bytes}`.
        resolution (`int`, *optional*):
            Target resolution in ticks per quarter note. If `None`, the native resolution is used.
            Note: this value is stored on the feature but is not applied by `decode_example`,
            which always reports the resolution found in the file.

    Example:

    ```py
    >>> from datasets import load_dataset, Midi
    >>> ds = load_dataset("example/midi_dataset", split="train")
    >>> ds = ds.cast_column("midi", Midi())
    >>> ds[0]["midi"]
    {'notes': array([[60., 64., 0., 1.], [62., 64., 1., 2.]]), 'tempo': 120.0, 'resolution': 480, 'instruments': [{'program': 0, 'name': 'Acoustic Grand Piano'}], 'path': None}
    ```
    """

    decode: bool = True
    resolution: Optional[int] = None
    id: Optional[str] = field(default=None, repr=False)
    # Automatically constructed
    dtype: ClassVar[str] = "dict"
    pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
    _type: str = field(default="Midi", init=False, repr=False)

    def __call__(self):
        return self.pa_type

    def encode_example(self, value: Union[str, Path, bytes, bytearray, dict]) -> dict:
        """Encode example into a format for Arrow.

        Args:
            value (`str`, `pathlib.Path`, `bytes`, `bytearray`, `dict`):
                Data passed as input to Midi feature.

        Returns:
            `dict` with the keys `bytes` and `path`.

        Raises:
            ValueError: If `value` is `None`, or is a dict with neither `notes`,
                a non-null `path` nor non-null `bytes`.
        """

        if value is None:
            raise ValueError("value must be provided")

        if isinstance(value, str):
            return {"bytes": None, "path": value}
        elif isinstance(value, Path):
            return {"bytes": None, "path": str(value.absolute())}
        elif isinstance(value, (bytes, bytearray)):
            return {"bytes": value, "path": None}
        elif "notes" in value:
            # Decoded note data: serialize it to a standard MIDI file held in memory
            midi_data = self._create_midi_from_data(value)
            buffer = BytesIO()
            midi_data.write(buffer)
            return {"bytes": buffer.getvalue(), "path": None}
        elif value.get("path") is not None and os.path.isfile(value["path"]):
            # The file exists locally: reference it by path only, the bytes are not stored
            return {"bytes": None, "path": value.get("path")}
        elif value.get("bytes") is not None or value.get("path") is not None:
            # store the MIDI bytes, and path is used to infer the MIDI format using the file extension
            return {"bytes": value.get("bytes"), "path": value.get("path")}
        else:
            raise ValueError(
                f"A MIDI sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
            )

    @classmethod
    def _create_midi_from_data(cls, value: dict) -> "pretty_midi.PrettyMIDI":
        """Create a `pretty_midi.PrettyMIDI` object from note data.

        Args:
            value (`dict`):
                Dictionary with a `notes` entry (rows of `[pitch, velocity, start, end]`) and
                optional `tempo` and `resolution` entries. Missing or `None` `tempo`/`resolution`
                fall back to the `pretty_midi` defaults.

        Returns:
            `pretty_midi.PrettyMIDI` holding every note on a single "Acoustic Grand Piano" instrument.

        Raises:
            ImportError: If `pretty_midi` is not installed.
        """
        try:
            import pretty_midi
        except ImportError as err:
            raise ImportError("To support encoding MIDI data, please install 'pretty_midi'.") from err

        # Only forward the settings that were actually provided, so that pretty_midi's own
        # defaults apply otherwise. `resolution` used to be silently dropped here.
        init_kwargs = {}
        if value.get("tempo") is not None:
            init_kwargs["initial_tempo"] = float(value["tempo"])
        if value.get("resolution") is not None:
            init_kwargs["resolution"] = int(value["resolution"])
        midi = pretty_midi.PrettyMIDI(**init_kwargs)

        # All notes are written to one piano instrument; instrument information is not read from `value`
        piano_program = pretty_midi.instrument_name_to_program("Acoustic Grand Piano")
        piano = pretty_midi.Instrument(program=piano_program)

        notes = value.get("notes")
        if notes is None:
            notes = []
        for note_data in notes:
            # Rows with fewer than 4 entries are skipped; extra entries are ignored
            if len(note_data) >= 4:
                pitch, velocity, start, end = note_data[:4]
                piano.notes.append(
                    pretty_midi.Note(velocity=int(velocity), pitch=int(pitch), start=float(start), end=float(end))
                )

        midi.instruments.append(piano)

        return midi

    def decode_example(
        self, value: dict, token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None
    ) -> dict:
        """Decode example MIDI file into MIDI data.

        Args:
            value (`dict`):
                A dictionary with keys:

                - `path`: String with relative MIDI file path.
                - `bytes`: Bytes of the MIDI file.
            token_per_repo_id (`dict`, *optional*):
                To access and decode
                MIDI files from private repositories on the Hub, you can pass
                a dictionary repo_id (`str`) -> token (`bool` or `str`)

        Returns:
            `dict`: Dictionary containing MIDI data with keys:
                - `notes`: Array of [pitch, velocity, start_time, end_time]
                - `tempo`: Float tempo value (the first tempo of the file)
                - `resolution`: Integer ticks per quarter note
                - `instruments`: List of instrument information
                - `path`: The `path` entry of `value`

        Raises:
            RuntimeError: If the feature was created with `decode=False`.
            ImportError: If `pretty_midi` is not installed.
            ValueError: If both `path` and `bytes` are `None`.
        """
        # Check the decode flag first so that a disabled feature reports the configuration error
        # even when the optional dependency is not installed.
        if not self.decode:
            raise RuntimeError("Decoding is disabled for this feature. Please use Midi(decode=True) instead.")

        try:
            import pretty_midi
        except ImportError as err:
            raise ImportError("To support decoding MIDI data, please install 'pretty_midi'.") from err

        path, bytes_data = value["path"], value["bytes"]
        if path is None and bytes_data is None:
            raise ValueError(f"A MIDI sample should have one of 'path' or 'bytes' but both are None in {value}.")

        if bytes_data is not None:
            midi = pretty_midi.PrettyMIDI(BytesIO(bytes_data))
        elif is_local_path(path):
            midi = pretty_midi.PrettyMIDI(path)
        else:
            # Remote file: resolve the token of the repository it belongs to (if any) and stream it
            token_per_repo_id = token_per_repo_id or {}
            source_url = path.split("::")[-1]
            pattern = (
                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
            )
            source_url_fields = string_to_dict(source_url, pattern)
            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None

            download_config = DownloadConfig(token=token)
            with xopen(path, "rb", download_config=download_config) as f:
                midi = pretty_midi.PrettyMIDI(BytesIO(f.read()))

        # One row per note, across all instruments
        notes = [
            [note.pitch, note.velocity, note.start, note.end]
            for instrument in midi.instruments
            for note in instrument.notes
        ]

        instruments = [
            {"program": instrument.program, "name": pretty_midi.program_to_instrument_name(instrument.program)}
            for instrument in midi.instruments
        ]

        # `get_tempo_changes()` returns `(tempo_change_times, tempi)`: the tempo values are the
        # second element, the first one holds the times at which the tempo changes.
        _, tempi = midi.get_tempo_changes()
        tempo = float(tempi[0]) if len(tempi) > 0 else 120.0  # 120.0 is the default tempo

        return {
            "notes": np.array(notes) if notes else np.empty((0, 4)),
            "tempo": tempo,
            "resolution": midi.resolution,
            "instruments": instruments,
            "path": path,
        }

    def flatten(self) -> Union["FeatureType", dict[str, "FeatureType"]]:
        """If in the decodable state, raise an error, otherwise flatten the feature into a dictionary."""
        from .features import Value

        if self.decode:
            raise ValueError("Cannot flatten a decoded Midi feature.")
        return {
            "bytes": Value("binary"),
            "path": Value("string"),
        }

    def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray:
        """Cast an Arrow array to the Midi arrow storage type.
        The Arrow types that can be converted to the Midi pyarrow storage type are:

        - `pa.string()` - it must contain the "path" data
        - `pa.binary()` - it must contain the MIDI bytes
        - `pa.struct({"bytes": pa.binary()})`
        - `pa.struct({"path": pa.string()})`
        - `pa.struct({"bytes": pa.binary(), "path": pa.string()})`  - order doesn't matter
        - a struct with a `notes` field - each item is re-encoded with `Midi().encode_example`

        Args:
            storage (`Union[pa.StringArray, pa.StructArray]`):
                PyArrow array to cast.

        Returns:
            `pa.StructArray`: Array in the Midi arrow storage type, that is
                `pa.struct({"bytes": pa.binary(), "path": pa.string()})`
        """
        if pa.types.is_string(storage.type):
            bytes_array = pa.array([None] * len(storage), type=pa.binary())
            storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
        elif pa.types.is_binary(storage.type):
            path_array = pa.array([None] * len(storage), type=pa.string())
            storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
        elif pa.types.is_struct(storage.type) and storage.type.get_all_field_indices("notes"):
            storage = pa.array(
                [Midi().encode_example(x) if x is not None else None for x in storage.to_numpy(zero_copy_only=False)]
            )
        elif pa.types.is_struct(storage.type):
            # Fill whichever of the two fields is absent with nulls
            if storage.type.get_field_index("bytes") >= 0:
                bytes_array = storage.field("bytes")
            else:
                bytes_array = pa.array([None] * len(storage), type=pa.binary())
            if storage.type.get_field_index("path") >= 0:
                path_array = storage.field("path")
            else:
                path_array = pa.array([None] * len(storage), type=pa.string())
            storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
        return array_cast(storage, self.pa_type)

    def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
        """Embed MIDI files into the Arrow array.

        Items that only carry a `path` have the file content read into `bytes`, and every
        stored `path` is reduced to its base name.

        Args:
            storage (`pa.StructArray`):
                PyArrow array to embed.
            token_per_repo_id (`dict`, *optional*):
                Dictionary repo_id (`str`) -> token used when opening the referenced files.

        Returns:
            `pa.StructArray`: Array in the Midi arrow storage type, that is
                `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
        """
        if token_per_repo_id is None:
            token_per_repo_id = {}

        @no_op_if_value_is_null
        def path_to_bytes(path):
            source_url = path.split("::")[-1]
            pattern = (
                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
            )
            source_url_fields = string_to_dict(source_url, pattern)
            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
            download_config = DownloadConfig(token=token)
            with xopen(path, "rb", download_config=download_config) as f:
                return f.read()

        bytes_array = pa.array(
            [
                (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None
                for x in storage.to_pylist()
            ],
            type=pa.binary(),
        )
        path_array = pa.array(
            [os.path.basename(path) if path is not None else None for path in storage.field("path").to_pylist()],
            type=pa.string(),
        )
        storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
        return array_cast(storage, self.pa_type)
(xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), + "midifolder": (midifolder.__name__, _hash_python_lines(inspect.getsource(midifolder).splitlines())), } # get importable module names and hash for caching @@ -93,6 +95,8 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext: ("midifolder", {}) for ext in midifolder.MidiFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("midifolder", {}) for ext in midifolder.MidiFolder.EXTENSIONS}) # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: dict[str, list[str]] = {} @@ -111,3 +115,4 @@ def _hash_python_lines(lines: list[str]) -> str: _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES +_MODULE_TO_METADATA_FILE_NAMES["midifolder"] = midifolder.MidiFolder.METADATA_FILENAMES diff --git a/src/datasets/packaged_modules/midifolder/__init__.py b/src/datasets/packaged_modules/midifolder/__init__.py new file mode 100644 index 00000000000..c9c93018dd9 --- /dev/null +++ b/src/datasets/packaged_modules/midifolder/__init__.py @@ -0,0 +1,4 @@ +from .midifolder import MidiFolder, MidiFolderConfig + + +__all__ = ["MidiFolder", "MidiFolderConfig"] diff --git a/src/datasets/packaged_modules/midifolder/midifolder.py b/src/datasets/packaged_modules/midifolder/midifolder.py new file mode 100644 index 00000000000..e75871b53af --- 
class TestMidiFeature(TestCase):
    """Unit tests for the `Midi` feature type."""

    # Two-track standard MIDI file (format 1, 220 ticks per quarter note): a tempo/time-signature
    # track followed by a program-0 track with two notes (pitches 60 and 62, velocity 64).
    MIDI_BYTES = b"MThd\x00\x00\x00\x06\x00\x01\x00\x02\x00\xdcMTrk\x00\x00\x00\x13\x00\xffQ\x03\x07\xa1 \x00\xffX\x04\x04\x02\x18\x08\x01\xff/\x00MTrk\x00\x00\x00\x16\x00\xc0\x00\x00\x90<@\x838<\x00\x00>@\x838>\x00\x01\xff/\x00"

    def test_midi_feature_type(self):
        midi = Midi()
        assert midi.dtype == "dict"
        assert midi.pa_type.names == ["bytes", "path"]

    def test_midi_feature_encode_example(self):
        midi = Midi()

        # Test with path
        encoded = midi.encode_example("path/to/midi.mid")
        assert encoded == {"bytes": None, "path": "path/to/midi.mid"}

        # Test with bytes
        encoded = midi.encode_example(b"fake_midi_bytes")
        assert encoded == {"bytes": b"fake_midi_bytes", "path": None}

    def test_midi_feature_encode_example_from_notes(self):
        # Encoding note data goes through pretty_midi, which is an optional dependency
        pytest.importorskip("pretty_midi")
        midi = Midi()

        notes_data = {"notes": [[60, 64, 0.0, 1.0], [62, 64, 1.0, 2.0]], "tempo": 120.0, "resolution": 480}
        encoded = midi.encode_example(notes_data)
        assert "bytes" in encoded
        assert isinstance(encoded["bytes"], bytes)
        assert encoded["path"] is None

    def test_midi_feature_decode_example(self):
        # Decoding goes through pretty_midi, which is an optional dependency
        pytest.importorskip("pretty_midi")
        midi = Midi()

        decoded = midi.decode_example({"bytes": self.MIDI_BYTES, "path": None})
        assert "notes" in decoded
        assert "tempo" in decoded
        assert "resolution" in decoded
        assert "instruments" in decoded
        # Values encoded in the fixture: two notes, 0x00dc == 220 ticks per quarter note
        assert decoded["notes"].shape == (2, 4)
        assert decoded["resolution"] == 220

    def test_midi_feature_with_dataset(self):
        features = Features({"midi": Midi()})
        data = {"midi": ["fake_path1.mid", "fake_path2.mid"]}

        dataset = Dataset.from_dict(data, features=features)
        assert "midi" in dataset.column_names
        assert dataset.features["midi"].dtype == "dict"

    def test_midi_feature_decode_false(self):
        midi = Midi(decode=False)
        encoded = midi.encode_example("path/to/midi.mid")
        assert encoded == {"bytes": None, "path": "path/to/midi.mid"}

    def test_midi_feature_resolution(self):
        midi = Midi(resolution=960)
        assert midi.resolution == 960

    def test_midi_feature_flatten(self):
        midi = Midi(decode=False)
        flattened = midi.flatten()
        assert "bytes" in flattened  # type: ignore
        assert "path" in flattened  # type: ignore

    def test_midi_feature_decode_error(self):
        # Decoding a feature created with decode=False must be rejected
        midi = Midi(decode=False)
        with pytest.raises(RuntimeError):
            midi.decode_example({"bytes": b"fake", "path": None})