From c81dfe2b5fc87ad40b2d5e9a573a503cd7c95c1b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 10:05:56 +0200 Subject: [PATCH] More idiomatic py.testing (#41) * Make py.test use more idiomatic (parametrize instead of internal looping) * CI: show tests and test durations --- .github/workflows/pytest.yml | 3 +- pyproject.toml | 2 - tests/__init__.py | 0 tests/test_pianoroll.py | 105 ++++++++++++++++------------------- tests/test_read_dump.py | 47 +++++----------- tests/utils.py | 5 ++ 6 files changed, 70 insertions(+), 92 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/utils.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b7aa6ad..d619caa 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,8 +30,7 @@ jobs: python -m pip install --upgrade pip pip install -e ".[tests]" - name: Test with pytest - run: | - pytest --cov=./ --cov-report=xml -n auto + run: pytest --cov=./ --cov-report=xml -n auto --durations=0 -v - name: Codecov uses: codecov/codecov-action@v3.1.0 build: diff --git a/pyproject.toml b/pyproject.toml index 14de95b..1cc5986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,6 @@ dependencies = [ tests = [ "pytest-cov", "pytest-xdist[psutil]", - "setuptools", - "tqdm", ] [project.urls] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_pianoroll.py b/tests/test_pianoroll.py index 9fc86f2..8e2168b 100644 --- a/tests/test_pianoroll.py +++ b/tests/test_pianoroll.py @@ -1,71 +1,64 @@ #!/usr/bin/python3 python +from operator import attrgetter -"""Testing creating pianorolls of notes. - -""" - -from pathlib import Path - -from tqdm import tqdm +import pytest from miditoolkit import MidiFile from miditoolkit.constants import PITCH_RANGE from miditoolkit.pianoroll import notes2pianoroll, pianoroll2notes +from tests.utils import MIDI_PATHS +test_sets = [ + {"pitch_range": (0, 127)}, + {"pitch_range": (24, 96)}, + {"pitch_range": (24, 116), "pitch_offset": 12}, + {"pitch_range": (6, 96), "pitch_offset": 12}, + {"pitch_range": (24, 96), "pitch_offset": 12, "velocity_threshold": 36}, +] -def test_pianoroll(): - midi_paths = list(Path("tests", "testcases").glob("**/*.mid")) - test_sets = [ - {"pitch_range": (0, 127)}, - {"pitch_range": (24, 96)}, - {"pitch_range": (24, 116), "pitch_offset": 12}, - {"pitch_range": (6, 96), "pitch_offset": 12}, - {"pitch_range": (24, 96), "pitch_offset": 12, "velocity_threshold": 36}, - ] - - for path in tqdm(midi_paths, desc="Checking pianoroll conversion"): - midi = MidiFile(path) - for track in midi.instruments: - # We do a first notes -> pianoroll -> notes conversion before - # This step is required as the pianoroll conversion is lossy with overlapping notes. - # notes2pianoroll has a "last income priority" logic, for which if a notes is occurs - # when another one of the same pitch is already being played, this new note will be - # represented and will end the previous one (if they have different velocities). +@pytest.mark.parametrize("midi_path", MIDI_PATHS, ids=attrgetter("name")) +@pytest.mark.parametrize("test_set", test_sets) +def test_pianoroll(midi_path, test_set): + """Testing creating pianorolls of notes.""" - for test_set in test_sets: - # Set pitch range parameters - pitch_range = test_set.get("pitch_range", PITCH_RANGE) - if "pitch_offset" in test_set: - pitch_range = ( - max(PITCH_RANGE[0], pitch_range[0] - test_set["pitch_offset"]), - min(PITCH_RANGE[1], pitch_range[1] + test_set["pitch_offset"]), - ) + # Set pitch range parameters + pitch_range = test_set.get("pitch_range", PITCH_RANGE) + if "pitch_offset" in test_set: + pitch_range = ( + max(PITCH_RANGE[0], pitch_range[0] - test_set["pitch_offset"]), + min(PITCH_RANGE[1], pitch_range[1] + test_set["pitch_offset"]), + ) - # First pianoroll <--> notes conversion, losing overlapping notes - pianoroll = notes2pianoroll(track.notes, **test_set) - new_notes = pianoroll2notes(pianoroll, pitch_range=pitch_range) + midi = MidiFile(midi_path) - # Second one, notes -> pianoroll -> new notes should be equal - new_pianoroll = notes2pianoroll(new_notes, **test_set) - new_new_notes = pianoroll2notes(new_pianoroll, pitch_range=pitch_range) - if "velocity_threshold" in test_set: - new_notes = [ - note - for note in new_notes - if note.velocity >= test_set["velocity_threshold"] - ] + for track in midi.instruments: + # We do a first notes -> pianoroll -> notes conversion before + # This step is required as the pianoroll conversion is lossy with overlapping notes. + # notes2pianoroll has a "last income priority" logic, for which if a notes is occurs + # when another one of the same pitch is already being played, this new note will be + # represented and will end the previous one (if they have different velocities). - # Assert notes are all retrieved - assert len(new_notes) == len( - new_new_notes - ), "Number of notes changed in pianoroll conversion" - for note1, note2 in zip(new_notes, new_new_notes): - # We don't test the resampling factor as it might later the number of notes - assert ( - note1 == note2 - ), "Notes before and after pianoroll conversion are not the same" + # First pianoroll <--> notes conversion, losing overlapping notes + pianoroll = notes2pianoroll(track.notes, **test_set) + new_notes = pianoroll2notes(pianoroll, pitch_range=pitch_range) + # Second one, notes -> pianoroll -> new notes should be equal + new_pianoroll = notes2pianoroll(new_notes, **test_set) + new_new_notes = pianoroll2notes(new_pianoroll, pitch_range=pitch_range) + if "velocity_threshold" in test_set: + new_notes = [ + note + for note in new_notes + if note.velocity >= test_set["velocity_threshold"] + ] -if __name__ == "__main__": - test_pianoroll() + # Assert notes are all retrieved + assert len(new_notes) == len( + new_new_notes + ), "Number of notes changed in pianoroll conversion" + for note1, note2 in zip(new_notes, new_new_notes): + # We don't test the resampling factor as it might later the number of notes + assert ( + note1 == note2 + ), "Notes before and after pianoroll conversion are not the same" diff --git a/tests/test_read_dump.py b/tests/test_read_dump.py index 0b32bf0..9efdcba 100644 --- a/tests/test_read_dump.py +++ b/tests/test_read_dump.py @@ -1,39 +1,22 @@ -#!/usr/bin/python3 python +from operator import attrgetter -"""Testing that a MIDI loaded and saved unchanged is indeed the save as before. - -""" - -import shutil -from pathlib import Path - -from tqdm import tqdm +import pytest from miditoolkit import MidiFile +from tests.utils import MIDI_PATHS -def test_load_dump(): - midi_paths = list(Path("tests", "testcases").glob("**/*.mid")) - out_path = Path("tests", "tmp", "load_dump") - out_path.mkdir(parents=True, exist_ok=True) - - for path in tqdm(midi_paths, desc="Checking midis load/save"): - midi = MidiFile(path) - # Writing it unchanged - midi.dump(out_path / path.name) - # Loading it back - midi2 = MidiFile(out_path / path.name) - - # Sorting the notes, as after dump the order might have changed - for track1, track2 in zip(midi.instruments, midi2.instruments): - track1.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) - track2.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) - - assert midi == midi2 - - # deletes tmp directory after tests - shutil.rmtree(out_path) +@pytest.mark.parametrize("midi_path", MIDI_PATHS, ids=attrgetter("name")) +def test_load_dump(midi_path, tmp_path): + """Test that a MIDI loaded and saved unchanged is indeed the save as before.""" + midi1 = MidiFile(midi_path) + dump_path = tmp_path / midi_path.name + midi1.dump(dump_path) # Writing it unchanged + midi2 = MidiFile(dump_path) # Loading it back + # Sorting the notes, as after dump the order might have changed + for track1, track2 in zip(midi1.instruments, midi2.instruments): + track1.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) + track2.notes.sort(key=lambda x: (x.start, x.pitch, x.end, x.velocity)) -if __name__ == "__main__": - test_load_dump() + assert midi1 == midi2 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..e246094 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,5 @@ +from pathlib import Path + +HERE = Path(__file__).parent + +MIDI_PATHS = sorted((HERE / "testcases").rglob("*.mid"))