This repository has been archived by the owner on Jul 3, 2024. It is now read-only.
forked from microsoft/Megatron-DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
1543 lines (1395 loc) · 82.7 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron arguments."""
import argparse
import dataclasses
import json
import os
import torch
import deepspeed
import types
from packaging import version
import torch.nn.functional as F
from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path
from megatron.core.transformer import TransformerConfig
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron argument parser and parse the command line.

    The built-in argument groups are registered first (in a fixed order),
    followed by the optional ``extra_args_provider`` hook and finally the
    DeepSpeed config arguments.

    Args:
        extra_args_provider: optional callable that receives the parser and
            returns the parser to keep using.
        ignore_unknown_args: if true, parse with ``parse_known_args`` and
            discard whatever was not recognised; otherwise use ``parse_args``.

    Returns:
        The parsed namespace with ``ds_pipeline_enabled``, ``rank`` and
        ``world_size`` added.
    """
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments; each registrar returns the parser to continue with.
    standard_registrars = (
        _add_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_logging_args,
        _add_zero_args,
        _add_memoryopt_args,
        _add_activation_checkpoint_args,
        _add_distillation_args,
        _add_inference_args,
        _add_transformer_engine_args,
        _add_retro_args,
    )
    for register in standard_registrars:
        parser = register(parser)

    # Custom arguments supplied by the caller.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    parser = deepspeed.add_config_arguments(parser)

    # Parse, optionally tolerating options this parser does not know.
    if ignore_unknown_args:
        args, _unrecognised = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Helper flag: DeepSpeed pipeline parallelism is on unless disabled.
    args.ds_pipeline_enabled = not args.no_pipeline_parallel

    # Values taken from the environment rather than the command line.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    return args
def validate_args(args, defaults=None):
    """Validate parsed arguments and fill in every derived setting.

    The function mutates ``args`` in place: it clamps the parallel sizes to
    the world size, removes deprecated attributes, applies ``defaults`` to
    arguments that are still ``None``, derives dtype / batch-size / layer
    settings, loads the retro arguments when a retro work directory is
    given, and finally prints the resulting arguments.

    Args:
        args: namespace returned by ``parse_args``.
        defaults: optional mapping of argument name to the value to use when
            the parsed value is ``None``. A value already set on ``args``
            wins and a warning is printed on rank 0.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: when one of the consistency checks fails.
        RuntimeError: when sequence parallelism or async tensor-parallel
            all-reduce is requested without ``CUDA_DEVICE_MAX_CONNECTIONS``
            being set to ``"1"`` (only checked when not in dataset-only mode).
    """
    # Avoid a shared mutable default argument.
    if defaults is None:
        defaults = {}

    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)

    # Zero bubble pipeline is defined on deepspeed's scheduler
    if args.enable_zbh1_pipeline:
        assert args.deepspeed, 'Use DeepSpeed to use zero-bubble H1 pipeline'
        assert args.sequence_parallel == False, "Sequence Parallel not tested, proceed at own will by removing this line"
    if args.enable_zbh1_exact_semantics:
        assert args.enable_zbh1_pipeline, 'Exact semantics require ZBH1 pipeline enabled'

    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )

    # Checks.
    if args.no_pipeline_parallel:
        assert args.pipeline_model_parallel_size == 1, \
            "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled"

    if args.ds_sequence_parallel_size > 1:
        assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+"

    model_parallel_size = args.pipeline_model_parallel_size * \
        args.tensor_model_parallel_size * \
        args.ds_sequence_parallel_size
    # The message names all three factors of the product being checked.
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({}) times sequence parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            args.ds_sequence_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'sequence-parallel size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.ds_sequence_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                args.pipeline_model_parallel_size, 'split rank needs'\
                ' to be less than pipeline model parallel size ({})'.format(
                    args.pipeline_model_parallel_size)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # HACK: below is commented because DeepSpeed still relies on the old
    # activation checkpointing mechanism.
    # if args.checkpoint_activations:
    #     if args.rank == 0:
    #         print('--checkpoint-activations is no longer valid, use --recompute-activations, '
    #               'or, for more control, --recompute-granularity and --recompute-method.')
    #     exit()
    # del args.checkpoint_activations

    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    # if args.DDP_impl == 'torch':
    if args.DDP_impl != 'local':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0
    args.consumed_train_tokens = 0

    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learnig rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        if not args.use_dataset_only:
            assert args.encoder_num_layers is not None, \
                'either num-layers or encoder-num-layers should be specified'
            args.num_layers = args.encoder_num_layers

    # Check required arguments.
    if not args.use_dataset_only:
        required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                         'max_position_embeddings']
        for req_arg in required_args:
            _check_arg_is_not_none(args, req_arg)

    # Checks.
    if not args.use_dataset_only:
        if args.ffn_hidden_size is None:
            if args.swiglu:
                # reduce the dimension for MLP since projections happens on
                # two linear layers. this keeps the number of parameters in
                # the same ballpark as the counterpart with 4*h size
                # we keep it a multiple of 64, which means the actual tensor size
                # will be a multiple of 64 / tp_size
                args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
            else:
                args.ffn_hidden_size = 4 * args.hidden_size

        if args.kv_channels is None:
            assert args.hidden_size % args.num_attention_heads == 0
            args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if not args.use_dataset_only:
        if args.seq_length is not None:
            assert args.max_position_embeddings >= args.seq_length
        if args.decoder_seq_length is not None:
            assert args.max_position_embeddings >= args.decoder_seq_length

    # When rotary position embeddings is used, set add_position_embedding
    # to false to turn off absolute position embedding.
    if args.use_rotary_position_embeddings:
        args.add_position_embedding = False
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None

    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    if not args.use_dataset_only:
        if args.weight_decay_incr_style == 'constant':
            assert args.start_weight_decay is None
            assert args.end_weight_decay is None
            args.start_weight_decay = args.weight_decay
            args.end_weight_decay = args.weight_decay
        else:
            assert args.start_weight_decay is not None
            assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])

    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'application to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        # Compare (major, minor) as a pair: checking the two numbers
        # independently would reject e.g. v2.0, whose minor is below 10.
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    # Transformer-Engine/FP8 related checking
    if args.fp8_e4m3 or args.fp8_hybrid:
        assert args.transformer_impl == 'transformer_engine', \
            'transformer-engine required for fp8 training and inference'

    assert not (args.fp8_e4m3 and args.fp8_hybrid), \
        'cannot train with both fp8 e4m3 and hybrid formatting'

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    # TODO: currently DeepSpeed seems to be incompatible with
    # async_tensor_model_parallel_allreduce thus temporarily disabling it.
    # Need further investigation.
    if args.deepspeed:
        args.async_tensor_model_parallel_allreduce = False

    if not args.use_dataset_only:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if args.sequence_parallel:
                raise RuntimeError(
                    "Using sequence parallelism requires setting the environment variable "
                    "CUDA_DEVICE_MAX_CONNECTIONS to 1")
            if args.async_tensor_model_parallel_allreduce:
                raise RuntimeError(
                    "Using async gradient all reduce requires setting the environment "
                    "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False

    # Load retro args.
    if args.retro_workdir:
        retro_args_path = get_retro_args_path(args.retro_workdir)
        if os.path.exists(retro_args_path):
            with open(retro_args_path) as f:
                retro_args = types.SimpleNamespace(**json.load(f))
                retro_args.retro_return_doc_ids = args.retro_return_doc_ids
                retro_args.retro_gpt_retrieved_length = \
                    args.retro_num_retrieved_chunks * \
                    retro_args.retro_gpt_chunk_length
                set_retro_args(retro_args)

    args.curriculum_learning_legacy = False
    args.compression_training = False

    # FlashAttention
    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or args.use_flash_attn_builder

    # AML
    if args.aml_data_download_path is not None:
        data_paths = []
        for path in args.data_path:
            data_paths.append(f"{args.aml_data_download_path}/{path}")
        args.data_path = data_paths

    # GQA
    if not args.use_dataset_only:
        if args.num_key_value_heads is None:
            args.num_key_value_heads = args.num_attention_heads
        assert args.num_attention_heads % args.num_key_value_heads == 0, \
            f"num_attention_heads must be divisible by num_key_value_heads (got `num_attention_heads`: {args.num_attention_heads} " \
            f"and `num_key_value_heads`: {args.num_key_value_heads})."
        if args.num_key_value_heads != args.num_attention_heads:
            # if GQA
            assert not args.mos, 'GQA currently does not support args.mos'
            assert not args.kd, 'GQA currently does not support args.kd'

    # Print arguments.
    _print_args("arguments", args)
    retro_args = get_retro_args()
    if retro_args and args != retro_args:
        _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))

    return args
def _print_args(title, args):
"""Print arguments."""
if args.rank == 0:
print(f'------------------------ {title} ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print(f'-------------------- end of {title} ---------------------',
flush=True)
def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def core_transformer_config_from_args(args):
    """Translate parsed arguments into a core ``TransformerConfig``.

    Every ``TransformerConfig`` field that exists under the same name on
    ``args`` is copied over; a handful of fields are then set from
    differently named arguments or fixed values.

    Args:
        args: namespace carrying the validated Megatron arguments.

    Returns:
        A ``TransformerConfig`` built from the collected keyword arguments.
    """
    # Same-named fields are taken from args as they are.
    config_field_names = [
        field.name for field in dataclasses.fields(TransformerConfig)
    ]
    kw_args = {
        name: getattr(args, name)
        for name in config_field_names
        if hasattr(args, name)
    }

    # Fields derived from arguments with a different name, or fixed values.
    kw_args.update(
        persist_layer_norm=not args.no_persist_layer_norm,
        layernorm_zero_centered_gamma=args.apply_layernorm_1p,
        deallocate_pipeline_outputs=True,
        pipeline_dtype=args.params_dtype,
        batch_p2p_comm=not args.overlap_p2p_comm,
    )

    if args.swiglu:
        # SwiGLU: SiLU activation with a gated linear unit, no bias-gelu fusion.
        kw_args.update(
            activation_func=F.silu,
            gated_linear_unit=True,
            bias_gelu_fusion=False,
        )

    if args.init_method_xavier_uniform:
        xavier_uniform = torch.nn.init.xavier_uniform_
        kw_args['init_method'] = xavier_uniform
        kw_args['scaled_init_method'] = xavier_uniform

    return TransformerConfig(**kw_args)
def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-e4m3', action='store_true',
help='E4M3 TransformerLayer', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true',
help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false',
help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
group.add_argument('--fp8-margin', type=int, default=0,
help='Scaling margin for fp8', dest='fp8_margin')
group.add_argument('--fp8-interval', type=int, default=1,
help='Scaling update interval for fp8', dest='fp8_interval')
group.add_argument('--transformer-impl', default='local',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.',
dest='transformer_impl')
group.add_argument('--fp8-amax-history-len', type=int, default=1,
help='Number of steps for which amax history is recorded per tensor',
dest='fp8_amax_history_len')
group.add_argument('--fp8-amax-compute-algo', default='most_recent',
choices=['most_recent', 'max'],
help='Algorithm for computing amax from history',
dest='fp8_amax_compute_algo')
return parser
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
group.add_argument('--max-tokens-to-oom',
type=int, default=12000,
help='Maximum number of tokens during inference'
'tokens here is # in prompt + # to generate'
'Allows us to throw an error before OOM crashes server')
group.add_argument('--output-bert-embeddings', action='store_true',
help='Output Bert embeddings (via mean pooling) from '
'model, rather than its binary head output or entire '
'hidden batch.')
group.add_argument('--bert-embedder-type', default="megatron",
choices=["megatron", "huggingface"],
help='Select either Megatron or Huggingface as the '
'Bert embedder.')
return parser
def _add_retro_args(parser):
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-workdir', default=None,
help='Retro working directory, which contains the '
'preprocessed data for for pretraining. This directory '
'is built during preprocessing (see '
'tools/retro/README.md), and contains subdirectories '
'for the chunk database and pretraining neighbors.')
group.add_argument('--retro-add-retriever',
action='store_true', default=False,
help='Add a retriever to the transformer, for use in '
'pretraining a Retro model.')
group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
help='Set number of training iterations for cyclic '
'Retro training.')
group.add_argument('--retro-encoder-layers', type=int, default=2,
help='Number of layers to use for the retrieval '
'encoder.')
group.add_argument('--retro-encoder-hidden-dropout',
type=float, default=0.1, help='Hidden dropout for '
'retrieval encoder.')
group.add_argument('--retro-encoder-attention-dropout',
type=float, default=0.1, help='Attention dropout for '
'retrieval encoder.')
group.add_argument("--retro-num-neighbors", type=int, default=2,
help='Number of neighbors to retrieve during '
'pretraining.')
group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
help='Number of chunks to retrieve from the retrieval '
'database.')
group.add_argument("--retro-return-doc-ids", action="store_true",
help="Turn this on when preprocessing retro data.")
# Enforce argument naming convention.
for action in group._group_actions:
prefix = action.dest.split("_")[0]
assert prefix == "retro", \
"Retro args must be prefixed with '--retro-*', for consistent " \
"styling. Please fix '%s'." % ", ".join(action.option_strings)
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
group.add_argument('--encoder-num-layers', type=int, default=None,
help='Number of encoder transformer layers.')
group.add_argument('--decoder-num-layers', type=int, default=None,
help='Number of decoder transformer layers.')
group.add_argument('--num-experts', type=int, nargs='+', default=[1,],
help='number of experts list, MoE related.')
group.add_argument('--mlp-type', type=str, default='standard',
help='Only applicable when num-experts > 1, accepts [standard, residual]')
group.add_argument('--topk', type=int, default=1,
help='Sets the k in TopK gating for MoE layers')
group.add_argument('--expert-interval', type=int, default=2,
help='Use experts in every "expert-interval" layers')
group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. '
'This is set to 4*hidden-size if not provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--num-key-value-heads', type=int, default=None,
help='Number of key_value heads that should be used to implement Grouped Query Attention.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head '
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not')
group.add_argument('--rotary-position-embeddings-theta', type=int, default=10000,
help='Rotary positional embeddings theta value.',
dest='rope_theta')
group.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%')
group.add_argument('--no-position-embedding',
action='store_false',
help='Disable position embedding.',
dest='add_position_embedding')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficieny reasons.')
group.add_argument('--normalization', type=str, default='layernorm',
choices=['layernorm', 'rmsnorm'],
help='Options for layer normalization type:'
' layernorm'
' rmsnorm')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
help='Layer norm epsilon.')
group.add_argument('--apply-layernorm-1p', action='store_true',
help='Adjust LayerNorm weights such that they are centered '
'around zero. This improves numerical stability.')
group.add_argument('--disable-mem-efficient-ln', action='store_false',
help='Disable the memory-efficient fused LayerNorm optimization '
'introduced in https://github.com/NVIDIA/apex/pull/1715', dest='mem_efficient_ln')
group.add_argument('--apply-residual-connection-post-layernorm',
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
'reasons.')
group.add_argument('--squared-relu', action='store_true',
help='Use squared relu activation instead of default gelu')
group.add_argument('--swiglu', action='store_true',
help='Use gated linear units and SiLU activation instead of default gelu')
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with '
'Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts-switch', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
group.add_argument('--embedding-weights-in-fp32', action='store_true',
help='Cast word embedding weights to fp32 before embedding fwd.'),
return parser
def _add_logging_args(parser):
    """Register the 'logging' argument group on ``parser``.

    Args:
        parser: the argument parser to extend; it must provide
            ``add_argument_group``.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0, 3),
                       help='Granularity level to measure and report timing. '
                       ' 0: report only iteration time and make sure timing '
                       ' does not introduce extra overhead.'
                       ' 1: report timing for operations that are executed '
                       ' very limited times (basically once) during '
                       ' each iteration (such as gradient all-reduce) '
                       ' 2: report timing for operations that might be '
                       ' executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       ' max: report the max timing across all ranks'
                       ' minmax: report min and max timings across all ranks'
                       ' all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The historical flag name is misspelled ("learnig"). It is kept so that
    # existing launch scripts continue to work; the correctly spelled name is
    # registered as an alias that writes to the same destination.
    group.add_argument('--no-log-learning-rate-to-tensorboard',
                       '--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-optimizer-states-to-tensorboard',
                       action='store_true',
                       help='If set, write various optimizer states to '
                       'tensorboard. This feature may consume extra GPU memory.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')

    return parser
def _add_regularization_args(parser):
    """Register the 'regularization' argument group on ``parser``.

    Args:
        parser: the argument parser to extend; it must provide
            ``add_argument_group``.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    # The trailing space after "improve" matters: adjacent string literals
    # are concatenated verbatim, so without it the help text runs the two
    # words together.
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser
def _parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    Passing ``bool`` directly as an argparse ``type`` treats every non-empty
    string as ``True``, so the text "False" would enable the option. This
    converter interprets the usual spellings explicitly.

    Args:
        value: the raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognized boolean
            spelling.
    """
    import argparse

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, matching what bool('') returned before.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (e.g. True/False), got {!r}'.format(value))


def _add_training_args(parser):
    """Register the 'training' argument group on ``parser``.

    Args:
        parser: the argument parser to extend; it must provide
            ``add_argument_group``.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    # The line-continuation backslash shown in the example is written as an
    # escaped backslash so the literal is valid on every Python version.
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       ' --rampup-batch-size <start batch size> '
                       ' <batch size increment> '
                       ' <ramp-up samples> '
                       'For example:'
                       ' --rampup-batch-size 16 8 300000 \\ '
                       ' --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--recompute-activations', action='store_true',
                       help='recompute activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--recompute-granularity', type=str, default=None,
                       choices=['full', 'selective'],
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes. '
                       'It is supported at two granularities 1) full: '
                       'whole transformer layer is recomputed, '
                       '2) selective: core attention part of the transformer '
                       'layer is recomputed.')
    group.add_argument('--distribute-saved-activations',
                       action='store_true',
                       help='If set, distribute recomputed activations '
                       'across model parallel group.')
    group.add_argument('--recompute-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and recompute the input activation of '
                       'each divided chunk at specified granularity, '
                       '2) recompute the input activations of only a set number of '
                       'individual Transformer layers per pipeline stage and do the '
                       'rest without any recomputing at specified granularity, '
                       'default) do not apply activations recompute to any layers')
    group.add_argument('--recompute-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided recompute unit, '
                       '2) block: the number of individual Transformer layers '
                       'to recompute within each pipeline stage.')
    group.add_argument('--enable-zbh1-pipeline', action='store_true',
                       help='Activate zero bubble pipeline parallelism schedule method')
    group.add_argument('--enable-zbh1-exact-semantics', action='store_true',
                       help='Use an exact semantics for zbh1 schedule, might be slower than the default.')

    # deprecated
    # HACK: added back arguments because DeepSpeed still relies on the old
    # activation checkpointing mechanism.
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-tokens', type=int, default=None,
                       help='Total number of tokens to train over all '
                       'training runs.')
    group.add_argument('--random-ltd',
                       action='store_true',
                       help='enable random layer token drop')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--disable-moe-token-dropping', action='store_false',
                       help='Disable MoE expert token dropping.',
                       dest='moe_token_dropping')
    group.add_argument('--moe-train-capacity-factor', type=float, default=1.0,
                       help='The capacity of the MoE expert at training time')
    group.add_argument('--moe-eval-capacity-factor', type=float, default=1.0,
                       help='The capacity of the MoE expert at eval time.')
    group.add_argument('--moe-min-capacity', type=int, default=4,
                       help='The minimum capacity per MoE expert regardless of the capacity_factor.')
    group.add_argument('--moe-loss-coeff', type=float, default=0.1,
                       help='Scaling coefficient for adding MoE loss to model loss')
    group.add_argument('--create-moe-param-group', action='store_true',
                       help='Create separate groups for MoE params. '
                       'This is necessary for techniques like ZeRO.')
    group.add_argument('--disable-moe-top2-2nd-expert-sampling', action='store_false',
                       help='Disable MoE top2 sampling of the 2nd expert. Instead of sampling, use argmax.',
                       dest='moe_top2_2nd_expert_sampling')
    group.add_argument('--use-flash-attn', '--use-flash-attn-v1', dest='use_flash_attn_v1', action='store_true',
                       help='use first version FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
    group.add_argument('--use-flash-attn-v2', action='store_true',
                       help='use second version FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2307.08691')
    group.add_argument('--use-flash-attn-triton', action='store_true',
                       help='use FlashAttention implementation of attention using Triton.')
    group.add_argument('--use-flash-attn-builder', action='store_true',
                       help='use FlashAttention op builder.')
    group.add_argument('--disable-bias-linear', action='store_false',
                       help='Disable bias in the linear layers',
                       dest='add_bias_linear')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--ds-inference', action='store_true',
                       help='DeepSpeed inference engine being used')
    group.add_argument('--cpu-optimizer', action='store_true',
                       help='Run optimizer on CPU')
    group.add_argument('--cpu_torch_adam', action='store_true',
                       help='Use Torch Adam as optimizer on CPU.')
    group.add_argument('--ds_fused_adam', action='store_true',
                       help='Use DeepSpeed FusedAdam as optimizer.')
    group.add_argument('--no-pipeline-parallel', action='store_true',
                       help='Disable pipeline parallelism')
    group.add_argument('--use-tutel', action='store_true',
                       help='Use Tutel optimization for MoE')
    group.add_argument('--inference', action='store_true',
                       help='Very basic inference mode: not allocating optim/lr - requires ZERO_STAGE=0')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
    group.add_argument('--sequence-parallel', action='store_true',
                       help='Enable Megatron-LM\'s sequence parallel optimization.')
    group.add_argument('--ds-sequence-parallel-size', type=int, default=1,
                       help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.')
    group.add_argument('--force-ds-sequence-parallel', action='store_true',
                       help='use DeepSpeed sequence parallelism regardless of sequence parallel size.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')
    # An explicit converter is used instead of type=bool so that
    # "--use-dataset-only False" really yields False.
    group.add_argument('--use-dataset-only', type=_parse_bool_flag, required=False, default=False,
                       help='If set to True, only use the megatron dataset for external trainer ')

    return parser
def _add_initialization_args(parser):
    """Attach the 'initialization' options to ``parser`` and return it.

    The options are declared as a table of ``(flag, keyword-arguments)``
    pairs and registered in the listed order.
    """
    init_group = parser.add_argument_group(title='initialization')

    option_table = (
        ('--seed',
         {'type': int,
          'default': 1234,
          'help': 'Random seed used for python, numpy, pytorch, and cuda.'}),
        ('--data-parallel-random-init',
         {'action': 'store_true',
          'help': 'Enable random initialization of params across data parallel ranks'}),
        ('--init-method-std',
         {'type': float,
          'default': 0.02,
          'help': 'Standard deviation of the zero mean normal distribution used for weight initialization.'}),
        ('--init-method-xavier-uniform',
         {'action': 'store_true',
          'help': 'Enable Xavier uniform parameter initialization'}),
    )
    for flag, settings in option_table:
        init_group.add_argument(flag, **settings)

    return parser
def _add_learning_rate_args(parser):
group = parser.add_argument_group(title='learning rate')
group.add_argument('--lr', type=float, default=None,
help='Initial learning rate. Depending on decay style '
'and initial warmup, the learing rate at each '
'iteration would be different.')