1- import os
2- import tempfile
31from unittest import TestCase
42
5- import numpy as np
63import pytest
74
8- from datasets import Dataset , Features , load_dataset
9- from datasets .features import Midi , Value
5+ from datasets import Dataset , Features
6+ from datasets .features import Midi
107
118
129class TestMidiFeature (TestCase ):
@@ -17,30 +14,26 @@ def test_audio_feature_type(self):
1714
1815 def test_audio_feature_encode_example (self ):
1916 midi = Midi ()
20-
17+
2118 # Test with path
2219 encoded = midi .encode_example ("path/to/midi.mid" )
2320 assert encoded == {"bytes" : None , "path" : "path/to/midi.mid" }
24-
21+
2522 # Test with bytes
2623 encoded = midi .encode_example (b"fake_midi_bytes" )
2724 assert encoded == {"bytes" : b"fake_midi_bytes" , "path" : None }
28-
25+
2926 # Test with dict containing notes
30- notes_data = {
31- "notes" : [[60 , 64 , 0.0 , 1.0 ], [62 , 64 , 1.0 , 2.0 ]],
32- "tempo" : 120.0 ,
33- "resolution" : 480
34- }
27+ notes_data = {"notes" : [[60 , 64 , 0.0 , 1.0 ], [62 , 64 , 1.0 , 2.0 ]], "tempo" : 120.0 , "resolution" : 480 }
3528 encoded = midi .encode_example (notes_data )
3629 assert "bytes" in encoded
3730 assert encoded ["path" ] is None
3831
3932 def test_audio_feature_decode_example (self ):
4033 midi = Midi ()
41-
34+
4235 # Test decode with bytes
43- fake_midi_bytes = b' MThd\x00 \x00 \x00 \x06 \x00 \x01 \x00 \x02 \x00 \xdc MTrk\x00 \x00 \x00 \x13 \x00 \xff Q\x03 \x07 \xa1 \x00 \xff X\x04 \x04 \x02 \x18 \x08 \x01 \xff /\x00 MTrk\x00 \x00 \x00 \x16 \x00 \xc0 \x00 \x00 \x90 <@\x83 8<\x00 \x00 >@\x83 8>\x00 \x01 \xff /\x00 '
36+ fake_midi_bytes = b" MThd\x00 \x00 \x00 \x06 \x00 \x01 \x00 \x02 \x00 \xdc MTrk\x00 \x00 \x00 \x13 \x00 \xff Q\x03 \x07 \xa1 \x00 \xff X\x04 \x04 \x02 \x18 \x08 \x01 \xff /\x00 MTrk\x00 \x00 \x00 \x16 \x00 \xc0 \x00 \x00 \x90 <@\x83 8<\x00 \x00 >@\x83 8>\x00 \x01 \xff /\x00 "
4437 decoded = midi .decode_example ({"bytes" : fake_midi_bytes , "path" : None })
4538 assert "notes" in decoded
4639 assert "tempo" in decoded
@@ -50,11 +43,10 @@ def test_audio_feature_decode_example(self):
5043 def test_audio_feature_with_dataset (self ):
5144 features = Features ({"midi" : Midi ()})
5245 data = {"midi" : ["fake_path1.mid" , "fake_path2.mid" ]}
53-
54- with tempfile .TemporaryDirectory () as tmp_dir :
55- dataset = Dataset .from_dict (data , features = features )
56- assert "midi" in dataset .column_names
57- assert dataset .features ["midi" ].dtype == "dict"
46+
47+ dataset = Dataset .from_dict (data , features = features )
48+ assert "midi" in dataset .column_names
49+ assert dataset .features ["midi" ].dtype == "dict"
5850
5951 def test_audio_feature_decode_false (self ):
6052 midi = Midi (decode = False )
@@ -68,10 +60,10 @@ def test_audio_feature_resolution(self):
6860 def test_audio_feature_flatten (self ):
6961 midi = Midi (decode = False )
7062 flattened = midi .flatten ()
71- assert "bytes" in flattened
72- assert "path" in flattened
63+ assert "bytes" in flattened # type: ignore
64+ assert "path" in flattened # type: ignore
7365
7466 def test_audio_feature_decode_error (self ):
7567 midi = Midi (decode = False )
7668 with pytest .raises (RuntimeError ):
77- midi .decode_example ({"bytes" : b"fake" , "path" : None })
69+ midi .decode_example ({"bytes" : b"fake" , "path" : None })
0 commit comments