Add maxtext sweep metrics collection (#98)
raymondzouu authored Feb 13, 2024
1 parent 1874623 commit 09b0443
Showing 9 changed files with 337 additions and 23 deletions.
9 changes: 6 additions & 3 deletions dags/examples/maxtext_sweep_gce_example_dag.py
@@ -24,7 +24,8 @@
from dags.multipod.configs import maxtext_sweep_gce_config
from dags.multipod.configs import common


# Set concurrency to the number of workers; otherwise tasks may time out
# if more tasks run concurrently than there are workers
with models.DAG(
dag_id="maxtext_sweep_gce_example_dag",
schedule=None,
@@ -34,11 +35,12 @@
concurrency=2,
) as dag:
# MaxText set up and run commands
base_output_directory = "gs://maxtext-experiments-multipod"
base_set_up_cmds = common.download_maxtext()
base_run_model_cmds = [
"cd /tmp/maxtext",
"bash setup.sh MODE=stable",
"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=gs://maxtext-experiments-multipod/ dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=1 steps=10",
f"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=1 steps=10",
]

# Get list of MaxText GCE QueuedResource jobs
@@ -51,6 +53,7 @@
tpu_version=TpuVersion.V4,
tpu_cores=8,
runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value,
base_output_directory=base_output_directory,
num_slices=[1],
run_name_prefix="maxtext-1b",
base_set_up_cmds=base_set_up_cmds,
@@ -60,4 +63,4 @@

# Run jobs
for test in maxtext_sweep_gce_test:
test.run()
test.run_with_run_name_generation()
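
For orientation (the same pattern applies to the GKE example below): once run_with_run_name_generation, defined in xlml/apis/task.py later in this commit, prepends the generated run name, a single sweep point executes a command list along these lines. MaxText reads M_-prefixed environment variables as config overrides; the swept batch-size param and the run-name value shown here are hypothetical, for illustration only.

# Illustrative only -- not code from this commit. Hypothetical effective
# command list for one sweep point of the GCE example above.
run_model_cmds = [
    "export M_RUN_NAME=maxtext-1b-2024-02-13-01-02-03",  # prepended by run_with_run_name_generation; format is a guess
    "export M_PER_DEVICE_BATCH_SIZE=4",  # hypothetical swept param, exported for MaxText to read
    "cd /tmp/maxtext",
    "bash setup.sh MODE=stable",
    "python3 MaxText/train.py MaxText/configs/base.yml"
    " base_output_directory=gs://maxtext-experiments-multipod"
    " dataset_path=gs://max-datasets-rogue enable_checkpointing=false"
    " global_parameter_scale=1 steps=10",
]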
9 changes: 6 additions & 3 deletions dags/examples/maxtext_sweep_gke_example_dag.py
@@ -24,7 +24,8 @@
from dags.vm_resource import TpuVersion, Zone, Project, ClusterName, DockerImage
from dags.multipod.configs import maxtext_sweep_gke_config


# Set concurrency to the number of workers; otherwise tasks may time out
# if more tasks run concurrently than there are workers
with models.DAG(
dag_id="maxtext_sweep_gke_example_dag",
schedule=None,
@@ -34,8 +35,9 @@
concurrency=2,
) as dag:
# MaxText set up and run commands
base_output_directory = "gs://maxtext-experiments-multipod"
base_run_model_cmds = [
"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=gs://maxtext-experiments-multipod/ dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10",
f"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10",
]

# Get list of MaxText GKE XPK jobs
@@ -45,6 +47,7 @@
cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
base_output_directory=base_output_directory,
tpu_version=TpuVersion.V4,
tpu_cores=128,
num_slices=[1],
@@ -56,4 +59,4 @@

# Run jobs
for test in maxtext_sweep_gke_test:
test.run()
test.run_with_run_name_generation()
20 changes: 13 additions & 7 deletions dags/multipod/configs/maxtext_sweep_gce_config.py
@@ -16,7 +16,6 @@

from xlml.apis import gcp_config, metric_config, task, test_config
from dags.vm_resource import TpuVersion
import datetime
import itertools
from typing import List, Iterable

@@ -32,16 +31,18 @@ def get_maxtext_sweep_gce_config(
run_name_prefix: str,
project_name: str,
runtime_version: str,
base_output_directory: str,
base_set_up_cmds: Iterable[str],
base_run_model_cmds: Iterable[str],
dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.BENCHMARK_DATASET,
is_tpu_reserved: bool = True,
network: str = "default",
subnetwork: str = "default",
) -> List[task.TpuQueuedResourceTask]:
job_gcp_config = gcp_config.GCPConfig(
project_name=project_name,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
dataset_name=dataset_name,
dataset_project=project_name,
composer_project=project_name,
)
@@ -56,18 +57,13 @@

# Generate all combinations of sweep param configurations and create a TpuQueuedResourceTask for each one
qr_task_list = []
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
for idx, config in enumerate(itertools.product(*sweep_params_list)):
config_dict = {key: value for (key, value) in config}

# Remove num_slices as a sweep param after combinations have been generated
curr_num_slices = config_dict["NUM_SLICES"]
del config_dict["NUM_SLICES"]

# Add MaxText run_name
run_name = f"{run_name_prefix}-{curr_num_slices}x{tpu_version.value}-{tpu_cores}-{current_datetime}-{idx}"
config_dict["M_RUN_NAME"] = run_name

# Export sweep params as env variables for MaxText to read
run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()]
for cmd in base_run_model_cmds:
@@ -89,9 +85,19 @@
task_owner=test_owner,
num_slices=curr_num_slices,
)

job_metric_config = metric_config.MetricConfig(
tensorboard_summary=metric_config.SummaryConfig(
file_location=base_output_directory,
aggregation_strategy=metric_config.AggregationStrategy.MEDIAN,
use_regex_file_location=True,
),
)

qr_task = task.TpuQueuedResourceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
)
qr_task_list.append(qr_task)

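Both sweep configs expand the sweep by taking itertools.product over per-param (key, value) lists, then popping NUM_SLICES back out of each combination. A minimal standalone sketch of that expansion, with hypothetical param names and values:

import itertools

# Hypothetical sweep grid: each param becomes a list of (key, value) pairs.
sweep_params = {"NUM_SLICES": [1, 2], "M_PER_DEVICE_BATCH_SIZE": [4, 8]}
sweep_params_list = [
    [(name, value) for value in values] for name, values in sweep_params.items()
]

for idx, config in enumerate(itertools.product(*sweep_params_list)):
  config_dict = dict(config)
  curr_num_slices = config_dict.pop("NUM_SLICES")  # consumed by the task, not exported
  print(idx, curr_num_slices, config_dict)
  # 0 1 {'M_PER_DEVICE_BATCH_SIZE': 4}
  # 1 1 {'M_PER_DEVICE_BATCH_SIZE': 8}
  # 2 2 {'M_PER_DEVICE_BATCH_SIZE': 4}
  # 3 2 {'M_PER_DEVICE_BATCH_SIZE': 8}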
20 changes: 13 additions & 7 deletions dags/multipod/configs/maxtext_sweep_gke_config.py
@@ -16,7 +16,6 @@

from xlml.apis import gcp_config, metric_config, task, test_config
from dags.vm_resource import TpuVersion
import datetime
import itertools
from typing import List, Iterable

@@ -33,13 +32,15 @@ def get_maxtext_sweep_gke_config(
project_name: str,
cluster_name: str,
docker_image: str,
base_output_directory: str,
base_run_model_cmds: Iterable[str],
base_set_up_cmds: Iterable[str] = None,
dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.BENCHMARK_DATASET,
) -> List[task.TpuXpkTask]:
job_gcp_config = gcp_config.GCPConfig(
project_name=project_name,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
dataset_name=dataset_name,
dataset_project=project_name,
composer_project=project_name,
)
@@ -54,18 +55,13 @@

# Generate all combinations of sweep param configurations and create a TpuXpkTask for each one
xpk_task_list = []
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
for idx, config in enumerate(itertools.product(*sweep_params_list)):
config_dict = {key: value for (key, value) in config}

# Remove num_slices as a sweep param after combinations have been generated
curr_num_slices = config_dict["NUM_SLICES"]
del config_dict["NUM_SLICES"]

# Add MaxText run_name
run_name = f"{run_name_prefix}-{curr_num_slices}x{tpu_version.value}-{tpu_cores}-{current_datetime}-{idx}"
config_dict["M_RUN_NAME"] = run_name

# Export sweep params as env variables for MaxText to read
run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()]
for cmd in base_run_model_cmds:
@@ -85,9 +81,19 @@
cluster_name=cluster_name,
docker_image=docker_image,
)

job_metric_config = metric_config.MetricConfig(
tensorboard_summary=metric_config.SummaryConfig(
file_location=base_output_directory,
aggregation_strategy=metric_config.AggregationStrategy.MEDIAN,
use_regex_file_location=True,
),
)

xpk_task = task.TpuXpkTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
)
xpk_task_list.append(xpk_task)

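Both sweep configs attach a MetricConfig whose TensorBoard summary uses AggregationStrategy.MEDIAN, i.e. each scalar series read from the run's event file is collapsed to a single value before upload. A rough illustration of that reduction (not the library's actual implementation):

import statistics

# (step, value) pairs as they might be read from one TensorBoard scalar tag.
series = [(1, 3.2), (2, 2.9), (3, 3.0), (4, 3.4), (5, 3.1)]
median_value = statistics.median(value for _, value in series)
print(median_value)  # 3.1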
3 changes: 3 additions & 0 deletions xlml/apis/metric_config.py
@@ -60,12 +60,15 @@ class SummaryConfig:
exclude_tag_patterns: The matching patterns of tags that will be excluded.
No tag is excluded by default. These patterns take priority over
include_tag_patterns.
use_regex_file_location: Whether to use file_location as a regex to get the
file in GCS.
"""

file_location: str
aggregation_strategy: AggregationStrategy
include_tag_patterns: Optional[Iterable[str]] = None
exclude_tag_patterns: Optional[Iterable[str]] = None
use_regex_file_location: bool = False


@dataclasses.dataclass
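The code that consumes use_regex_file_location is not visible in this diff. One plausible reading is that file_location is matched as a regex against GCS object names, which is what lets the sweep configs point it at a pattern before the concrete run name exists. An illustrative resolver under that assumption (the function name and behavior are guesses; only the google-cloud-storage client calls are real):

import re

from google.cloud import storage


def resolve_file_location(file_location: str, use_regex_file_location: bool) -> str:
  # Hypothetical helper: with the flag unset, the path is used verbatim;
  # with it set, the path is a regex over object names in the bucket.
  if not use_regex_file_location:
    return file_location
  bucket_name, pattern = file_location.removeprefix("gs://").split("/", 1)
  client = storage.Client()
  for blob in client.list_blobs(bucket_name):
    if re.fullmatch(pattern, blob.name):
      return f"gs://{bucket_name}/{blob.name}"
  raise FileNotFoundError(f"no GCS object matches {file_location}")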
63 changes: 62 additions & 1 deletion xlml/apis/task.py
@@ -22,7 +22,7 @@
from airflow.models.taskmixin import DAGNode
from airflow.utils.task_group import TaskGroup
from xlml.apis import gcp_config, metric_config, test_config
from xlml.utils import gpu, metric, ssh, tpu, xpk, startup_script
from xlml.utils import gpu, metric, name_format, ssh, tpu, xpk, startup_script


class BaseTask(abc.ABC):
@@ -77,6 +77,39 @@ def run(self) -> DAGNode:

return group

def run_with_run_name_generation(self) -> DAGNode:
"""Generate a unique run name and tensorboard file location, then run a test job.
Returns:
A task group with the following tasks chained: generate_run_name, generate_tb_file_location, provision, run_model,
post_process and clean_up.
"""
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
run_name = name_format.generate_run_name(self.task_test_config.benchmark_id)
tb_file_location = name_format.generate_tb_file_location(
run_name, self.task_metric_config.tensorboard_summary.file_location
)

# Set run_name in run_model_cmds
new_run_model_cmds = [f"export M_RUN_NAME={run_name}"]
for cmd in self.task_test_config.run_model_cmds:
new_run_model_cmds.append(cmd)
self.task_test_config.run_model_cmds = new_run_model_cmds

# Update tensorboard file location
self.task_metric_config.tensorboard_summary.file_location = tb_file_location

provision, queued_resource, ssh_keys = self.provision()
run_model = self.run_model(queued_resource, ssh_keys)
post_process = self.post_process()
clean_up = self.clean_up(queued_resource)

run_name >> tb_file_location >> provision >> run_model >> post_process >> clean_up

return group
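
The name_format helpers used above are new in this commit, but their module is not rendered in this view. A minimal sketch of what they plausibly look like, assuming they are Airflow @task callables whose XCom return values are chained into the task group (the timestamp format and event-file pattern are guesses):

import datetime
import posixpath

from airflow.decorators import task


@task
def generate_run_name(benchmark_id: str) -> str:
  # Unique run name: the benchmark id plus a timestamp, replacing the
  # current_datetime suffix the sweep configs previously added themselves.
  now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  return f"{benchmark_id}-{now}"


@task
def generate_tb_file_location(run_name: str, base_output_directory: str) -> str:
  # Pattern matching the run's TensorBoard event file; usable downstream
  # because SummaryConfig is created with use_regex_file_location=True.
  return posixpath.join(
      base_output_directory, run_name, "tensorboard", "events.out.tfevents.*"
  )

Under that assumption, the chain run_name >> tb_file_location >> provision simply orders the two XCom-producing tasks ahead of provisioning.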

def run_with_startup_script(self) -> DAGNode:
"""Run a test job on GCE with startup script.
@@ -261,6 +294,34 @@ def run(self) -> DAGNode:

return group

def run_with_run_name_generation(self) -> DAGNode:
"""Generate a unique run name and tensorboard file location, then run a test job within a docker image.
Returns:
A task group with the following tasks chained: generate_run_name, generate_tb_file_location, run provision, run_model,
post_process.
"""
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
run_name = name_format.generate_run_name(self.task_test_config.benchmark_id)
tb_file_location = name_format.generate_tb_file_location(
run_name, self.task_metric_config.tensorboard_summary.file_location
)

# Set run_name in run_model_cmds
new_run_model_cmds = [f"export M_RUN_NAME={run_name}"]
for cmd in self.task_test_config.run_model_cmds:
new_run_model_cmds.append(cmd)
self.task_test_config.run_model_cmds = new_run_model_cmds

# Update tensorboard file location
self.task_metric_config.tensorboard_summary.file_location = tb_file_location

run_name >> tb_file_location >> self.run_model() >> self.post_process()

return group

def run_model(self) -> DAGNode:
"""Run the TPU test in `task_test_config` using xpk.