Skip to content

Commit

Permalink
Merge branch 'master' into yijiaj/mlperf-a2
Browse files Browse the repository at this point in the history
  • Loading branch information
jyj0w0 authored Oct 1, 2024
2 parents a57f750 + 7f05c24 commit f94e9d2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
9 changes: 9 additions & 0 deletions dags/legacy_test/tests/pytorch/r2.5/ci.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ local tpus = import 'templates/tpus.libsonnet';
accelerator: tpus.v5litepod_4,
},

// Mixin that targets the test at the `tpus.trillium_4` accelerator.
// Declared as a hidden field (`::`) and aliased through a `local` so it can be
// composed into the entries of `configs` below.
local trillium_4 = self.trillium_4,
trillium_4:: {
// `+:` merges into the inherited tpuSettings rather than replacing it.
tpuSettings+: {
softwareVersion: 'v2-alpha-tpuv6e',
},
accelerator: tpus.trillium_4,
},

// Test configurations exported by this file: the `ci` test composed with the
// `pjrt` mixin, once per accelerator mixin (v5litepod_4 and trillium_4).
configs: [
ci + v5litepod_4 + pjrt,
ci + pjrt + trillium_4,
],
}
7 changes: 6 additions & 1 deletion dags/legacy_test/tests/pytorch/r2.5/common.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ local volumes = import 'templates/volumes.libsonnet';
tpuVmPytorchSetup: |||
pip3 install -U 'setuptools>=70.0.0,<71.0.0'
# `unattended-upgr` blocks us from installing apt dependencies
sudo systemctl stop unattended-upgrades
if systemctl is-active --quiet unattended-upgrades; then
sudo systemctl stop unattended-upgrades
echo "unattended-upgrades stopped."
else
echo "unattended-upgrades is not running."
fi
sudo apt-get -y update
sudo apt install -y libopenblas-base
# for huggingface tests
Expand Down
17 changes: 16 additions & 1 deletion dags/pytorch_xla/r2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from airflow import models
from xlml.apis import gcp_config, metric_config, task, test_config
from dags import composer_env
from dags.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS
from dags.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS, V6E_SUBNETWORKS


# Run once a day at 2 pm UTC (6 am PST)
Expand Down Expand Up @@ -52,6 +52,12 @@
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)

# GCP configuration for the TPU_PROD_ENV_AUTOMATED project in zone
# US_CENTRAL2_B, with the dataset option set to XLML_DATASET.
US_CENTRAL2_B_TPU_PROD_ENV = gcp_config.GCPConfig(
project_name=Project.TPU_PROD_ENV_AUTOMATED.value,
zone=Zone.US_CENTRAL2_B.value,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
)


@task_group(prefix_group_id=False)
def torchvision():
Expand Down Expand Up @@ -194,3 +200,12 @@ def llama():
),
US_EAST1_C,
)

# Queued-resource run of the Jsonnet-defined test "pt-2-5-ci-func-v6e-4-1vm",
# using the US_CENTRAL2_B_TPU_PROD_ENV GCP configuration.
# NOTE(review): this pairs V5_NETWORKS with V6E_SUBNETWORKS — presumably the
# v6e subnetwork lives in the shared V5 network; confirm against dags.vm_resource.
ci_trillium_4 = task.run_queued_resource_test(
test_config.JSonnetTpuVmTest.from_pytorch(
"pt-2-5-ci-func-v6e-4-1vm",
network=V5_NETWORKS,
subnetwork=V6E_SUBNETWORKS,
),
US_CENTRAL2_B_TPU_PROD_ENV,
)

0 comments on commit f94e9d2

Please sign in to comment.