Skip to content

Commit 78c7d2a

Browse files
committed
chore: fix check code quality
1 parent 88d6dd3 commit 78c7d2a

File tree

2 files changed

+20
-37
lines changed

2 files changed

+20
-37
lines changed

src/datasets/features/midi.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,25 +118,17 @@ def _create_midi_from_data(cls, value: dict) -> "pretty_midi.PrettyMIDI":
118118
except ImportError as err:
119119
raise ImportError("To support encoding MIDI data, please install 'pretty_midi'.") from err
120120

121-
122-
123121
# Create piano instrument
124-
piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
122+
piano_program = pretty_midi.instrument_name_to_program("Acoustic Grand Piano")
125123
piano = pretty_midi.Instrument(program=piano_program)
126124

127125
notes = value.get("notes", [])
128126
for note_data in notes:
129127
if len(note_data) >= 4:
130128
pitch, velocity, start, end = note_data[:4]
131-
note = pretty_midi.Note(
132-
velocity=int(velocity),
133-
pitch=int(pitch),
134-
start=float(start),
135-
end=float(end)
136-
)
129+
note = pretty_midi.Note(velocity=int(velocity), pitch=int(pitch), start=float(start), end=float(end))
137130
piano.notes.append(note)
138131

139-
140132
if "tempo" in value:
141133
midi = pretty_midi.PrettyMIDI(initial_tempo=value["tempo"])
142134
else:
@@ -208,10 +200,9 @@ def decode_example(
208200
# Extract instrument information
209201
instruments = []
210202
for instrument in midi.instruments:
211-
instruments.append({
212-
"program": instrument.program,
213-
"name": pretty_midi.program_to_instrument_name(instrument.program)
214-
})
203+
instruments.append(
204+
{"program": instrument.program, "name": pretty_midi.program_to_instrument_name(instrument.program)}
205+
)
215206

216207
# Get tempo
217208
tempo = 120.0 # Default tempo

tests/features/test_midi.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import os
2-
import tempfile
31
from unittest import TestCase
42

5-
import numpy as np
63
import pytest
74

8-
from datasets import Dataset, Features, load_dataset
9-
from datasets.features import Midi, Value
5+
from datasets import Dataset, Features
6+
from datasets.features import Midi
107

118

129
class TestMidiFeature(TestCase):
@@ -17,30 +14,26 @@ def test_audio_feature_type(self):
1714

1815
def test_audio_feature_encode_example(self):
1916
midi = Midi()
20-
17+
2118
# Test with path
2219
encoded = midi.encode_example("path/to/midi.mid")
2320
assert encoded == {"bytes": None, "path": "path/to/midi.mid"}
24-
21+
2522
# Test with bytes
2623
encoded = midi.encode_example(b"fake_midi_bytes")
2724
assert encoded == {"bytes": b"fake_midi_bytes", "path": None}
28-
25+
2926
# Test with dict containing notes
30-
notes_data = {
31-
"notes": [[60, 64, 0.0, 1.0], [62, 64, 1.0, 2.0]],
32-
"tempo": 120.0,
33-
"resolution": 480
34-
}
27+
notes_data = {"notes": [[60, 64, 0.0, 1.0], [62, 64, 1.0, 2.0]], "tempo": 120.0, "resolution": 480}
3528
encoded = midi.encode_example(notes_data)
3629
assert "bytes" in encoded
3730
assert encoded["path"] is None
3831

3932
def test_audio_feature_decode_example(self):
4033
midi = Midi()
41-
34+
4235
# Test decode with bytes
43-
fake_midi_bytes = b'MThd\x00\x00\x00\x06\x00\x01\x00\x02\x00\xdcMTrk\x00\x00\x00\x13\x00\xffQ\x03\x07\xa1 \x00\xffX\x04\x04\x02\x18\x08\x01\xff/\x00MTrk\x00\x00\x00\x16\x00\xc0\x00\x00\x90<@\x838<\x00\x00>@\x838>\x00\x01\xff/\x00'
36+
fake_midi_bytes = b"MThd\x00\x00\x00\x06\x00\x01\x00\x02\x00\xdcMTrk\x00\x00\x00\x13\x00\xffQ\x03\x07\xa1 \x00\xffX\x04\x04\x02\x18\x08\x01\xff/\x00MTrk\x00\x00\x00\x16\x00\xc0\x00\x00\x90<@\x838<\x00\x00>@\x838>\x00\x01\xff/\x00"
4437
decoded = midi.decode_example({"bytes": fake_midi_bytes, "path": None})
4538
assert "notes" in decoded
4639
assert "tempo" in decoded
@@ -50,11 +43,10 @@ def test_audio_feature_decode_example(self):
5043
def test_audio_feature_with_dataset(self):
5144
features = Features({"midi": Midi()})
5245
data = {"midi": ["fake_path1.mid", "fake_path2.mid"]}
53-
54-
with tempfile.TemporaryDirectory() as tmp_dir:
55-
dataset = Dataset.from_dict(data, features=features)
56-
assert "midi" in dataset.column_names
57-
assert dataset.features["midi"].dtype == "dict"
46+
47+
dataset = Dataset.from_dict(data, features=features)
48+
assert "midi" in dataset.column_names
49+
assert dataset.features["midi"].dtype == "dict"
5850

5951
def test_audio_feature_decode_false(self):
6052
midi = Midi(decode=False)
@@ -68,10 +60,10 @@ def test_audio_feature_resolution(self):
6860
def test_audio_feature_flatten(self):
6961
midi = Midi(decode=False)
7062
flattened = midi.flatten()
71-
assert "bytes" in flattened
72-
assert "path" in flattened
63+
assert "bytes" in flattened # type: ignore
64+
assert "path" in flattened # type: ignore
7365

7466
def test_audio_feature_decode_error(self):
7567
midi = Midi(decode=False)
7668
with pytest.raises(RuntimeError):
77-
midi.decode_example({"bytes": b"fake", "path": None})
69+
midi.decode_example({"bytes": b"fake", "path": None})

0 commit comments

Comments
 (0)