Migrate RelTokenizer #6

Merged
merged 20 commits on Nov 27, 2024
17 changes: 4 additions & 13 deletions ariautils/config/config.json
@@ -176,13 +176,9 @@
"sfx": 120
},
"drum_velocity": 60,
"velocity_quantization": {
"step": 15
},
"time_quantization": {
"num_steps": 500,
"step": 10
},
"velocity_quantization_step": 10,
"max_time_ms": 5000,
"time_step_ms": 10,
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"],
"genre_names": ["jazz", "classical"]
@@ -225,18 +221,13 @@
"sfx": 120
},
"drum_velocity": 60,
"velocity_quantization": {
"step": 10
},
"velocity_quantization_step": 10,
"abs_time_step_ms": 5000,
"max_dur_ms": 5000,
"time_step_ms": 10,
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"],
"genre_names": ["jazz", "classical"]
},
"lm": {
"tags": ["happy", "sad"]
}
}
}
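Note on the config migration above: both tokenizer sections lose their nested quantization blocks in favor of flat scalar keys, and the arithmetic is preserved — the old time_quantization block encoded 500 steps of 10 ms, i.e. 5000 ms, matching the new max_time_ms. A minimal sketch of reading the flattened keys, assuming config.json is loaded with the stdlib json module and that the hunk containing abs_time_step_ms sits under 'tokenizer.abs' (as the AbsTokenizer docstring later in this PR states):

```python
# Minimal sketch, not the ariautils loader: read the flattened
# quantization settings for the abs tokenizer section.
import json

with open("ariautils/config/config.json") as f:
    cfg = json.load(f)["tokenizer"]["abs"]  # section path assumed

velocity_step = cfg["velocity_quantization_step"]  # was velocity_quantization["step"]
abs_time_step_ms = cfg["abs_time_step_ms"]         # 5000 in this diff
max_dur_ms = cfg["max_dur_ms"]                     # 5000 in this diff
time_step_ms = cfg["time_step_ms"]                 # 10 in this diff
```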
3 changes: 2 additions & 1 deletion ariautils/tokenizer/__init__.py
@@ -2,5 +2,6 @@

from ._base import Tokenizer
from .absolute import AbsTokenizer
from .relative import RelTokenizer

__all__ = ["Tokenizer", "AbsTokenizer"]
__all__ = ["Tokenizer", "AbsTokenizer", "RelTokenizer"]
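This re-export makes RelTokenizer importable from the package root alongside AbsTokenizer. A hedged usage sketch, assuming RelTokenizer exposes the same tokenize/detokenize pair as AbsTokenizer (detokenize is visible in absolute.py below; tokenize and the MidiDict.from_midi constructor are assumptions here):

```python
# Usage sketch for the newly exported class; method names are assumed
# to mirror AbsTokenizer's interface.
from ariautils.midi import MidiDict  # module path assumed
from ariautils.tokenizer import RelTokenizer

tokenizer = RelTokenizer()
midi_dict = MidiDict.from_midi("example.mid")  # constructor assumed
tokens = tokenizer.tokenize(midi_dict)
round_trip = tokenizer.detokenize(tokens)
```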
144 changes: 75 additions & 69 deletions ariautils/tokenizer/absolute.py
@@ -30,6 +30,7 @@
# channels, with the same instrument, have overlapping notes
# - There are tons of edge cases here e.g., what if there are two identical
# notes on different channels.
# - Add information about the config, i.e., which instruments removed


class AbsTokenizer(Tokenizer):
@@ -55,10 +56,13 @@ class AbsTokenizer(Tokenizer):

Notes:
- Notes are ordered according to onset time
- Sustain pedals effects are incorporated directly into note durations
- Various configuration settings effecting instrument processing,
timing resolution, and quantization levels can be adjusted the
config.json at 'tokenizer.abs'.
- Sustain pedal effects are incorporated directly into note durations
- Start (<S>) and end (<E>) tokens wrap the tokenized sequence, and
prefix tokens for instrument, genre, composer, and form are
prepended, i.e., before the <S> token
- Various configuration settings affecting instrument processing,
timing resolution, quantization levels, and prefix tokens can be
adjusted in config.json at 'tokenizer.abs'.
"""

def __init__(self) -> None: # Not sure why this is required by
@@ -67,20 +71,21 @@ def __init__(self) -> None: # Not sure why this is required by
self.name = "abs"

# Calculate time quantizations (in ms)
self.abs_time_step: int = self.config["abs_time_step_ms"]
self.max_dur: int = self.config["max_dur_ms"]
self.time_step: int = self.config["time_step_ms"]
self.abs_time_step_ms: int = self.config["abs_time_step_ms"]
self.max_dur_ms: int = self.config["max_dur_ms"]
self.time_step_ms: int = self.config["time_step_ms"]

self.dur_time_quantizations = [
self.time_step * i
for i in range((self.max_dur // self.time_step) + 1)
self.time_step_ms * i
for i in range((self.max_dur_ms // self.time_step_ms) + 1)
]
self.onset_time_quantizations = [
self.time_step * i for i in range((self.max_dur // self.time_step))
self.time_step_ms * i
for i in range((self.max_dur_ms // self.time_step_ms))
]

# Calculate velocity quantizations
self.velocity_step: int = self.config["velocity_quantization"]["step"]
self.velocity_step: int = self.config["velocity_quantization_step"]
self.velocity_quantizations = [
i * self.velocity_step
for i in range(int(127 / self.velocity_step) + 1)
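The renamed attributes feed the quantization grids directly. A standalone sketch of the same arithmetic, with values hard-coded from the config diff above (time_step_ms=10, max_dur_ms=5000, velocity_quantization_step=10):

```python
# Standalone recomputation of the grids built in __init__ above.
time_step_ms, max_dur_ms, velocity_step = 10, 5000, 10

dur_grid = [time_step_ms * i for i in range((max_dur_ms // time_step_ms) + 1)]
onset_grid = [time_step_ms * i for i in range(max_dur_ms // time_step_ms)]
velocity_grid = [i * velocity_step for i in range(int(127 / velocity_step) + 1)]

assert dur_grid[-1] == 5000    # durations may reach max_dur_ms
assert onset_grid[-1] == 4990  # onsets stay inside one abs time window
assert velocity_grid[-1] == 120
```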
@@ -148,7 +153,7 @@ def _quantize_dur(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
dur = self._find_closest_int(time, self.dur_time_quantizations)

return dur if dur != 0 else self.time_step
return dur if dur != 0 else self.time_step_ms

def _quantize_onset(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
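_quantize_dur snaps a duration to the nearest grid value, then floors an exact zero to one time step so no note ends up with zero length. A sketch with min() standing in for the private _find_closest_int helper (which this diff does not show; tie-breaking may differ):

```python
# Sketch of the snapping behaviour; min() is a stand-in for
# _find_closest_int, so ties may resolve differently.
def quantize_dur(time_ms: int, grid: list[int], time_step_ms: int) -> int:
    dur = min(grid, key=lambda q: abs(q - time_ms))
    return dur if dur != 0 else time_step_ms  # never return a 0 ms duration

grid = [10 * i for i in range(501)]
assert quantize_dur(4, grid, 10) == 10   # snaps to 0, floored to 10
assert quantize_dur(17, grid, 10) == 20  # nearest grid value
```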
@@ -187,8 +192,17 @@ def _format(
return res

def calc_length_ms(self, seq: list[Token], onset: bool = False) -> int:
"""Calculates time (ms) end of sequence to the end of the last note. If
onset=True, then it will return the onset time of the last note instead
"""Calculates sequence time length in milliseconds.

Args:
seq (list[Token]): List of tokens to process.
onset (bool): If True, returns onset time of last note instead of
total duration.

Returns:
int: Time in milliseconds from start to either:
- End of last note if onset=False
- Onset of last note if onset=True
"""

# Find the index of the last onset or dur token
@@ -200,7 +214,7 @@ def calc_length_ms(self, seq: list[Token], onset: bool = False) -> int:
else:
seq.pop()

time_offset_ms = seq.count(self.time_tok) * self.abs_time_step
time_offset_ms = seq.count(self.time_tok) * self.abs_time_step_ms
idx = len(seq) - 1
for tok in seq[::-1]:
if type(tok) is tuple and tok[0] == "dur":
@@ -219,7 +233,6 @@ def calc_length_ms(self, seq: list[Token], onset: bool = False) -> int:

idx -= 1

# If it gets to this point, an error has occurred
raise Exception("Invalid sequence format")
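The new docstring makes the calc_length_ms contract concrete. A worked example of the arithmetic, assuming abs_time_step_ms = 5000 and that the time token renders as "<T>":

```python
# Worked example for calc_length_ms (token literals assumed).
seq = [
    ("piano", 60, 90), ("onset", 100), ("dur", 250),
    "<T>",  # one full 5000 ms window elapses
    ("piano", 64, 90), ("onset", 40), ("dur", 500),
]
# offset      = seq.count("<T>") * 5000  -> 5000
# onset=True  -> 5000 + 40               = 5040 (onset of last note)
# onset=False -> 5000 + 40 + 500         = 5540 (end of last note)
```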

def truncate_by_time(
@@ -229,7 +242,7 @@ def truncate_by_time(
time_offset_ms = 0
for idx, tok in enumerate(tokenized_seq):
if tok == self.time_tok:
time_offset_ms += self.abs_time_step
time_offset_ms += self.abs_time_step_ms
elif type(tok) is tuple and tok[0] == "onset":
if time_offset_ms + tok[1] > trunc_time_ms:
return tokenized_seq[: idx - 1]
@@ -304,8 +317,8 @@ def _tokenize_midi_dict(

# Add abs time token if necessary
time_toks_to_append = (
curr_time_since_onset // self.abs_time_step
) - (prev_time_since_onset // self.abs_time_step)
curr_time_since_onset // self.abs_time_step_ms
) - (prev_time_since_onset // self.abs_time_step_ms)
if time_toks_to_append > 0:
for _ in range(time_toks_to_append):
tokenized_seq.append(self.time_tok)
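The emission count here is a difference of floor divisions: one time token per abs_time_step_ms boundary crossed since the previous note. A small check with the configured 5000 ms window:

```python
# Sketch of the time-token emission rule in _tokenize_midi_dict.
abs_time_step_ms = 5000
prev_time_since_onset, curr_time_since_onset = 4_800, 15_200

time_toks_to_append = (
    curr_time_since_onset // abs_time_step_ms
) - (prev_time_since_onset // abs_time_step_ms)
assert time_toks_to_append == 3  # boundaries at 5000, 10000, 15000
```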
@@ -314,7 +327,7 @@
# MIDI channel is 9 when 0 indexing
if _channel == 9:
_note_onset = self._quantize_onset(
curr_time_since_onset % self.abs_time_step
curr_time_since_onset % self.abs_time_step_ms
)
tokenized_seq.append(("drum", _pitch))
tokenized_seq.append(("onset", _note_onset))
@@ -339,10 +352,9 @@
ticks_per_beat=ticks_per_beat,
)

# Quantize
_velocity = self._quantize_velocity(_velocity)
_note_onset = self._quantize_onset(
curr_time_since_onset % self.abs_time_step
curr_time_since_onset % self.abs_time_step_ms
)
_note_duration = self._quantize_dur(_note_duration)

@@ -384,7 +396,6 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
TICKS_PER_BEAT: Final[int] = 500
TEMPO: Final[int] = 500000

# Set message tempos
tempo_msgs: list[TempoMessage] = [
{"type": "tempo", "data": TEMPO, "tick": 0}
]
@@ -403,7 +414,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:

if tok in self.special_tokens:
if tok == self.time_tok:
curr_tick += self.abs_time_step
curr_tick += self.abs_time_step_ms
continue
elif (
tok[0] == "prefix"
@@ -463,7 +474,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
_tok_type_3 = tok_3[0]

if tok_1 == self.time_tok:
curr_tick += self.abs_time_step
curr_tick += self.abs_time_step_ms

elif (
_tok_type_1 == "special"
@@ -480,36 +491,34 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
tok_1[1], int
), f"Expected int for pitch, got {tok_1[1]}"

_start_tick: int = curr_tick + tok_2[1]
_end_tick: int = _start_tick + self.time_step
_pitch: int = tok_1[1]
_channel: int = instrument_to_channel[tok_1[0]]
_channel = instrument_to_channel["drum"]
_velocity: int = self.config["drum_velocity"]

if _channel is None:
logger.warning(
"Tried to decode note message for unexpected instrument"
)
else:
note_msgs.append(
{
"type": "note",
"data": {
"pitch": _pitch,
"start": _start_tick,
"end": _end_tick,
"velocity": _velocity,
},
"tick": _start_tick,
"channel": _channel,
}
)
_start_tick: int = curr_tick + tok_2[1]
_end_tick: int = _start_tick + self.time_step_ms

note_msgs.append(
{
"type": "note",
"data": {
"pitch": _pitch,
"start": _start_tick,
"end": _end_tick,
"velocity": _velocity,
},
"tick": _start_tick,
"channel": _channel,
}
)

elif (
_tok_type_1 in self.instruments_nd
and _tok_type_2 == "onset"
and _tok_type_3 == "dur"
):
assert isinstance(
tok_1[0], str
), f"Expected str for instrument, got {tok_1[0]}"
assert isinstance(
tok_1[1], int
), f"Expected int for pitch, got {tok_1[1]}"
@@ -523,17 +532,18 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
tok_3[1], int
), f"Expected int for duration, got {tok_3[1]}"

_instrument = tok_1[0]
_pitch = tok_1[1]
_channel = instrument_to_channel[tok_1[0]]
_velocity = tok_1[2]
_start_tick = curr_tick + tok_2[1]
_end_tick = _start_tick + tok_3[1]

if _channel is None:
if _instrument not in instrument_to_channel.keys():
logger.warning(
"Tried to decode note message for unexpected instrument"
f"Tried to decode note message for unexpected instrument: {_instrument} "
)
else:
_channel = instrument_to_channel[_instrument]
note_msgs.append(
{
"type": "note",
@@ -574,8 +584,6 @@ def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict:

return self._detokenize_midi_dict(tokenized_seq=tokenized_seq)

from typing import Optional

def export_pitch_aug(
self, max_pitch_aug: int
) -> Callable[Concatenate[list[Token], ...], list[Token]]:
@@ -585,11 +593,11 @@
with the unknown token '<U>'.

Args:
aug_range (int): Returned function will randomly augment the pitch
from a value in the range (-aug_range, aug_range).
max_pitch_aug (int): Returned function will randomly augment the pitch
from a value in the range (-max_pitch_aug, max_pitch_aug).

Returns:
Callable[[list[Token]], list[Token]]: Exported function.
Callable[[list[Token], int], list[Token]]: Exported function.
"""

def pitch_aug_seq(
@@ -637,7 +645,6 @@ def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:

return [pitch_aug_tok(x, pitch_aug) for x in src]

# See functools.partial docs
return self.export_aug_fn_concat(
functools.partial(
pitch_aug_seq,
@@ -655,12 +662,12 @@
valid range.

Args:
aug_steps_range (int): Returned function will randomly augment
velocity in the range aug_steps_range * (-self.velocity_step,
self.velocity step).
max_num_aug_steps (int): Returned function will randomly augment
velocity in the range self.velocity_step * (-max_num_aug_steps,
max_num_aug_steps).

Returns:
Callable[[list[Token]], list[Token]]: Exported function.
Callable[[list[Token], int], list[Token]]: Exported function.
"""

def velocity_aug_seq(
@@ -709,7 +716,6 @@ def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token:

return [velocity_aug_tok(x, velocity_aug) for x in src]

# See functools.partial docs
return self.export_aug_fn_concat(
functools.partial(
velocity_aug_seq,
@@ -719,7 +725,7 @@
)
)
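A hedged usage sketch for the two exported augmentation functions above. It assumes the exported callables accept a token sequence and draw their own random offsets, per the docstrings; token shapes follow the rest of this diff:

```python
# Usage sketch; call signatures of the exported functions are assumed.
from ariautils.tokenizer import AbsTokenizer

tokenizer = AbsTokenizer()
pitch_fn = tokenizer.export_pitch_aug(max_pitch_aug=5)
velocity_fn = tokenizer.export_velocity_aug(max_num_aug_steps=2)

seq = [("piano", 60, 90), ("onset", 100), ("dur", 250)]
seq = pitch_fn(seq)     # pitch shifted within ±5; out-of-range -> '<U>'
seq = velocity_fn(seq)  # velocity shifted within ±2 * velocity_step
```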

# TODO: Refactor the logic
# TODO: Refactor this logic
def export_tempo_aug(
self, max_tempo_aug: float, mixup: bool
) -> Callable[Concatenate[list[Token], ...], list[Token]]:
@@ -733,12 +739,12 @@ def export_tempo_aug(
export_data_aug.

Args:
tempo_aug_range (int): Returned function will randomly augment
tempo by a factor in the range (1 - tempo_aug_range,
1 + tempo_aug_range).
max_tempo_aug (float): Returned function will randomly augment
tempo by a factor in the range (1 - max_tempo_aug,
1 + max_tempo_aug).

Returns:
Callable[[list[Token]], list[Token]]: Exported function.
Callable[[list[Token], float], list[Token]]: Exported function.
"""

def tempo_aug(
@@ -876,9 +882,9 @@ def _quantize_time(_n: int | float) -> int:
return self.export_aug_fn_concat(
functools.partial(
tempo_aug,
abs_time_step=self.abs_time_step,
max_dur=self.max_dur,
time_step=self.time_step,
abs_time_step=self.abs_time_step_ms,
max_dur=self.max_dur_ms,
time_step=self.time_step_ms,
unk_tok=self.unk_tok,
time_tok=self.time_tok,
dim_tok=self.dim_tok,
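All three exporters share one pattern: bind configuration (including the renamed *_ms timing attributes) into the augmentation function with functools.partial, then wrap it with export_aug_fn_concat so it applies across concatenated sequences. A compact, self-contained sketch of that pattern with stand-in names:

```python
# Export pattern sketch; the tempo_aug body is elided, names are stand-ins.
import functools
import random

def tempo_aug(src: list, abs_time_step: int, max_tempo_aug: float) -> list:
    factor = 1 + random.uniform(-max_tempo_aug, max_tempo_aug)
    # The real implementation rescales onset/dur values by `factor` and
    # re-buckets notes into abs_time_step windows; omitted here.
    return src

exported = functools.partial(
    tempo_aug,
    abs_time_step=5000,  # self.abs_time_step_ms after this PR
    max_tempo_aug=0.2,
)
augmented = exported([("piano", 60, 90), ("onset", 100), ("dur", 250)])
```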