Skip to content

Commit

Permalink
Trillium addition - maxtext inference microbenchmarking (#413)
Browse files Browse the repository at this point in the history
* Trillium addition - maxtext inference microbenchmarking

* Add Trillium to jetstream-maxtext inference

* Switch to working aqtp version

* remove uninstall aqtp

---------

Co-authored-by: singhvijaya <[email protected]>
  • Loading branch information
mailvijayasingh and singhvijaya authored Oct 3, 2024
1 parent 8ecdce4 commit 5ecb3f5
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_config(
f"cd maxtext && bash setup.sh MODE={test_mode.value} && cd ..",
"cd JetStream && pip install -e . && cd benchmarks && pip install -r requirements.in",
"pip install torch --index-url https://download.pytorch.org/whl/cpu",
"pip install aqtp==0.7.5",
)

additional_metadata_dict = {
Expand Down
4 changes: 2 additions & 2 deletions dags/inference/jetstream_inference_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"jetstream_branch": "",
"sleep_time": 360,
"time_out_in_min": 60,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": LLAMA2_7B,
"tokenizer": "tokenizer.llama2",
"weight_dtype": "bfloat16",
Expand Down Expand Up @@ -91,7 +91,7 @@
"jetstream_branch": "",
"sleep_time": 360,
"time_out_in_min": 60,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": GEMMA_7B,
"tokenizer": "tokenizer.gemma",
"weight_dtype": "bfloat16",
Expand Down
10 changes: 5 additions & 5 deletions dags/inference/maxtext_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"jetstream_branch": jetstream_branch,
"sleep_time": 360,
"time_out_in_min": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": LLAMA2_7B,
"tokenizer": "tokenizer.llama2",
"weight_dtype": "bfloat16",
Expand Down Expand Up @@ -147,7 +147,7 @@
"jetstream_branch": jetstream_branch,
"sleep_time": 360,
"time_out_in_min": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": LLAMA2_13B,
"tokenizer": "tokenizer.llama2",
"weight_dtype": "bfloat16",
Expand Down Expand Up @@ -185,7 +185,7 @@
"jetstream_branch": jetstream_branch,
"sleep_time": 360,
"time_out_in_min": 240,
"tpu_version_cores": [(TpuVersion.V5P, 8)],
"tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": LLAMA2_70B,
"tokenizer": "tokenizer.llama2",
"weight_dtype": "bfloat16",
Expand Down Expand Up @@ -222,7 +222,7 @@
"jetstream_branch": jetstream_branch,
"sleep_time": 360,
"time_out_in_min": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": GEMMA_7B,
"tokenizer": "tokenizer.gemma",
"weight_dtype": "bfloat16",
Expand Down Expand Up @@ -261,7 +261,7 @@
"jetstream_branch": jetstream_branch,
"sleep_time": 240,
"time_out_in_min": 240,
"tpu_version_cores": [(TpuVersion.V5P, 8)],
"tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": MIXTRAL_8_7B,
"tokenizer": "gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral",
"weight_dtype": "bfloat16",
Expand Down
10 changes: 8 additions & 2 deletions dags/inference/maxtext_inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import itertools
import numpy
from airflow import models
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK
from dags.inference.configs import maxtext_inference_microbenchmark_gce_config
from dags.multipod.configs.common import SetupMode

Expand Down Expand Up @@ -184,6 +184,12 @@ def generate_model_configs(
network = V5_NETWORKS
subnetwork = V5E_SUBNETWORKS
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
if tpu_version == TpuVersion.TRILLIUM:
project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value
zone = Zone.EUROPE_WEST4_A.value
network = V6E_GCE_NETWORK
subnetwork = V6E_GCE_SUBNETWORK
runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value

maxtext_kv_cache_layout_optimization = (
maxtext_inference_microbenchmark_gce_config.config(
Expand Down Expand Up @@ -241,7 +247,7 @@ def generate_model_configs(
if not MAXTEXT_BRANCH
else f"-b {MAXTEXT_BRANCH}",
"sleep_time": 60,
"tpu_version_cores": [(TpuVersion.V5E, 8)],
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
"model_name": LLAMA2_7B,
"tokenizer": "tokenizer.llama2",
"weight_dtype": "bfloat16",
Expand Down
9 changes: 7 additions & 2 deletions dags/inference/maxtext_model_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""A helper to generate maxtext model configs."""

from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK
from dags.inference.configs import jetstream_benchmark_serving_gce_config
from dags.multipod.configs.common import SetupMode

Expand Down Expand Up @@ -112,7 +112,12 @@ def generate_model_configs(
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
network = V5_NETWORKS
subnetwork = V5P_SUBNETWORKS

elif tpu_version == TpuVersion.TRILLIUM:
zone = Zone.EUROPE_WEST4_A.value
runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value
project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value
network = V6E_GCE_NETWORK
subnetwork = V6E_GCE_SUBNETWORK
jetstream_benchmark_serving = (
jetstream_benchmark_serving_gce_config.get_config(
tpu_version=tpu_version,
Expand Down
6 changes: 6 additions & 0 deletions dags/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
V6E_SUBNETWORKS = (
f"{V5_NETWORKS_PREFIX}/regions/us-central2/subnetworks/mas-test"
)
# TODO: Figure out the correct values for V6E_GCE_NETWORK and V6E_GCE_SUBNETWORK
V6E_GCE_NETWORK = "default"
V6E_GCE_SUBNETWORK = "default"

BM_NETWORKS_PREFIX_BENCHMARKING = "projects/cloud-ml-benchmarking"
BM_NETWORKS = f"{BM_NETWORKS_PREFIX_BENCHMARKING}/global/networks/mas-test"
Expand Down Expand Up @@ -100,6 +103,8 @@ class Zone(enum.Enum):
US_WEST1_C = "us-west1-c"
# reserved a3+ cluster in supercomputer-testing
AUSTRALIA_SOUTHEAST1_C = "australia-southeast1-c"
# reserved TRILLIUM capacity
EUROPE_WEST4_A = "europe-west4-a"


class MachineVersion(enum.Enum):
Expand Down Expand Up @@ -159,6 +164,7 @@ class RuntimeVersion(enum.Enum):
TPU_VM_V4_BASE = "tpu-vm-v4-base"
V2_ALPHA_TPUV5_LITE = "v2-alpha-tpuv5-lite"
V2_ALPHA_TPUV5 = "v2-alpha-tpuv5"
V2_ALPHA_TPUV6 = "v2-alpha-tpuv6e"


class XpkClusters:
Expand Down

0 comments on commit 5ecb3f5

Please sign in to comment.