diff --git a/llmfoundry/callbacks/sliding_window_size_warmup_callback.py b/llmfoundry/callbacks/sliding_window_size_warmup_callback.py index a6962c9943..1d2593bbc3 100644 --- a/llmfoundry/callbacks/sliding_window_size_warmup_callback.py +++ b/llmfoundry/callbacks/sliding_window_size_warmup_callback.py @@ -3,9 +3,9 @@ from typing import Any -from composer import Callback, Logger, State, TimeUnit -from git import Optional +from composer import Logger, State, TimeUnit +from llmfoundry.interfaces import CallbackWithConfig from llmfoundry.utils.warnings import experimental_class __all__ = [ @@ -14,11 +14,12 @@ @experimental_class('SlidingWindowSizeWarmerUpper') -class SlidingWindowSizeWarmerUpper(Callback): +class SlidingWindowSizeWarmerUpper(CallbackWithConfig): """Warms up the sliding window size for the model based on a schedule. Args: - t_warmup (str|None): Warmup duration for sliding window size, defaults to scheduler.t_warmup. + train_config (dict[str, Any]): Training configuration. + t_warmup (float): Warmup duration (as a fraction of max duration) for sliding window size. """ def __init__( @@ -32,12 +33,18 @@ def __init__( def before_train_batch(self, state: State, logger: Logger): del logger + if state.max_duration is None: + raise ValueError( + 'SlidingWindowSizeWarmerUpper callback requires max_duration to be set.', + ) if state.max_duration.unit == TimeUnit.TOKEN: current_time_frac = state.timestamp.token.value / state.max_duration.value elif state.max_duration.unit == TimeUnit.BATCH: current_time_frac = state.timestamp.batch.value / state.max_duration.value else: - raise ValueError(f'Unsupported time unit {state.max_duration.unit}') + raise ValueError( + f'Unsupported time unit {state.max_duration.unit=} for SlidingWindowSizeWarmerUpper callback.', + ) if self.orig_sliding_window_size_list is None: self.orig_sliding_window_size_list = [None] * len( state.model.model.transformer.blocks, @@ -50,7 +57,7 @@ def before_train_batch(self, state: State, logger: Logger): if self.orig_sliding_window_size_list[idx] is None: self.orig_sliding_window_size_list[ idx] = attn_block.sliding_window_size - if self.orig_sliding_window_size_list[idx] != -1: + if self.orig_sliding_window_size_list[idx] != -1: # type: ignore attn_block.sliding_window_size = round( max( self.orig_sliding_window_size_list[idx],