Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Oct 12, 2024
1 parent 1877be7 commit 7668896
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions llmfoundry/callbacks/sliding_window_size_warmup_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any

from composer import Callback, Logger, State, Time
from composer import Callback, Logger, State, TimeUnit
from git import Optional

from llmfoundry.utils.warnings import experimental_class
Expand All @@ -24,25 +24,20 @@ class SlidingWindowSizeWarmerUpper(Callback):
def __init__(
self,
train_config: dict[str, Any],
t_warmup: Optional[str | Time] = None,
t_warmup: float,
):
if t_warmup is None:
if 'scheduler' in train_config and 't_warmup' in train_config[
'scheduler']:
t_warmup = train_config['scheduler']['t_warmup']
else:
raise ValueError(
't_warmup must be provided if t_warmup is not in scheduler config',
)
if isinstance(t_warmup, str):
t_warmup = Time.from_timestring(t_warmup)
self.t_warmup = t_warmup.BATCH
self.t_warmup = t_warmup
self.max_seq_len = train_config['max_seq_len']
self.orig_sliding_window_size_list = None

def before_train_batch(self, state: State, logger: Logger):
del logger
current_batch = state.timestamp.batch
if state.max_duration.unit == TimeUnit.TOKEN:
current_time_frac = state.timestamp.token.value / state.max_duration.value
elif state.max_duration.unit == TimeUnit.BATCH:
current_time_frac = state.timestamp.batch.value / state.max_duration.value
else:
raise ValueError(f'Unsupported time unit {state.max_duration.unit}')
if self.orig_sliding_window_size_list is None:
self.orig_sliding_window_size_list = [None] * len(
state.model.model.transformer.blocks,
Expand All @@ -56,11 +51,12 @@ def before_train_batch(self, state: State, logger: Logger):
self.orig_sliding_window_size_list[
idx] = attn_block.sliding_window_size
if self.orig_sliding_window_size_list[idx] != -1:
attn_block.sliding_window_size = max(
self.orig_sliding_window_size_list[idx],
self.max_seq_len - (
self.max_seq_len -
self.orig_sliding_window_size_list[idx]
) * current_batch / self.t_warmup,
attn_block.sliding_window_size = round(
max(
self.orig_sliding_window_size_list[idx],
self.max_seq_len - (
self.max_seq_len -
self.orig_sliding_window_size_list[idx]
) * current_time_frac / self.t_warmup,
),
)
block.norm_attn_norm.attn.sliding_window_size

0 comments on commit 7668896

Please sign in to comment.