Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable docker image feature with xpk #28

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions .github/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ google-cloud-storage
google-cloud-tpu>=1.16.0
jsonlines
tensorflow-cpu
apache-airflow-providers-cncf-kubernetes
84 changes: 79 additions & 5 deletions apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow.utils.task_group import TaskGroup
from apis import gcp_config, metric_config, test_config
from implementations.utils import metric
from implementations.utils import ssh, tpu
from implementations.utils import ssh, tpu, xpk


class BaseTask(abc.ABC):
Expand All @@ -41,15 +41,15 @@ def run() -> DAGNode:


@dataclasses.dataclass
class TpuTask(BaseTask):
"""This is a class to set up tasks for TPU.
class TpuGceTask(BaseTask):
RissyRan marked this conversation as resolved.
Show resolved Hide resolved
"""This is a class to set up tasks for TPU in GCE.

Attributes:
task_test_config: Test configs to run on this TPU.
task_gcp_config: Runtime TPU creation parameters.
task_metric_config: Metric configs to process metrics.
custom_tpu_name: A custom TPU name. By default the name is
test name + accelerator name.
custom_tpu_name: A custom TPU name. By default the name is test name +
accelerator name.
suffix_tpu_name: The flag to define if add auto-generated suffix.
all_workers: The flag to define if run commands on all workers or worker 0
only.
Expand Down Expand Up @@ -189,6 +189,80 @@ def clean_up(self, queued_resource: airflow.XComArg) -> DAGNode:
)


@dataclasses.dataclass
class TpuGkeTask(BaseTask):
  """This is a class to set up tasks for TPU in GKE.

  Builds an Airflow task group that runs a containerized test workload on a
  GKE TPU cluster via xpk, then post-processes its metrics into BigQuery.

  Attributes:
    task_test_config: Test configs to run on this TPU. Expected to be a
      test_config.TpuGkeTest (provides cluster_name, cluster_config and
      docker_image) — TODO confirm with callers.
    task_gcp_config: Runtime TPU creation parameters.
    task_metric_config: Metric configs to process metrics. May be None;
      post_process passes it straight to metric.process_metrics, which
      presumably tolerates None — verify.
  """

  task_test_config: test_config.TestConfig[test_config.Tpu]
  task_gcp_config: gcp_config.GCPConfig
  task_metric_config: Optional[metric_config.MetricConfig] = None

  def run(self) -> DAGNode:
    """Run a test job within a docker image.

    Returns:
      A task group with the following tasks chained: run_model and
      post_process.
    """
    with TaskGroup(group_id=self.task_test_config.benchmark_id) as group:
      self.run_model() >> self.post_process()

    return group

  def run_model(self) -> DAGNode:
    """Run the TPU test in `task_test_config` using xpk.

    Returns:
      A DAG node that executes the model test.
    """
    with TaskGroup(group_id="run_model") as group:
      # Workload id is derived from the benchmark id so runs are traceable
      # back to the test that launched them.
      workload_id = xpk.generate_workload_id(self.task_test_config.benchmark_id)
      run_workload = xpk.run_workload(
          task_id="run_workload",
          project_id=self.task_gcp_config.project_name,
          zone=self.task_gcp_config.zone,
          cluster_name=self.task_test_config.cluster_name,
          benchmark_id=self.task_test_config.benchmark_id,
          workload_id=workload_id,
          docker_image=self.task_test_config.docker_image,
          accelerator_type=self.task_test_config.accelerator.name,
          run_cmds=self.task_test_config.run_model_cmds,
          task_owner=self.task_test_config.task_owner,
      )
      # Converts the test's timeout from minutes to the seconds that
      # override(timeout=...) expects, so the wait task fails rather than
      # hanging forever on a stuck workload.
      wait_for_workload_completion = xpk.wait_for_workload_completion.override(
          timeout=self.task_test_config.time_out_in_min * 60
      )(
          workload_id=workload_id,
          cluster_config=self.task_test_config.cluster_config,
      )

      # Explicit chaining: generate id -> launch workload -> wait for it.
      workload_id >> run_workload >> wait_for_workload_completion
    return group

  def post_process(self) -> DAGNode:
    """Process metrics and metadata, and insert them into BigQuery tables.

    Returns:
      A DAG node that executes the post process.
    """
    with TaskGroup(group_id="post_process") as group:
      # retries=1: one retry to absorb transient failures in metric handling.
      process_id = metric.generate_process_id.override(retries=1)()
      metric.process_metrics.override(retries=1)(
          process_id,
          self.task_test_config,
          self.task_metric_config,
          self.task_gcp_config,
      )

    return group


@dataclasses.dataclass
class GpuTask(BaseTask):
"""This is a class to set up tasks for GPU.
Expand Down
25 changes: 24 additions & 1 deletion apis/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, accelerator, task_owner=None, test_name):
import shlex
from typing import Any, Generic, Iterable, List, Optional, TypeVar
import attrs
from configs import vm_resource


class Accelerator(abc.ABC):
Expand Down Expand Up @@ -78,7 +79,7 @@ class Tpu(Accelerator):

version: str
cores: int
runtime_version: str
runtime_version: str = vm_resource.RuntimeVersion.TPU_UBUNTU2204_BASE.value
RissyRan marked this conversation as resolved.
Show resolved Hide resolved
network: str = 'default'
subnetwork: str = 'default'
reserved: bool = False
Expand Down Expand Up @@ -154,6 +155,28 @@ def test_script(self) -> str:
return '\n'.join(self.run_model_cmds)


@attrs.define
class TpuGkeTest(TestConfig[Tpu]):
  """Test config for a TPU workload that runs on a GKE cluster via xpk.

  Attributes:
    test_name: Unique name of this test/model.
    cluster_name: Name of the GKE cluster running the workload.
    cluster_config: Name of the cluster config file (see
      configs.vm_resource.ClusterConfig).
    docker_image: Docker image containing the test to run.
    set_up_cmds: Commands to prepare the environment; may be None when the
      docker image is already fully built.
    run_model_cmds: Commands that run the model test.
  """

  test_name: str
  cluster_name: str
  cluster_config: str
  docker_image: str
  set_up_cmds: Iterable[str]
  run_model_cmds: Iterable[str]

  @property
  def benchmark_id(self) -> str:
    """Unique benchmark id: test name plus accelerator name."""
    return f'{self.test_name}-{self.accelerator.name}'

  @property
  def setup_script(self) -> Optional[str]:
    """Set-up commands joined into one `;`-separated shell line.

    Returns None when there are no set-up commands. Callers such as
    gke_example_config.py pass set_up_cmds=None, which would make a bare
    `';'.join(...)` raise TypeError — guard against that here.
    """
    if not self.set_up_cmds:
      return None
    return ';'.join(self.set_up_cmds)

  @property
  def test_script(self) -> str:
    """Model-run commands joined into one `;`-separated shell line."""
    return ';'.join(self.run_model_cmds)


@attrs.define
class JSonnetTpuVmTest(TestConfig[Tpu]):
"""Convert legacy JSonnet test configs into a Task.
Expand Down
4 changes: 2 additions & 2 deletions configs/benchmark/jax/solutionsTeam_jax_npi_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_jax_vit_config(
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
) -> task.TpuTask:
) -> task.TpuGceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS,
zone=tpu_zone,
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_jax_vit_config(
)
)

return task.TpuTask(
return task.TpuGceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
Expand Down
4 changes: 2 additions & 2 deletions configs/benchmark/pytorch/pytorchxla_torchbench_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_torchbench_config(
time_out_in_min: int,
model_name: str = "",
extraFlags: str = "",
) -> task.TpuTask:
) -> task.TpuGceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=vm_resource.PROJECT_CLOUD_ML_AUTO_SOLUTIONS,
zone=tpu_zone,
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_torchbench_config(
)
)

return task.TpuTask(
return task.TpuGceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
Expand Down
26 changes: 26 additions & 0 deletions configs/cluster/v5e_cluster_config
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
apiVersion: v1
# kubeconfig exported for cluster gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
clusters:
- cluster:
certificate-authority-data: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVMVENDQXBXZ0F3SUJBZ0lSQU5sdUp6SkFSS3FvNWFvZXF2M2l0NTh3RFFZSktvWklodmNOQVFFTEJRQXcKTHpFdE1Dc0dBMVVFQXhNa016ZzJZak5qTkRRdE9HWmtOQzAwTjJRNUxXSTBNVFF0WWpkall6RmlNbUV4TmpOaApNQ0FYRFRJek1URXlNREl4TURNek5Wb1lEekl3TlRNeE1URXlNakl3TXpNMVdqQXZNUzB3S3dZRFZRUURFeVF6Ck9EWmlNMk0wTkMwNFptUTBMVFEzWkRrdFlqUXhOQzFpTjJOak1XSXlZVEUyTTJFd2dnR2lNQTBHQ1NxR1NJYjMKRFFFQkFRVUFBNElCandBd2dnR0tBb0lCZ1FDbGoyd2V3UkNFMHBnTjdRMFJOVFlOVlJLQUIrQmVROU1OTG82dwpkZU9EV09uQ3Npd2VvSEpYUlkyM09zdTRDTVRLUVNCbFNrWkZCZFZGR0pYSFdzeEhjdjJ1ZEdwa3NLNnUwSkNlCjdjL2lnc1NLVGwvVlkxQ0RKWnpDQTFmeWpab2hxdGhMaXZFajRFU3JqTTBSSkZUQlVEdXRMQmpsQ0xOWnRKaS8KZm54aldhSlBxVGl1TDR4MUNGMy9EbkxGTjhKQWdROE9JQzFGbkpFdHRlUE9vcXZKOG0rdVZLRHltNUFHR1NvSgpHTjdieTgrcmFJWEFCa1BuNkloVmdqWnllNWJwNzRKZ29GdVAxSTFPeVBoVDVEV2ZQcFl4U3lFTTIzSnI0ZEQrCkhpaGxBdXhQc1M1a29vODJPSVlFcXFtTnM5VDdIV3ByMnJrNDJ2dUtFN0JzMVRrb1dxTFQvUFAzOGdJMzZJL0kKWEZXYzE3NU45L1hRQmlaMGFYNWVxRWdrWWkzRjdGV3FrWEI5N2VWdkJDYTQ1YXdmdFpTQWdwQWtuWnZ3MnpFagp3N2hneUdQZ1g2L0xIM05naFYyelNDdUhJanJIOEtNY0FsaE1nbWRhZlphUzZtdkJvVEEyNmZvdTF5eU1vRHJQCmJSbkxvaFBtQXJEdm9xalJLTnZCbjRrbExUTUNBd0VBQWFOQ01FQXdEZ1lEVlIwUEFRSC9CQVFEQWdJRU1BOEcKQTFVZEV3RUIvd1FGTUFNQkFmOHdIUVlEVlIwT0JCWUVGTXBNcW9XRDMyOUk2K1hJVmNnanVvVXpvZklXTUEwRwpDU3FHU0liM0RRRUJDd1VBQTRJQmdRQVExQ3ZlTDF1a1JBeEQzLzdvZUhReXRESDB4QkEwT3hYaU5nc0VkY3RLCnhsMlpFMjlmdCtEWEpDSkI2ZjgzYkFGUW5leWVGYmRZTFFadWtzdHdWRk9JZ21BWVExb3pTRjM1R2d5UCtML3AKN0IxR2hDWGhoYUZ5Rk1kdGEwRER4aGZkQ2J1V280N1RMcmpGZG5GdlFmVVlMVWNPcDVoZkNZNUxaWUlXaVF0Twp1VDRRUnIrc09YNm9JeUtxQ1dvUjBDVHpJUS8weHhHaE8zYXMwNy9LbmZ6OWNFZE1xb0xmNS9BOFVrc09mNzcrCkdUVE8wa082NzhQYWxadityZm05K216aTIzQVFmUkxwOWVkRUFyQjl4YlQrTVZIcENPK3ZYaHRLL1JPcUR5b1AKM0FxanQ4N3A0UnNZT01nT0V3YVVRRHlmdzFDTTFmamN6NDlveEF2SHRxZC9Tc0dOSW9lWFc5bFlMWmV1TGE4KwpwYnV2UXhnaVdlUnBCSUVPTTB2dERTQmgrbW16eUVrS20xM29RbDlGV0p2NGNmOWM3SUEvWG8zQWVhMC96R08rCmVtakJ3Z3BpSUxDR1JNRDNSaFovbERNRjhTRVhjeHhkYjlaMmc5Y2lSek1aa1IweG94NDcxRENlZ0Q1eFFmS3cKbDBxeHNqakQ3YVpiK1VOTFBDUCth
MUk9Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K
server: https://34.125.160.150
name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
contexts:
- context:
cluster: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
user: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
current-context: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
kind: Config
preferences: {}
users:
- name: gke_tpu-prod-env-one-vm_us-west4_ran-xpk-test-zone
user:
exec:
apiVersion: client.authentication.k8s.io/v1beta1
args: null
command: gke-gcloud-auth-plugin
env: null
installHint: Install gke-gcloud-auth-plugin for use with kubectl by following
https://cloud.google.com/blog/products/containers-kubernetes/kubectl-auth-changes-in-gke
interactiveMode: IfAvailable
provideClusterInfo: true
61 changes: 61 additions & 0 deletions configs/example/gke_example_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for example_dag."""

from apis import gcp_config, metric_config, task, test_config
from configs import test_owner


def get_flax_resnet_gke_config(
    tpu_version: str,
    tpu_cores: int,
    tpu_zone: str,
    cluster_name: str,
    cluster_config: str,
    docker_image: str,
    time_out_in_min: int,
) -> task.TpuGkeTask:
  """Build an example Flax ResNet (ImageNet) test task for a GKE TPU cluster.

  Args:
    tpu_version: TPU version (e.g. "v4", "v5e") for test_config.Tpu.
    tpu_cores: Number of TPU cores to request.
    tpu_zone: GCP zone to run in.
    cluster_name: Name of the GKE cluster to target.
    cluster_config: Cluster config name (see configs.vm_resource.ClusterConfig).
    docker_image: Docker image containing the test environment.
    time_out_in_min: Test timeout, in minutes.

  Returns:
    A task.TpuGkeTask ready to be added to a DAG.
  """
  # TODO(ranran): update the project once quota is approved (b/311073979).
  job_gcp_config = gcp_config.GCPConfig(
      project_name="tpu-prod-env-one-vm",
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

  # NOTE(review): adjacent string literals concatenate, so this is a single
  # string, NOT a tuple of commands. TpuGkeTest.run_model_cmds is declared
  # Iterable[str]; if it is ever ';'.join-ed, a plain string would be joined
  # character by character — confirm the intended type with xpk.run_workload.
  run_model_cmds = (
      "python3 /tmp/flax/examples/imagenet/main.py"
      " --config=/tmp/flax/examples/imagenet/configs/tpu.py"
      " --workdir=/tmp/imagenet --config.num_epochs=1"
  )

  job_test_config = test_config.TpuGkeTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
      ),
      test_name="flax-resnet-gke",
      cluster_name=cluster_name,
      cluster_config=cluster_config,
      docker_image=docker_image,
      run_model_cmds=run_model_cmds,
      # No set-up commands: the docker image is expected to be pre-built.
      set_up_cmds=None,
      time_out_in_min=time_out_in_min,
      task_owner=test_owner.RAN_R,
  )

  # No task_metric_config: this example skips metric processing defaults.
  return task.TpuGkeTask(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
  )
17 changes: 17 additions & 0 deletions configs/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ class RuntimeVersion(enum.Enum):
TPU_VM_TF_NIGHTLY_POD = "tpu-vm-tf-nightly-pod"
TPU_UBUNTU2204_BASE = "tpu-ubuntu2204-base"
TPU_VM_V4_BASE = "tpu-vm-v4-base"


# TODO(ranran): update the cluster name once quota is approved (b/311073979).
class ClusterName(enum.Enum):
  """GKE cluster names targeted by xpk workloads."""

  # Empty until the v4 cluster exists (see TODO above).
  V4_CLUSTER = ""
  V5E_CLUSTER = "ran-xpk-test-zone"


# TODO(ranran): update the cluster name once quota is approved (b/311073979).
class ClusterConfig(enum.Enum):
  """File names under configs/cluster/ holding each cluster's kubeconfig."""

  V4_CONFIG = "v4_cluster_config"
  V5E_CONFIG = "v5e_cluster_config"


# TODO(ranran): update the project once quota is approved (b/311073979).
class DockerImage(enum.Enum):
  """Docker images available for xpk test workloads."""

  # Name suggests a JAX-based demo/test image — confirm contents.
  DEMO_TEST = "gcr.io/tpu-prod-env-one-vm/xpk_jax_test:latest"
Loading
Loading