Merge pull request AI-Hypercomputer#1071 from AI-Hypercomputer:raymondzou-llama

PiperOrigin-RevId: 702160249
maxtext authors committed Dec 3, 2024
2 parents 5cdbabb + e6eb55a commit 6f36556
Showing 3 changed files with 47 additions and 1 deletion.
1 change: 1 addition & 0 deletions benchmarks/benchmark_runner.py
@@ -93,6 +93,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'mixtral_8x7b_dropless',
'gemma2_9b_8192',
'gemma2_27b_8192',
'llama3_1_70b_129024',
],
default='llama2_70b_4096',
help=(
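For context, the hunk above registers 'llama3_1_70b_129024' as a new model choice in the runner. Below is a minimal sketch of the surrounding argparse pattern, assuming the argument is named '--model_name' and abbreviating the choices list (both inferred from the diff, not copied from the full file):

import argparse

def add_shared_arguments(custom_parser: argparse.ArgumentParser):
  # Sketch only: the real choices list in benchmark_runner.py is longer,
  # and '--model_name' is an assumed flag name inferred from the diff.
  custom_parser.add_argument(
      '--model_name',
      type=str,
      choices=[
          'mixtral_8x7b_dropless',
          'gemma2_9b_8192',
          'gemma2_27b_8192',
          'llama3_1_70b_129024',  # added by this commit
      ],
      default='llama2_70b_4096',
      help='Model config to benchmark.',
  )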
44 changes: 44 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -412,6 +412,49 @@ class MaxTextModel:
),
)

llama3_1_70b_129024 = MaxTextModel(
model_name="llama3_1-70b-129024",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 0.125,
"ici_fsdp_parallelism": -1,
"ici_sequence_parallelism": 8,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"out_proj": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 129024,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"allow_split_physical_axes": True,
"custom_mesh": "hybrid_ring_32x8",
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
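
As a rough sanity check on the new config's scale: reading the fractional per_device_batch_size as one sequence split across several chips by sequence parallelism is an assumption, and the 256-chip count is inferred from the hybrid_ring_32x8 mesh.

# Back-of-the-envelope numbers for llama3_1_70b_129024 (interpretation is
# an assumption, not stated in the diff).
per_device_batch_size = 0.125
max_target_length = 129024
chips = 32 * 8  # inferred from custom_mesh="hybrid_ring_32x8"

global_batch = per_device_batch_size * chips        # 32 sequences per step
tokens_per_step = global_batch * max_target_length  # ~4.1M tokens per step
print(int(global_batch), int(tokens_per_step))      # 32 4128768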

mixtral_8x7b_dropless = MaxTextModel(
model_name="mixtral-8x7b",
model_type="mixtral-8x7b",
@@ -576,6 +619,7 @@ class MaxTextModel:
llama3_8b_8192,  # Not Optimized yet
llama3_70b_8192,  # Not Optimized yet
llama3_1_405b_8192_fsdp_dcn,
llama3_1_70b_129024,
mixtral_8x7b_dropped,
mixtral_8x7b_dropped_int8,
mixtral_8x7b_dropless,
3 changes: 2 additions & 1 deletion benchmarks/xla_flags_library.py
@@ -51,14 +51,15 @@

# Only ready for 1D All-Gather, but should support 2D soon and
# hopefully All-Reduce soon.
-ENABLE_SPARECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
+ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
" --xla_sc_disable_megacore_partitioning=true"
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_all_gather_offload_tracing=true"
" --xla_tpu_use_tc_device_shape_on_sc=true"
" --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
" --xla_sc_enable_instruction_fusion=false"
" --xla_sc_disjoint_spmem=false"
" --2a886c8_chip_config_name=megachip_tccontrol"
# Interesting flags to try:
# " --xla_tpu_enable_offloading_gather_to_sparsecore=true"
# " --xla_tpu_enable_offloading_reduce_to_sparsecore=true"
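Each constant in xla_flags_library is a space-prefixed string of XLA/libtpu flags, so model configs compose them with plain +, as llama3_1_70b_129024 does above. A minimal sketch of that composition follows; the constant is abbreviated to two flags shown in this diff, and exporting the result through LIBTPU_INIT_ARGS is an assumption about how the runner hands flags to the TPU runtime:

import os

# Abbreviated copy of the renamed constant; the full value lives in
# benchmarks/xla_flags_library.py.
ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER = (
    " --xla_sc_disable_megacore_partitioning=true"
    " --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
)

# Space-prefixed fragments concatenate into one well-formed flag string.
flags = ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER

# Assumption: MaxText-style runners pass composed flags to the runtime via
# the LIBTPU_INIT_ARGS environment variable.
os.environ["LIBTPU_INIT_ARGS"] = flags.strip()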
