diff --git a/dags/sparsity_diffusion_devx/configs/project_bite_config.py b/dags/sparsity_diffusion_devx/configs/project_bite_config.py index 14d1068d..65aff62c 100644 --- a/dags/sparsity_diffusion_devx/configs/project_bite_config.py +++ b/dags/sparsity_diffusion_devx/configs/project_bite_config.py @@ -51,6 +51,7 @@ def get_bite_tpu_config( runtime_version: str, model_config: str, time_out_in_min: int, + task_owner: str, is_tpu_reserved: bool = False, pinned_version: Optional[str] = None, ): @@ -82,7 +83,7 @@ def get_bite_tpu_config( set_up_cmds=set_up_cmds, run_model_cmds=run_model_cmds, timeout=datetime.timedelta(minutes=time_out_in_min), - task_owner=test_owner.RAN_R, + task_owner=task_owner, gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/jax", ) diff --git a/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py b/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py index d18d7b4a..804b2f32 100644 --- a/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py @@ -16,7 +16,7 @@ import datetime from airflow import models -from dags import composer_env +from dags import composer_env, test_owner from dags.vm_resource import TpuVersion, Zone, RuntimeVersion from dags.sparsity_diffusion_devx.configs import project_bite_config as config @@ -47,6 +47,7 @@ runtime_version=RuntimeVersion.TPU_UBUNTU2204_BASE.value, model_config="fuji-test-v1", time_out_in_min=180, + task_owner=test_owner.Maggie_Z, ) # AXLearn pinned version against JAX head @@ -61,4 +62,5 @@ model_config="fuji-test-v1", pinned_version="e918d7c219d067dfcace8a25e619d90c5a54c36b", time_out_in_min=180, + task_owner=test_owner.Maggie_Z, ) diff --git a/dags/test_owner.py b/dags/test_owner.py index e96a8c5e..bcc4bee5 100644 --- a/dags/test_owner.py +++ b/dags/test_owner.py @@ -69,8 +69,12 @@ class Team(enum.Enum): # FRAMEWORK QINY_Y = "Qinyi Y." + # JAX AKANKSHA_G = "Akanksha G." # MAP_REPRODUCIBILITY GUNJAN_J = "Gunjan J." + +# Bite +Maggie_Z = "Maggie Z."