diff --git a/dags/examples/maxtext_sweep_gce_example_dag.py b/dags/examples/maxtext_sweep_gce_example_dag.py index 88c502cff..1131cded5 100644 --- a/dags/examples/maxtext_sweep_gce_example_dag.py +++ b/dags/examples/maxtext_sweep_gce_example_dag.py @@ -24,7 +24,8 @@ from dags.multipod.configs import maxtext_sweep_gce_config from dags.multipod.configs import common - +# Set concurrency to number of workers otherwise tasks may time out +# if there are more concurrent tasks running at a time than number of workers with models.DAG( dag_id="maxtext_sweep_gce_example_dag", schedule=None, @@ -34,11 +35,12 @@ concurrency=2, ) as dag: # MaxText set up and run commands + base_output_directory = "gs://maxtext-experiments-multipod" base_set_up_cmds = common.download_maxtext() base_run_model_cmds = [ "cd /tmp/maxtext", "bash setup.sh MODE=stable", - "python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=gs://maxtext-experiments-multipod/ dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=1 steps=10", + f"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=1 steps=10", ] # Get list of MaxText GCE QueuedResource jobs @@ -51,6 +53,7 @@ tpu_version=TpuVersion.V4, tpu_cores=8, runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value, + base_output_directory=base_output_directory, num_slices=[1], run_name_prefix="maxtext-1b", base_set_up_cmds=base_set_up_cmds, @@ -60,4 +63,4 @@ # Run jobs for test in maxtext_sweep_gce_test: - test.run() + test.run_with_run_name_generation() diff --git a/dags/examples/maxtext_sweep_gke_example_dag.py b/dags/examples/maxtext_sweep_gke_example_dag.py index 5a70c87f1..0a16dc3cf 100644 --- a/dags/examples/maxtext_sweep_gke_example_dag.py +++ b/dags/examples/maxtext_sweep_gke_example_dag.py @@ -24,7 +24,8 @@ from dags.vm_resource import TpuVersion, Zone, Project, ClusterName, DockerImage from dags.multipod.configs import maxtext_sweep_gke_config - +# Set concurrency to number of workers otherwise tasks may time out +# if there are more concurrent tasks running at a time than number of workers with models.DAG( dag_id="maxtext_sweep_gke_example_dag", schedule=None, @@ -34,8 +35,9 @@ concurrency=2, ) as dag: # MaxText set up and run commands + base_output_directory = "gs://maxtext-experiments-multipod" base_run_model_cmds = [ - "python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=gs://maxtext-experiments-multipod/ dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10", + f"python3 MaxText/train.py MaxText/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10", ] # Get list of MaxText GKE XPK jobs @@ -45,6 +47,7 @@ cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value, tpu_zone=Zone.US_CENTRAL2_B.value, time_out_in_min=60, + base_output_directory=base_output_directory, tpu_version=TpuVersion.V4, tpu_cores=128, num_slices=[1], @@ -56,4 +59,4 @@ # Run jobs for test in maxtext_sweep_gke_test: - test.run() + test.run_with_run_name_generation() diff --git a/dags/multipod/configs/maxtext_sweep_gce_config.py b/dags/multipod/configs/maxtext_sweep_gce_config.py index 6ea8a2b07..891a47513 100644 --- a/dags/multipod/configs/maxtext_sweep_gce_config.py +++ b/dags/multipod/configs/maxtext_sweep_gce_config.py @@ -16,7 +16,6 @@ from xlml.apis import gcp_config, metric_config, task, test_config from dags.vm_resource import TpuVersion -import datetime import itertools from typing import List, Iterable @@ -32,8 +31,10 @@ def get_maxtext_sweep_gce_config( run_name_prefix: str, project_name: str, runtime_version: str, + base_output_directory: str, base_set_up_cmds: Iterable[str], base_run_model_cmds: Iterable[str], + dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.BENCHMARK_DATASET, is_tpu_reserved: bool = True, network: str = "default", subnetwork: str = "default", @@ -41,7 +42,7 @@ def get_maxtext_sweep_gce_config( job_gcp_config = gcp_config.GCPConfig( project_name=project_name, zone=tpu_zone, - dataset_name=metric_config.DatasetOption.XLML_DATASET, + dataset_name=dataset_name, dataset_project=project_name, composer_project=project_name, ) @@ -56,7 +57,6 @@ def get_maxtext_sweep_gce_config( # Generate all combinations of sweep param configurations and create a TpuQueuedResourceTask for each one qr_task_list = [] - current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") for idx, config in enumerate(itertools.product(*sweep_params_list)): config_dict = {key: value for (key, value) in config} @@ -64,10 +64,6 @@ def get_maxtext_sweep_gce_config( curr_num_slices = config_dict["NUM_SLICES"] del config_dict["NUM_SLICES"] - # Add MaxText run_name - run_name = f"{run_name_prefix}-{curr_num_slices}x{tpu_version.value}-{tpu_cores}-{current_datetime}-{idx}" - config_dict["M_RUN_NAME"] = run_name - # Export sweep params as env variables for MaxText to read run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()] for cmd in base_run_model_cmds: @@ -89,9 +85,20 @@ def get_maxtext_sweep_gce_config( task_owner=test_owner, num_slices=curr_num_slices, ) + + job_metric_config = metric_config.MetricConfig( + tensorboard_summary=metric_config.SummaryConfig( + file_location=base_output_directory, + aggregation_strategy=metric_config.AggregationStrategy.MEDIAN, + use_regex_file_location=True, + ), + write_results_to_bigquery=True, + ) + qr_task = task.TpuQueuedResourceTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, + task_metric_config=job_metric_config, ) qr_task_list.append(qr_task) diff --git a/dags/multipod/configs/maxtext_sweep_gke_config.py b/dags/multipod/configs/maxtext_sweep_gke_config.py index 414f4fe58..51cd7b029 100644 --- a/dags/multipod/configs/maxtext_sweep_gke_config.py +++ b/dags/multipod/configs/maxtext_sweep_gke_config.py @@ -16,7 +16,6 @@ from xlml.apis import gcp_config, metric_config, task, test_config from dags.vm_resource import TpuVersion -import datetime import itertools from typing import List, Iterable @@ -33,13 +32,15 @@ def get_maxtext_sweep_gke_config( project_name: str, cluster_name: str, docker_image: str, + base_output_directory: str, base_run_model_cmds: Iterable[str], base_set_up_cmds: Iterable[str] = None, + dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.BENCHMARK_DATASET, ) -> List[task.TpuXpkTask]: job_gcp_config = gcp_config.GCPConfig( project_name=project_name, zone=tpu_zone, - dataset_name=metric_config.DatasetOption.XLML_DATASET, + dataset_name=dataset_name, dataset_project=project_name, composer_project=project_name, ) @@ -54,7 +55,6 @@ def get_maxtext_sweep_gke_config( # Generate all combinations of sweep param configurations and create a TpuXpkTask for each one xpk_task_list = [] - current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") for idx, config in enumerate(itertools.product(*sweep_params_list)): config_dict = {key: value for (key, value) in config} @@ -62,10 +62,6 @@ def get_maxtext_sweep_gke_config( curr_num_slices = config_dict["NUM_SLICES"] del config_dict["NUM_SLICES"] - # Add MaxText run_name - run_name = f"{run_name_prefix}-{curr_num_slices}x{tpu_version.value}-{tpu_cores}-{current_datetime}-{idx}" - config_dict["M_RUN_NAME"] = run_name - # Export sweep params as env variables for MaxText to read run_model_cmds = [f"export {key}={value}" for (key, value) in config_dict.items()] for cmd in base_run_model_cmds: @@ -85,9 +81,20 @@ def get_maxtext_sweep_gke_config( cluster_name=cluster_name, docker_image=docker_image, ) + + job_metric_config = metric_config.MetricConfig( + tensorboard_summary=metric_config.SummaryConfig( + file_location=base_output_directory, + aggregation_strategy=metric_config.AggregationStrategy.MEDIAN, + use_regex_file_location=True, + ), + write_results_to_bigquery=True, + ) + xpk_task = task.TpuXpkTask( task_test_config=job_test_config, task_gcp_config=job_gcp_config, + task_metric_config=job_metric_config, ) xpk_task_list.append(xpk_task) diff --git a/xlml/apis/metric_config.py b/xlml/apis/metric_config.py index c5ff415b8..c50e73a9b 100644 --- a/xlml/apis/metric_config.py +++ b/xlml/apis/metric_config.py @@ -60,12 +60,15 @@ class SummaryConfig: exclude_tag_patterns: The matching patterns of tags that will be excluded. No tag is excluded by default. This pattern has higher prioirty to include_tag_pattern. + use_regex_file_location: Whether to use file_location as a regex to get the + file in GCS. """ file_location: str aggregation_strategy: AggregationStrategy include_tag_patterns: Optional[Iterable[str]] = None exclude_tag_patterns: Optional[Iterable[str]] = None + use_regex_file_location: bool = False @dataclasses.dataclass @@ -89,8 +92,10 @@ class MetricConfig: json_lines: The config for JSON Lines input. tensorboard_summary: The config for TensorBoard summary input. profile: The config for profile input. + write_results_to_bigquery: Override requirements to write to bigquery. """ json_lines: Optional[JSONLinesConfig] = None tensorboard_summary: Optional[SummaryConfig] = None profile: Optional[ProfileConfig] = None + write_results_to_bigquery: bool = False diff --git a/xlml/apis/task.py b/xlml/apis/task.py index 444c9733d..dd9a86d5a 100644 --- a/xlml/apis/task.py +++ b/xlml/apis/task.py @@ -22,7 +22,7 @@ from airflow.models.taskmixin import DAGNode from airflow.utils.task_group import TaskGroup from xlml.apis import gcp_config, metric_config, test_config -from xlml.utils import gpu, metric, ssh, tpu, xpk, startup_script +from xlml.utils import gpu, metric, name_format, ssh, tpu, xpk, startup_script class BaseTask(abc.ABC): @@ -77,6 +77,39 @@ def run(self) -> DAGNode: return group + def run_with_run_name_generation(self) -> DAGNode: + """Generate a unique run name and tensorboard file location, then run a test job. + + Returns: + A task group with the following tasks chained: generate_run_name, generate_tb_file_location, provision, run_model, + post_process and clean_up. + """ + with TaskGroup( + group_id=self.task_test_config.benchmark_id, prefix_group_id=True + ) as group: + run_name = name_format.generate_run_name(self.task_test_config.benchmark_id) + tb_file_location = name_format.generate_tb_file_location( + run_name, self.task_metric_config.tensorboard_summary.file_location + ) + + # Set run_name in run_model_cmds + new_run_model_cmds = [f"export M_RUN_NAME={run_name}"] + for cmd in self.task_test_config.run_model_cmds: + new_run_model_cmds.append(cmd) + self.task_test_config.run_model_cmds = new_run_model_cmds + + # Update tensorboard file location + self.task_metric_config.tensorboard_summary.file_location = tb_file_location + + provision, queued_resource, ssh_keys = self.provision() + run_model = self.run_model(queued_resource, ssh_keys) + post_process = self.post_process() + clean_up = self.clean_up(queued_resource) + + run_name >> tb_file_location >> provision >> run_model >> post_process >> clean_up + + return group + def run_with_startup_script(self) -> DAGNode: """Run a test job on GCE with startup script. @@ -261,6 +294,34 @@ def run(self) -> DAGNode: return group + def run_with_run_name_generation(self) -> DAGNode: + """Generate a unique run name and tensorboard file location, then run a test job within a docker image. + + Returns: + A task group with the following tasks chained: generate_run_name, generate_tb_file_location, run provision, run_model, + post_process. + """ + with TaskGroup( + group_id=self.task_test_config.benchmark_id, prefix_group_id=True + ) as group: + run_name = name_format.generate_run_name(self.task_test_config.benchmark_id) + tb_file_location = name_format.generate_tb_file_location( + run_name, self.task_metric_config.tensorboard_summary.file_location + ) + + # Set run_name in run_model_cmds + new_run_model_cmds = [f"export M_RUN_NAME={run_name}"] + for cmd in self.task_test_config.run_model_cmds: + new_run_model_cmds.append(cmd) + self.task_test_config.run_model_cmds = new_run_model_cmds + + # Update tensorboard file location + self.task_metric_config.tensorboard_summary.file_location = tb_file_location + + run_name >> tb_file_location >> self.run_model() >> self.post_process() + + return group + def run_model(self) -> DAGNode: """Run the TPU test in `task_test_config` using xpk. diff --git a/xlml/utils/metric.py b/xlml/utils/metric.py index cc21103e8..90dc52ed1 100644 --- a/xlml/utils/metric.py +++ b/xlml/utils/metric.py @@ -23,7 +23,9 @@ from typing import Dict, Iterable, List, Optional import uuid from absl import logging +import airflow from airflow.decorators import task +from airflow.exceptions import AirflowFailException from airflow.models import TaskInstance from airflow.operators.python import get_current_context from xlml.apis import gcp_config, test_config @@ -35,6 +37,7 @@ import numpy as np import tensorflow as tf from tensorflow.core.util import event_pb2 +from urllib.parse import urlparse @dataclasses.dataclass @@ -109,6 +112,11 @@ def read_from_tb( metrics[value.tag].append(TensorBoardScalar(float(t), event.step)) elif value_type == "text": metadata[value.tag] = bytes(value.tensor.string_val[0]).decode("utf-8") + elif value.HasField("simple_value"): + # simple_value indicates the value is a float: + # https://github.com/tensorflow/tensorflow/blob/4dacf3f/tensorflow/core/framework/summary.proto#L122 + scalar = TensorBoardScalar(value.simple_value, event.step) + metrics.setdefault(value.tag, []).append(scalar) else: logging.info(f"Discarding data point {value.tag} with type {value_type}.") @@ -220,7 +228,15 @@ def process_tensorboard_summary( a list of MetadataHistoryRow ofr a test run in a test job. """ uuid = generate_row_uuid(base_id, 0) - file_location = summary_config.file_location + + if isinstance(summary_config.file_location, airflow.XComArg): + file_location = summary_config.file_location.resolve(get_current_context()) + else: + file_location = summary_config.file_location + + if summary_config.use_regex_file_location: + file_location = get_gcs_file_location_with_regex(file_location) + aggregation_strategy = summary_config.aggregation_strategy include_tag_patterns = summary_config.include_tag_patterns exclude_tag_patterns = summary_config.exclude_tag_patterns @@ -251,6 +267,35 @@ def process_tensorboard_summary( return [metric_history_rows], [metadata_history_rows] +def get_gcs_file_location_with_regex(file_location: str) -> str: + """ + Get a file from GCS given a regex in the form of `gs:///`. + Does not support bucket name or path regex. Only supports file name regex. + + Args: + file_location: File location regex in the form of `gs:////`. + + Returns: + The file location of the first file that fits the given regex. + """ + storage_client = storage.Client() + + url = urlparse(file_location) + bucket_name = url.netloc + file_path = url.path.strip("/") + file_path_regex = re.compile(file_path) + prefix = "/".join(file_path.split("/")[:-1]) + + all_blobs_names = [ + b.name for b in storage_client.list_blobs(bucket_name, prefix=prefix) + ] + + try: + return f"gs://{bucket_name}/{next(filter(file_path_regex.match, all_blobs_names))}" + except StopIteration: + raise AirflowFailException(f"No objects matched supplied regex: {file_location}") + + # TODO(qinwen): implement profile metrics & upload to Vertex AI TensorBoard def process_profile( uuid: str, file_location: str @@ -332,6 +377,50 @@ def add_airflow_metadata( return metadata +def add_test_config_metadata( + base_id: str, + task_test_config: test_config.TestConfig[test_config.Accelerator], + task_gcp_config: gcp_config.GCPConfig, + metadata: List[List[bigquery.MetricHistoryRow]], +) -> List[List[bigquery.MetricHistoryRow]]: + for index in range(len(metadata)): + uuid = generate_row_uuid(base_id, index) + test_config_meta = [] + + test_config_meta.append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="accelerator", + metadata_value=task_test_config.accelerator.name, + ) + ) + test_config_meta.append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="project", + metadata_value=task_gcp_config.project_name, + ) + ) + if hasattr(task_test_config, "num_slices"): + test_config_meta.append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="num_slices", + metadata_value=task_test_config.num_slices, + ) + ) + test_config_meta.append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="multislice_topology", + metadata_value=f"{task_test_config.num_slices}x{task_test_config.accelerator.name}", + ) + ) + metadata[index].extend(test_config_meta) + + return metadata + + def generate_row_uuid(base_id: str, index: int) -> str: """Generate uuid for entry. @@ -355,11 +444,16 @@ def generate_process_id() -> str: return str(uuid.uuid4()) -def is_valid_entry() -> bool: +def is_valid_entry(task_metric_config: metric_config.MetricConfig) -> bool: """Define if entries are valid to insert into the table. Only scheduled runs from the prod composer environment are allowed. """ + + # Allow inserting entries if `write_results_to_bigquery` has been set to true + if task_metric_config.write_results_to_bigquery: + return True + # if it's a non-prod run, no entries are inserted if not composer_env.is_prod_env(): logging.info("This is a non-prod run, and no entries are inserted into tables.") @@ -538,6 +632,10 @@ def process_metrics( base_id, task_gcp_config.composer_project, metadata_history_rows_list ) + metadata_history_rows_list = add_test_config_metadata( + base_id, task_test_config, task_gcp_config, metadata_history_rows_list + ) + # append profile metrics to metric_history_rows_list if any if has_profile: if len(metric_history_rows_list) != len(profile_history_rows_list): @@ -577,5 +675,5 @@ def process_metrics( print("Test run rows:", test_run_rows) - if is_valid_entry(): + if is_valid_entry(task_metric_config): bigquery_metric.insert(test_run_rows) diff --git a/xlml/utils/metric_test.py b/xlml/utils/metric_test.py index 11caaff83..8c1922fb9 100644 --- a/xlml/utils/metric_test.py +++ b/xlml/utils/metric_test.py @@ -22,11 +22,12 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized -from xlml.apis import metric_config +from xlml.apis import metric_config, gcp_config, test_config from dags import composer_env from xlml.utils import bigquery, composer, metric import jsonlines import tensorflow as tf +from dags.vm_resource import TpuVersion, RuntimeVersion """Tests for Benchmark metric.py.""" @@ -315,27 +316,116 @@ def test_add_airflow_metadata(self): self.assert_metric_and_dimension_equal([], [], actual_value, expected_value) + def test_add_test_config_metadata(self): + base_id = "test_run" + uuid = hashlib.sha256(str(base_id + str(0)).encode("utf-8")).hexdigest() + + raw_meta = [ + [ + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="framework", + metadata_value="jax", + ) + ] + ] + task_test_config = test_config.TpuVmTest( + test_config.Tpu( + version=TpuVersion.V4, + cores=8, + runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value, + network="default", + subnetwork="default", + ), + test_name="test_name", + set_up_cmds="set_up_cmds", + run_model_cmds="run_model_cmds", + time_out_in_min=60, + task_owner="test_owner", + num_slices=1, + ) + + task_gcp_config = gcp_config.GCPConfig( + project_name="test_project", + zone="tpu_zone", + dataset_name="dataset_name", + dataset_project="test_project", + composer_project="test_project", + ) + + actual_value = metric.add_test_config_metadata( + base_id, + task_test_config, + task_gcp_config, + raw_meta, + ) + print("actual_value", actual_value) + + expected_value = raw_meta + print("expected_value", expected_value) + expected_value[0].append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="accelerator", + metadata_value="v4-8", + ) + ) + expected_value[0].append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="project", + metadata_value="test-project", + ) + ) + expected_value[0].append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="num_slices", + metadata_value="1", + ) + ) + expected_value[0].append( + bigquery.MetadataHistoryRow( + job_uuid=uuid, + metadata_key="topology", + metadata_value="1xv4-8", + ) + ) + self.assert_metric_and_dimension_equal([], [], actual_value, expected_value) + @parameterized.named_parameters( ( "prod_scheduled_run", composer_env.PROD_COMPOSER_ENV_NAME, "scheduled__2023-08-07T21:03:49.181263+00:00", True, + True, ), ( "non-prod_scheduled_run", composer_env.DEV_COMPOSER_ENV_NAME, "scheduled__2023-08-07T21:03:49.181263+00:00", False, + False, ), ( "prod_manual_run", composer_env.PROD_COMPOSER_ENV_NAME, "manual__2023-08-07T21:03:49.181263+00:00", False, + False, + ), + ( + "prod_manual_run_override", + composer_env.PROD_COMPOSER_ENV_NAME, + "manual__2023-08-07T21:03:49.181263+00:00", + True, + True, ), ) - def test_is_valid_entry(self, env_name, run_id, expected_value): + def test_is_valid_entry( + self, env_name, run_id, write_results_to_bigquery, expected_value + ): with mock.patch("xlml.utils.metric.get_current_context") as mock_context: mock_context.return_value = { "run_id": run_id, @@ -347,9 +437,32 @@ def test_is_valid_entry(self, env_name, run_id, expected_value): "COMPOSER_ENVIRONMENT": env_name, }, ) as mock_variable: - actual_value = metric.is_valid_entry() + job_metric_config = metric_config.MetricConfig( + write_results_to_bigquery=write_results_to_bigquery, + ) + actual_value = metric.is_valid_entry(job_metric_config) self.assertEqual(actual_value, expected_value) + def test_get_gcs_file_location_with_regex(self): + with mock.patch("xlml.utils.metric.storage") as mock_storage: + mock_gcs_client = mock_storage.Client.return_value + + expected_path = "path/to/events.out.tfevents.123" + mock_blob_1 = mock.MagicMock() + mock_blob_1.name = expected_path + + mock_blob_2 = mock.MagicMock() + mock_blob_2.name = "path/to/events.out.tfevents.234" + + mock_gcs_client.list_blobs.return_value = [mock_blob_1, mock_blob_2] + + actual_value = metric.get_gcs_file_location_with_regex( + "gs://my-bucket/path/to/events.out.tfevents.1*" + ) + mock_storage.Client.assert_called_once() + mock_gcs_client.list_blobs.assert_called_once() + self.assertEqual(actual_value, f"gs://my-bucket/{expected_path}") + if __name__ == "__main__": absltest.main() diff --git a/xlml/utils/name_format.py b/xlml/utils/name_format.py new file mode 100644 index 000000000..eafaeb117 --- /dev/null +++ b/xlml/utils/name_format.py @@ -0,0 +1,41 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +from airflow.decorators import task + + +@task +def generate_run_name(benchmark_id) -> str: + """Generates a unique run name by appending the current datetime to benchmark_id. + + Args: + benchmark_id: Benchmark id of the test + """ + current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + return f"{benchmark_id}-{current_datetime}" + + +@task +def generate_tb_file_location(run_name: str, base_output_directory: str) -> str: + """Generates a path to the tensorboard file to be used as a regex. Assumes the file is located in //tensorboard/events.out.tfevents.* + + Args: + run_name: run name for the tensorboard file location + base_output_directory: GCS bucket path + """ + return os.path.join( + base_output_directory, run_name, "tensorboard", "events.out.tfevents.*" + )