Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add profiler support in llm foundry #678

Merged
merged 17 commits into from
Oct 18, 2023
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,16 @@ composer train/train.py \
eval_interval=0 \
save_folder=mpt-125m

# To iterate on training using the CPU, run this command (dev only):
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
composer train/train.py \
train/yamls/pretrain/mpt-125m-cpu.yaml \
data_local=my-copy-c4 \
train_loader.dataset.split=train_small \
eval_loader.dataset.split=val_small \
max_duration=10ba \
eval_interval=0 \
save_folder=mpt-125m

# Convert the model to HuggingFace format
python inference/convert_composer_to_hf.py \
--composer_path mpt-125m/ep0-ba10-rank0.pt \
Expand Down
3 changes: 3 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_icl_task_dataloader
from composer.loggers import (InMemoryLogger, LoggerDestination, MLFlowLogger,
TensorboardLogger, WandBLogger)
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.optim import DecoupledAdamW
from composer.optim.scheduler import (ComposerScheduler,
ConstantWithWarmupScheduler,
Expand Down Expand Up @@ -133,6 +134,8 @@ def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
return MLFlowLogger(**kwargs)
elif name == 'inmemory':
return InMemoryLogger(**kwargs)
elif name == 's3':
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
return RemoteUploaderDownloader(**kwargs)
else:
raise ValueError(f'Not sure how to build logger: {name}')

Expand Down
20 changes: 20 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
cyclic_schedule)
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
Expand Down Expand Up @@ -454,6 +456,23 @@ def main(cfg: DictConfig) -> Trainer:
for name, logger_cfg in logger_configs.items()
] if logger_configs else None

# Profiling
profiler: Optional[Profiler] = None
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
profiler_cfg = cfg.get('profiler', None)
if profiler_cfg:
profiler_schedule_cfg: Dict = profiler_cfg.pop('schedule')
profiler_schedule = cyclic_schedule(**profiler_schedule_cfg)
# Only the JSON trace handler is currently supported.
profiler_trace_handlers: List[TraceHandler] = []
profiler_trace_cfg: Optional[Dict] = profiler_cfg.pop(
'json_trace_handler', None)
if profiler_trace_cfg:
profiler_trace_handlers.append(
JSONTraceHandler(**profiler_trace_cfg))
profiler = Profiler(**profiler_cfg,
trace_handlers=profiler_trace_handlers,
schedule=profiler_schedule)

# Callbacks
callbacks: List[Callback] = [
build_callback(str(name), callback_cfg)
Expand Down Expand Up @@ -571,6 +590,7 @@ def main(cfg: DictConfig) -> Trainer:
autoresume=autoresume,
python_log_level=python_log_level,
dist_timeout=dist_timeout,
profiler=profiler,
)

print('Logging config')
Expand Down
154 changes: 154 additions & 0 deletions scripts/train/yamls/pretrain/mpt-125m-cpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
data_local: ./my-copy-c4
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
data_remote: # If blank, files must be present in data_local
max_seq_len: 128
global_seed: 17

# Run Name
run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME

# Model
model:
name: mpt_causal_lm
init_device: cpu
d_model: 16
n_heads: 4
n_layers: 4
expansion_ratio: 5
max_seq_len: ${max_seq_len}
vocab_size: 50368
attn_config:
attn_impl: torch
loss_fn: torch_crossentropy

# Tokenizer
tokenizer:
name: EleutherAI/gpt-neox-20b
kwargs:
model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
name: text
dataset:
local: ${data_local}
remote: ${data_remote}
split: train
shuffle: true
max_seq_len: ${max_seq_len}
shuffle_seed: ${global_seed}
drop_last: true
num_workers: 2

eval_loader:
name: text
dataset:
local: ${data_local}
remote: ${data_remote}
split: val
shuffle: false
max_seq_len: ${max_seq_len}
shuffle_seed: ${global_seed}
drop_last: false
num_workers: 2

# Optimization
scheduler:
name: cosine_with_warmup
t_warmup: 100ba
alpha_f: 0.1

optimizer:
name: decoupled_adamw
lr: 6.0e-4
betas:
- 0.9
- 0.95
eps: 1.0e-08
weight_decay: 0.0

algorithms:
gradient_clipping:
clipping_type: norm
clipping_threshold: 1.0

max_duration: 10ba
eval_interval: 5ba
eval_first: false
eval_subset_num_batches: 5
global_train_batch_size: 256
autoresume: false

# System
seed: ${global_seed}
device_eval_batch_size: 16
device_train_microbatch_size: 16
# device_train_microbatch_size: auto
precision: fp32

# FSDP
fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: PURE
activation_checkpointing: false
activation_checkpointing_reentrant: false
activation_cpu_offload: false
limit_all_gathers: true
verbose: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
speed_monitor:
window_size: 10
lr_monitor: {}
memory_monitor: {}
runtime_estimator: {}

# loggers:
# wandb: {}

# Profiler
profiler:
sys_prof_cpu: true
sys_prof_memory: true
sys_prof_disk: true
sys_prof_net: true
sys_prof_stats_thread_interval_seconds: 0.5
torch_prof_folder: '{run_name}/torch_traces'
torch_prof_filename: 'rank{rank}.batch{batch}.pt.trace.json'
torch_prof_remote_file_name: '{run_name}/torch_traces/rank{rank}.batch{batch}.pt.trace.json'
torch_prof_overwrite: true
torch_prof_use_gzip: false
torch_prof_record_shapes: true
torch_prof_profile_memory: true
torch_prof_with_stack: true
torch_prof_with_flops: true
torch_prof_num_traces_to_keep: -1 # -1 means keep all traces
schedule:
skip_first: 3
wait: 2
warmup: 2
active: 1
repeat: 1
json_trace_handler:
folder: '{run_name}/composer_traces'
filename: 'ep{epoch}-ba{batch}-rank{rank}.json'
remote_file_name: '{run_name}/traces/ep{epoch}-ba{batch}-rank{rank}.json'
merged_trace_filename: 'merged_trace.json'
merged_trace_remote_file_name: '{run_name}/traces/merged_trace.json'
overwrite: true
num_traces_to_keep: -1

# Checkpoint to local filesystem or remote object store
save_overwrite: true
save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
save_interval: 5ba
save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Load from local filesystem or remote object store
# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt
# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt
Loading