Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add masking_rate_scheduling to examples #406

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions examples/benchmarks/bert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import src.hf_bert as hf_bert_module
import src.mosaic_bert as mosaic_bert_module
import src.text_data as text_data_module
import src.mlm_scheduling as mlm_scheduling_module
from composer import Trainer, algorithms
from composer.callbacks import (HealthChecker, LRMonitor, MemoryMonitor,
OptimizerMonitor, RuntimeEstimator,
Expand Down Expand Up @@ -52,6 +53,26 @@ def update_batch_size_info(cfg: DictConfig):
return cfg


def update_mlm_schedule(cfg: DictConfig):
    """Make sure the train/eval dataset configs carry an ``mlm_schedule``.

    A dataset config that already defines ``mlm_schedule`` keeps it. One that
    only defines ``mlm_probability`` gets an equivalent ``'constant'``
    schedule whose initial and final masking rates both equal that
    probability. A dataset config with neither key is left untouched, so a
    dataset without masked language modeling does not end up with a schedule
    whose rates are ``None``.

    Args:
        cfg (DictConfig): Run config holding ``train_loader.dataset`` and
            ``eval_loader.dataset``.

    Returns:
        DictConfig: The same ``cfg`` object, updated in place.
    """

    def convert_constant_rate(dataset_cfg: DictConfig):
        """Return the dataset's schedule, or ``None`` if none can be derived."""
        mlm_schedule = dataset_cfg.get('mlm_schedule', None)
        if mlm_schedule is not None:
            return mlm_schedule

        # `mlm_probability` is optional, so it is read with `.get` rather than
        # by attribute access.
        mlm_probability = dataset_cfg.get('mlm_probability', None)
        if mlm_probability is None:
            return None

        return om.create({
            'name': 'constant',
            'initial_masking_rate': mlm_probability,
            'final_masking_rate': mlm_probability,
        })

    for loader_cfg in (cfg.train_loader, cfg.eval_loader):
        mlm_schedule = convert_constant_rate(loader_cfg.dataset)
        if mlm_schedule is not None:
            loader_cfg.dataset.mlm_schedule = mlm_schedule
    return cfg


def log_config(cfg: DictConfig):
print(om.to_yaml(cfg))
if 'wandb' in cfg.get('loggers', {}):
Expand Down Expand Up @@ -174,15 +195,15 @@ def main(cfg: DictConfig,

# Dataloaders
print('Building train loader...')
train_loader = build_dataloader(
train_loader, distributed_masking_rate = build_dataloader(
cfg.train_loader,
model.tokenizer,
cfg.global_train_batch_size // dist.get_world_size(),
)
print('Building eval loader...')
global_eval_batch_size = cfg.get('global_eval_batch_size',
cfg.global_train_batch_size)
eval_loader = build_dataloader(
eval_loader, _ = build_dataloader(
cfg.eval_loader,
model.tokenizer,
global_eval_batch_size // dist.get_world_size(),
Expand All @@ -205,6 +226,9 @@ def main(cfg: DictConfig,
build_callback(name, callback_cfg)
for name, callback_cfg in cfg.get('callbacks', {}).items()
]
# `mlm_scheduling_module` is the imported module object and cannot be called;
# the callback is created by its `build_mlm_scheduler_callback` factory.
# The dataloader builder returns None instead of a shared masking rate when
# no schedule is configured, in which case there is nothing to schedule.
if distributed_masking_rate is not None:
    callbacks.append(
        mlm_scheduling_module.build_mlm_scheduler_callback(
            cfg.train_loader.dataset.mlm_schedule,
            distributed_masking_rate))

# Algorithms
algorithms = [
Expand Down Expand Up @@ -265,5 +289,6 @@ def main(cfg: DictConfig,
yaml_cfg = om.load(f)
cli_cfg = om.from_cli(args_list)
cfg = om.merge(yaml_cfg, cli_cfg)
cfg = update_mlm_schedule(cfg)
cfg = cast(DictConfig, cfg) # for type checking
main(cfg)
81 changes: 81 additions & 0 deletions examples/benchmarks/bert/src/mlm_scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Union
import multiprocessing

from composer import Callback, State, Logger, Event, Time
from composer.optim.scheduler import (ComposerScheduler, ConstantScheduler,
CosineAnnealingScheduler, _convert_time,
LinearScheduler)
from omegaconf import DictConfig
"""
Definition of schedulers and callbacks for setting the masking rate dynamically
"""


# Special case of step-wise scheduling in which the decay is performed only
# once, so the schedule is defined by its start and terminal masking rates.
class StepScheduler(ComposerScheduler):
    r"""Switches from one multiplier to another at a single point in training.

    The value returned by this scheduler is a multiplier; the masking-rate
    callback in this module multiplies it by the initial masking rate.

    Args:
        alpha_i (float): Multiplier returned before ``t_step`` is reached.
            Default = ``1.0``.
        alpha_f (float): Multiplier returned once ``t_step`` is reached.
            Default = ``0.5``.
        t_step (str | Time): The time at which to switch multipliers.
            Default = ``"0.5dur"``.
    """

    def __init__(self,
                 alpha_i: float = 1.0,
                 alpha_f: float = 0.5,
                 t_step: Union[str, Time] = '0.5dur'):
        self.alpha_i = alpha_i
        self.alpha_f = alpha_f
        self.t_step = t_step

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        """Return ``alpha_i`` before the switch time and ``alpha_f`` after it."""
        switch_time = _convert_time(self.t_step, state, ssr=ssr)
        # Read the current time in the same unit as the switch time so the
        # two values are directly comparable.
        now = state.timestamp.get(switch_time.unit)

        if now.value < switch_time.value:
            return self.alpha_i
        return self.alpha_f


class MaskingRateSetter(Callback):
    """Callback that publishes the scheduled masking rate after every batch.

    Args:
        scheduler (ComposerScheduler): Called with the trainer state; its
            return value is used as a multiplier.
        initial_masking_rate (float): Rate the multiplier is applied to.
        dynamic_masking_rate: Shared object whose ``.value`` is overwritten
            with the newly computed masking rate.
    """

    def __init__(self, scheduler: ComposerScheduler,
                 initial_masking_rate: float,
                 dynamic_masking_rate: multiprocessing.Value):
        super().__init__()
        self.dynamic_masking_rate = dynamic_masking_rate
        self.initial_masking_rate = initial_masking_rate
        self.scheduler = scheduler

    def run_event(self, event: Event, state: State, logger: Logger):
        """Recompute, store and log the masking rate on ``Event.BATCH_END``."""
        # Every other event is ignored.
        if event != Event.BATCH_END:
            return

        multiplier = self.scheduler(state)
        current_rate = multiplier * self.initial_masking_rate

        self.dynamic_masking_rate.value = current_rate
        logger.log_metrics({'mlm_schedule/masking_rate': current_rate})


def build_mlm_scheduler_callback(
        cfg: DictConfig, distributed_masking_rate: multiprocessing.Value):
    """Build the callback that drives the masking rate from a schedule config.

    Args:
        cfg (DictConfig): Schedule config with ``name`` (one of
            ``'constant'``, ``'cosine'``, ``'linear'``, ``'step'``),
            ``initial_masking_rate`` and, for every schedule except
            ``'constant'``, ``final_masking_rate``.
        distributed_masking_rate: Shared object handed to the callback, which
            writes the current masking rate to its ``.value``.

    Returns:
        MaskingRateSetter: Callback wrapping the selected scheduler.

    Raises:
        ValueError: If ``cfg.name`` is not a known schedule, or if a decaying
            schedule is requested with a non-positive initial masking rate.
    """
    name = cfg.name
    initial_masking_rate = cfg.initial_masking_rate

    if name == 'constant':
        # A constant schedule never uses the final/initial ratio, so it is
        # valid even when the initial masking rate is 0.
        mlm_schedule = ConstantScheduler()
    else:
        if name not in ('cosine', 'linear', 'step'):
            raise ValueError(
                f'Not sure how to build masking rate scheduler: {cfg.name}')
        if initial_masking_rate <= 0:
            # The schedulers scale the initial rate by a multiplier, so a
            # non-positive initial rate cannot reach the final rate.
            raise ValueError(
                f'initial_masking_rate must be positive for the {name!r} '
                f'schedule, got {initial_masking_rate}')

        # Multiplier that takes the initial masking rate to the final one.
        alpha_f = cfg.final_masking_rate / initial_masking_rate

        if name == 'cosine':
            mlm_schedule = CosineAnnealingScheduler(alpha_f=alpha_f)
        elif name == 'linear':
            mlm_schedule = LinearScheduler(alpha_f=alpha_f)
        else:
            mlm_schedule = StepScheduler(alpha_f=alpha_f)

    return MaskingRateSetter(mlm_schedule,
                             initial_masking_rate=initial_masking_rate,
                             dynamic_masking_rate=distributed_masking_rate)
34 changes: 28 additions & 6 deletions examples/benchmarks/bert/src/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""Build a StreamingTextDataset dataset and dataloader for training."""

import os
import multiprocessing
from itertools import islice
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -184,6 +185,23 @@ def __getitem__(self, idx: int) -> Union[Dict[str, Any], torch.Tensor]:
return token_sample


class ScheduledDataCollatorForLanguageModeling(
        transformers.DataCollatorForLanguageModeling):
    """Language-modeling collator whose masking rate comes from a shared value.

    ``mlm_probability`` is a property backed by
    ``distributed_mlm_probability.value``, so a new rate written to that
    shared object is what the collator sees the next time it reads
    ``mlm_probability``.

    Args:
        distributed_mlm_probability: Object exposing the current masking rate
            as ``.value``, or ``None`` when no masking schedule is configured.
        *args: Forwarded to ``transformers.DataCollatorForLanguageModeling``.
        **kwargs: Forwarded to ``transformers.DataCollatorForLanguageModeling``.
    """

    def __init__(self, distributed_mlm_probability: multiprocessing.Value,
                 *args: Any, **kwargs: Any):
        # Bind the shared value before the parent initializer runs, so the
        # property below is usable if the parent touches `mlm_probability`.
        # NOTE(review): whether the parent reads it during init depends on the
        # installed transformers release -- confirm.
        self.distributed_mlm_probability = distributed_mlm_probability
        super().__init__(*args, **kwargs)

    @property
    def mlm_probability(self):
        """Current masking rate: the shared value, else the last assigned one."""
        shared = self.distributed_mlm_probability
        if shared is None:
            # No schedule configured: fall back to whatever was last assigned
            # instead of failing on `None.value`.
            return getattr(self, '_static_mlm_probability', None)
        return shared.value

    @mlm_probability.setter
    def mlm_probability(self, value):
        # Assignments never override the shared value; the assigned number is
        # kept only as the fallback used when no shared value exists.
        self._static_mlm_probability = value


class ConcatenatedSequenceCollatorWrapper:
"""Collator wrapper to add sequence_id to batch."""

Expand Down Expand Up @@ -293,11 +311,15 @@ def build_text_dataloader(
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
)

mlm_probability = cfg.dataset.get('mlm_probability', None)
collate_fn = transformers.DataCollatorForLanguageModeling(
mlm_schedule = cfg.dataset.get('mlm_schedule', None)
distributed_mlm_probability = None
if mlm_schedule:
distributed_mlm_probability = multiprocessing.Value(
"d", mlm_schedule.initial_masking_rate)
collate_fn = ScheduledDataCollatorForLanguageModeling(
tokenizer=dataset.tokenizer,
mlm=mlm_probability is not None,
mlm_probability=mlm_probability)
mlm=mlm_schedule is not None,
distributed_mlm_probability=distributed_mlm_probability)

eos_token_id = cfg.dataset.get('eos_token_id')
bos_token_id = cfg.dataset.get('bos_token_id')
Expand All @@ -318,7 +340,7 @@ def build_text_dataloader(
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
), distributed_mlm_probability


# Helpful to test if your dataloader is working locally
Expand Down