From 0a9a5439f189ccafb26d642531e37deda5f18b8c Mon Sep 17 00:00:00 2001 From: Chuck Tang Date: Mon, 16 Oct 2023 22:33:42 -0700 Subject: [PATCH] ok --- scripts/train/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 20a94b3a56..3204959903 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -5,13 +5,14 @@ import os import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from composer import Trainer from composer.core import Evaluator from composer.core.callback import Callback -from composer.profiler import TraceHandler, JSONTraceHandler, Profiler, cyclic_schedule +from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, + cyclic_schedule) from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om @@ -466,7 +467,8 @@ def main(cfg: DictConfig) -> Trainer: profiler_trace_cfg: Optional[Dict] = profiler_cfg.pop( 'json_trace_handler', None) if profiler_trace_cfg: - profiler_trace_handlers.append(JSONTraceHandler(**profiler_trace_cfg)) + profiler_trace_handlers.append( + JSONTraceHandler(**profiler_trace_cfg)) profiler = Profiler(**profiler_cfg, trace_handlers=profiler_trace_handlers, schedule=profiler_schedule)