Skip to content

Commit 88d6dd3

Browse files
committed
tests: Add test for Midi feature
1 parent 5b17f1d commit 88d6dd3

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

tests/features/test_midi.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
import tempfile
3+
from unittest import TestCase
4+
5+
import numpy as np
6+
import pytest
7+
8+
from datasets import Dataset, Features, load_dataset
9+
from datasets.features import Midi, Value
10+
11+
12+
class TestMidiFeature(TestCase):
13+
def test_audio_feature_type(self):
14+
midi = Midi()
15+
assert midi.dtype == "dict"
16+
assert midi.pa_type.names == ["bytes", "path"]
17+
18+
def test_audio_feature_encode_example(self):
19+
midi = Midi()
20+
21+
# Test with path
22+
encoded = midi.encode_example("path/to/midi.mid")
23+
assert encoded == {"bytes": None, "path": "path/to/midi.mid"}
24+
25+
# Test with bytes
26+
encoded = midi.encode_example(b"fake_midi_bytes")
27+
assert encoded == {"bytes": b"fake_midi_bytes", "path": None}
28+
29+
# Test with dict containing notes
30+
notes_data = {
31+
"notes": [[60, 64, 0.0, 1.0], [62, 64, 1.0, 2.0]],
32+
"tempo": 120.0,
33+
"resolution": 480
34+
}
35+
encoded = midi.encode_example(notes_data)
36+
assert "bytes" in encoded
37+
assert encoded["path"] is None
38+
39+
def test_audio_feature_decode_example(self):
40+
midi = Midi()
41+
42+
# Test decode with bytes
43+
fake_midi_bytes = b'MThd\x00\x00\x00\x06\x00\x01\x00\x02\x00\xdcMTrk\x00\x00\x00\x13\x00\xffQ\x03\x07\xa1 \x00\xffX\x04\x04\x02\x18\x08\x01\xff/\x00MTrk\x00\x00\x00\x16\x00\xc0\x00\x00\x90<@\x838<\x00\x00>@\x838>\x00\x01\xff/\x00'
44+
decoded = midi.decode_example({"bytes": fake_midi_bytes, "path": None})
45+
assert "notes" in decoded
46+
assert "tempo" in decoded
47+
assert "resolution" in decoded
48+
assert "instruments" in decoded
49+
50+
def test_audio_feature_with_dataset(self):
51+
features = Features({"midi": Midi()})
52+
data = {"midi": ["fake_path1.mid", "fake_path2.mid"]}
53+
54+
with tempfile.TemporaryDirectory() as tmp_dir:
55+
dataset = Dataset.from_dict(data, features=features)
56+
assert "midi" in dataset.column_names
57+
assert dataset.features["midi"].dtype == "dict"
58+
59+
def test_audio_feature_decode_false(self):
60+
midi = Midi(decode=False)
61+
encoded = midi.encode_example("path/to/midi.mid")
62+
assert encoded == {"bytes": None, "path": "path/to/midi.mid"}
63+
64+
def test_audio_feature_resolution(self):
65+
midi = Midi(resolution=960)
66+
assert midi.resolution == 960
67+
68+
def test_audio_feature_flatten(self):
69+
midi = Midi(decode=False)
70+
flattened = midi.flatten()
71+
assert "bytes" in flattened
72+
assert "path" in flattened
73+
74+
def test_audio_feature_decode_error(self):
75+
midi = Midi(decode=False)
76+
with pytest.raises(RuntimeError):
77+
midi.decode_example({"bytes": b"fake", "path": None})

0 commit comments

Comments
 (0)