Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate AbsTokenizer #3

Merged
merged 13 commits into from
Nov 20, 2024
129 changes: 66 additions & 63 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Utils for data/MIDI processing."""
"""Utils for MIDI processing."""

import re
import os
Expand All @@ -7,25 +7,28 @@
import unicodedata
import mido

from mido.midifiles.units import tick2second
from collections import defaultdict
from pathlib import Path
from typing import (
List,
Dict,
Any,
Tuple,
Final,
Concatenate,
Callable,
TypeAlias,
Literal,
TypedDict,
cast,
)

from mido.midifiles.units import tick2second
from ariautils.utils import load_config, load_maestro_metadata_json


# TODO:
# - Remove unneeded comments
# - Add asserts


class MetaMessage(TypedDict):
"""Meta message type corresponding to text or copyright MIDI meta messages."""

Expand Down Expand Up @@ -83,37 +86,37 @@ class NoteMessage(TypedDict):
class MidiDictData(TypedDict):
"""Type for MidiDict attributes in dictionary form."""

meta_msgs: List[MetaMessage]
tempo_msgs: List[TempoMessage]
pedal_msgs: List[PedalMessage]
instrument_msgs: List[InstrumentMessage]
note_msgs: List[NoteMessage]
meta_msgs: list[MetaMessage]
tempo_msgs: list[TempoMessage]
pedal_msgs: list[PedalMessage]
instrument_msgs: list[InstrumentMessage]
note_msgs: list[NoteMessage]
ticks_per_beat: int
metadata: Dict[str, Any]
metadata: dict[str, Any]


class MidiDict:
"""Container for MIDI data in dictionary form.
Args:
meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages.
tempo_msgs (List[TempoMessage]): List of tempo change messages.
pedal_msgs (List[PedalMessage]): List of sustain pedal messages.
instrument_msgs (List[InstrumentMessage]): List of program change messages.
note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events.
meta_msgs (list[MetaMessage]): List of text or copyright MIDI meta messages.
tempo_msgs (list[TempoMessage]): List of tempo change messages.
pedal_msgs (list[PedalMessage]): List of sustain pedal messages.
instrument_msgs (list[InstrumentMessage]): List of program change messages.
note_msgs (list[NoteMessage]): List of note messages from paired note-on/off events.
ticks_per_beat (int): MIDI ticks per beat.
metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}).
"""

def __init__(
self,
meta_msgs: List[MetaMessage],
tempo_msgs: List[TempoMessage],
pedal_msgs: List[PedalMessage],
instrument_msgs: List[InstrumentMessage],
note_msgs: List[NoteMessage],
meta_msgs: list[MetaMessage],
tempo_msgs: list[TempoMessage],
pedal_msgs: list[PedalMessage],
instrument_msgs: list[InstrumentMessage],
note_msgs: list[NoteMessage],
ticks_per_beat: int,
metadata: Dict[str, Any],
metadata: dict[str, Any],
):
self.meta_msgs = meta_msgs
self.tempo_msgs = tempo_msgs
Expand Down Expand Up @@ -147,10 +150,10 @@ def __init__(
self.program_to_instrument = self.get_program_to_instrument()

@classmethod
def get_program_to_instrument(cls) -> Dict[int, str]:
def get_program_to_instrument(cls) -> dict[int, str]:
"""Return a map of MIDI program to instrument name."""

PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = (
PROGRAM_TO_INSTRUMENT: Final[dict[int, str]] = (
{i: "piano" for i in range(0, 7 + 1)}
| {i: "chromatic" for i in range(8, 15 + 1)}
| {i: "organ" for i in range(16, 23 + 1)}
Expand Down Expand Up @@ -213,7 +216,7 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict":
return cls(**midi_to_dict(mid))

def calculate_hash(self) -> str:
msg_dict_to_hash = dict(self.get_msg_dict())
msg_dict_to_hash = cast(dict, self.get_msg_dict())

# Remove metadata before calculating hash
msg_dict_to_hash.pop("meta_msgs")
Expand All @@ -234,12 +237,12 @@ def tick_to_ms(self, tick: int) -> int:
ticks_per_beat=self.ticks_per_beat,
)

def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]:
def _build_pedal_intervals(self) -> dict[int, list[list[int]]]:
"""Returns a mapping of channels to sustain pedal intervals."""

self.pedal_msgs.sort(key=lambda msg: msg["tick"])
channel_to_pedal_intervals = defaultdict(list)
pedal_status: Dict[int, int] = {}
pedal_status: dict[int, int] = {}

for pedal_msg in self.pedal_msgs:
tick = pedal_msg["tick"]
Expand Down Expand Up @@ -276,7 +279,7 @@ def resolve_overlaps(self) -> "MidiDict":
"""

# Organize notes by channel and pitch
note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict(
note_msgs_c: dict[int, dict[int, list[NoteMessage]]] = defaultdict(
lambda: defaultdict(list)
)
for msg in self.note_msgs:
Expand Down Expand Up @@ -330,7 +333,7 @@ def resolve_pedal(self) -> "MidiDict":

return self

# TODO: Needs to be refactored and tested
# TODO: Needs to be refactored
def remove_redundant_pedals(self) -> "MidiDict":
"""Removes redundant pedal messages from the MIDI data in place.
Expand All @@ -342,7 +345,7 @@ def remove_redundant_pedals(self) -> "MidiDict":
def _is_pedal_useful(
pedal_start_tick: int,
pedal_end_tick: int,
note_msgs: List[NoteMessage],
note_msgs: list[NoteMessage],
) -> bool:
# This logic loops through the note_msgs that could possibly
# be affected by the pedal which starts at pedal_start_tick
Expand Down Expand Up @@ -486,7 +489,7 @@ def remove_instruments(self, config: dict) -> "MidiDict":
channels_to_remove = [i for i in channels_to_remove if i != 9]

# Remove unwanted messages of all types by looping over the message types
_msg_dict: Dict[str, List] = {
_msg_dict: dict[str, list] = {
"meta_msgs": self.meta_msgs,
"tempo_msgs": self.tempo_msgs,
"pedal_msgs": self.pedal_msgs,
Expand All @@ -511,20 +514,20 @@ def remove_instruments(self, config: dict) -> "MidiDict":
# TODO: The sign has been changed. Make sure this function isn't used anywhere else
def _extract_track_data(
track: mido.MidiTrack,
) -> Tuple[
List[MetaMessage],
List[TempoMessage],
List[PedalMessage],
List[InstrumentMessage],
List[NoteMessage],
) -> tuple[
list[MetaMessage],
list[TempoMessage],
list[PedalMessage],
list[InstrumentMessage],
list[NoteMessage],
]:
"""Converts MIDI messages into format used by MidiDict."""

meta_msgs: List[MetaMessage] = []
tempo_msgs: List[TempoMessage] = []
pedal_msgs: List[PedalMessage] = []
instrument_msgs: List[InstrumentMessage] = []
note_msgs: List[NoteMessage] = []
meta_msgs: list[MetaMessage] = []
tempo_msgs: list[TempoMessage] = []
pedal_msgs: list[PedalMessage] = []
instrument_msgs: list[InstrumentMessage] = []
note_msgs: list[NoteMessage] = []

last_note_on = defaultdict(list)
for message in track:
Expand Down Expand Up @@ -684,7 +687,7 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData:
metadata_fn = get_metadata_fn(
metadata_process_name=metadata_process_name
)
fn_args: Dict = metadata_process_config["args"]
fn_args: dict = metadata_process_config["args"]

collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args)
if collected_metadata:
Expand Down Expand Up @@ -788,11 +791,11 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile:
)

# Magic sorting function
def _sort_fn(msg: mido.Message) -> Tuple[int, int]:
def _sort_fn(msg: mido.Message) -> tuple[int, int]:
if hasattr(msg, "velocity"):
return (msg.time, msg.velocity)
return (msg.time, msg.velocity) # pyright: ignore
else:
return (msg.time, 1000)
return (msg.time, 1000) # pyright: ignore

# Sort and convert from abs_time -> delta_time
track = sorted(track, key=_sort_fn)
Expand All @@ -812,7 +815,7 @@ def _sort_fn(msg: mido.Message) -> Tuple[int, int]:
def get_duration_ms(
start_tick: int,
end_tick: int,
tempo_msgs: List[TempoMessage],
tempo_msgs: list[TempoMessage],
ticks_per_beat: int,
) -> int:
"""Calculates elapsed time (in ms) between start_tick and end_tick."""
Expand Down Expand Up @@ -897,7 +900,7 @@ def to_ascii(s: str) -> str:

def meta_composer_filename(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
matched_names_unique = set()
for name in composer_names:
Expand All @@ -914,7 +917,7 @@ def meta_composer_filename(

def meta_form_filename(
mid: mido.MidiFile, msg_data: MidiDictData, form_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
matched_names_unique = set()
for name in form_names:
Expand All @@ -931,7 +934,7 @@ def meta_form_filename(

def meta_composer_metamsg(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
) -> Dict[str, str]:
) -> dict[str, str]:
matched_names_unique = set()
for msg in msg_data["meta_msgs"]:
for name in composer_names:
Expand All @@ -952,7 +955,7 @@ def meta_maestro_json(
msg_data: MidiDictData,
composer_names: list,
form_names: list,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Loads composer and form metadata from MAESTRO metadata json file.
Expand Down Expand Up @@ -990,16 +993,16 @@ def meta_maestro_json(
return res


def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]:
def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]:
return {"abs_path": str(Path(str(mid.filename)).absolute())}


def get_metadata_fn(
metadata_process_name: str,
) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]:
name_to_fn: Dict[
) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]:
name_to_fn: dict[
str,
Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]],
Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]],
] = {
"composer_filename": meta_composer_filename,
"composer_metamsg": meta_composer_metamsg,
Expand All @@ -1017,7 +1020,7 @@ def get_metadata_fn(
return fn


def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]:
"""Returns false if midi_dict uses more than {max} programs."""
present_programs = set(
map(
Expand All @@ -1032,7 +1035,7 @@ def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
return False, len(present_programs)


def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]:
present_instruments = set(
map(
lambda msg: midi_dict.program_to_instrument[msg["data"]],
Expand All @@ -1048,7 +1051,7 @@ def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:

def test_note_frequency(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
if not midi_dict.note_msgs:
return False, 0.0

Expand All @@ -1073,7 +1076,7 @@ def test_note_frequency(

def test_note_frequency_per_instrument(
midi_dict: MidiDict, max_per_second: float, min_per_second: float
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
num_instruments = len(
set(
map(
Expand Down Expand Up @@ -1111,7 +1114,7 @@ def test_note_frequency_per_instrument(

def test_min_length(
midi_dict: MidiDict, min_seconds: int
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
if not midi_dict.note_msgs:
return False, 0.0

Expand All @@ -1130,9 +1133,9 @@ def test_min_length(

def get_test_fn(
test_name: str,
) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]:
name_to_fn: Dict[
str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]
) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]:
name_to_fn: dict[
str, Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]
] = {
"max_programs": test_max_programs,
"max_instruments": test_max_instruments,
Expand Down
5 changes: 5 additions & 0 deletions ariautils/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Includes Tokenizers and pre-processing utilities."""

from ariautils.tokenizer._base import Tokenizer

__all__ = ["Tokenizer"]
Loading