From 5988b40ac25ee5b8351ce9ae6c2551e1c142ed19 Mon Sep 17 00:00:00 2001 From: tonyjohnchen <127781047+tonyjohnchen@users.noreply.github.com> Date: Thu, 8 Feb 2024 13:21:49 -0800 Subject: [PATCH] Adding MXLA collective benchmark test (#106) * Adding MXLA collective benchmark test. --- dags/multipod/configs/common.py | 7 ++ .../configs/mxla_collective_config.py | 85 ++++++++++++++++ dags/multipod/mxla_collective_nightly.py | 96 +++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 dags/multipod/configs/mxla_collective_config.py create mode 100644 dags/multipod/mxla_collective_nightly.py diff --git a/dags/multipod/configs/common.py b/dags/multipod/configs/common.py index 0b296b61..0f419a27 100644 --- a/dags/multipod/configs/common.py +++ b/dags/multipod/configs/common.py @@ -43,3 +43,10 @@ def setup_maxtext(mode: SetupMode, platform: Platform) -> Tuple[str]: return download_maxtext() + ( f"cd /tmp/maxtext && bash setup.sh MODE={mode.value} && bash preflight.sh PLATFORM={platform.value}", ) + + +def setup_mxla_collective() -> Tuple[str]: + """Common set up for MXLA collective repo.""" + return ( + f"mkdir -p /tmp/mxla_collective && gsutil -m cp gs://mxla_collective_benchmark_script/test_scripts/* /tmp/mxla_collective", + ) diff --git a/dags/multipod/configs/mxla_collective_config.py b/dags/multipod/configs/mxla_collective_config.py new file mode 100644 index 00000000..c6dbd264 --- /dev/null +++ b/dags/multipod/configs/mxla_collective_config.py @@ -0,0 +1,85 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities to construct configs for maxtext DAG.""" + +from xlml.apis import gcp_config, metric_config, task, test_config +from dags import test_owner, gcs_bucket +from dags.multipod.configs import common +from dags.vm_resource import TpuVersion, Project, RuntimeVersion +import datetime + +PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value +RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value + + +def get_mxla_collective_config( + tpu_version: TpuVersion, + tpu_cores: int, + tpu_zone: str, + time_out_in_min: int, + test_name: str, + bytes_to_transfer: int, + project_name: str = PROJECT_NAME, + runtime_version: str = RUNTIME_IMAGE, + network: str = "default", + subnetwork: str = "default", + is_tpu_reserved: bool = True, + num_slices: int = 1, +) -> task.TpuQueuedResourceTask: + job_gcp_config = gcp_config.GCPConfig( + project_name=project_name, + zone=tpu_zone, + dataset_name=metric_config.DatasetOption.XLML_DATASET, + dataset_project=project_name, + composer_project=project_name, + ) + + current_time = datetime.datetime.now() + current_date = current_time.strftime("%Y-%m-%d") + current_datetime = current_time.strftime("%Y-%m-%d-%H-%M-%S") + + base_output_directory = f"{gcs_bucket.XLML_OUTPUT_DIR}/multipod/mxla/nightly/automated/{current_date}/{num_slices}slice-V{tpu_version.value}_{tpu_cores}-mxla-collective-{bytes_to_transfer}transferBytes-{current_datetime}" + + test_platform = common.Platform.GCE + set_up_cmds = common.setup_mxla_collective() + run_model_cmds = ( + ( + "cd /tmp/mxla_collective &&" + f" sudo bash run_mxla_collective_benchmark_gcp.sh BYTES_TO_TRANSFER={bytes_to_transfer} NUM_STEPS=100 SLICES={num_slices} ACCELERATOR_TYPE=v{tpu_version.value}_{tpu_cores} GCS_PATH={base_output_directory}" + " exit $(cat /tmp/benchmark_status.txt)" + ), + ) + + job_test_config = test_config.TpuVmTest( + test_config.Tpu( + version=tpu_version, + cores=tpu_cores, + runtime_version=runtime_version, + reserved=is_tpu_reserved, + network=network, + subnetwork=subnetwork, + ), + test_name=test_name, + set_up_cmds=set_up_cmds, + run_model_cmds=run_model_cmds, + time_out_in_min=time_out_in_min, + task_owner=test_owner.TONY_C, + num_slices=num_slices, + ) + + return task.TpuQueuedResourceTask( + task_test_config=job_test_config, + task_gcp_config=job_gcp_config, + ) diff --git a/dags/multipod/mxla_collective_nightly.py b/dags/multipod/mxla_collective_nightly.py new file mode 100644 index 00000000..5a0f21ba --- /dev/null +++ b/dags/multipod/mxla_collective_nightly.py @@ -0,0 +1,96 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A DAG to run MaxText tests with nightly version.""" + +import datetime +from airflow import models +from dags import composer_env +from dags.vm_resource import TpuVersion, Zone +from dags.multipod.configs import mxla_collective_config +from dags.multipod.configs.common import SetupMode, Platform + + +# Run once a day at 8 am UTC (12 pm PST) +SCHEDULED_TIME = "0 8 * * *" if composer_env.is_prod_env() else None + + +with models.DAG( + dag_id="mxla_collective_nightly", + schedule=SCHEDULED_TIME, + tags=["multipod_team", "mxla_collective", "nightly"], + start_date=datetime.datetime(2024, 2, 7), + catchup=False, +) as dag: + mxla_1mb_test_name = "mxla-collective-nightly-1mb" + mxla_256mb_test_name = "mxla-collective-nightly-256mb" + + mxla_collective_1mb_nightly_4slice_v4_8 = ( + mxla_collective_config.get_mxla_collective_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + bytes_to_transfer=1000000, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=4, + test_name=mxla_1mb_test_name, + ).run() + ) + + mxla_collective_1mb_nightly_8slice_v4_8 = ( + mxla_collective_config.get_mxla_collective_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + bytes_to_transfer=1000000, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=8, + test_name=mxla_1mb_test_name, + ).run() + ) + + mxla_collective_256mb_nightly_4slice_v4_8 = ( + mxla_collective_config.get_mxla_collective_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + bytes_to_transfer=256000000, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=4, + test_name=mxla_256mb_test_name, + ).run() + ) + + mxla_collective_256mb_nightly_8slice_v4_8 = ( + mxla_collective_config.get_mxla_collective_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + bytes_to_transfer=256000000, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + is_tpu_reserved=False, + num_slices=8, + test_name=mxla_256mb_test_name, + ).run() + ) + # Test dependencie + ( + mxla_collective_1mb_nightly_4slice_v4_8 + >> mxla_collective_256mb_nightly_4slice_v4_8 + >> mxla_collective_1mb_nightly_8slice_v4_8 + >> mxla_collective_256mb_nightly_8slice_v4_8 + )