diff --git a/llmfoundry/callbacks/sliding_window_size_warmup_callback.py b/llmfoundry/callbacks/sliding_window_size_warmup_callback.py index e2cacedfe1..a6962c9943 100644 --- a/llmfoundry/callbacks/sliding_window_size_warmup_callback.py +++ b/llmfoundry/callbacks/sliding_window_size_warmup_callback.py @@ -3,7 +3,7 @@ from typing import Any -from composer import Callback, Logger, State, Time +from composer import Callback, Logger, State, TimeUnit from git import Optional from llmfoundry.utils.warnings import experimental_class @@ -24,25 +24,20 @@ class SlidingWindowSizeWarmerUpper(Callback): def __init__( self, train_config: dict[str, Any], - t_warmup: Optional[str | Time] = None, + t_warmup: float, ): - if t_warmup is None: - if 'scheduler' in train_config and 't_warmup' in train_config[ - 'scheduler']: - t_warmup = train_config['scheduler']['t_warmup'] - else: - raise ValueError( - 't_warmup must be provided if t_warmup is not in scheduler config', - ) - if isinstance(t_warmup, str): - t_warmup = Time.from_timestring(t_warmup) - self.t_warmup = t_warmup.BATCH + self.t_warmup = t_warmup self.max_seq_len = train_config['max_seq_len'] self.orig_sliding_window_size_list = None def before_train_batch(self, state: State, logger: Logger): del logger - current_batch = state.timestamp.batch + if state.max_duration.unit == TimeUnit.TOKEN: + current_time_frac = state.timestamp.token.value / state.max_duration.value + elif state.max_duration.unit == TimeUnit.BATCH: + current_time_frac = state.timestamp.batch.value / state.max_duration.value + else: + raise ValueError(f'Unsupported time unit {state.max_duration.unit}') if self.orig_sliding_window_size_list is None: self.orig_sliding_window_size_list = [None] * len( state.model.model.transformer.blocks, @@ -56,11 +51,12 @@ def before_train_batch(self, state: State, logger: Logger): self.orig_sliding_window_size_list[ idx] = attn_block.sliding_window_size if self.orig_sliding_window_size_list[idx] != -1: - attn_block.sliding_window_size = max( - self.orig_sliding_window_size_list[idx], - self.max_seq_len - ( - self.max_seq_len - - self.orig_sliding_window_size_list[idx] - ) * current_batch / self.t_warmup, + attn_block.sliding_window_size = round( + max( + self.orig_sliding_window_size_list[idx], + self.max_seq_len - ( + self.max_seq_len - + self.orig_sliding_window_size_list[idx] + ) * current_time_frac / self.t_warmup, + ), ) - block.norm_attn_norm.attn.sliding_window_size