From bb0c171e9cbe4e09f553051b96c82a775a409dc6 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Thu, 30 Nov 2023 03:25:03 +0000 Subject: [PATCH 01/12] Enable docker image feature with xpk --- .github/requirements.txt | 1 + apis/task.py | 84 +++++++++++++- apis/test_config.py | 25 ++++- .../jax/solutionsTeam_jax_npi_config.py | 4 +- .../pytorch/pytorchxla_torchbench_config.py | 4 +- configs/cluster/v5e_cluster_config | 26 +++++ configs/example/gke_example_config.py | 61 ++++++++++ configs/vm_resource.py | 17 +++ ...utionsTeam_flax_latest_supported_config.py | 26 ++--- ...lutionsTeam_pax_latest_supported_config.py | 4 +- ...lutionsTeam_tf_nightly_supported_config.py | 8 +- dags/benchmark/pytorchxla_torchbench.py | 6 +- dags/example/gke_example_dag.py | 45 ++++++++ dags/xlml/pytorchxla_huggingface.py | 8 +- dags/xlml/pytorchxla_llama.py | 4 +- dags/xlml/pytorchxla_torchvision.py | 6 +- .../solutionsTeam_jax_latest_integration.py | 4 +- deployment/cloud_composer_template.tf | 1 + implementations/utils/metric.py | 51 +++++++-- implementations/utils/xpk.py | 106 ++++++++++++++++++ 20 files changed, 440 insertions(+), 51 deletions(-) create mode 100644 configs/cluster/v5e_cluster_config create mode 100644 configs/example/gke_example_config.py create mode 100644 dags/example/gke_example_dag.py create mode 100644 implementations/utils/xpk.py diff --git a/.github/requirements.txt b/.github/requirements.txt index 7bc1563d..c7e0d9cb 100644 --- a/.github/requirements.txt +++ b/.github/requirements.txt @@ -5,3 +5,4 @@ google-cloud-storage google-cloud-tpu>=1.16.0 jsonlines tensorflow-cpu +apache-airflow-providers-cncf-kubernetes \ No newline at end of file diff --git a/apis/task.py b/apis/task.py index d8fbc96d..897bf9e9 100644 --- a/apis/task.py +++ b/apis/task.py @@ -24,7 +24,7 @@ from airflow.utils.task_group import TaskGroup from apis import gcp_config, metric_config, test_config from implementations.utils import metric -from implementations.utils import ssh, tpu +from implementations.utils import ssh, tpu, xpk class BaseTask(abc.ABC): @@ -41,15 +41,15 @@ def run() -> DAGNode: @dataclasses.dataclass -class TpuTask(BaseTask): - """This is a class to set up tasks for TPU. +class TpuGceTask(BaseTask): + """This is a class to set up tasks for TPU in GCE. Attributes: task_test_config: Test configs to run on this TPU. task_gcp_config: Runtime TPU creation parameters. task_metric_config: Metric configs to process metrics. - custom_tpu_name: A custom TPU name. By default the name is - test name + accelerator name. + custom_tpu_name: A custom TPU name. By default the name is test name + + accelerator name. suffix_tpu_name: The flag to define if add auto-generated suffix. all_workers: The flag to define if run commands on all workers or worker 0 only. @@ -189,6 +189,80 @@ def clean_up(self, queued_resource: airflow.XComArg) -> DAGNode: ) +@dataclasses.dataclass +class TpuGkeTask(BaseTask): + """This is a class to set up tasks for TPU in GKE. + + Attributes: + task_test_config: Test configs to run on this TPU. + task_gcp_config: Runtime TPU creation parameters. + task_metric_config: Metric configs to process metrics. + """ + + task_test_config: test_config.TestConfig[test_config.Tpu] + task_gcp_config: gcp_config.GCPConfig + task_metric_config: Optional[metric_config.MetricConfig] = None + + def run(self) -> DAGNode: + """Run a test job within a docker image. + + Returns: + A task group with the following tasks chained: run_model and + post_process. 
+ """ + with TaskGroup(group_id=self.task_test_config.benchmark_id) as group: + self.run_model() >> self.post_process() + + return group + + def run_model(self) -> DAGNode: + """Run the TPU test in `task_test_config` using xpk. + + Returns: + A DAG node that executes the model test. + """ + with TaskGroup(group_id="run_model") as group: + workload_id = xpk.generate_workload_id(self.task_test_config.benchmark_id) + run_workload = xpk.run_workload( + task_id="run_workload", + project_id=self.task_gcp_config.project_name, + zone=self.task_gcp_config.zone, + cluster_name=self.task_test_config.cluster_name, + benchmark_id=self.task_test_config.benchmark_id, + workload_id=workload_id, + docker_image=self.task_test_config.docker_image, + accelerator_type=self.task_test_config.accelerator.name, + run_cmds=self.task_test_config.run_model_cmds, + task_owner=self.task_test_config.task_owner, + ) + wait_for_workload_completion = xpk.wait_for_workload_completion.override( + timeout=self.task_test_config.time_out_in_min * 60 + )( + workload_id=workload_id, + cluster_config=self.task_test_config.cluster_config, + ) + + workload_id >> run_workload >> wait_for_workload_completion + return group + + def post_process(self) -> DAGNode: + """Process metrics and metadata, and insert them into BigQuery tables. + + Returns: + A DAG node that executes the post process. + """ + with TaskGroup(group_id="post_process") as group: + process_id = metric.generate_process_id.override(retries=1)() + metric.process_metrics.override(retries=1)( + process_id, + self.task_test_config, + self.task_metric_config, + self.task_gcp_config, + ) + + return group + + @dataclasses.dataclass class GpuTask(BaseTask): """This is a class to set up tasks for GPU. diff --git a/apis/test_config.py b/apis/test_config.py index a4cdd1fe..558d25b6 100644 --- a/apis/test_config.py +++ b/apis/test_config.py @@ -49,6 +49,7 @@ def __init__(self, accelerator, task_owner=None, test_name): import shlex from typing import Any, Generic, Iterable, List, Optional, TypeVar import attrs +from configs import vm_resource class Accelerator(abc.ABC): @@ -78,7 +79,7 @@ class Tpu(Accelerator): version: str cores: int - runtime_version: str + runtime_version: str = vm_resource.RuntimeVersion.TPU_UBUNTU2204_BASE.value network: str = 'default' subnetwork: str = 'default' reserved: bool = False @@ -154,6 +155,28 @@ def test_script(self) -> str: return '\n'.join(self.run_model_cmds) +@attrs.define +class TpuGkeTest(TestConfig[Tpu]): + test_name: str + cluster_name: str + cluster_config: str + docker_image: str + set_up_cmds: Iterable[str] + run_model_cmds: Iterable[str] + + @property + def benchmark_id(self) -> str: + return f'{self.test_name}-{self.accelerator.name}' + + @property + def setup_script(self) -> Optional[str]: + return ';'.join(self.set_up_cmds) + + @property + def test_script(self) -> str: + return ';'.join(self.run_model_cmds) + + @attrs.define class JSonnetTpuVmTest(TestConfig[Tpu]): """Convert legacy JSonnet test configs into a Task. 
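
A note on how the new classes compose, before the config changes below: TpuGkeTest describes what to run (cluster, docker image, model commands) while TpuGkeTask.run() wires workload creation, completion polling, and metric post-processing into one task group. The following is a minimal usage sketch mirroring the example config added later in this patch; the TPU shape, cluster, image, and owner values are placeholders for illustration, not part of the patch itself:

    from airflow import models
    from apis import gcp_config, metric_config, task, test_config

    # Placeholder values for illustration only.
    job_test_config = test_config.TpuGkeTest(
        test_config.Tpu(version="5litepod", cores=16),
        test_name="demo-test",
        cluster_name="demo-cluster",
        cluster_config="v5e_cluster_config",
        docker_image="gcr.io/<project>/demo:latest",
        set_up_cmds=(),
        run_model_cmds=("python3 train.py",),
        time_out_in_min=60,
        task_owner="demo-owner",
    )
    job_gcp_config = gcp_config.GCPConfig(
        project_name="<project>",
        zone="us-west4-a",
        dataset_name=metric_config.DatasetOption.XLML_DATASET,
    )

    with models.DAG(dag_id="demo_dag") as dag:
      # run() returns a task group chaining run_model >> post_process.
      task.TpuGkeTask(
          task_test_config=job_test_config,
          task_gcp_config=job_gcp_config,
      ).run()
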
diff --git a/configs/benchmark/jax/solutionsTeam_jax_npi_config.py b/configs/benchmark/jax/solutionsTeam_jax_npi_config.py index 6c6ce5b2..f8ce91bc 100644 --- a/configs/benchmark/jax/solutionsTeam_jax_npi_config.py +++ b/configs/benchmark/jax/solutionsTeam_jax_npi_config.py @@ -25,7 +25,7 @@ def get_jax_vit_config( tpu_cores: int, tpu_zone: str, time_out_in_min: int, -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -99,7 +99,7 @@ def get_jax_vit_config( ) ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, task_metric_config=job_metric_config, diff --git a/configs/benchmark/pytorch/pytorchxla_torchbench_config.py b/configs/benchmark/pytorch/pytorchxla_torchbench_config.py index a5d874f4..c2739ed5 100644 --- a/configs/benchmark/pytorch/pytorchxla_torchbench_config.py +++ b/configs/benchmark/pytorch/pytorchxla_torchbench_config.py @@ -70,7 +70,7 @@ def get_torchbench_config( time_out_in_min: int, model_name: str = "", extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -114,7 +114,7 @@ def get_torchbench_config( ) ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, task_metric_config=job_metric_config, diff --git a/configs/cluster/v5e_cluster_config b/configs/cluster/v5e_cluster_config new file mode 100644 index 00000000..f273bf28 --- /dev/null +++ b/configs/cluster/v5e_cluster_config @@ -0,0 +1,26 @@ +apiVersion: v1 +clusters: +- cluster: + certificate-authority-data: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVMVENDQXBXZ0F3SUJBZ0lSQU5sdUp6SkFSS3FvNWFvZXF2M2l0NTh3RFFZSktvWklodmNOQVFFTEJRQXcKTHpFdE1Dc0dBMVVFQXhNa016ZzJZak5qTkRRdE9HWmtOQzAwTjJRNUxXSTBNVFF0WWpkall6RmlNbUV4TmpOaApNQ0FYRFRJek1URXlNREl4TURNek5Wb1lEekl3TlRNeE1URXlNakl3TXpNMVdqQXZNUzB3S3dZRFZRUURFeVF6Ck9EWmlNMk0wTkMwNFptUTBMVFEzWkRrdFlqUXhOQzFpTjJOak1XSXlZVEUyTTJFd2dnR2lNQTBHQ1NxR1NJYjMKRFFFQkFRVUFBNElCandBd2dnR0tBb0lCZ1FDbGoyd2V3UkNFMHBnTjdRMFJOVFlOVlJLQUIrQmVROU1OTG82dwpkZU9EV09uQ3Npd2VvSEpYUlkyM09zdTRDTVRLUVNCbFNrWkZCZFZGR0pYSFdzeEhjdjJ1ZEdwa3NLNnUwSkNlCjdjL2lnc1NLVGwvVlkxQ0RKWnpDQTFmeWpab2hxdGhMaXZFajRFU3JqTTBSSkZUQlVEdXRMQmpsQ0xOWnRKaS8KZm54aldhSlBxVGl1TDR4MUNGMy9EbkxGTjhKQWdROE9JQzFGbkpFdHRlUE9vcXZKOG0rdVZLRHltNUFHR1NvSgpHTjdieTgrcmFJWEFCa1BuNkloVmdqWnllNWJwNzRKZ29GdVAxSTFPeVBoVDVEV2ZQcFl4U3lFTTIzSnI0ZEQrCkhpaGxBdXhQc1M1a29vODJPSVlFcXFtTnM5VDdIV3ByMnJrNDJ2dUtFN0JzMVRrb1dxTFQvUFAzOGdJMzZJL0kKWEZXYzE3NU45L1hRQmlaMGFYNWVxRWdrWWkzRjdGV3FrWEI5N2VWdkJDYTQ1YXdmdFpTQWdwQWtuWnZ3MnpFagp3N2hneUdQZ1g2L0xIM05naFYyelNDdUhJanJIOEtNY0FsaE1nbWRhZlphUzZtdkJvVEEyNmZvdTF5eU1vRHJQCmJSbkxvaFBtQXJEdm9xalJLTnZCbjRrbExUTUNBd0VBQWFOQ01FQXdEZ1lEVlIwUEFRSC9CQVFEQWdJRU1BOEcKQTFVZEV3RUIvd1FGTUFNQkFmOHdIUVlEVlIwT0JCWUVGTXBNcW9XRDMyOUk2K1hJVmNnanVvVXpvZklXTUEwRwpDU3FHU0liM0RRRUJDd1VBQTRJQmdRQVExQ3ZlTDF1a1JBeEQzLzdvZUhReXRESDB4QkEwT3hYaU5nc0VkY3RLCnhsMlpFMjlmdCtEWEpDSkI2ZjgzYkFGUW5leWVGYmRZTFFadWtzdHdWRk9JZ21BWVExb3pTRjM1R2d5UCtML3AKN0IxR2hDWGhoYUZ5Rk1kdGEwRER4aGZkQ2J1V280N1RMcmpGZG5GdlFmVVlMVWNPcDVoZkNZNUxaWUlXaVF0Twp1VDRRUnIrc09YNm9JeUtxQ1dvUjBDVHpJUS8weHhHaE8zYXMwNy9LbmZ6OWNFZE1xb0xmNS9BOFVrc09mNzcrCkdUVE8wa082NzhQYWxadityZm05K216aTIzQVFmUkxwOWVkRUFyQjl4YlQrTVZIcENPK3ZYaHRLL1JPcUR5b1AKM0FxanQ4N3A0UnNZT01nT0V3YVVRRHlmdzFDTTFmamN6NDlveEF2SHRxZC9Tc0dOSW9lWFc5bFlMWmV1TGE4KwpwYnV2UXhnaVdlUnBCSUVPTTB2dE
RTQmgrbW16eUVrS20xM29RbDlGV0p2NGNmOWM3SUEvWG8zQWVhMC96R08rCmVtakJ3Z3BpSUxDR1JNRDNSaFovbERNRjhTRVhjeHhkYjlaMmc5Y2lSek1aa1IweG94NDcxRENlZ0Q1eFFmS3cKbDBxeHNqakQ3YVpiK1VOTFBDUCthMUk9Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K
+    server: https://34.125.160.150
+  name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+contexts:
+- context:
+    cluster: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+    user: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+  name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+current-context: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+kind: Config
+preferences: {}
+users:
+- name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
+  user:
+    exec:
+      apiVersion: client.authentication.k8s.io/v1beta1
+      args: null
+      command: gke-gcloud-auth-plugin
+      env: null
+      installHint: Install gke-gcloud-auth-plugin for use with kubectl by following
+        https://cloud.google.com/blog/products/containers-kubernetes/kubectl-auth-changes-in-gke
+      interactiveMode: IfAvailable
+      provideClusterInfo: true
diff --git a/configs/example/gke_example_config.py b/configs/example/gke_example_config.py
new file mode 100644
index 00000000..03ee2350
--- /dev/null
+++ b/configs/example/gke_example_config.py
@@ -0,0 +1,61 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities to construct configs for example_dag."""
+
+from apis import gcp_config, metric_config, task, test_config
+from configs import test_owner
+
+
+def get_flax_resnet_gke_config(
+    tpu_version: str,
+    tpu_cores: int,
+    tpu_zone: str,
+    cluster_name: str,
+    cluster_config: str,
+    docker_image: str,
+    time_out_in_min: int,
+) -> task.TpuGkeTask:
+  # TODO(ranran): update the project once quota is approved (b/311073979).
+  job_gcp_config = gcp_config.GCPConfig(
+      project_name="tpu-prod-env-one-vm",
+      zone=tpu_zone,
+      dataset_name=metric_config.DatasetOption.XLML_DATASET,
+  )
+
+  run_model_cmds = (
+      "python3 /tmp/flax/examples/imagenet/main.py"
+      " --config=/tmp/flax/examples/imagenet/configs/tpu.py"
+      " --workdir=/tmp/imagenet --config.num_epochs=1",
+  )
+
+  job_test_config = test_config.TpuGkeTest(
+      test_config.Tpu(
+          version=tpu_version,
+          cores=tpu_cores,
+      ),
+      test_name="flax-resnet-gke",
+      cluster_name=cluster_name,
+      cluster_config=cluster_config,
+      docker_image=docker_image,
+      run_model_cmds=run_model_cmds,
+      set_up_cmds=(),
+      time_out_in_min=time_out_in_min,
+      task_owner=test_owner.RAN_R,
+  )
+
+  return task.TpuGkeTask(
+      task_test_config=job_test_config,
+      task_gcp_config=job_gcp_config,
+  )
diff --git a/configs/vm_resource.py b/configs/vm_resource.py
index f87be53b..3bfb324e 100644
--- a/configs/vm_resource.py
+++ b/configs/vm_resource.py
@@ -32,3 +32,20 @@ class RuntimeVersion(enum.Enum):
   TPU_VM_TF_NIGHTLY_POD = "tpu-vm-tf-nightly-pod"
   TPU_UBUNTU2204_BASE = "tpu-ubuntu2204-base"
   TPU_VM_V4_BASE = "tpu-vm-v4-base"
+
+
+# TODO(ranran): update the cluster name once quota is approved (b/311073979).
+class ClusterName(enum.Enum): + V4_CLUSTER = "" + V5E_CLUSTER = "ran-xpk-test-zone" + + +# TODO(ranran): update the cluster name once quota is approved (b/311073979). +class ClusterConfig(enum.Enum): + V4_CONFIG = "v4_cluster_config" + V5E_CONFIG = "v5e_cluster_config" + + +# TODO(ranran): update the project once quota is approved (b/311073979). +class DockerImage(enum.Enum): + DEMO_TEST = "gcr.io/tpu-prod-env-one-vm/xpk_jax_test:latest" diff --git a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py b/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py index 8eb0feac..3c42115e 100644 --- a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py +++ b/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py @@ -33,7 +33,7 @@ def get_flax_resnet_config( time_out_in_min: int, data_dir: str = gcs_bucket.TFDS_DATA_DIR, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -66,7 +66,7 @@ def get_flax_resnet_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -134,7 +134,7 @@ def get_flax_vit_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -204,7 +204,7 @@ def get_flax_gpt2_config( tpu_zone: str, time_out_in_min: int, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -252,7 +252,7 @@ def get_flax_gpt2_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -265,7 +265,7 @@ def get_flax_sd_config( time_out_in_min: int, num_train_epochs: int, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -307,7 +307,7 @@ def get_flax_sd_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -319,7 +319,7 @@ def get_flax_bart_config( tpu_zone: str, time_out_in_min: int, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -358,7 +358,7 @@ def get_flax_bart_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -372,7 +372,7 @@ def get_flax_bert_config( task_name: str, num_train_epochs: int = 1, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -408,7 +408,7 @@ def get_flax_bert_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -422,7 +422,7 @@ def get_flax_wmt_config( num_train_steps: int, data_dir: str = gcs_bucket.TFDS_DATA_DIR, extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -461,7 +461,7 @@ def get_flax_wmt_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, 
task_gcp_config=job_gcp_config, ) diff --git a/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py b/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py index 3689ae4f..9f5347c1 100644 --- a/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py +++ b/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py @@ -30,7 +30,7 @@ def get_pax_lm_config( log_dir: str, ckp_path: str = "", extraFlags: str = "", -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -63,7 +63,7 @@ def get_pax_lm_config( task_owner=test_owner.GERSON_K, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) diff --git a/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py b/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py index fb7a7d79..5ec7d3b1 100644 --- a/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py +++ b/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py @@ -30,7 +30,7 @@ def get_tf_resnet_config( tfds_data_dir: str = gcs_bucket.TFDS_DATA_DIR, train_steps: int = 320, validation_interval: int = 320, -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -84,7 +84,7 @@ def get_tf_resnet_config( task_owner=test_owner.CHANDRA_D, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, custom_tpu_name=tpu_name, @@ -103,7 +103,7 @@ def get_tf_bert_config( tfds_data_dir: str = gcs_bucket.TFDS_DATA_DIR, train_steps: int = 2000, validation_interval: int = 1000, -) -> task.TpuTask: +) -> task.TpuGceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -165,7 +165,7 @@ def get_tf_bert_config( task_owner=test_owner.CHANDRA_D, ) - return task.TpuTask( + return task.TpuGceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, custom_tpu_name=tpu_name, diff --git a/dags/benchmark/pytorchxla_torchbench.py b/dags/benchmark/pytorchxla_torchbench.py index 7d0a011b..5bbe9ffa 100644 --- a/dags/benchmark/pytorchxla_torchbench.py +++ b/dags/benchmark/pytorchxla_torchbench.py @@ -31,9 +31,9 @@ SCHEDULED_TIME = "0 17 * * *" if composer_env.is_prod_env() else None with models.DAG( - dag_id="pytorchxla-torchbench", - schedule=SCHEDULED_TIME, - tags=["pytorchxla", "nightly", "torchbench"], + dag_id="pytorch_nightly_torchbench", + schedule=None, + tags=["pytorch", "nightly", "torchbench", "benchmark"], start_date=datetime.datetime(2023, 8, 29), catchup=False, ) as dag: diff --git a/dags/example/gke_example_dag.py b/dags/example/gke_example_dag.py new file mode 100644 index 00000000..e9b3e114 --- /dev/null +++ b/dags/example/gke_example_dag.py @@ -0,0 +1,45 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""A DAG to run all GKE examples.""" + +import datetime +from airflow import models +from configs import vm_resource +from configs.example import gke_example_config as config + + +# TODO(ranran): add following examples: +# 1) jax_resnet_tpu_example (gce example dag) +# 2) jax_vit_tpu_benchmark_example (gce example dag) +# 3) jax_vit_tpu_benchmark_example (same dag) +with models.DAG( + dag_id="gke_example_dag", + schedule=None, + tags=["example", "gke", "xlml", "benchmark"], + start_date=datetime.datetime(2023, 11, 29), + catchup=False, +) as dag: + jax_resnet_tpu_example = config.get_flax_resnet_gke_config( + tpu_version="5litepod", + tpu_cores=16, + tpu_zone="us-west4-a", + cluster_name=vm_resource.ClusterName.V5E_CLUSTER.value, + cluster_config=vm_resource.ClusterConfig.V5E_CONFIG.value, + docker_image=vm_resource.DockerImage.DEMO_TEST.value, + time_out_in_min=60, + ).run() + + # Test dependencies + jax_resnet_tpu_example diff --git a/dags/xlml/pytorchxla_huggingface.py b/dags/xlml/pytorchxla_huggingface.py index 5b73ad15..cc82d68f 100644 --- a/dags/xlml/pytorchxla_huggingface.py +++ b/dags/xlml/pytorchxla_huggingface.py @@ -39,19 +39,19 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - accelerate_v2_8 = task.TpuTask( + accelerate_v2_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-accelerate-smoke-v2-8-1vm" ), US_CENTRAL1_C, ).run() - accelerate_v4_8 = task.TpuTask( + accelerate_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-accelerate-smoke-v4-8-1vm" ), US_CENTRAL2_B, ).run() - diffusers_v4_8 = task.TpuTask( + diffusers_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-hf-diffusers-func-v4-8-1vm" ), @@ -61,7 +61,7 @@ accelerate_v4_8 >> accelerate_v2_8 accelerate_v4_8 >> diffusers_v4_8 - task.TpuTask( + task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-hf-fsmt-pjrt-func-v4-8-1vm" ), diff --git a/dags/xlml/pytorchxla_llama.py b/dags/xlml/pytorchxla_llama.py index ace862f9..2101ae6a 100644 --- a/dags/xlml/pytorchxla_llama.py +++ b/dags/xlml/pytorchxla_llama.py @@ -33,13 +33,13 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - llama_inference_v4_8 = task.TpuTask( + llama_inference_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-llama2-i-infer-func-v4-8-1vm" ), US_CENTRAL2_B, ).run() - llama_train_v4_8 = task.TpuTask( + llama_train_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-llama2-t-train-spmd-func-v4-8-1vm" ), diff --git a/dags/xlml/pytorchxla_torchvision.py b/dags/xlml/pytorchxla_torchvision.py index 49ab8037..2f657225 100644 --- a/dags/xlml/pytorchxla_torchvision.py +++ b/dags/xlml/pytorchxla_torchvision.py @@ -39,19 +39,19 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - mnist_v2_8 = task.TpuTask( + mnist_v2_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-mnist-pjrt-func-v2-8-1vm" ), US_CENTRAL1_C, ).run() - resnet_v2_8 = task.TpuTask( + resnet_v2_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-resnet50-pjrt-fake-v2-8-1vm" ), US_CENTRAL1_C, ).run() - resnet_v4_8 = task.TpuTask( + resnet_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-resnet50-pjrt-fake-v4-8-1vm" ), diff --git a/dags/xlml/solutionsTeam_jax_latest_integration.py 
b/dags/xlml/solutionsTeam_jax_latest_integration.py index fe68a932..c2a34c69 100644 --- a/dags/xlml/solutionsTeam_jax_latest_integration.py +++ b/dags/xlml/solutionsTeam_jax_latest_integration.py @@ -39,13 +39,13 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - compilation_cache = task.TpuTask( + compilation_cache = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_jax( "jax-compilation-cache-test-func-v2-8-1vm" ), US_CENTRAL1_C, ).run() - pod = task.TpuTask( + pod = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_jax( "jax-pod-latest-tpu-ubuntu2204-base-func-v2-32-1vm" ), diff --git a/deployment/cloud_composer_template.tf b/deployment/cloud_composer_template.tf index 632136cc..a5f0fd64 100644 --- a/deployment/cloud_composer_template.tf +++ b/deployment/cloud_composer_template.tf @@ -113,6 +113,7 @@ resource "google_composer_environment" "example_environment" { # google-cloud-bigquery = "" # google-cloud-storage = "" # tensorflow-cpu = "" + # apache-airflow-providers-cncf-kubernetes = "" } } diff --git a/implementations/utils/metric.py b/implementations/utils/metric.py index 6fcd58cd..1fced4f5 100644 --- a/implementations/utils/metric.py +++ b/implementations/utils/metric.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """Utilities to process Benchmark metrics.""" import dataclasses @@ -397,8 +396,38 @@ def is_valid_entry() -> bool: return True -def get_job_status(benchmark_id: str) -> bigquery.JobStatus: - """Get job status for the run. +def get_gke_job_status(benchmark_id: str) -> bigquery.JobStatus: + """Get job status for the GKE run. + + FAILED - if any failure occurs in run_model + SUCCESS - end-to-end model tests are successful in run_model + """ + context = get_current_context() + execution_date = context["dag_run"].logical_date + current_dag = context["dag"] + + workload_completion = current_dag.get_task( + task_id=f"{benchmark_id}.run_model.wait_for_workload_completion" + ) + workload_completion_ti = TaskInstance(workload_completion, execution_date) + workload_completion_state = workload_completion_ti.current_state() + + if workload_completion_state == TaskState.SUCCESS.value: + logging.info( + "The wait_for_workload_completion state is success, and the job status" + " is success." + ) + return bigquery.JobStatus.SUCCESS + + logging.info( + "The wait_for_workload_completion state is not success, and the job" + " status is failed." + ) + return bigquery.JobStatus.FAILED + + +def get_gce_job_status(benchmark_id: str) -> bigquery.JobStatus: + """Get job status for the GCE run. 
 
   MISSED - if any failure occurs in initialize & create_queued_resource
   FAILED - if any failure occurs in setup & run_model (including timeout of
@@ -414,12 +443,12 @@ def get_job_status(benchmark_id: str) -> bigquery.JobStatus:
   setup_ti = TaskInstance(setup_task, execution_date)
   setup_state = setup_ti.current_state()
   if setup_state == TaskState.SKIPPED.value:
-    print("The setup state is skipped, and the job status is missed.")
+    logging.info("The setup state is skipped, and the job status is missed.")
     return bigquery.JobStatus.MISSED
 
   # check setup status to see if setup step is successful
   if setup_state == TaskState.FAILED.value:
-    print("The setup state is failed, and the job status is failed.")
+    logging.info("The setup state is failed, and the job status is failed.")
     return bigquery.JobStatus.FAILED
 
   # check run_model status to see if run_model step is successful
@@ -428,10 +457,12 @@ def get_job_status(benchmark_id: str) -> bigquery.JobStatus:
   run_model_state = run_model_ti.current_state()
 
   if run_model_state == TaskState.SUCCESS.value:
-    print("The run_model state is success, and the job status is success.")
+    logging.info(
+        "The run_model state is success, and the job status is success."
+    )
     return bigquery.JobStatus.SUCCESS
 
-  print("The run_model state is failed, and the job status is failed.")
+  logging.info("The run_model state is failed, and the job status is failed.")
   return bigquery.JobStatus.FAILED
 
 
@@ -493,7 +524,12 @@ def process_metrics(
       task_gcp_config.project_name, task_gcp_config.dataset_name.value
   )
 
-  test_job_status = get_job_status(task_test_config.benchmark_id)
+  # Only GKE test configs define cluster_name; GCE test configs do not.
+  if getattr(task_test_config, "cluster_name", None):
+    test_job_status = get_gke_job_status(task_test_config.benchmark_id)
+  else:
+    test_job_status = get_gce_job_status(task_test_config.benchmark_id)
+
   for index in range(len(metadata_history_rows_list)):
     job_history_row = bigquery.JobHistoryRow(
         uuid=generate_row_uuid(base_id, index),
diff --git a/implementations/utils/xpk.py b/implementations/utils/xpk.py
new file mode 100644
index 00000000..eb91c5cc
--- /dev/null
+++ b/implementations/utils/xpk.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities to run workloads with xpk (https://github.com/google/xpk)."""
+
+import uuid
+from absl import logging
+from airflow.decorators import task
+from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
+from kubernetes import client, config
+
+
+@task
+def generate_workload_id(benchmark_id: str) -> str:
+  """Generate a workload ID."""
+  short_id = str(uuid.uuid4())[:8]
+  return f"{benchmark_id}-{short_id}"
+
+
+def run_workload(
+    task_id: str,
+    project_id: str,
+    zone: str,
+    cluster_name: str,
+    benchmark_id: str,
+    workload_id: str,
+    docker_image: str,
+    accelerator_type: str,
+    run_cmds: str,
+    task_owner: str,
+    num_slices: int = 1,
+) -> KubernetesPodOperator:
+  """Run workload through xpk tool.
+
+  The reason to use KubernetesPodOperator instead of BashOperator is that
+  xpk must run with Python 3.10 or greater; however, the latest version in
+  Composer is Python 3.8, and it's non-trivial to upgrade it as Composer
+  uses docker images that bundle Airflow releases with Python and other
+  libraries.
+  """
+
+  cmds = (
+      "set -x",
+      f"gcloud config set project {project_id}",
+      f"gcloud config set compute/zone {zone}",
+      "git clone -b xpk-namespace https://github.com/google/xpk.git /tmp/xpk",
+      "cd /tmp/xpk",
+      (
+          "python3 xpk.py workload create"
+          f" --cluster={cluster_name} --workload={workload_id} --command='{run_cmds}'"
+          f" --tpu-type={accelerator_type} --num-slices={num_slices} --docker-image={docker_image} --namespace=default"
+      ),
+  )
+
+  return KubernetesPodOperator(
+      task_id=task_id,
+      name=benchmark_id,
+      cmds=["/bin/bash", "-c"],
+      arguments=[";".join(cmds)],
+      namespace="composer-user-workloads",
+      image=docker_image,
+      config_file="/home/airflow/composer_kube_config",
+      kubernetes_conn_id="kubernetes_default",
+      owner=task_owner,
+  )
+
+
+@task.sensor(poke_interval=60, timeout=600, mode="reschedule")
+def wait_for_workload_completion(workload_id: str, cluster_config: str) -> bool:
+  """Check the workload status."""
+
+  # Load the config for the cluster with TPUs in the pool
+  config.load_kube_config(
+      config_file=f"/home/airflow/gcs/dags/configs/cluster/{cluster_config}"
+  )
+  core_api = client.CoreV1Api()
+
+  logging.info(f"workload_id: {workload_id}")
+  pods = core_api.list_namespaced_pod(
+      label_selector=f"jobset.sigs.k8s.io/jobset-name={workload_id}",
+      namespace="default",
+  )
+
+  if not pods.items:
+    raise RuntimeError(f"No pod is found for workload: {workload_id}.")
+
+  for pod in pods.items:
+    if pod.status.phase in ["Pending", "Running"]:
+      logging.info(f"One pod phase is: {pod.status.phase}")
+      return False
+    elif pod.status.phase in ["Failed", "Unknown"]:
+      raise RuntimeError(f"Bad pod phase: {pod.status.phase}")
+
+  logging.info("All pod(s) have succeeded.")
+  return True

From 3cde9532b4a84d847e61aed84dc9701971c42f46 Mon Sep 17 00:00:00 2001
From: Ran Ran
Date: Thu, 30 Nov 2023 18:26:56 +0000
Subject: [PATCH 02/12] Fix issue from the merge

---
 .../xlml/jax/solutionsTeam_flax_latest_supported_config.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py b/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py
index 3c42115e..8ab8bfa6 100644
--- a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py
+++ b/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py
@@ -110,7 +110,7 @@ def get_flax_vit_config(
     time_out_in_min: int,
     num_train_epochs: int = 3,
     extraFlags: str = "",
-) -> task.TpuTask:
+) -> task.TpuGceTask:
   job_gcp_config = gcp_config.GCPConfig(
       project_name=PROJECT_NAME,
       zone=tpu_zone,
@@ -147,7 +147,7 @@ def get_flax_vit_conv_config(
     time_out_in_min: int,
     num_train_epochs: int = 30,
     extraFlags: str = "",
-) -> task.TpuTask:
+) -> task.TpuGceTask:
   job_gcp_config = gcp_config.GCPConfig(
       project_name=PROJECT_NAME,
       zone=tpu_zone,
@@ -191,7 +191,7 @@ def get_flax_vit_conv_config(
     )
   )
 
-  return task.TpuTask(
+  return task.TpuGceTask(
       task_test_config=job_test_config,
       task_gcp_config=job_gcp_config,
      task_metric_config=job_metric_config,

From 006de9828d774b44213b3e6f9ae020ef43d62fc0 Mon Sep 17 00:00:00 2001
From: Ran Ran
Date: Thu, 30 Nov 2023 23:11:39 -0800
Subject: [PATCH 03/12] Add unit tests to checks (#14)

* Add unit tests to checks

* Update
test for auth --- ...e-checklist.yaml => require-checklist.yml} | 0 .github/workflows/unit-test.yml | 32 +++++++++++++++++++ implementations/utils/bigquery_test.py | 16 +++++++--- implementations/utils/metric_test.py | 19 ++++++++--- 4 files changed, 57 insertions(+), 10 deletions(-) rename .github/workflows/{require-checklist.yaml => require-checklist.yml} (100%) create mode 100644 .github/workflows/unit-test.yml diff --git a/.github/workflows/require-checklist.yaml b/.github/workflows/require-checklist.yml similarity index 100% rename from .github/workflows/require-checklist.yaml rename to .github/workflows/require-checklist.yml diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml new file mode 100644 index 00000000..bb86ea5e --- /dev/null +++ b/.github/workflows/unit-test.yml @@ -0,0 +1,32 @@ +# Run all unit tests in files named as *_test.py +name: Unit Test + +on: + pull_request: + branches: [master] + types: [opened, synchronize, edited] + + push: + branches: [master] + + workflow_dispatch: {} + +jobs: + dags: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + submodules: 'true' + + - uses: actions/setup-python@v4 + with: + # Note: this should match Cloud Composer + # https://cloud.google.com/composer/docs/concepts/versioning/composer-versions + python-version: '3.8' + + - name: Install Python dependencies + run: pip install -r .github/requirements.txt + + - name: Run tests + run: python3 -m unittest discover implementations "*_test.py" diff --git a/implementations/utils/bigquery_test.py b/implementations/utils/bigquery_test.py index 7283b371..4826ae53 100644 --- a/implementations/utils/bigquery_test.py +++ b/implementations/utils/bigquery_test.py @@ -19,6 +19,7 @@ from unittest import mock from absl.testing import absltest from absl.testing import parameterized +import google.auth from google.cloud import bigquery from implementations.utils import bigquery as test_bigquery @@ -55,21 +56,26 @@ def setUp(self): ("5.0", 5.0, True), ) def test_is_valid_metric(self, x: float, expected_value: bool): - bq_metric = test_bigquery.BigQueryMetricClient() - actual_value = bq_metric.is_valid_metric(x) - self.assertEqual(actual_value, expected_value) + with mock.patch.object( + google.auth, "default", return_value=["mock", "mock_project"] + ) as mock_object: + bq_metric = test_bigquery.BigQueryMetricClient() + actual_value = bq_metric.is_valid_metric(x) + self.assertEqual(actual_value, expected_value) + @mock.patch.object(google.auth, "default", return_value=["mock", "mock_project"]) @mock.patch.object(bigquery.Client, "get_table", return_value="mock_table") @mock.patch.object( bigquery.Client, "insert_rows", return_value=["there is an error"] ) - def test_insert_failure(self, get_table, insert_rows): + def test_insert_failure(self, default, get_table, insert_rows): bq_metric = test_bigquery.BigQueryMetricClient() self.assertRaises(RuntimeError, bq_metric.insert, self.test_runs) + @mock.patch.object(google.auth, "default", return_value=["mock", "mock_project"]) @mock.patch.object(bigquery.Client, "get_table", return_value="mock_table") @mock.patch.object(bigquery.Client, "insert_rows", return_value=[]) - def test_insert_success(self, get_table, insert_rows): + def test_insert_success(self, default, get_table, insert_rows): bq_metric = test_bigquery.BigQueryMetricClient() bq_metric.insert(self.test_runs) diff --git a/implementations/utils/metric_test.py b/implementations/utils/metric_test.py index 9d0e3ce0..3bddca3b 100644 --- 
a/implementations/utils/metric_test.py +++ b/implementations/utils/metric_test.py @@ -16,15 +16,17 @@ import hashlib import os +import sys from typing import Iterable, Optional from unittest import mock +from absl import flags from absl.testing import absltest from absl.testing import parameterized from apis import metric_config from configs import composer_env from implementations.utils import bigquery -from implementations.utils import metric from implementations.utils import composer +from implementations.utils import metric import jsonlines import tensorflow as tf @@ -34,8 +36,15 @@ class BenchmarkMetricTest(parameterized.TestCase, absltest.TestCase): + def get_tempdir(self): + try: + flags.FLAGS.test_tmpdir + except flags.UnparsedFlagAccessError: + flags.FLAGS(sys.argv) + return self.create_tempdir().full_path + def generate_tb_file(self): - temp_dir = self.create_tempdir().full_path + temp_dir = self.get_tempdir() summary_writer = tf.summary.create_file_writer(temp_dir) with summary_writer.as_default(): @@ -269,9 +278,9 @@ def test_add_airflow_metadata(self): "COMPOSER_ENVIRONMENT": "test_env", }, ) as mock_variable: - with mock.patch.object(composer, - "get_airflow_url", - return_value="http://airflow") as mock_object: + with mock.patch.object( + composer, "get_airflow_url", return_value="http://airflow" + ) as mock_object: raw_meta = [ [ bigquery.MetadataHistoryRow( From 908246c9599460f8a2a7fbddf261fdfbcd23797a Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Fri, 1 Dec 2023 09:27:42 -0800 Subject: [PATCH 04/12] Update workflow name to unit test (#29) --- .github/workflows/unit-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index bb86ea5e..f43a9f60 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: {} jobs: - dags: + unit_test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 From 6b73c420c275839d941aa747fd192f2b39b7915a Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Fri, 1 Dec 2023 16:14:08 -0800 Subject: [PATCH 05/12] Add auto push after commit merge (#30) * Add auto push after commit merge * Address comment --- pipeline/auto-push.cloudbuild.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 pipeline/auto-push.cloudbuild.yml diff --git a/pipeline/auto-push.cloudbuild.yml b/pipeline/auto-push.cloudbuild.yml new file mode 100644 index 00000000..255c7d1a --- /dev/null +++ b/pipeline/auto-push.cloudbuild.yml @@ -0,0 +1,6 @@ +steps: + - name: google/cloud-sdk + args: + - scripts/upload-tests.sh + - gs://us-central1-ml-automation-s-bc954647-bucket/dags + entrypoint: bash From 908a358760d971db1c1b7f266403e13a0211dec1 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Fri, 1 Dec 2023 17:19:07 -0800 Subject: [PATCH 06/12] Update email address in README file (#31) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d80a5387..c32ef4af 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# ML Automation Solutions +# ML Automation Solutions (MAS) A simplified and automated orchestration workflow to perform ML end-to-end (E2E) model tests and benchmarking on Cloud VMs across different frameworks. @@ -12,7 +12,7 @@ bash scripts/upload-tests.sh gs:///dags ``` 4. 
After the automatically scheduled tests start running, integrate [Looker Studio](https://cloud.google.com/bigquery/docs/bi-engine-looker-studio) or any other dashboard with BigQuery to monitor metrics. -If you have a use case that ML Automation Solutions does not cover, please email ml-testing-accelerators-users@googlegroups.com. We're here to help! +If you have a use case that MAS does not cover, please email ml-auto-solutions-users@googlegroups.com. We're here to help! ## Contributing From fd1681aafd47cb2bf15dda27fbe872f5593b6867 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 4 Dec 2023 15:51:24 -0800 Subject: [PATCH 07/12] Build JSonnet files as part of the postsubmit (#34) * Build JSonnet files as part of the postsubmit Change-Id: I811493a530d8eed91c975aebe15d714903283678 * remove extra quotes Change-Id: Iba35dacab442d783304f738a2a42bb0bd350fb63 --- pipeline/auto-push.cloudbuild.yml | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pipeline/auto-push.cloudbuild.yml b/pipeline/auto-push.cloudbuild.yml index 255c7d1a..9d52ac1e 100644 --- a/pipeline/auto-push.cloudbuild.yml +++ b/pipeline/auto-push.cloudbuild.yml @@ -1,6 +1,22 @@ steps: - - name: google/cloud-sdk - args: - - scripts/upload-tests.sh - - gs://us-central1-ml-automation-s-bc954647-bucket/dags - entrypoint: bash +- name: golang:alpine + id: download-jsonnet + entrypoint: 'go' + args: + - install + - github.com/google/go-jsonnet/cmd/jsonnet@latest +- name: golang:alpine + id: build-templates + entrypoint: sh + args: + - scripts/gen-configs.sh +- name: google/cloud-sdk:slim + args: + - scripts/upload-tests.sh + - gs://us-central1-ml-automation-s-bc954647-bucket/dags + entrypoint: bash +options: + machineType: E2_HIGHCPU_32 + volumes: + - name: go-modules + path: /go From 49aa13f1d582fce704f8b1de0c4605ac8f25fba0 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Mon, 4 Dec 2023 16:10:25 -0800 Subject: [PATCH 08/12] Pull the latest submodule commit (#33) * Pull the latest submodule commit and add configs Change-Id: Ia43089a8224e7f258eb1a9b44876ebd8f0c2b20f * Apply config changes by regenerating Change-Id: I9418e17389e421695d4511a4979f28f318c224df * Remove Jsonnet configs and update test name Change-Id: I4ff436944ed3b5986f7b50f8caa24c673ec15d3b * Remove configs Change-Id: I774bbe8d2d03835bef0cbaededff1f5575a45e83 * Add pjrt in the test name Change-Id: I4847ac74d883def7b51363f5a60f0f99e9665b67 --- dags/xlml/pytorchxla_llama.py | 4 ++-- ml-testing-accelerators | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dags/xlml/pytorchxla_llama.py b/dags/xlml/pytorchxla_llama.py index 2101ae6a..ebffdfae 100644 --- a/dags/xlml/pytorchxla_llama.py +++ b/dags/xlml/pytorchxla_llama.py @@ -35,13 +35,13 @@ ): llama_inference_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( - "pt-nightly-llama2-i-infer-func-v4-8-1vm" + "pt-nightly-llama2-pjrt-infer-func-v4-8-1vm-1vm" ), US_CENTRAL2_B, ).run() llama_train_v4_8 = task.TpuGceTask( test_config.JSonnetTpuVmTest.from_pytorch( - "pt-nightly-llama2-t-train-spmd-func-v4-8-1vm" + "pt-nightly-llama2-pjrt-train-spmd-func-v4-8-1vm-1vm" ), US_CENTRAL2_B, ).run() diff --git a/ml-testing-accelerators b/ml-testing-accelerators index c99553c9..52b290a1 160000 --- a/ml-testing-accelerators +++ b/ml-testing-accelerators @@ -1 +1 @@ -Subproject commit c99553c955294d6bb7e00ff0ad4051f5412ff9fe +Subproject commit 52b290a149b760b270085d4f8191188f986047ed From 2c84814178dcdb16738a79353850d70a95867d10 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: 
Mon, 4 Dec 2023 16:28:07 -0800 Subject: [PATCH 09/12] Update pax test names (#35) Change-Id: I58ee7d2e9237b686004e0a7454d54cf63f2e5e70 --- configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py b/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py index 9f5347c1..31d48fdd 100644 --- a/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py +++ b/configs/xlml/jax/solutionsTeam_pax_latest_supported_config.py @@ -28,6 +28,7 @@ def get_pax_lm_config( exp_path: str, model_name: str, log_dir: str, + pax_version: str = "stable", ckp_path: str = "", extraFlags: str = "", ) -> task.TpuGceTask: @@ -56,7 +57,7 @@ def get_pax_lm_config( runtime_version=vm_resource.RuntimeVersion.TPU_VM_V4_BASE.value, reserved=True, ), - test_name=f"pax_{model_name}_c4", + test_name=f"pax_{pax_version}_{model_name}", set_up_cmds=set_up_cmds, run_model_cmds=run_model_cmds, time_out_in_min=time_out_in_min, From 7214504a98c1f1cc37e104e9b0becf488bfbc9bc Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Tue, 5 Dec 2023 17:32:44 -0800 Subject: [PATCH 10/12] Add Pax lmspmd2b and lmtransformeradam nightly tests (#37) Change-Id: I58947fff8617ce27740b566d2aef90bbeda21152 --- configs/xlml/jax/common.py | 8 -- configs/xlml/pax/__init__.py | 0 configs/xlml/pax/common.py | 31 ++++ .../pax/solutionsTeam_pax_supported_config.py | 133 ++++++++++++++++++ configs/xlml/tensorflow/__init__.py | 0 .../solutionsTeam_pax_latest_supported.py | 26 ++-- .../solutionsTeam_pax_nightly_supported.py | 99 +++++++++++++ 7 files changed, 277 insertions(+), 20 deletions(-) create mode 100644 configs/xlml/pax/__init__.py create mode 100644 configs/xlml/pax/common.py create mode 100644 configs/xlml/pax/solutionsTeam_pax_supported_config.py create mode 100644 configs/xlml/tensorflow/__init__.py create mode 100644 dags/xlml/solutionsTeam_pax_nightly_supported.py diff --git a/configs/xlml/jax/common.py b/configs/xlml/jax/common.py index 00e163d1..3e4551bf 100644 --- a/configs/xlml/jax/common.py +++ b/configs/xlml/jax/common.py @@ -34,14 +34,6 @@ def set_up_google_flax() -> Tuple[str]: ) -def set_up_google_pax() -> Tuple[str]: - """Common set up for pax repo.""" - return ( - "pip install paxml", - INSTALL_LATEST_JAX, - ) - - def set_up_hugging_face_transformers() -> Tuple[str]: """Common set up for hugging face transformer repo.""" return ( diff --git a/configs/xlml/pax/__init__.py b/configs/xlml/pax/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/configs/xlml/pax/common.py b/configs/xlml/pax/common.py new file mode 100644 index 00000000..1623da02 --- /dev/null +++ b/configs/xlml/pax/common.py @@ -0,0 +1,31 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities to construct common configs.""" + +from typing import Tuple + +UPGRADE_PIP = "pip install --upgrade pip" +INSTALL_LATEST_JAX = ( + "pip install jax[tpu] -f" + " https://storage.googleapis.com/jax-releases/libtpu_releases.html" +) + + +def set_up_google_pax() -> Tuple[str]: + """Common set up for pax repo.""" + return ( + "pip install paxml", + INSTALL_LATEST_JAX, + ) diff --git a/configs/xlml/pax/solutionsTeam_pax_supported_config.py b/configs/xlml/pax/solutionsTeam_pax_supported_config.py new file mode 100644 index 00000000..a1cf65af --- /dev/null +++ b/configs/xlml/pax/solutionsTeam_pax_supported_config.py @@ -0,0 +1,133 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities to construct configs for pax DAGs.""" + +from datetime import datetime +import enum +from typing import Tuple +import uuid +from absl import logging +from apis import gcp_config, metric_config, task, test_config +from configs import test_owner, vm_resource +from configs.xlml.pax import common + + +class PaxVersion(enum.Enum): + NIGHTLY = "nightly" + STABLE = "stable" + + +def get_setup_cmds( + pax_version: PaxVersion, + ckp_path: str, + job_log_dir: str, +) -> Tuple[str]: + if pax_version is PaxVersion.STABLE: + logging.info("Running the latest stable Pax version.") + ckp_cmds = ( + f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo" + ) + return common.set_up_google_pax() + (ckp_cmds,) + elif pax_version is PaxVersion.NIGHTLY: + logging.info("Running nightly Pax version.") + build_date = datetime.today().strftime("%Y%m%d") + ckp_cmds = ( + f"gsutil -m cp -r {ckp_path} {job_log_dir}" if ckp_path else "echo" + ) + return ( + ckp_cmds, + ( + "set -x; set -e; gsutil cp" + f" gs://pax-on-cloud-tpu-project/wheels/{build_date}/paxml*.whl ." + ), + ( + "set -x; set -e; gsutil cp" + f" gs://pax-on-cloud-tpu-project/wheels/{build_date}/praxis*.whl ." + ), + "pip install praxis*.whl", + "pip install paxml*.whl", + "sudo pip uninstall --yes jax jaxlib libtpu-nightly", + "pip install git+https://github.com/google/jax.git", + ( + "pip install --pre -U jaxlib -f" + " https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html" + ), + ( + "pip install --no-index -U libtpu-nightly -f" + " https://storage.googleapis.com/jax-releases/libtpu_releases.html" + ), + ) + else: + raise RuntimeError(f"Please specify set up cmds for: {pax_version.value}.") + + +def get_runtime_version(pax_version: PaxVersion) -> str: + if pax_version is PaxVersion.STABLE: + return vm_resource.RuntimeVersion.TPU_VM_V4_BASE.value + elif pax_version is PaxVersion.NIGHTLY: + return vm_resource.RuntimeVersion.TPU_UBUNTU2204_BASE.value + else: + raise RuntimeError( + f"Please specify runtime version for: {pax_version.value}." 
+    )
+
+
+def get_pax_lm_config(
+    tpu_version: str,
+    tpu_cores: int,
+    tpu_zone: str,
+    time_out_in_min: int,
+    exp_path: str,
+    model_name: str,
+    log_dir: str,
+    pax_version: PaxVersion = PaxVersion.STABLE,
+    ckp_path: str = "",
+    extraFlags: str = "",
+) -> task.TpuGceTask:
+  job_gcp_config = gcp_config.GCPConfig(
+      project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS,
+      zone=tpu_zone,
+      dataset_name=metric_config.DatasetOption.XLML_DATASET,
+  )
+
+  short_id = str(uuid.uuid4())[:8]
+  job_log_dir = f"{log_dir}/{model_name}-{short_id}"
+  set_up_cmds = get_setup_cmds(pax_version, ckp_path, job_log_dir)
+
+  run_model_cmds = (
+      (
+          "python3 .local/lib/python3.8/site-packages/paxml/main.py"
+          f" --exp={exp_path} --job_log_dir={job_log_dir} {extraFlags}"
+      ),
+  )
+
+  job_test_config = test_config.TpuVmTest(
+      test_config.Tpu(
+          version=tpu_version,
+          cores=tpu_cores,
+          runtime_version=get_runtime_version(pax_version),
+          reserved=True,
+      ),
+      test_name=f"pax_{pax_version.value}_{model_name}",
+      set_up_cmds=set_up_cmds,
+      run_model_cmds=run_model_cmds,
+      time_out_in_min=time_out_in_min,
+      task_owner=test_owner.GERSON_K,
+  )
+
+  return task.TpuGceTask(
+      task_test_config=job_test_config,
+      task_gcp_config=job_gcp_config,
+  )
diff --git a/configs/xlml/tensorflow/__init__.py b/configs/xlml/tensorflow/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/dags/xlml/solutionsTeam_pax_latest_supported.py b/dags/xlml/solutionsTeam_pax_latest_supported.py
index c0f73f62..2ce4cdb3 100644
--- a/dags/xlml/solutionsTeam_pax_latest_supported.py
+++ b/dags/xlml/solutionsTeam_pax_latest_supported.py
@@ -17,7 +17,7 @@
 import datetime
 from airflow import models
 from configs import composer_env, gcs_bucket, vm_resource
-from configs.xlml.jax import solutionsTeam_pax_latest_supported_config as pax_config
+from configs.xlml.pax import solutionsTeam_pax_supported_config as pax_config
 
 
 # Run once a day at 10 am UTC (2 am PST)
@@ -31,24 +31,27 @@
   start_date=datetime.datetime(2023, 11, 8),
   catchup=False,
 ) as dag:
+  log_dir_prefix = f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/stable"
+
   # Language model with SPMD
-  pax_lmspmd2b_v4_8 = pax_config.get_pax_lm_config(
+  lmspmd2b_exp_path = "tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps"
+  pax_stable_lmspmd2b_v4_8 = pax_config.get_pax_lm_config(
     tpu_version="4",
     tpu_cores=8,
     tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value,
     time_out_in_min=60,
-    log_dir=f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/lmspmd2b/v4-8",
-    exp_path="tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps",
+    log_dir=f"{log_dir_prefix}/lmspmd2b/v4-8",
+    exp_path=lmspmd2b_exp_path,
     model_name="lmspmd2b",
   ).run()
 
-  pax_lmspmd2b_ckpt_v4_8 = pax_config.get_pax_lm_config(
+  pax_stable_lmspmd2b_ckpt_v4_8 = pax_config.get_pax_lm_config(
     tpu_version="4",
     tpu_cores=8,
     tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value,
     time_out_in_min=60,
-    log_dir=f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/lmspmd2b_ckpt/v4-8",
-    exp_path="tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps",
+    log_dir=f"{log_dir_prefix}/lmspmd2b_ckpt/v4-8",
+    exp_path=lmspmd2b_exp_path,
     model_name="lmspmd2b_ckpt",
     ckp_path=f"{gcs_bucket.PAX_DIR}/lmcloudspmd2B/pax-nightly-lmspmd2b-func-v4-8-1vm-run1/*",
   ).run()
 
@@ -58,18 +61,17 @@
     "--jax_fully_async_checkpoint=False",
     "--pmap_use_tensorstore=True",
   ]
-  pax_lmtransformeradam_v4_8 = pax_config.get_pax_lm_config(
+  pax_stable_lmtransformeradam_v4_8 = pax_config.get_pax_lm_config(
     tpu_version="4",
     tpu_cores=8,
     tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value,
     time_out_in_min=60,
-
log_dir=f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/lmtransformeradam/v4-8", + log_dir=f"{log_dir_prefix}/lmtransformeradam/v4-8", exp_path="tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps", model_name="lmtransformeradam", extraFlags=" ".join(pax_transformer_adam_extra_flags), ).run() # Test dependencies - pax_lmspmd2b_v4_8 - pax_lmspmd2b_ckpt_v4_8 - pax_lmtransformeradam_v4_8 + pax_stable_lmspmd2b_v4_8 >> pax_stable_lmspmd2b_ckpt_v4_8 + pax_stable_lmtransformeradam_v4_8 diff --git a/dags/xlml/solutionsTeam_pax_nightly_supported.py b/dags/xlml/solutionsTeam_pax_nightly_supported.py new file mode 100644 index 00000000..7085a1e3 --- /dev/null +++ b/dags/xlml/solutionsTeam_pax_nightly_supported.py @@ -0,0 +1,99 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A DAG to run all supported ML models with the nightly PAX version.""" + +import datetime +from airflow import models +from configs import composer_env, gcs_bucket, vm_resource +from configs.xlml.pax import solutionsTeam_pax_supported_config as pax_config + + +# Run once a day at 12 am UTC (4 am PST) +SCHEDULED_TIME = "0 12 * * *" if composer_env.is_prod_env() else None + + +with models.DAG( + dag_id="pax_nightly_supported", + schedule=SCHEDULED_TIME, + tags=["solutions_team", "pax", "nightly", "supported", "xlml"], + start_date=datetime.datetime(2023, 12, 5), + catchup=False, +) as dag: + log_dir_prefix = f"{gcs_bucket.XLML_OUTPUT_DIR}/pax/nightly" + + # Language model with SPMD + pax_lmspmd2b_extra_flags = [ + "--jax_fully_async_checkpoint=False", + ] + lmspmd2b_exp_path = "tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps" + pax_nightly_lmspmd2b_v4_8 = pax_config.get_pax_lm_config( + tpu_version="4", + tpu_cores=8, + tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + log_dir=f"{log_dir_prefix}/lmspmd2b/v4-8", + pax_version=pax_config.PaxVersion.NIGHTLY, + exp_path=lmspmd2b_exp_path, + model_name="lmspmd2b", + extraFlags=" ".join(pax_lmspmd2b_extra_flags), + ).run() + + pax_nightly_lmspmd2b_ckpt_v4_8 = pax_config.get_pax_lm_config( + tpu_version="4", + tpu_cores=8, + tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + log_dir=f"{log_dir_prefix}/lmspmd2b_ckpt/v4-8", + pax_version=pax_config.PaxVersion.NIGHTLY, + exp_path=lmspmd2b_exp_path, + model_name="lmspmd2b_ckpt", + ckp_path=f"{gcs_bucket.PAX_DIR}/lmcloudspmd2B/pax-nightly-lmspmd2b-func-v4-8-1vm-run1/*", + ).run() + + # Language model transformer with adam + pax_transformer_adam_extra_flags = [ + "--jax_fully_async_checkpoint=False", + "--pmap_use_tensorstore=True", + ] + lmtransformeradam_exp_path = ( + "tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps" + ) + pax_nightly_lmtransformeradam_v4_8 = pax_config.get_pax_lm_config( + tpu_version="4", + tpu_cores=8, + tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + log_dir=f"{log_dir_prefix}/lmtransformeradam/v4-8", + exp_path=lmtransformeradam_exp_path, + pax_version=pax_config.PaxVersion.NIGHTLY, + model_name="lmtransformeradam", + 
extraFlags=" ".join(pax_transformer_adam_extra_flags), + ).run() + + pax_nightly_lmtransformeradam_v4_16 = pax_config.get_pax_lm_config( + tpu_version="4", + tpu_cores=16, + tpu_zone=vm_resource.Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + log_dir=f"{log_dir_prefix}/lmtransformeradam/v4-16", + exp_path=lmtransformeradam_exp_path, + pax_version=pax_config.PaxVersion.NIGHTLY, + model_name="lmtransformeradam", + extraFlags=" ".join(pax_transformer_adam_extra_flags), + ).run() + + # Test dependencies + pax_nightly_lmspmd2b_v4_8 >> pax_nightly_lmspmd2b_ckpt_v4_8 + pax_nightly_lmtransformeradam_v4_8 >> pax_nightly_lmtransformeradam_v4_16 From 1c096df7fd09762311cfbcaf18b35850de4be766 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Tue, 5 Dec 2023 18:48:15 -0800 Subject: [PATCH 11/12] Update solutionsTeam to solutionsteam in file name (#38) Change-Id: I575c2afcdbba6883ccc1569bd38d0a60e6f24f74 --- ...nsTeam_jax_npi_config.py => solutionsteam_jax_npi_config.py} | 2 +- ..._config.py => solutionsteam_flax_latest_supported_config.py} | 2 +- ...upported_config.py => solutionsteam_pax_supported_config.py} | 0 ...d_config.py => solutionsteam_tf_nightly_supported_config.py} | 2 +- .../{solutionsTeam_jax_npi.py => solutionsteam_jax_npi.py} | 2 +- ...test_supported.py => solutionsteam_flax_latest_supported.py} | 2 +- ...t_integration.py => solutionsteam_jax_latest_integration.py} | 0 ...atest_supported.py => solutionsteam_pax_latest_supported.py} | 2 +- ...htly_supported.py => solutionsteam_pax_nightly_supported.py} | 2 +- ...ghtly_supported.py => solutionsteam_tf_nightly_supported.py} | 2 +- 10 files changed, 8 insertions(+), 8 deletions(-) rename configs/benchmark/jax/{solutionsTeam_jax_npi_config.py => solutionsteam_jax_npi_config.py} (98%) rename configs/xlml/jax/{solutionsTeam_flax_latest_supported_config.py => solutionsteam_flax_latest_supported_config.py} (99%) rename configs/xlml/pax/{solutionsTeam_pax_supported_config.py => solutionsteam_pax_supported_config.py} (100%) rename configs/xlml/tensorflow/{solutionsTeam_tf_nightly_supported_config.py => solutionsteam_tf_nightly_supported_config.py} (98%) rename dags/benchmark/{solutionsTeam_jax_npi.py => solutionsteam_jax_npi.py} (94%) rename dags/xlml/{solutionsTeam_flax_latest_supported.py => solutionsteam_flax_latest_supported.py} (99%) rename dags/xlml/{solutionsTeam_jax_latest_integration.py => solutionsteam_jax_latest_integration.py} (100%) rename dags/xlml/{solutionsTeam_pax_latest_supported.py => solutionsteam_pax_latest_supported.py} (97%) rename dags/xlml/{solutionsTeam_pax_nightly_supported.py => solutionsteam_pax_nightly_supported.py} (98%) rename dags/xlml/{solutionsTeam_tf_nightly_supported.py => solutionsteam_tf_nightly_supported.py} (97%) diff --git a/configs/benchmark/jax/solutionsTeam_jax_npi_config.py b/configs/benchmark/jax/solutionsteam_jax_npi_config.py similarity index 98% rename from configs/benchmark/jax/solutionsTeam_jax_npi_config.py rename to configs/benchmark/jax/solutionsteam_jax_npi_config.py index f8ce91bc..5142e22e 100644 --- a/configs/benchmark/jax/solutionsTeam_jax_npi_config.py +++ b/configs/benchmark/jax/solutionsteam_jax_npi_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Utilities to construct configs for solutionsTeam_jax_npi DAG.""" +"""Utilities to construct configs for solutionsteam_jax_npi DAG.""" from apis import gcp_config, metric_config, task, test_config from configs import gcs_bucket, test_owner, vm_resource diff --git a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py b/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py similarity index 99% rename from configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py rename to configs/xlml/jax/solutionsteam_flax_latest_supported_config.py index 8ab8bfa6..d1ec89d1 100644 --- a/configs/xlml/jax/solutionsTeam_flax_latest_supported_config.py +++ b/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities to construct configs for solutionsTeam_flax_latest_supported DAG.""" +"""Utilities to construct configs for solutionsteam_flax_latest_supported DAG.""" from typing import Tuple import uuid diff --git a/configs/xlml/pax/solutionsTeam_pax_supported_config.py b/configs/xlml/pax/solutionsteam_pax_supported_config.py similarity index 100% rename from configs/xlml/pax/solutionsTeam_pax_supported_config.py rename to configs/xlml/pax/solutionsteam_pax_supported_config.py diff --git a/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py b/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py similarity index 98% rename from configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py rename to configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py index 5ec7d3b1..7fcfa073 100644 --- a/configs/xlml/tensorflow/solutionsTeam_tf_nightly_supported_config.py +++ b/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Utilities to construct configs for solutionsTeam_tf_nightly_supported DAG.""" +"""Utilities to construct configs for solutionsteam_tf_nightly_supported DAG.""" import uuid from apis import gcp_config, metric_config, task, test_config diff --git a/dags/benchmark/solutionsTeam_jax_npi.py b/dags/benchmark/solutionsteam_jax_npi.py similarity index 94% rename from dags/benchmark/solutionsTeam_jax_npi.py rename to dags/benchmark/solutionsteam_jax_npi.py index 6ca1ecbc..d92c57c1 100644 --- a/dags/benchmark/solutionsTeam_jax_npi.py +++ b/dags/benchmark/solutionsteam_jax_npi.py @@ -17,7 +17,7 @@ import datetime from airflow import models from configs import vm_resource -from configs.benchmark.jax import solutionsTeam_jax_npi_config as jax_npi_config +from configs.benchmark.jax import solutionsteam_jax_npi_config as jax_npi_config with models.DAG( diff --git a/dags/xlml/solutionsTeam_flax_latest_supported.py b/dags/xlml/solutionsteam_flax_latest_supported.py similarity index 99% rename from dags/xlml/solutionsTeam_flax_latest_supported.py rename to dags/xlml/solutionsteam_flax_latest_supported.py index fb42371e..5f6223cf 100644 --- a/dags/xlml/solutionsTeam_flax_latest_supported.py +++ b/dags/xlml/solutionsteam_flax_latest_supported.py @@ -17,7 +17,7 @@ import datetime from airflow import models from configs import composer_env, vm_resource -from configs.xlml.jax import solutionsTeam_flax_latest_supported_config as flax_config +from configs.xlml.jax import solutionsteam_flax_latest_supported_config as flax_config # Run once a day at 2 am UTC (6 pm PST) diff --git a/dags/xlml/solutionsTeam_jax_latest_integration.py b/dags/xlml/solutionsteam_jax_latest_integration.py similarity index 100% rename from dags/xlml/solutionsTeam_jax_latest_integration.py rename to dags/xlml/solutionsteam_jax_latest_integration.py diff --git a/dags/xlml/solutionsTeam_pax_latest_supported.py b/dags/xlml/solutionsteam_pax_latest_supported.py similarity index 97% rename from dags/xlml/solutionsTeam_pax_latest_supported.py rename to dags/xlml/solutionsteam_pax_latest_supported.py index 2ce4cdb3..67a511de 100644 --- a/dags/xlml/solutionsTeam_pax_latest_supported.py +++ b/dags/xlml/solutionsteam_pax_latest_supported.py @@ -17,7 +17,7 @@ import datetime from airflow import models from configs import composer_env, gcs_bucket, vm_resource -from configs.xlml.pax import solutionsTeam_pax_supported_config as pax_config +from configs.xlml.pax import solutionsteam_pax_supported_config as pax_config # Run once a day at 10 am UTC (2 am PST) diff --git a/dags/xlml/solutionsTeam_pax_nightly_supported.py b/dags/xlml/solutionsteam_pax_nightly_supported.py similarity index 98% rename from dags/xlml/solutionsTeam_pax_nightly_supported.py rename to dags/xlml/solutionsteam_pax_nightly_supported.py index 7085a1e3..42a5ee9b 100644 --- a/dags/xlml/solutionsTeam_pax_nightly_supported.py +++ b/dags/xlml/solutionsteam_pax_nightly_supported.py @@ -17,7 +17,7 @@ import datetime from airflow import models from configs import composer_env, gcs_bucket, vm_resource -from configs.xlml.pax import solutionsTeam_pax_supported_config as pax_config +from configs.xlml.pax import solutionsteam_pax_supported_config as pax_config # Run once a day at 12 am UTC (4 am PST) diff --git a/dags/xlml/solutionsTeam_tf_nightly_supported.py b/dags/xlml/solutionsteam_tf_nightly_supported.py similarity index 97% rename from dags/xlml/solutionsTeam_tf_nightly_supported.py rename to dags/xlml/solutionsteam_tf_nightly_supported.py index 1de2abe6..fe987ac8 100644 --- 
a/dags/xlml/solutionsTeam_tf_nightly_supported.py
+++ b/dags/xlml/solutionsteam_tf_nightly_supported.py
@@ -17,7 +17,7 @@
 import datetime
 from airflow import models
 from configs import composer_env, vm_resource
-from configs.xlml.tensorflow import solutionsTeam_tf_nightly_supported_config as tf_config
+from configs.xlml.tensorflow import solutionsteam_tf_nightly_supported_config as tf_config
 
 
 # Run once a day at 6 am UTC (10 pm PST)

From d7c3cda34b9d28f91f1a9de0b4c4cc985671bc23 Mon Sep 17 00:00:00 2001
From: Ran Ran
Date: Wed, 6 Dec 2023 07:46:31 +0000
Subject: [PATCH 12/12] Update names

Change-Id: I39d23e223b6c0c2c2633b4cee3a3d65909dbf63f
---
 apis/task.py | 8 ++---
 apis/test_config.py | 13 +++++++-
 .../jax/solutionsteam_jax_npi_config.py | 4 +--
 .../pytorch/pytorchxla_torchbench_config.py | 4 +--
 configs/example/gke_example_config.py | 4 +--
 ...utionsteam_flax_latest_supported_config.py | 32 +++++++++----------
 ...lutionsteam_tf_nightly_supported_config.py | 8 ++---
 dags/benchmark/pytorchxla_torchbench.py | 6 ++--
 dags/xlml/pytorchxla_huggingface.py | 8 ++---
 dags/xlml/pytorchxla_llama.py | 4 +--
 dags/xlml/pytorchxla_torchvision.py | 6 ++--
 .../solutionsteam_jax_latest_integration.py | 4 +--
 implementations/utils/xpk.py | 6 ++--
 13 files changed, 59 insertions(+), 48 deletions(-)

diff --git a/apis/task.py b/apis/task.py
index 897bf9e9..7e3e41e6 100644
--- a/apis/task.py
+++ b/apis/task.py
@@ -41,8 +41,8 @@ def run() -> DAGNode:
 
 @dataclasses.dataclass
-class TpuGceTask(BaseTask):
- """This is a class to set up tasks for TPUs provisioned in GCE.
+class TpuQueuedResourceTask(BaseTask):
+ """This is a class to set up tasks for TPUs provisioned by Queued Resources.
 
 Attributes:
 task_test_config: Test configs to run on this TPU.
 task_gcp_config: Runtime TPU creation parameters.
 task_metric_config: Metric configs to process metrics.
@@ -190,8 +190,8 @@
 
 @dataclasses.dataclass
-class TpuGkeTask(BaseTask):
- """This is a class to set up tasks for TPUs provisioned in GKE.
+class TpuXpkTask(BaseTask):
+ """This is a class to set up tasks for TPUs provisioned by the XPK tool.
 
 Attributes:
 task_test_config: Test configs to run on this TPU.
 task_gcp_config: Runtime TPU creation parameters.
 task_metric_config: Metric configs to process metrics.
diff --git a/apis/test_config.py b/apis/test_config.py
index 558d25b6..47606593 100644
--- a/apis/test_config.py
+++ b/apis/test_config.py
@@ -79,7 +79,7 @@ class Tpu(Accelerator):
 
 version: str
 cores: int
- runtime_version: str = vm_resource.RuntimeVersion.TPU_UBUNTU2204_BASE.value
+ runtime_version: Optional[str] = None
 network: str = 'default'
 subnetwork: str = 'default'
 reserved: bool = False
@@ -157,6 +157,17 @@
 
 @attrs.define
 class TpuGkeTest(TestConfig[Tpu]):
+ """Test config for a TPU test that runs in a GKE cluster via xpk.
+
+ Attributes:
+ test_name: Unique name for this test/model.
+ cluster_name: Name of the cluster that has provisioned TPUs.
+ cluster_config: Config of the cluster.
+ docker_image: Docker image to run.
+ set_up_cmds: List of commands to run once when the TPU is created.
+ run_model_cmds: List of commands to run the model under test.
+ """ + test_name: str cluster_name: str cluster_config: str diff --git a/configs/benchmark/jax/solutionsteam_jax_npi_config.py b/configs/benchmark/jax/solutionsteam_jax_npi_config.py index 5142e22e..be92a83a 100644 --- a/configs/benchmark/jax/solutionsteam_jax_npi_config.py +++ b/configs/benchmark/jax/solutionsteam_jax_npi_config.py @@ -25,7 +25,7 @@ def get_jax_vit_config( tpu_cores: int, tpu_zone: str, time_out_in_min: int, -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -99,7 +99,7 @@ def get_jax_vit_config( ) ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, task_metric_config=job_metric_config, diff --git a/configs/benchmark/pytorch/pytorchxla_torchbench_config.py b/configs/benchmark/pytorch/pytorchxla_torchbench_config.py index c2739ed5..e5efd0a1 100644 --- a/configs/benchmark/pytorch/pytorchxla_torchbench_config.py +++ b/configs/benchmark/pytorch/pytorchxla_torchbench_config.py @@ -70,7 +70,7 @@ def get_torchbench_config( time_out_in_min: int, model_name: str = "", extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -114,7 +114,7 @@ def get_torchbench_config( ) ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, task_metric_config=job_metric_config, diff --git a/configs/example/gke_example_config.py b/configs/example/gke_example_config.py index 03ee2350..9c8bddb1 100644 --- a/configs/example/gke_example_config.py +++ b/configs/example/gke_example_config.py @@ -26,7 +26,7 @@ def get_flax_resnet_gke_config( cluster_config: str, docker_image: str, time_out_in_min: int, -) -> task.TpuGkeTask: +) -> task.TpuXpkTask: # TODO(ranran): update the project once quota is approved (b/311073979). 
job_gcp_config = gcp_config.GCPConfig( project_name="tpu-prod-env-one-vm", @@ -55,7 +55,7 @@ def get_flax_resnet_gke_config( task_owner=test_owner.RAN_R, ) - return task.TpuGkeTask( + return task.TpuXpkTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) diff --git a/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py b/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py index d1ec89d1..f4e1bf08 100644 --- a/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py +++ b/configs/xlml/jax/solutionsteam_flax_latest_supported_config.py @@ -33,7 +33,7 @@ def get_flax_resnet_config( time_out_in_min: int, data_dir: str = gcs_bucket.TFDS_DATA_DIR, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -66,7 +66,7 @@ def get_flax_resnet_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -110,7 +110,7 @@ def get_flax_vit_config( time_out_in_min: int, num_train_epochs: int = 3, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -134,7 +134,7 @@ def get_flax_vit_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -147,7 +147,7 @@ def get_flax_vit_conv_config( time_out_in_min: int, num_train_epochs: int = 30, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -191,7 +191,7 @@ def get_flax_vit_conv_config( ) ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, task_metric_config=job_metric_config, @@ -204,7 +204,7 @@ def get_flax_gpt2_config( tpu_zone: str, time_out_in_min: int, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -252,7 +252,7 @@ def get_flax_gpt2_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -265,7 +265,7 @@ def get_flax_sd_config( time_out_in_min: int, num_train_epochs: int, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -307,7 +307,7 @@ def get_flax_sd_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -319,7 +319,7 @@ def get_flax_bart_config( tpu_zone: str, time_out_in_min: int, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -358,7 +358,7 @@ def get_flax_bart_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -372,7 +372,7 @@ def get_flax_bert_config( task_name: str, num_train_epochs: int = 1, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: 
job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -408,7 +408,7 @@ def get_flax_bert_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) @@ -422,7 +422,7 @@ def get_flax_wmt_config( num_train_steps: int, data_dir: str = gcs_bucket.TFDS_DATA_DIR, extraFlags: str = "", -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=PROJECT_NAME, zone=tpu_zone, @@ -461,7 +461,7 @@ def get_flax_wmt_config( task_owner=test_owner.SHIVA_S, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, ) diff --git a/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py b/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py index 7fcfa073..dc06e251 100644 --- a/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py +++ b/configs/xlml/tensorflow/solutionsteam_tf_nightly_supported_config.py @@ -30,7 +30,7 @@ def get_tf_resnet_config( tfds_data_dir: str = gcs_bucket.TFDS_DATA_DIR, train_steps: int = 320, validation_interval: int = 320, -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -84,7 +84,7 @@ def get_tf_resnet_config( task_owner=test_owner.CHANDRA_D, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, custom_tpu_name=tpu_name, @@ -103,7 +103,7 @@ def get_tf_bert_config( tfds_data_dir: str = gcs_bucket.TFDS_DATA_DIR, train_steps: int = 2000, validation_interval: int = 1000, -) -> task.TpuGceTask: +) -> task.TpuQueuedResourceTask: job_gcp_config = gcp_config.GCPConfig( project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS, zone=tpu_zone, @@ -165,7 +165,7 @@ def get_tf_bert_config( task_owner=test_owner.CHANDRA_D, ) - return task.TpuGceTask( + return task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, custom_tpu_name=tpu_name, diff --git a/dags/benchmark/pytorchxla_torchbench.py b/dags/benchmark/pytorchxla_torchbench.py index 5bbe9ffa..7d0a011b 100644 --- a/dags/benchmark/pytorchxla_torchbench.py +++ b/dags/benchmark/pytorchxla_torchbench.py @@ -31,9 +31,9 @@ SCHEDULED_TIME = "0 17 * * *" if composer_env.is_prod_env() else None with models.DAG( - dag_id="pytorch_nightly_torchbench", - schedule=None, - tags=["pytorch", "nightly", "torchbench", "benchmark"], + dag_id="pytorchxla-torchbench", + schedule=SCHEDULED_TIME, + tags=["pytorchxla", "nightly", "torchbench"], start_date=datetime.datetime(2023, 8, 29), catchup=False, ) as dag: diff --git a/dags/xlml/pytorchxla_huggingface.py b/dags/xlml/pytorchxla_huggingface.py index cc82d68f..c602f399 100644 --- a/dags/xlml/pytorchxla_huggingface.py +++ b/dags/xlml/pytorchxla_huggingface.py @@ -39,19 +39,19 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - accelerate_v2_8 = task.TpuGceTask( + accelerate_v2_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-accelerate-smoke-v2-8-1vm" ), US_CENTRAL1_C, ).run() - accelerate_v4_8 = task.TpuGceTask( + accelerate_v4_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-accelerate-smoke-v4-8-1vm" ), US_CENTRAL2_B, ).run() - diffusers_v4_8 = task.TpuGceTask( + diffusers_v4_8 = 
task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-hf-diffusers-func-v4-8-1vm" ), @@ -61,7 +61,7 @@ accelerate_v4_8 >> accelerate_v2_8 accelerate_v4_8 >> diffusers_v4_8 - task.TpuGceTask( + task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-hf-fsmt-pjrt-func-v4-8-1vm" ), diff --git a/dags/xlml/pytorchxla_llama.py b/dags/xlml/pytorchxla_llama.py index ebffdfae..b557f34a 100644 --- a/dags/xlml/pytorchxla_llama.py +++ b/dags/xlml/pytorchxla_llama.py @@ -33,13 +33,13 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - llama_inference_v4_8 = task.TpuGceTask( + llama_inference_v4_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-llama2-pjrt-infer-func-v4-8-1vm-1vm" ), US_CENTRAL2_B, ).run() - llama_train_v4_8 = task.TpuGceTask( + llama_train_v4_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-llama2-pjrt-train-spmd-func-v4-8-1vm-1vm" ), diff --git a/dags/xlml/pytorchxla_torchvision.py b/dags/xlml/pytorchxla_torchvision.py index 2f657225..dd2bb673 100644 --- a/dags/xlml/pytorchxla_torchvision.py +++ b/dags/xlml/pytorchxla_torchvision.py @@ -39,19 +39,19 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - mnist_v2_8 = task.TpuGceTask( + mnist_v2_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-mnist-pjrt-func-v2-8-1vm" ), US_CENTRAL1_C, ).run() - resnet_v2_8 = task.TpuGceTask( + resnet_v2_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-resnet50-pjrt-fake-v2-8-1vm" ), US_CENTRAL1_C, ).run() - resnet_v4_8 = task.TpuGceTask( + resnet_v4_8 = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_pytorch( "pt-nightly-resnet50-pjrt-fake-v4-8-1vm" ), diff --git a/dags/xlml/solutionsteam_jax_latest_integration.py b/dags/xlml/solutionsteam_jax_latest_integration.py index c2a34c69..0999496a 100644 --- a/dags/xlml/solutionsteam_jax_latest_integration.py +++ b/dags/xlml/solutionsteam_jax_latest_integration.py @@ -39,13 +39,13 @@ start_date=datetime.datetime(2023, 7, 12), catchup=False, ): - compilation_cache = task.TpuGceTask( + compilation_cache = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_jax( "jax-compilation-cache-test-func-v2-8-1vm" ), US_CENTRAL1_C, ).run() - pod = task.TpuGceTask( + pod = task.TpuQueuedResourceTask( test_config.JSonnetTpuVmTest.from_jax( "jax-pod-latest-tpu-ubuntu2204-base-func-v2-32-1vm" ), diff --git a/implementations/utils/xpk.py b/implementations/utils/xpk.py index eb91c5cc..24562430 100644 --- a/implementations/utils/xpk.py +++ b/implementations/utils/xpk.py @@ -18,7 +18,7 @@ from absl import logging from airflow.decorators import task from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator -from kubernetes import client, config +from kubernetes import client as kubernetes_client, config as kubernetes_config @task @@ -81,10 +81,10 @@ def wait_for_workload_completion(workload_id: str, cluster_config: str) -> bool: """Check the workload status.""" # Load the config for the cluster with TPUs in the pool - config.load_kube_config( + kubernetes_config.load_kube_config( config_file=f"/home/airflow/gcs/dags/configs/cluster/{cluster_config}" ) - core_api = client.CoreV1Api() + core_api = kubernetes_client.CoreV1Api() logging.info(f"workload_id: {workload_id}") pods = core_api.list_namespaced_pod(
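
For reference, the pieces this series introduces compose as follows: a test_config.TpuGkeTest describes the dockerized workload and target cluster, and task.TpuXpkTask.run() chains the xpk helpers (generate_workload_id >> run_workload >> wait_for_workload_completion) with metric post-processing. Below is a minimal sketch of that wiring, using only the modules this series adds or renames; every literal value (project, zone, cluster, image, command, owner) is an illustrative placeholder, not a value taken from this patch.

# Hypothetical usage sketch -- not part of the patch series. Assumes the
# apis.task, apis.test_config, apis.gcp_config, and apis.metric_config
# modules as defined by the diffs above.
import datetime

from airflow import models
from apis import gcp_config, metric_config, task, test_config

with models.DAG(
    dag_id="xpk_usage_sketch",
    schedule=None,  # trigger manually while experimenting
    start_date=datetime.datetime(2023, 12, 6),
    catchup=False,
) as dag:
  job_gcp_config = gcp_config.GCPConfig(
      project_name="my-gcp-project",  # placeholder
      zone="us-central2-b",
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )
  job_test_config = test_config.TpuGkeTest(
      test_config.Tpu(version="4", cores=8),  # runtime_version is now optional
      test_name="flax-resnet-sketch",  # placeholder
      cluster_name="my-xpk-cluster",  # placeholder; cluster must already exist
      cluster_config="v5e_cluster_config",  # a file under configs/cluster/
      docker_image="gcr.io/my-gcp-project/my-image:nightly",  # placeholder
      set_up_cmds=[],  # the image is prebuilt, so no setup commands are needed
      run_model_cmds=["python3 /app/train.py"],  # placeholder command
      time_out_in_min=60,
      task_owner="owner-name",  # placeholder
  )
  # run() returns a TaskGroup chaining run_model >> post_process.
  task.TpuXpkTask(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
  ).run()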