From 5ecb3f58579e48f9eaa777666d327adb2307e64a Mon Sep 17 00:00:00 2001 From: mailvijayasingh Date: Thu, 3 Oct 2024 08:56:53 -0700 Subject: [PATCH] Trillium addition - maxtext inference microbenchmarking (#413) * Trillium addition - maxtext inference microbenchmarking * Add Trillium to jetstream-maxtext inference * Switch to working aqtp version * remove uninstall aqtp --------- Co-authored-by: singhvijaya --- .../configs/jetstream_benchmark_serving_gce_config.py | 1 + dags/inference/jetstream_inference_e2e.py | 4 ++-- dags/inference/maxtext_inference.py | 10 +++++----- dags/inference/maxtext_inference_microbenchmark.py | 10 ++++++++-- dags/inference/maxtext_model_config_generator.py | 9 +++++++-- dags/vm_resource.py | 6 ++++++ 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/dags/inference/configs/jetstream_benchmark_serving_gce_config.py b/dags/inference/configs/jetstream_benchmark_serving_gce_config.py index f0314a93..8736cf32 100644 --- a/dags/inference/configs/jetstream_benchmark_serving_gce_config.py +++ b/dags/inference/configs/jetstream_benchmark_serving_gce_config.py @@ -65,6 +65,7 @@ def get_config( f"cd maxtext && bash setup.sh MODE={test_mode.value} && cd ..", "cd JetStream && pip install -e . && cd benchmarks && pip install -r requirements.in", "pip install torch --index-url https://download.pytorch.org/whl/cpu", + "pip install aqtp==0.7.5", ) additional_metadata_dict = { diff --git a/dags/inference/jetstream_inference_e2e.py b/dags/inference/jetstream_inference_e2e.py index 2c63ba94..749a9316 100644 --- a/dags/inference/jetstream_inference_e2e.py +++ b/dags/inference/jetstream_inference_e2e.py @@ -63,7 +63,7 @@ "jetstream_branch": "", "sleep_time": 360, "time_out_in_min": 60, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_7B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", @@ -91,7 +91,7 @@ "jetstream_branch": "", "sleep_time": 360, "time_out_in_min": 60, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": GEMMA_7B, "tokenizer": "tokenizer.gemma", "weight_dtype": "bfloat16", diff --git a/dags/inference/maxtext_inference.py b/dags/inference/maxtext_inference.py index 416a0142..c98c70d6 100644 --- a/dags/inference/maxtext_inference.py +++ b/dags/inference/maxtext_inference.py @@ -109,7 +109,7 @@ "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_7B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", @@ -147,7 +147,7 @@ "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_13B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", @@ -185,7 +185,7 @@ "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 240, - "tpu_version_cores": [(TpuVersion.V5P, 8)], + "tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_70B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", @@ -222,7 +222,7 @@ "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": GEMMA_7B, "tokenizer": "tokenizer.gemma", "weight_dtype": "bfloat16", @@ -261,7 +261,7 @@ "jetstream_branch": jetstream_branch, "sleep_time": 240, "time_out_in_min": 240, - "tpu_version_cores": [(TpuVersion.V5P, 8)], + "tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)], "model_name": MIXTRAL_8_7B, "tokenizer": "gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral", "weight_dtype": "bfloat16", diff --git a/dags/inference/maxtext_inference_microbenchmark.py b/dags/inference/maxtext_inference_microbenchmark.py index 7cb67cdf..519a93fa 100644 --- a/dags/inference/maxtext_inference_microbenchmark.py +++ b/dags/inference/maxtext_inference_microbenchmark.py @@ -19,7 +19,7 @@ import itertools import numpy from airflow import models -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion +from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import maxtext_inference_microbenchmark_gce_config from dags.multipod.configs.common import SetupMode @@ -184,6 +184,12 @@ def generate_model_configs( network = V5_NETWORKS subnetwork = V5E_SUBNETWORKS runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value + if tpu_version == TpuVersion.TRILLIUM: + project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value + zone = Zone.EUROPE_WEST4_A.value + network = V6E_GCE_NETWORK + subnetwork = V6E_GCE_SUBNETWORK + runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value maxtext_kv_cache_layout_optimization = ( maxtext_inference_microbenchmark_gce_config.config( @@ -241,7 +247,7 @@ def generate_model_configs( if not MAXTEXT_BRANCH else f"-b {MAXTEXT_BRANCH}", "sleep_time": 60, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_7B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", diff --git a/dags/inference/maxtext_model_config_generator.py b/dags/inference/maxtext_model_config_generator.py index 9830807d..085f3d2a 100644 --- a/dags/inference/maxtext_model_config_generator.py +++ b/dags/inference/maxtext_model_config_generator.py @@ -14,7 +14,7 @@ """A helper to generate maxtext model configs.""" -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion +from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import jetstream_benchmark_serving_gce_config from dags.multipod.configs.common import SetupMode @@ -112,7 +112,12 @@ def generate_model_configs( project_name = Project.TPU_PROD_ENV_AUTOMATED.value network = V5_NETWORKS subnetwork = V5P_SUBNETWORKS - + elif tpu_version == TpuVersion.TRILLIUM: + zone = Zone.EUROPE_WEST4_A.value + runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value + project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value + network = V6E_GCE_NETWORK + subnetwork = V6E_GCE_SUBNETWORK jetstream_benchmark_serving = ( jetstream_benchmark_serving_gce_config.get_config( tpu_version=tpu_version, diff --git a/dags/vm_resource.py b/dags/vm_resource.py index 99de432e..8b03670b 100644 --- a/dags/vm_resource.py +++ b/dags/vm_resource.py @@ -26,6 +26,9 @@ V6E_SUBNETWORKS = ( f"{V5_NETWORKS_PREFIX}/regions/us-central2/subnetworks/mas-test" ) +# TODO: Figure V6E_GCE_NETWORK and V6E_GCE_SUBNETWORK +V6E_GCE_NETWORK = "default" +V6E_GCE_SUBNETWORK = "default" BM_NETWORKS_PREFIX_BENCHMARKING = "projects/cloud-ml-benchmarking" BM_NETWORKS = f"{BM_NETWORKS_PREFIX_BENCHMARKING}/global/networks/mas-test" @@ -100,6 +103,8 @@ class Zone(enum.Enum): US_WEST1_C = "us-west1-c" # reserved a3+ cluster in supercomputer-testing AUSTRALIA_SOUTHEAST1_C = "australia-southeast1-c" + # reserved TRILLIUM capacity + EUROPE_WEST4_A = "europe-west4-a" class MachineVersion(enum.Enum): @@ -159,6 +164,7 @@ class RuntimeVersion(enum.Enum): TPU_VM_V4_BASE = "tpu-vm-v4-base" V2_ALPHA_TPUV5_LITE = "v2-alpha-tpuv5-lite" V2_ALPHA_TPUV5 = "v2-alpha-tpuv5" + V2_ALPHA_TPUV6 = "v2-alpha-tpuv6e" class XpkClusters: