Add AbsTokenizer augmentation tests #5

Merged · 18 commits · Nov 22, 2024
74 changes: 39 additions & 35 deletions ariautils/tokenizer/absolute.py
@@ -6,7 +6,7 @@
import copy

from collections import defaultdict
from typing import Final, Callable, Any
from typing import Final, Callable, Any, Concatenate

from ariautils.midi import (
MidiDict,
@@ -28,8 +28,8 @@
# - Add asserts to the tokenization / detokenization for user error
# - Need to add a tokenization or MidiDict check for how to resolve different
# channels with the same instrument that have overlapping notes
# - There are tons of edge cases here e.g., what if there are two indetical notes?
# on different channels.
# - There are tons of edge cases here e.g., what if there are two identical
# notes on different channels.


class AbsTokenizer(Tokenizer):
@@ -139,7 +139,7 @@ def __init__(self) -> None: # Not sure why this is required by

def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]:
return [
self.export_tempo_aug(tempo_aug_range=0.2, mixup=True),
self.export_tempo_aug(max_tempo_aug=0.2, mixup=True),
self.export_pitch_aug(5),
self.export_velocity_aug(1),
]
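
For review context, a minimal sketch of how these exported closures compose, assuming the usual entry points (`MidiDict.from_midi` and `AbsTokenizer.tokenize`; the file path and variable names are illustrative):

```python
from ariautils.midi import MidiDict
from ariautils.tokenizer import AbsTokenizer

tokenizer = AbsTokenizer()
midi_dict = MidiDict.from_midi("example.mid")  # illustrative path
seq = tokenizer.tokenize(midi_dict)

# Each exported augmentation maps list[Token] -> list[Token],
# so they can be chained in any order.
for aug_fn in tokenizer.export_data_aug():
    seq = aug_fn(seq)
```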
Expand Down Expand Up @@ -574,9 +574,11 @@ def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict:

return self._detokenize_midi_dict(tokenized_seq=tokenized_seq)

from typing import Optional

def export_pitch_aug(
self, aug_range: int
) -> Callable[[list[Token]], list[Token]]:
self, max_pitch_aug: int
) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function that augments the pitch of all note tokens.

Notes which fall out of the range (0, 127) will be replaced
@@ -593,7 +595,7 @@ def export_pitch_aug(
def pitch_aug_seq(
src: list[Token],
unk_tok: str,
_aug_range: int,
_max_pitch_aug: int,
pitch_aug: int | None = None,
) -> list[Token]:
def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
@@ -630,8 +632,8 @@ def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
else:
return unk_tok

if not pitch_aug:
pitch_aug = random.randint(-_aug_range, _aug_range)
if pitch_aug is None:
pitch_aug = random.randint(-_max_pitch_aug, _max_pitch_aug)

return [pitch_aug_tok(x, pitch_aug) for x in src]

@@ -640,13 +642,13 @@ def pitch_aug_tok(tok: Token, _pitch_aug: int) -> Token:
functools.partial(
pitch_aug_seq,
unk_tok=self.unk_tok,
_aug_range=aug_range,
_max_pitch_aug=max_pitch_aug,
)
)

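Worth highlighting for reviewers: the switch from `if not pitch_aug:` to `if pitch_aug is None:` is behavioral, since `not 0` is truthy and the old guard re-randomized an explicit `pitch_aug=0`. A self-contained sketch of the distinction:

```python
import random

def choose_aug(aug: int | None, max_aug: int) -> int:
    # `if not aug:` would wrongly re-roll when aug == 0;
    # `is None` keeps an explicit zero augmentation intact.
    if aug is None:
        aug = random.randint(-max_aug, max_aug)
    return aug

assert choose_aug(0, 5) == 0            # explicit 0 is respected
assert -5 <= choose_aug(None, 5) <= 5   # None means "sample randomly"
```
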
def export_velocity_aug(
self, aug_steps_range: int
) -> Callable[[list[Token]], list[Token]]:
self, max_num_aug_steps: int
) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function which augments the velocity of all pitch tokens.

Velocity values are clipped so that they don't fall outside of the
@@ -663,10 +665,10 @@

def velocity_aug_seq(
src: list[Token],
velocity_step: int,
min_velocity_step: int,
max_velocity: int,
_aug_steps_range: int,
velocity_aug: int | None = None,
_max_num_aug_steps: int,
aug_step: int | None = None,
) -> list[Token]:
def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token:
if isinstance(tok, str): # Stand in for SpecialToken
@@ -693,32 +695,34 @@ def velocity_aug_tok(tok: Token, _velocity_aug: int) -> Token:
# Check it doesn't go out of bounds
if _velocity + _velocity_aug >= max_velocity:
return (_instrument, _pitch, max_velocity)
elif _velocity + _velocity_aug <= velocity_step:
return (_instrument, _pitch, velocity_step)
elif _velocity + _velocity_aug <= min_velocity_step:
return (_instrument, _pitch, min_velocity_step)

return (_instrument, _pitch, _velocity + _velocity_aug)

if not velocity_aug:
velocity_aug = velocity_step * random.randint(
-_aug_steps_range, _aug_steps_range
if aug_step is None:
velocity_aug = min_velocity_step * random.randint(
-_max_num_aug_steps, _max_num_aug_steps
)
else:
velocity_aug = aug_step * min_velocity_step

return [velocity_aug_tok(x, velocity_aug) for x in src]

# See functools.partial docs
return self.export_aug_fn_concat(
functools.partial(
velocity_aug_seq,
velocity_step=self.velocity_step,
min_velocity_step=self.velocity_step,
max_velocity=self.max_velocity,
_aug_steps_range=aug_steps_range,
_max_num_aug_steps=max_num_aug_steps,
)
)

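A hedged usage sketch for the renamed velocity parameters, continuing the earlier sketch; passing `aug_step` through the returned closure is an assumption based on the new `Concatenate[list[Token], ...]` annotation and the keyword plumbing via `functools.partial`:

```python
velocity_fn = tokenizer.export_velocity_aug(max_num_aug_steps=1)

# Sampled: every velocity shifts by k * velocity_step with k in [-1, 1].
aug_seq = velocity_fn(seq)

# Pinned: aug_step=1 shifts by exactly one velocity step, the kind of
# determinism the new tests presumably rely on (assumed pass-through).
aug_seq_fixed = velocity_fn(seq, aug_step=1)
```
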
# TODO: Adjust this so it can handle other tokens like <SEP>
# TODO: Refactor the logic
def export_tempo_aug(
self, tempo_aug_range: float, mixup: bool
) -> Callable[[list[Token]], list[Token]]:
self, max_tempo_aug: float, mixup: bool
) -> Callable[Concatenate[list[Token], ...], list[Token]]:
"""Exports a function which augments the tempo of a sequence of tokens.

Additionally this function performs note-mixup: randomly re-ordering
@@ -749,23 +753,23 @@ def tempo_aug(
end_tok: str,
instruments_wd: list,
tokenizer_name: str,
_tempo_aug_range: float,
_max_tempo_aug: float,
_mixup: bool,
tempo_aug: float | None = None,
) -> list[Token]:
"""This must be used with export_aug_fn_concat in order to work
properly for concatenated sequences."""

def _quantize_time(_n: int) -> int:
def _quantize_time(_n: int | float) -> int:
return round(_n / time_step) * time_step

assert (
tokenizer_name == "abs"
), f"Augmentation function only supports base AbsTokenizer"

if not tempo_aug:
if tempo_aug is None:
tempo_aug = random.uniform(
1 - _tempo_aug_range, 1 + _tempo_aug_range
1 - _max_tempo_aug, 1 + _max_tempo_aug
)

src_time_tok_cnt = 0
@@ -785,8 +789,8 @@ def _quantize_time(_n: int) -> int:
elif tok_1 == start_tok:
res.append(tok_1)
continue
elif tok_1 == dim_tok and note_buffer:
assert isinstance(note_buffer["onset"], int)
elif tok_1 == dim_tok and note_buffer is not None:
assert isinstance(note_buffer["onset"], tuple)
dim_tok_seen = (src_time_tok_cnt, note_buffer["onset"][1])
continue
elif tok_1[0] == "prefix":
@@ -822,9 +826,9 @@ def _quantize_time(_n: int) -> int:
for src_time_tok_cnt, interval_notes in sorted(buffer.items()):
for src_onset, notes_by_onset in sorted(interval_notes.items()):
src_time = src_time_tok_cnt * abs_time_step + src_onset
tgt_time = round(src_time * tempo_aug)
tgt_time = _quantize_time(src_time * tempo_aug)
curr_tgt_time_tok_cnt = tgt_time // abs_time_step
curr_tgt_onset = _quantize_time(tgt_time % abs_time_step)
curr_tgt_onset = tgt_time % abs_time_step

if curr_tgt_onset == abs_time_step:
curr_tgt_onset -= time_step
@@ -846,7 +850,7 @@ def _quantize_time(_n: int) -> int:
if _src_dur_tok is not None:
assert isinstance(_src_dur_tok[1], int)
tgt_dur = _quantize_time(
round(_src_dur_tok[1] * tempo_aug)
_src_dur_tok[1] * tempo_aug
)
tgt_dur = min(tgt_dur, max_dur)
else:
@@ -882,7 +886,7 @@ def _quantize_time(_n: int) -> int:
start_tok=self.bos_tok,
instruments_wd=self.instruments_wd,
tokenizer_name=self.name,
_tempo_aug_range=tempo_aug_range,
_max_tempo_aug=max_tempo_aug,
_mixup=mixup,
)
)
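
The quantization reorder above (quantize `tgt_time` first, then take the modulo) keeps both the segment count and the onset on the `time_step` grid. A worked instance, with `abs_time_step=5000` ms and `time_step=10` ms assumed as the tokenizer defaults:

```python
abs_time_step, time_step, tempo_aug = 5000, 10, 0.8  # assumed defaults

src_time_tok_cnt, src_onset = 2, 2498
src_time = src_time_tok_cnt * abs_time_step + src_onset         # 12498 ms

# New order: quantize the scaled time, then split segment / onset.
tgt_time = round(src_time * tempo_aug / time_step) * time_step  # 10000
curr_tgt_time_tok_cnt = tgt_time // abs_time_step               # 2
curr_tgt_onset = tgt_time % abs_time_step                       # 0, on-grid

# Old order: round(12498 * 0.8) = 9998 lands in segment 1, and the
# remainder 4998 quantizes up to 5000, tripping the special-case clamp
# down to 4990; the new order avoids that boundary artifact entirely.
```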