ShashankMosaicML committed Oct 13, 2024
1 parent 7668896 commit af5dade
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions llmfoundry/callbacks/sliding_window_size_warmup_callback.py
@@ -3,9 +3,9 @@
 
 from typing import Any
 
-from composer import Callback, Logger, State, TimeUnit
-from git import Optional
+from composer import Logger, State, TimeUnit
 
+from llmfoundry.interfaces import CallbackWithConfig
 from llmfoundry.utils.warnings import experimental_class
 
 __all__ = [
@@ -14,11 +14,12 @@
 
 
 @experimental_class('SlidingWindowSizeWarmerUpper')
-class SlidingWindowSizeWarmerUpper(Callback):
+class SlidingWindowSizeWarmerUpper(CallbackWithConfig):
     """Warms up the sliding window size for the model based on a schedule.
 
     Args:
-        t_warmup (str|None): Warmup duration for sliding window size, defaults to scheduler.t_warmup.
+        train_config (dict[str, Any]): Training configuration.
+        t_warmup (float): Warmup duration (as a fraction of max duration) for sliding window size.
     """
 
     def __init__(
@@ -32,12 +33,18 @@ def __init__(
 
     def before_train_batch(self, state: State, logger: Logger):
         del logger
+        if state.max_duration is None:
+            raise ValueError(
+                'SlidingWindowSizeWarmerUpper callback requires max_duration to be set.',
+            )
         if state.max_duration.unit == TimeUnit.TOKEN:
             current_time_frac = state.timestamp.token.value / state.max_duration.value
         elif state.max_duration.unit == TimeUnit.BATCH:
             current_time_frac = state.timestamp.batch.value / state.max_duration.value
         else:
-            raise ValueError(f'Unsupported time unit {state.max_duration.unit}')
+            raise ValueError(
+                f'Unsupported time unit {state.max_duration.unit=} for SlidingWindowSizeWarmerUpper callback.',
+            )
         if self.orig_sliding_window_size_list is None:
             self.orig_sliding_window_size_list = [None] * len(
                 state.model.model.transformer.blocks,
@@ -50,7 +57,7 @@ def before_train_batch(self, state: State, logger: Logger):
             if self.orig_sliding_window_size_list[idx] is None:
                 self.orig_sliding_window_size_list[
                     idx] = attn_block.sliding_window_size
-            if self.orig_sliding_window_size_list[idx] != -1:
+            if self.orig_sliding_window_size_list[idx] != -1:  # type: ignore
                 attn_block.sliding_window_size = round(
                     max(
                         self.orig_sliding_window_size_list[idx],
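The last hunk is cut off after round(max(..., so the exact interpolation the callback applies is not visible in this diff. As a reading aid, here is a minimal standalone sketch of the schedule the callback appears to implement, assuming a linear anneal from the full sequence length down to the configured window size over the first t_warmup fraction of training. The function name, the max_seq_len parameter, and the interpolation formula are assumptions for illustration, not the committed code.

def warmed_up_window_size(
    orig_size: int,
    max_seq_len: int,
    time_frac: float,
    t_warmup: float,
) -> int:
    """Shrink the window from max_seq_len toward orig_size as warmup completes."""
    if orig_size == -1:  # -1 disables the sliding window; such blocks are skipped
        return orig_size
    progress = min(time_frac / t_warmup, 1.0)
    # max() keeps the window from ever dropping below its configured size,
    # matching the round(max(orig_size, ...)) structure visible in the hunk above.
    return round(max(orig_size, max_seq_len - progress * (max_seq_len - orig_size)))

# Worked example with original window 1024, sequence length 4096, t_warmup = 0.1:
print(warmed_up_window_size(1024, 4096, 0.00, 0.1))  # 4096 (warmup start)
print(warmed_up_window_size(1024, 4096, 0.05, 0.1))  # 2560 (halfway through warmup)
print(warmed_up_window_size(1024, 4096, 0.50, 0.1))  # 1024 (warmup finished)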

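For context, a hypothetical usage sketch follows. The __init__ body is collapsed in this diff, so the keyword arguments below are inferred from the updated docstring (train_config and t_warmup) and may not match the exact signature; the train_config contents are placeholders.

from llmfoundry.callbacks.sliding_window_size_warmup_callback import (
    SlidingWindowSizeWarmerUpper,
)

train_config = {'max_seq_len': 4096}  # placeholder; normally the full training config dict
callback = SlidingWindowSizeWarmerUpper(
    train_config=train_config,
    t_warmup=0.1,  # finish the window-size anneal 10% of the way through training
)
# Pass the callback to the composer Trainer via callbacks=[callback]. Note that
# max_duration must be set in TOKEN or BATCH units, or before_train_batch raises.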