1+ import os
2+ import tempfile
3+ from unittest import TestCase
4+
5+ import numpy as np
6+ import pytest
7+
8+ from datasets import Dataset , Features , load_dataset
9+ from datasets .features import Midi , Value
10+
11+
class TestMidiFeature(TestCase):
    """Unit tests for the ``Midi`` feature imported from ``datasets.features``.

    The tests cover the reported dtype and storage field names, encoding
    of paths / raw bytes / note dictionaries, decoding of a small in-memory
    MIDI file, use inside a ``Dataset``, and the ``decode=False`` and
    ``resolution`` constructor options.
    """

    def test_audio_feature_type(self):
        """A default ``Midi`` has dtype ``"dict"`` and ``bytes``/``path`` fields."""
        midi = Midi()
        assert midi.dtype == "dict"
        assert midi.pa_type.names == ["bytes", "path"]

    def test_audio_feature_encode_example(self):
        """``encode_example`` accepts a path string, raw bytes, or a notes dict."""
        midi = Midi()

        # A string is stored as the path, with no bytes.
        encoded = midi.encode_example("path/to/midi.mid")
        assert encoded == {"bytes": None, "path": "path/to/midi.mid"}

        # A bytes object is stored as the payload, with no path.
        encoded = midi.encode_example(b"fake_midi_bytes")
        assert encoded == {"bytes": b"fake_midi_bytes", "path": None}

        # A dict describing notes yields an entry with a "bytes" key and no path.
        notes_data = {
            "notes": [[60, 64, 0.0, 1.0], [62, 64, 1.0, 2.0]],
            "tempo": 120.0,
            "resolution": 480,
        }
        encoded = midi.encode_example(notes_data)
        assert "bytes" in encoded
        assert encoded["path"] is None

    def test_audio_feature_decode_example(self):
        """Decoding in-memory MIDI bytes yields the expected top-level keys."""
        midi = Midi()

        # Minimal two-track Standard MIDI File, written chunk by chunk so that
        # every chunk body matches the length declared in its own header.
        fake_midi_bytes = (
            # Header chunk: length 6, format 1, 2 tracks, 0xdc ticks per quarter.
            b"MThd\x00\x00\x00\x06\x00\x01\x00\x02\x00\xdc"
            # Track 1: declared length 0x13 (19 bytes).
            b"MTrk\x00\x00\x00\x13"
            b"\x00\xffQ\x03\x07\xa1\x20"  # set-tempo meta, 0x07A120 us/quarter
            b"\x00\xffX\x04\x04\x02\x18\x08"  # time-signature meta
            b"\x01\xff/\x00"  # end of track
            # Track 2: declared length 0x16 (22 bytes).
            b"MTrk\x00\x00\x00\x16"
            b"\x00\xc0\x00"  # program change
            b"\x00\x90<@\x838<\x00"  # note 60 on (velocity 64), then off
            b"\x00>@\x838>\x00"  # note 62 on (running status), then off
            b"\x01\xff/\x00"  # end of track
        )
        decoded = midi.decode_example({"bytes": fake_midi_bytes, "path": None})
        for key in ("notes", "tempo", "resolution", "instruments"):
            assert key in decoded

    def test_audio_feature_with_dataset(self):
        """A ``Midi`` column can be declared when building a ``Dataset``."""
        features = Features({"midi": Midi()})
        data = {"midi": ["fake_path1.mid", "fake_path2.mid"]}

        # No temporary directory is needed: nothing here touches the disk path.
        dataset = Dataset.from_dict(data, features=features)
        assert "midi" in dataset.column_names
        assert dataset.features["midi"].dtype == "dict"

    def test_audio_feature_decode_false(self):
        """With ``decode=False`` a path string still encodes to a path entry."""
        midi = Midi(decode=False)
        encoded = midi.encode_example("path/to/midi.mid")
        assert encoded == {"bytes": None, "path": "path/to/midi.mid"}

    def test_audio_feature_resolution(self):
        """The ``resolution`` constructor argument is exposed as an attribute."""
        midi = Midi(resolution=960)
        assert midi.resolution == 960

    def test_audio_feature_flatten(self):
        """Flattening a non-decoding ``Midi`` exposes ``bytes`` and ``path``."""
        midi = Midi(decode=False)
        flattened = midi.flatten()
        assert "bytes" in flattened
        assert "path" in flattened

    def test_audio_feature_decode_error(self):
        """``decode_example`` raises ``RuntimeError`` when ``decode=False``."""
        midi = Midi(decode=False)
        with pytest.raises(RuntimeError):
            midi.decode_example({"bytes": b"fake", "path": None})
0 commit comments