Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
j316chuck committed Oct 17, 2023
1 parent 0cb3a4d commit d1c2517
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ composer train/train.py \
eval_interval=0 \
save_folder=mpt-125m

# To iterate on training on CPU run command (dev only):
# To iterate on training on CPU run command (dev only):
composer train/train.py \
train/yamls/pretrain/mpt-125m-cpu.yaml \
data_local=my-copy-c4 \
Expand Down
16 changes: 10 additions & 6 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
from composer.profiler import Profiler, cyclic_schedule, JSONTraceHandler
from composer.profiler import JSONTraceHandler, Profiler, cyclic_schedule
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
Expand Down Expand Up @@ -454,21 +454,25 @@ def main(cfg: DictConfig) -> Trainer:
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
] if logger_configs else None

# Profiling
profiler: Optional[Profiler] = None
profiler_cfg = cfg.get("profiler", None)
profiler_cfg = cfg.get('profiler', None)
if profiler_cfg:
profiler_schedule: Optional[Callable] = None
profiler_schedule_cfg: Optional[Dict] = profiler_cfg.pop('schedule', None)
profiler_schedule_cfg: Optional[Dict] = profiler_cfg.pop(
'schedule', None)
if profiler_schedule_cfg:
profiler_schedule = cyclic_schedule(**profiler_schedule_cfg)
# Only support json trace handler
profiler_trace_handler: Optional[JSONTraceHandler] = None
profiler_trace_cfg: Optional[Dict] = profiler_cfg.pop('json_trace_handler', None)
profiler_trace_cfg: Optional[Dict] = profiler_cfg.pop(
'json_trace_handler', None)
if profiler_trace_cfg:
profiler_trace_handler = JSONTraceHandler(**profiler_trace_cfg)
profiler = Profiler(**profiler_cfg, trace_handlers=[profiler_trace_handler], schedule=profiler_schedule)
profiler = Profiler(**profiler_cfg,
trace_handlers=[profiler_trace_handler],
schedule=profiler_schedule)

# Callbacks
callbacks: List[Callback] = [
Expand Down

0 comments on commit d1c2517

Please sign in to comment.