From e6eb55a24aff275a8512f6b223af4f6a1584b563 Mon Sep 17 00:00:00 2001
From: Raymond Zou
Date: Tue, 5 Nov 2024 08:11:50 +0000
Subject: [PATCH] Add llama 3.1 70b config

---
 benchmarks/benchmark_runner.py               |  1 +
 benchmarks/maxtext_trillium_model_configs.py | 44 ++++++++++++++++++++
 benchmarks/xla_flags_library.py              |  3 +-
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py
index f7b859da0..909c36276 100644
--- a/benchmarks/benchmark_runner.py
+++ b/benchmarks/benchmark_runner.py
@@ -93,6 +93,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
           'mixtral_8x7b_dropless',
           'gemma2_9b_8192',
           'gemma2_27b_8192',
+          'llama3_1_70b_129024',
       ],
       default='llama2_70b_4096',
       help=(
diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py
index 7a28effde..b3606ec8f 100644
--- a/benchmarks/maxtext_trillium_model_configs.py
+++ b/benchmarks/maxtext_trillium_model_configs.py
@@ -412,6 +412,49 @@ class MaxTextModel:
     ),
 )
 
+llama3_1_70b_129024 = MaxTextModel(
+    model_name="llama3_1-70b-129024",
+    model_type="llama3.1-70b",
+    tuning_params={
+        "per_device_batch_size": 0.125,
+        "ici_fsdp_parallelism": -1,
+        "ici_sequence_parallelism": 8,
+        "remat_policy": "custom",
+        "decoder_layer_input": "offload",
+        "out_proj": "offload",
+        "query_proj": "offload",
+        "key_proj": "offload",
+        "value_proj": "offload",
+        "max_target_length": 129024,
+        "attention": "flash",
+        "use_iota_embed": True,
+        "dataset_path": "gs://max-datasets-rogue",
+        "dataset_type": "synthetic",
+        "enable_checkpointing": False,
+        "sa_block_q": 2048,
+        "sa_block_kv": 2048,
+        "sa_block_kv_compute": 2048,
+        "sa_block_q_dkv": 2048,
+        "sa_block_kv_dkv": 2048,
+        "sa_block_kv_dkv_compute": 2048,
+        "sa_block_q_dq": 2048,
+        "sa_block_kv_dq": 2048,
+        "sa_use_fused_bwd_kernel": True,
+        "profiler": "xplane",
+        "skip_first_n_steps_for_profiler": 10,
+        "profiler_steps": 5,
+        "allow_split_physical_axes": True,
+        "custom_mesh": "hybrid_ring_32x8",
+    },
+    xla_flags=(
+        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+        + xla_flags_library.DATA_PARALLEL_OVERLAP
+        + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER
+        + xla_flags_library.HOST_OFFLOAD_FLAGS
+    ),
+)
+
 mixtral_8x7b_dropless = MaxTextModel(
     model_name="mixtral-8x7b",
     model_type="mixtral-8x7b",
@@ -576,6 +619,7 @@ class MaxTextModel:
     llama3_8b_8192,  # Not Optimizied yet
     llama3_70b_8192,  # Not Optimizied yet
     llama3_1_405b_8192_fsdp_dcn,
+    llama3_1_70b_129024,
     mixtral_8x7b_dropped,
     mixtral_8x7b_dropped_int8,
     mixtral_8x7b_dropless,
diff --git a/benchmarks/xla_flags_library.py b/benchmarks/xla_flags_library.py
index 35d03bdf1..705e838d1 100644
--- a/benchmarks/xla_flags_library.py
+++ b/benchmarks/xla_flags_library.py
@@ -51,7 +51,7 @@
 
 #Only ready for 1D All-Gather but should support 2D soon, and
 # hopefully All-Reduce soon.
-ENABLE_SPARECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
+ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
     " --xla_sc_disable_megacore_partitioning=true"
     " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
     " --xla_tpu_enable_all_gather_offload_tracing=true"
@@ -59,6 +59,7 @@
     " --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
     " --xla_sc_enable_instruction_fusion=false"
     " --xla_sc_disjoint_spmem=false"
+    " --2a886c8_chip_config_name=megachip_tccontrol"
     # Interesting flags to try:
     # " --xla_tpu_enable_offloading_gather_to_sparsecore=true"
     # " --xla_tpu_enable_offloading_reduce_to_sparsecore=true"
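
Note on how the pieces above compose: the constants in benchmarks/xla_flags_library.py are plain Python strings of space-prefixed flags, so the xla_flags=(... + ...) expression in the new MaxTextModel is simple string concatenation. Below is a minimal sketch of that mechanism; the DENSE_VMEM_LIMIT_FLAG value is an assumption, the sparsecore constant is abridged from the hunk above, and handing the result to libtpu via LIBTPU_INIT_ARGS is assumed rather than shown in this patch.

    import os

    # Stand-ins for two of the five constants the new config concatenates.
    # DENSE_VMEM_LIMIT_FLAG's value here is assumed; the sparsecore constant
    # is abridged from the diff above. The real definitions live in
    # benchmarks/xla_flags_library.py.
    DENSE_VMEM_LIMIT_FLAG = " --xla_tpu_scoped_vmem_limit_kib=98304"
    ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
        " --xla_sc_disable_megacore_partitioning=true"
        " --xla_sc_disjoint_spmem=false"
    )

    # Same pattern as the xla_flags=(...) expression in the new config:
    # `+` on strings yields one space-separated flag string.
    xla_flags = (
        DENSE_VMEM_LIMIT_FLAG
        + ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER
    )

    # Assumption: a launcher exports the combined string so that libtpu
    # reads the flags at initialization.
    os.environ["LIBTPU_INIT_ARGS"] = xla_flags.strip()

With the patch applied, the new config should be selectable from benchmark_runner.py by passing llama3_1_70b_129024 for the model-name argument whose choices list is extended above (the argument's flag name sits outside the hunk), in place of the default llama2_70b_4096.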