diff --git a/ariautils/midi.py b/ariautils/midi.py index 5977b25..a59526e 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1,4 +1,4 @@ -"""Utils for data/MIDI processing.""" +"""Utils for MIDI processing.""" import re import os @@ -7,22 +7,20 @@ import unicodedata import mido +from mido.midifiles.units import tick2second from collections import defaultdict from pathlib import Path from typing import ( - List, - Dict, Any, - Tuple, Final, Concatenate, Callable, TypeAlias, Literal, TypedDict, + cast, ) -from mido.midifiles.units import tick2second from ariautils.utils import load_config, load_maestro_metadata_json @@ -83,37 +81,37 @@ class NoteMessage(TypedDict): class MidiDictData(TypedDict): """Type for MidiDict attributes in dictionary form.""" - meta_msgs: List[MetaMessage] - tempo_msgs: List[TempoMessage] - pedal_msgs: List[PedalMessage] - instrument_msgs: List[InstrumentMessage] - note_msgs: List[NoteMessage] + meta_msgs: list[MetaMessage] + tempo_msgs: list[TempoMessage] + pedal_msgs: list[PedalMessage] + instrument_msgs: list[InstrumentMessage] + note_msgs: list[NoteMessage] ticks_per_beat: int - metadata: Dict[str, Any] + metadata: dict[str, Any] class MidiDict: """Container for MIDI data in dictionary form. Args: - meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages. - tempo_msgs (List[TempoMessage]): List of tempo change messages. - pedal_msgs (List[PedalMessage]): List of sustain pedal messages. - instrument_msgs (List[InstrumentMessage]): List of program change messages. - note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events. + meta_msgs (list[MetaMessage]): List of text or copyright MIDI meta messages. + tempo_msgs (list[TempoMessage]): List of tempo change messages. + pedal_msgs (list[PedalMessage]): List of sustain pedal messages. + instrument_msgs (list[InstrumentMessage]): List of program change messages. + note_msgs (list[NoteMessage]): List of note messages from paired note-on/off events. ticks_per_beat (int): MIDI ticks per beat. metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}). """ def __init__( self, - meta_msgs: List[MetaMessage], - tempo_msgs: List[TempoMessage], - pedal_msgs: List[PedalMessage], - instrument_msgs: List[InstrumentMessage], - note_msgs: List[NoteMessage], + meta_msgs: list[MetaMessage], + tempo_msgs: list[TempoMessage], + pedal_msgs: list[PedalMessage], + instrument_msgs: list[InstrumentMessage], + note_msgs: list[NoteMessage], ticks_per_beat: int, - metadata: Dict[str, Any], + metadata: dict[str, Any], ): self.meta_msgs = meta_msgs self.tempo_msgs = tempo_msgs @@ -147,10 +145,10 @@ def __init__( self.program_to_instrument = self.get_program_to_instrument() @classmethod - def get_program_to_instrument(cls) -> Dict[int, str]: + def get_program_to_instrument(cls) -> dict[int, str]: """Return a map of MIDI program to instrument name.""" - PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = ( + PROGRAM_TO_INSTRUMENT: Final[dict[int, str]] = ( {i: "piano" for i in range(0, 7 + 1)} | {i: "chromatic" for i in range(8, 15 + 1)} | {i: "organ" for i in range(16, 23 + 1)} @@ -213,7 +211,7 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict": return cls(**midi_to_dict(mid)) def calculate_hash(self) -> str: - msg_dict_to_hash = dict(self.get_msg_dict()) + msg_dict_to_hash = cast(dict, self.get_msg_dict()) # Remove metadata before calculating hash msg_dict_to_hash.pop("meta_msgs") @@ -234,12 +232,12 @@ def tick_to_ms(self, tick: int) -> int: ticks_per_beat=self.ticks_per_beat, ) - def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]: + def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: """Returns a mapping of channels to sustain pedal intervals.""" self.pedal_msgs.sort(key=lambda msg: msg["tick"]) channel_to_pedal_intervals = defaultdict(list) - pedal_status: Dict[int, int] = {} + pedal_status: dict[int, int] = {} for pedal_msg in self.pedal_msgs: tick = pedal_msg["tick"] @@ -276,7 +274,7 @@ def resolve_overlaps(self) -> "MidiDict": """ # Organize notes by channel and pitch - note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict( + note_msgs_c: dict[int, dict[int, list[NoteMessage]]] = defaultdict( lambda: defaultdict(list) ) for msg in self.note_msgs: @@ -330,7 +328,7 @@ def resolve_pedal(self) -> "MidiDict": return self - # TODO: Needs to be refactored and tested + # TODO: Needs to be refactored def remove_redundant_pedals(self) -> "MidiDict": """Removes redundant pedal messages from the MIDI data in place. @@ -342,7 +340,7 @@ def remove_redundant_pedals(self) -> "MidiDict": def _is_pedal_useful( pedal_start_tick: int, pedal_end_tick: int, - note_msgs: List[NoteMessage], + note_msgs: list[NoteMessage], ) -> bool: # This logic loops through the note_msgs that could possibly # be effected by the pedal which starts at pedal_start_tick @@ -486,7 +484,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": channels_to_remove = [i for i in channels_to_remove if i != 9] # Remove unwanted messages all type by looping over msgs types - _msg_dict: Dict[str, List] = { + _msg_dict: dict[str, list] = { "meta_msgs": self.meta_msgs, "tempo_msgs": self.tempo_msgs, "pedal_msgs": self.pedal_msgs, @@ -511,20 +509,20 @@ def remove_instruments(self, config: dict) -> "MidiDict": # TODO: The sign has been changed. Make sure this function isn't used anywhere else def _extract_track_data( track: mido.MidiTrack, -) -> Tuple[ - List[MetaMessage], - List[TempoMessage], - List[PedalMessage], - List[InstrumentMessage], - List[NoteMessage], +) -> tuple[ + list[MetaMessage], + list[TempoMessage], + list[PedalMessage], + list[InstrumentMessage], + list[NoteMessage], ]: """Converts MIDI messages into format used by MidiDict.""" - meta_msgs: List[MetaMessage] = [] - tempo_msgs: List[TempoMessage] = [] - pedal_msgs: List[PedalMessage] = [] - instrument_msgs: List[InstrumentMessage] = [] - note_msgs: List[NoteMessage] = [] + meta_msgs: list[MetaMessage] = [] + tempo_msgs: list[TempoMessage] = [] + pedal_msgs: list[PedalMessage] = [] + instrument_msgs: list[InstrumentMessage] = [] + note_msgs: list[NoteMessage] = [] last_note_on = defaultdict(list) for message in track: @@ -684,7 +682,7 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: metadata_fn = get_metadata_fn( metadata_process_name=metadata_process_name ) - fn_args: Dict = metadata_process_config["args"] + fn_args: dict = metadata_process_config["args"] collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args) if collected_metadata: @@ -788,11 +786,11 @@ def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: ) # Magic sorting function - def _sort_fn(msg: mido.Message) -> Tuple[int, int]: + def _sort_fn(msg: mido.Message) -> tuple[int, int]: if hasattr(msg, "velocity"): - return (msg.time, msg.velocity) + return (msg.time, msg.velocity) # pyright: ignore else: - return (msg.time, 1000) + return (msg.time, 1000) # pyright: ignore # Sort and convert from abs_time -> delta_time track = sorted(track, key=_sort_fn) @@ -812,7 +810,7 @@ def _sort_fn(msg: mido.Message) -> Tuple[int, int]: def get_duration_ms( start_tick: int, end_tick: int, - tempo_msgs: List[TempoMessage], + tempo_msgs: list[TempoMessage], ticks_per_beat: int, ) -> int: """Calculates elapsed time (in ms) between start_tick and end_tick.""" @@ -897,7 +895,7 @@ def to_ascii(s: str) -> str: def meta_composer_filename( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in composer_names: @@ -914,7 +912,7 @@ def meta_composer_filename( def meta_form_filename( mid: mido.MidiFile, msg_data: MidiDictData, form_names: list -) -> Dict[str, str]: +) -> dict[str, str]: file_name = Path(str(mid.filename)).stem matched_names_unique = set() for name in form_names: @@ -931,7 +929,7 @@ def meta_form_filename( def meta_composer_metamsg( mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list -) -> Dict[str, str]: +) -> dict[str, str]: matched_names_unique = set() for msg in msg_data["meta_msgs"]: for name in composer_names: @@ -952,7 +950,7 @@ def meta_maestro_json( msg_data: MidiDictData, composer_names: list, form_names: list, -) -> Dict[str, str]: +) -> dict[str, str]: """Loads composer and form metadata from MAESTRO metadata json file. @@ -990,16 +988,16 @@ def meta_maestro_json( return res -def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> Dict[str, str]: +def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]: return {"abs_path": str(Path(str(mid.filename)).absolute())} def get_metadata_fn( metadata_process_name: str, -) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]]: - name_to_fn: Dict[ +) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]: + name_to_fn: dict[ str, - Callable[Concatenate[mido.MidiFile, MidiDictData, ...], Dict[str, str]], + Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]], ] = { "composer_filename": meta_composer_filename, "composer_metamsg": meta_composer_metamsg, @@ -1017,7 +1015,7 @@ def get_metadata_fn( return fn -def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]: """Returns false if midi_dict uses more than {max} programs.""" present_programs = set( map( @@ -1032,7 +1030,7 @@ def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: return False, len(present_programs) -def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: +def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]: present_instruments = set( map( lambda msg: midi_dict.program_to_instrument[msg["data"]], @@ -1048,7 +1046,7 @@ def test_max_instruments(midi_dict: MidiDict, max: int) -> Tuple[bool, int]: def test_note_frequency( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1073,7 +1071,7 @@ def test_note_frequency( def test_note_frequency_per_instrument( midi_dict: MidiDict, max_per_second: float, min_per_second: float -) -> Tuple[bool, float]: +) -> tuple[bool, float]: num_instruments = len( set( map( @@ -1111,7 +1109,7 @@ def test_note_frequency_per_instrument( def test_min_length( midi_dict: MidiDict, min_seconds: int -) -> Tuple[bool, float]: +) -> tuple[bool, float]: if not midi_dict.note_msgs: return False, 0.0 @@ -1130,9 +1128,9 @@ def test_min_length( def get_test_fn( test_name: str, -) -> Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]]: - name_to_fn: Dict[ - str, Callable[Concatenate[MidiDict, ...], Tuple[bool, Any]] +) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]: + name_to_fn: dict[ + str, Callable[Concatenate[MidiDict, ...], tuple[bool, Any]] ] = { "max_programs": test_max_programs, "max_instruments": test_max_instruments, diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py new file mode 100644 index 0000000..382adda --- /dev/null +++ b/ariautils/tokenizer/__init__.py @@ -0,0 +1,246 @@ +"""Includes Tokenizers and pre-processing utilities.""" + +import functools + +from typing import ( + Any, + Final, + Callable, + TypeAlias, +) + +from ariautils.midi import MidiDict + + +SpecialToken: TypeAlias = str +Token: TypeAlias = tuple[Any, ...] | str + + +class Tokenizer: + """Abstract Tokenizer class for tokenizing midi_dict objects. + + Args: + return_tensors (bool, optional): If True, encode will return tensors. + Defaults to False. + """ + + def __init__( + self, + return_tensors: bool = False, + ): + self.name: str = "" + self.return_tensors = return_tensors # DELETE + + self.bos_tok: Final[SpecialToken] = "" + self.eos_tok: Final[SpecialToken] = "" + self.pad_tok: Final[SpecialToken] = "

" + self.unk_tok: Final[SpecialToken] = "" + self.dim_tok: Final[SpecialToken] = "" + + self.special_tokens: list[SpecialToken] = [ + self.bos_tok, + self.eos_tok, + self.pad_tok, + self.unk_tok, + self.dim_tok, + ] + + # These must be implemented in child class (abstract params) + self.vocab: tuple[Token, ...] = () + self.instruments_wd: list[str] = [] + self.instruments_nd: list[str] = [] + self.config: dict[str, Any] = {} + self.tok_to_id: dict[Token, int] = {} + self.id_to_tok: dict[int, Token] = {} + self.vocab_size: int = -1 + self.pad_id: int = -1 + + def _tokenize_midi_dict(self, midi_dict: MidiDict) -> list[Token]: + """Abstract method for tokenizing a MidiDict object into a sequence of + tokens. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + raise NotImplementedError + + def tokenize(self, midi_dict: MidiDict, **kwargs: Any) -> list[Token]: + """Tokenizes a MidiDict object. + + This function should be overridden if additional transformations are + required, e.g., adding additional tokens. The default behavior is to + call tokenize_midi_dict. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + **kwargs (Any): Additional keyword arguments passed to _tokenize_midi_dict. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict(midi_dict, **kwargs) + + def _detokenize_midi_dict(self, tokenized_seq: list[int]) -> MidiDict: + """Abstract method for de-tokenizing a sequence of tokens into a + MidiDict Object. + + Args: + tokenized_seq (list[int]): The sequence of tokens to detokenize. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + raise NotImplementedError + + def detokenize(self, tokenized_seq: list[int], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + This function should be overridden if additional are required during + detokenization. The default behavior is to call detokenize_midi_dict. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize. + **kwargs (Any): Additional keyword arguments passed to detokenize_midi_dict. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + return self._detokenize_midi_dict(tokenized_seq, **kwargs) + + def export_data_aug(cls) -> list[Callable[[list[Token]], list[Token]]]: + """Export a list of implemented data augmentation functions.""" + + raise NotImplementedError + + def encode(self, unencoded_seq: list[Token]) -> list[int]: + """Converts tokenized sequence into the corresponding list of ids.""" + + def _enc_fn(tok: Token) -> int: + return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + + if self.tok_to_id is None: + raise NotImplementedError("tok_to_id") + + encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + + return encoded_seq + + def decode(self, encoded_seq: list[int]) -> list[Token]: + """Converts list of ids into the corresponding list of tokens.""" + + def _dec_fn(id: int) -> Token: + return self.id_to_tok.get(id, self.unk_tok) + + if self.id_to_tok is None: + raise NotImplementedError("id_to_tok") + + decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + + return decoded_seq + + @classmethod + def _find_closest_int(cls, n: int, sorted_list: list[int]) -> int: + # Selects closest integer to n from sorted_list + # Time ~ Log(n) + + if not sorted_list: + raise ValueError("List is empty") + + left, right = 0, len(sorted_list) - 1 + closest = float("inf") + + while left <= right: + mid = (left + right) // 2 + diff = abs(sorted_list[mid] - n) + + if diff < abs(closest - n): + closest = sorted_list[mid] + + if sorted_list[mid] < n: + left = mid + 1 + else: + right = mid - 1 + + return closest # type: ignore[return-value] + + def add_tokens_to_vocab(self, tokens: list[Token] | tuple[Token]) -> None: + """Utility function for safely adding extra tokens to vocab.""" + + for token in tokens: + assert token not in self.vocab + + self.vocab = self.vocab + tuple(tokens) + self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} + self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} + self.vocab_size = len(self.vocab) + + def export_aug_fn_concat( + self, aug_fn: Callable[[list[Token]], list[Token]] + ) -> Callable[[list[Token]], list[Token]]: + """Transforms an augmentation function for concatenated sequences. + + This is useful for augmentation functions that are only defined for + sequences which start with and end with . + + Args: + aug_fn (Callable[[list[Token]], list[Token]]): The augmentation + function to transform. + + Returns: + Callable[[list[Token]], list[Token]]: A transformed augmentation + function that can handle concatenated sequences. + """ + + def _aug_fn_concat( + src: list[Token], + _aug_fn: Callable[[list[Token]], list[Token]], + pad_tok: str, + eos_tok: str, + **kwargs: Any, + ) -> list[Token]: + # Split list on '' + initial_seq_len = len(src) + src_sep = [] + prev_idx = 0 + for curr_idx, tok in enumerate(src, start=1): + if tok == eos_tok: + src_sep.append(src[prev_idx:curr_idx]) + prev_idx = curr_idx + + # Last sequence + if prev_idx != curr_idx: + src_sep.append(src[prev_idx:]) + + # Augment + src_sep = [ + _aug_fn( + _src, + **kwargs, + ) + for _src in src_sep + ] + + # Concatenate + src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] + + # Pad or truncate to original sequence length as necessary + src_aug_concat = src_aug_concat[:initial_seq_len] + src_aug_concat += [pad_tok] * ( + initial_seq_len - len(src_aug_concat) + ) + + return src_aug_concat + + return functools.partial( + _aug_fn_concat, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + eos_tok=self.eos_tok, + ) diff --git a/ariautils/utils/__init__.py b/ariautils/utils/__init__.py index eb1ccad..4c51084 100644 --- a/ariautils/utils/__init__.py +++ b/ariautils/utils/__init__.py @@ -4,7 +4,7 @@ import logging from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast from .config import load_config @@ -26,14 +26,14 @@ def get_logger(name: str) -> logging.Logger: return logger -def load_maestro_metadata_json() -> Dict[str, Any]: +def load_maestro_metadata_json() -> dict[str, Any]: """Loads MAESTRO metadata json .""" with ( resources.files("ariautils.config") .joinpath("maestro_metadata.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) __all__ = ["load_config", "load_maestro_metadata_json", "get_logger"] diff --git a/ariautils/utils/config.py b/ariautils/utils/config.py index a4fd267..4093318 100644 --- a/ariautils/utils/config.py +++ b/ariautils/utils/config.py @@ -1,17 +1,16 @@ """Includes functionality for loading config files.""" -import os import json from importlib import resources -from typing import Dict, Any, cast +from typing import Any, cast -def load_config() -> Dict[str, Any]: +def load_config() -> dict[str, Any]: """Returns a dictionary loaded from the config.json file.""" with ( resources.files("ariautils.config") .joinpath("config.json") .open("r") as f ): - return cast(Dict[str, Any], json.load(f)) + return cast(dict[str, Any], json.load(f)) diff --git a/tests/test_midi.py b/tests/test_midi.py index 74a7402..5d6aa02 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -1,4 +1,6 @@ import unittest +import tempfile +import shutil from importlib import resources from pathlib import Path @@ -40,6 +42,33 @@ def test_save(self) -> None: midi_dict = MidiDict.from_midi(mid_path=load_path) midi_dict.to_midi().save(save_path) + def test_tick_to_ms(self) -> None: + CORRECT_LAST_NOTE_ONSET_MS: Final[int] = 220140 + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict = MidiDict.from_midi(load_path) + last_note = midi_dict.note_msgs[-1] + last_note_onset_tick = last_note["tick"] + last_note_onset_ms = midi_dict.tick_to_ms(last_note_onset_tick) + self.assertEqual(last_note_onset_ms, CORRECT_LAST_NOTE_ONSET_MS) + + def test_calculate_hash(self) -> None: + # Load two identical files with different filenames and metadata + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict_orig = MidiDict.from_midi(load_path) + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + shutil.copy(load_path, temp_file.name) + midi_dict_temp = MidiDict.from_midi(temp_file.name) + + midi_dict_temp.meta_msgs.append({"type": "text", "data": "test"}) + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["composer"] = "test" + midi_dict_temp.metadata["ticks_per_beat"] = -1 + + self.assertEqual( + midi_dict_orig.calculate_hash(), midi_dict_temp.calculate_hash() + ) + def test_resolve_pedal(self) -> None: load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") save_path = RESULTS_DATA_DIRECTORY.joinpath(