Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Jan 2, 2024
1 parent 91cd0f1 commit 6ccadec
Showing 1 changed file with 5 additions and 37 deletions.
42 changes: 5 additions & 37 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -186,27 +187,6 @@ def validate_interval(interval: Union[str, int, Time],
return result


def validate_check_interval(check_interval: Union[str, int, Time],
                            interval: Time) -> Time:
    """Coerce ``check_interval`` to a ``Time`` and validate it against ``interval``.

    Args:
        check_interval: How often to check whether an eval run should be
            launched. A string is parsed with ``Time.from_timestring``, an
            integer is interpreted as a count of ``TimeUnit.EPOCH``, and any
            other value is used as given.
        interval: The already-validated eval launch interval that the check
            interval is compared against.

    Returns:
        The check interval as a ``Time``.

    Raises:
        ValueError: If the unit is not in ``SUPPORTED_UNITS``, if the unit
            differs from the unit of ``interval``, or if the check interval
            is greater than ``interval``.
    """
    # Normalize the accepted input forms into a single ``Time`` value.
    parsed = check_interval
    if isinstance(parsed, str):
        parsed = Time.from_timestring(parsed)
    elif isinstance(parsed, int):
        parsed = Time(parsed, TimeUnit.EPOCH)

    if parsed.unit not in SUPPORTED_UNITS:
        unit_error = f'Async eval check interval must be in units {SUPPORTED_UNITS}'
        raise ValueError(unit_error)

    # The units are compared before the magnitudes so that the ordering check
    # below only runs on values expressed in the same unit.
    if interval.unit != parsed.unit:
        raise ValueError('Check interval and interval must be in the same unit')

    if parsed > interval:
        order_error = 'Async eval interval must be equal or greater (less frequent) than async check interval'
        raise ValueError(order_error)

    return parsed


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
Expand All @@ -217,9 +197,6 @@ class AsyncEval(Callback):
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH` or :attr:`.TimeUnit.BATCH`.
check_interval: Optional[Union[str, int, Time]]: The interval describing how often
to check if an eval run should be launched. If not provided, it will be set to
check 5 times each :attr:`interval` (floored to the nearest integer)
compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to
use for the eval run. If not provided, the same cluster as the current run and a
single, full GPU node will be used.
Expand All @@ -229,7 +206,6 @@ def __init__(
self,
training_config: Dict[str, Any],
interval: Union[str, int, Time],
check_interval: Optional[Union[str, int, Time]] = None,
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
):

Expand All @@ -253,17 +229,9 @@ def __init__(
self.interval = validate_interval(interval,
self.training_config['save_interval'])

# Validate and compute the check interval (how often to check for new checkpoints)
if check_interval is None:
unit = self.interval.value // 5
if unit == 0:
unit = 1
check_interval = Time(unit, self.interval.unit)
log.info(
f'No check interval provided, defaulting to {check_interval}')

self.check_interval = validate_check_interval(check_interval,
self.interval)
# Configures how often to check for new checkpoints
self.check_interval = Time(max(self.interval.value // 5, 1),
self.interval.unit)

# Keep track of checkpoints by interval that have already been evaled
# Format: {interval: (checkpoint, run_name)}
Expand Down Expand Up @@ -297,7 +265,7 @@ def _get_checkpoints_and_launch_runs(self, state: State):
break

if not checkpointer:
log.warning('No checkpoint saver callback found. Skipping eval')
warnings.warn('No checkpoint saver callback found. Skipping eval')
return

if not checkpointer.saved_checkpoints:
Expand Down

0 comments on commit 6ccadec

Please sign in to comment.