Skip to content

Commit

Permalink
Migrate AbsTokenizer (#3)
Browse files Browse the repository at this point in the history
* add skeleton

* port midi.py

* update path for maestro metadata json

* add tests and ci

* add space

* update midi tests

* add abstract tokenizer class

* fix mypy and upgrade to pep 585

* rmv import

* fix docstring

* migrate abstokenizer

* fix mypy
  • Loading branch information
loubbrad authored Nov 20, 2024
1 parent da74e2d commit 44fa104
Show file tree
Hide file tree
Showing 7 changed files with 1,207 additions and 74 deletions.
129 changes: 66 additions & 63 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Utils for data/MIDI processing."""
"""Utils for MIDI processing."""

import re
import os
Expand All @@ -7,25 +7,28 @@
import unicodedata
import mido

from mido.midifiles.units import tick2second
from collections import defaultdict
from pathlib import Path
from typing import (
List,
Dict,
Any,
Tuple,
Final,
Concatenate,
Callable,
TypeAlias,
Literal,
TypedDict,
cast,
)

from mido.midifiles.units import tick2second
from ariautils.utils import load_config, load_maestro_metadata_json


# TODO:
# - Remove unneeded comments
# - Add asserts


class MetaMessage(TypedDict):
"""Meta message type corresponding text or copyright MIDI meta messages."""

Expand Down Expand Up @@ -83,37 +86,37 @@ class NoteMessage(TypedDict):
class MidiDictData(TypedDict):
"""Type for MidiDict attributes in dictionary form."""

meta_msgs: List[MetaMessage]
tempo_msgs: List[TempoMessage]
pedal_msgs: List[PedalMessage]
instrument_msgs: List[InstrumentMessage]
note_msgs: List[NoteMessage]
meta_msgs: list[MetaMessage]
tempo_msgs: list[TempoMessage]
pedal_msgs: list[PedalMessage]
instrument_msgs: list[InstrumentMessage]
note_msgs: list[NoteMessage]
ticks_per_beat: int
metadata: Dict[str, Any]
metadata: dict[str, Any]


class MidiDict:
"""Container for MIDI data in dictionary form.
Args:
meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages.
tempo_msgs (List[TempoMessage]): List of tempo change messages.
pedal_msgs (List[PedalMessage]): List of sustain pedal messages.
instrument_msgs (List[InstrumentMessage]): List of program change messages.
note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events.
meta_msgs (list[MetaMessage]): List of text or copyright MIDI meta messages.
tempo_msgs (list[TempoMessage]): List of tempo change messages.
pedal_msgs (list[PedalMessage]): List of sustain pedal messages.
instrument_msgs (list[InstrumentMessage]): List of program change messages.
note_msgs (list[NoteMessage]): List of note messages from paired note-on/off events.
ticks_per_beat (int): MIDI ticks per beat.
metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}).
"""

def __init__(
self,
meta_msgs: List[MetaMessage],
tempo_msgs: List[TempoMessage],
pedal_msgs: List[PedalMessage],
instrument_msgs: List[InstrumentMessage],
note_msgs: List[NoteMessage],
meta_msgs: list[MetaMessage],
tempo_msgs: list[TempoMessage],
pedal_msgs: list[PedalMessage],
instrument_msgs: list[InstrumentMessage],
note_msgs: list[NoteMessage],
ticks_per_beat: int,
metadata: Dict[str, Any],
metadata: dict[str, Any],
):
self.meta_msgs = meta_msgs
self.tempo_msgs = tempo_msgs
Expand Down Expand Up @@ -147,10 +150,10 @@ def __init__(
self.program_to_instrument = self.get_program_to_instrument()

@classmethod
def get_program_to_instrument(cls) -> Dict[int, str]:
def get_program_to_instrument(cls) -> dict[int, str]:
"""Return a map of MIDI program to instrument name."""

PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = (
PROGRAM_TO_INSTRUMENT: Final[dict[int, str]] = (
{i: "piano" for i in range(0, 7 + 1)}
| {i: "chromatic" for i in range(8, 15 + 1)}
| {i: "organ" for i in range(16, 23 + 1)}
Expand Down Expand Up @@ -213,7 +216,7 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict":
return cls(**midi_to_dict(mid))

def calculate_hash(self) -> str:
msg_dict_to_hash = dict(self.get_msg_dict())
msg_dict_to_hash = cast(dict, self.get_msg_dict())

# Remove metadata before calculating hash
msg_dict_to_hash.pop("meta_msgs")
Expand All @@ -234,12 +237,12 @@ def tick_to_ms(self, tick: int) -> int:
ticks_per_beat=self.ticks_per_beat,
)

def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]:
def _build_pedal_intervals(self) -> dict[int, list[list[int]]]:
"""Returns a mapping of channels to sustain pedal intervals."""

self.pedal_msgs.sort(key=lambda msg: msg["tick"])
channel_to_pedal_intervals = defaultdict(list)
pedal_status: Dict[int, int] = {}
pedal_status: dict[int, int] = {}

for pedal_msg in self.pedal_msgs:
tick = pedal_msg["tick"]
Expand Down Expand Up @@ -276,7 +279,7 @@ def resolve_overlaps(self) -> "MidiDict":
"""

# Organize notes by channel and pitch
note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict(
note_msgs_c: dict[int, dict[int, list[NoteMessage]]] = defaultdict(
lambda: defaultdict(list)
)
for msg in self.note_msgs:
Expand Down Expand Up @@ -330,7 +333,7 @@ def resolve_pedal(self) -> "MidiDict":

return self

# TODO: Needs to be refactored and tested
# TODO: Needs to be refactored
def remove_redundant_pedals(self) -> "MidiDict":
"""Removes redundant pedal messages from the MIDI data in place.
Expand All @@ -342,7 +345,7 @@ def remove_redundant_pedals(self) -> "MidiDict":
def _is_pedal_useful(
pedal_start_tick: int,
pedal_end_tick: int,
note_msgs: List[NoteMessage],
note_msgs: list[NoteMessage],
) -> bool:
# This logic loops through the note_msgs that could possibly
# be affected by the pedal which starts at pedal_start_tick
Expand Down Expand Up @@ -486,7 +489,7 @@ def remove_instruments(self, config: dict) -> "MidiDict":
channels_to_remove = [i for i in channels_to_remove if i != 9]

# Remove unwanted messages of all types by looping over the message types
_msg_dict: Dict[str, List] = {
_msg_dict: dict[str, list] = {
"meta_msgs": self.meta_msgs,
"tempo_msgs": self.tempo_msgs,
"pedal_msgs": self.pedal_msgs,
Expand All @@ -511,20 +514,20 @@ def remove_instruments(self, config: dict) -> "MidiDict":
# TODO: The signature has been changed. Make sure this function isn't used anywhere else
def _extract_track_data(
track: mido.MidiTrack,
) -> Tuple[
List[MetaMessage],
List[TempoMessage],
List[PedalMessage],
List[InstrumentMessage],
List[NoteMessage],
) -> tuple[
list[MetaMessage],
list[TempoMessage],
list[PedalMessage],
list[InstrumentMessage],
list[NoteMessage],
]:
"""Converts MIDI messages into format used by MidiDict."""

meta_msgs: List[MetaMessage] = []
tempo_msgs: List[TempoMessage] = []
pedal_msgs: List[PedalMessage] = []
instrument_msgs: List[InstrumentMessage] = []
note_msgs: List[NoteMessage] = []
meta_msgs: list[MetaMessage] = []
tempo_msgs: list[TempoMessage] = []
pedal_msgs: list[PedalMessage] = []
instrument_msgs: list[InstrumentMessage] = []
note_msgs: list[NoteMessage] = []

last_note_on = defaultdict(list)
for message in track:
Expand Down Expand Up @@ -684,7 +687,7 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData:
metadata_fn = get_metadata_fn(
metadata_process_name=metadata_process_name
)
fn_args: Dict = metadata_process_config["args"]
fn_args: dict = metadata_process_config["args"]

collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args)
if collected_metadata:
Expand Down Expand Up @@ -788,11 +791,11 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile:
)

# Magic sorting function
def _sort_fn(msg: mido.Message) -> Tuple[int, int]:
def _sort_fn(msg: mido.Message) -> tuple[int, int]:
if hasattr(msg, "velocity"):
return (msg.time, msg.velocity)
return (msg.time, msg.velocity) # pyright: ignore
else:
return (msg.time, 1000)
return (msg.time, 1000) # pyright: ignore

# Sort and convert from abs_time -> delta_time
track = sorted(track, key=_sort_fn)
Expand All @@ -812,7 +815,7 @@ def _sort_fn(msg: mido.Message) -> Tuple[int, int]:
def get_duration_ms(
start_tick: int,
end_tick: int,
tempo_msgs: List[TempoMessage],
tempo_msgs: list[TempoMessage],
ticks_per_beat: int,
) -> int:
"""Calculates elapsed time (in ms) between start_tick and end_tick."""
Expand Down Expand Up @@ -897,7 +900,7 @@ def to_ascii(s: str) -> str:

def meta_composer_filename(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
matched_names_unique = set()
for name in composer_names:
Expand All @@ -914,7 +917,7 @@ def meta_composer_filename(

def meta_form_filename(
mid: mido.MidiFile, msg_data: MidiDictData, form_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
matched_names_unique = set()
for name in form_names:
Expand All @@ -931,7 +934,7 @@ def meta_form_filename(

def meta_composer_metamsg(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
matched_names_unique = set()
for msg in msg_data["meta_msgs"]:
for name in composer_names:
Expand All @@ -952,7 +955,7 @@ def meta_maestro_json(
msg_data: MidiDictData,
composer_names: list,
form_names: list,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Loads composer and form metadata from MAESTRO metadata json file.
Expand Down Expand Up @@ -990,16 +993,16 @@ def meta_maestro_json(
return res


def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]:
def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]:
return {"abs_path": str(Path(str(mid.filename)).absolute())}


def get_metadata_fn(
metadata_process_name: str,
) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]:
name_to_fn: Dict[
) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]:
name_to_fn: dict[
str,
Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]],
Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]],
] = {
"composer_filename": meta_composer_filename,
"composer_metamsg": meta_composer_metamsg,
Expand All @@ -1017,7 +1020,7 @@ def get_metadata_fn(
return fn


def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]:
"""Returns false if midi_dict uses more than {max} programs."""
present_programs = set(
map(
Expand All @@ -1032,7 +1035,7 @@ def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
return False, len(present_programs)


def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]:
present_instruments = set(
map(
lambda msg: midi_dict.program_to_instrument[msg["data"]],
Expand All @@ -1048,7 +1051,7 @@ def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:

def test_note_frequency(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
if not midi_dict.note_msgs:
return False, 0.0

Expand All @@ -1073,7 +1076,7 @@ def test_note_frequency(

def test_note_frequency_per_instrument(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
num_instruments = len(
set(
map(
Expand Down Expand Up @@ -1111,7 +1114,7 @@ def test_note_frequency_per_instrument(

def test_min_length(
midi_dict: MidiDict, min_seconds: int
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
if not midi_dict.note_msgs:
return False, 0.0

Expand All @@ -1130,9 +1133,9 @@ def test_min_length(

def get_test_fn(
test_name: str,
) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]:
name_to_fn: Dict[
str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]
) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]:
name_to_fn: dict[
str, Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]
] = {
"max_programs": test_max_programs,
"max_instruments": test_max_instruments,
Expand Down
5 changes: 5 additions & 0 deletions ariautils/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Includes Tokenizers and pre-processing utilities."""

from ariautils.tokenizer._base import Tokenizer

__all__ = ["Tokenizer"]
Loading

0 comments on commit 44fa104

Please sign in to comment.