Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Free outputs callback #2598

Merged
merged 10 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
Expand Down Expand Up @@ -38,4 +39,5 @@
'RuntimeEstimator',
'SystemMetricsMonitor',
'Generate',
'FreeOutputs',
]
16 changes: 16 additions & 0 deletions composer/callbacks/free_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Free train metrics ."""
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved

import torch

from composer.core import Callback, State
from composer.loggers import Logger


class FreeOutputs(Callback):
"""Free train metrics on AFTER_LOSS to reduce peak memory usage if not using train metrics."""
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved

def after_loss(self, state: State, logger: Logger) -> None:
state.outputs = torch.Tensor()
1 change: 1 addition & 0 deletions composer/callbacks/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations from a set of prompts."""

mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
from typing import Any, List, Optional, Union, cast

from composer.callbacks.utils import create_interval_scheduler
Expand Down
10 changes: 7 additions & 3 deletions tests/callbacks/callback_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import composer.loggers
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, Generate, HealthChecker, ImageVisualizer,
MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, ThresholdStopper)
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, HealthChecker,
ImageVisualizer, MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
Expand Down Expand Up @@ -223,5 +224,8 @@ def get_cb_model_and_datasets(cb: Callback,
)
return (configure_tiny_gpt2_hf_model(), dummy_gpt_lm_dataloader(size=dl_size),
dummy_gpt_lm_dataloader(size=dl_size))
return (SimpleModel(), DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs),
model = SimpleModel()
if isinstance(cb, FreeOutputs):
model.get_metrics = lambda is_train=False: {}
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
return (model, DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs),
DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs))