diff --git a/azure/batch/scripts/train_model_finetune_on_catalog.py b/azure/batch/scripts/train_model_finetune_on_catalog.py index b5fbef9..90c70bc 100644 --- a/azure/batch/scripts/train_model_finetune_on_catalog.py +++ b/azure/batch/scripts/train_model_finetune_on_catalog.py @@ -22,6 +22,7 @@ # expects path to csv parser.add_argument('--catalog', dest='catalog_loc', type=str, required=True) parser.add_argument('--checkpoint', dest='checkpoint', type=str, required=True) + parser.add_argument('--schema', dest='schema', type=str, default='cosmic_dawn') parser.add_argument('--num-workers', dest='num_workers', type=int, default=11) # benchmarks show 11 work on our VM types - was int((os.cpu_count()) parser.add_argument('--prefetch-factor', dest='prefetch_factor', type=int, default=9) # benchmarks show 9 works on our VM types (lots of ram) - was 4 (default) # V100 GPU can handle 128 - can look at --mixed-precision opt to decrease the ram use @@ -38,6 +39,15 @@ parser.add_argument('--debug', dest='debug', default=False, action='store_true') args = parser.parse_args() + schema_dict = { + 'cosmic_dawn': cosmic_dawn_ortho_schema, + 'euclid': { + 'label_cols': ['smooth-or-featured-euclid_smooth', 'smooth-or-featured-euclid_featured-or-disk', 'smooth-or-featured-euclid_problem', 'disk-edge-on-euclid_yes', 'disk-edge-on-euclid_no', 'has-spiral-arms-euclid_yes', 'has-spiral-arms-euclid_no', 'bar-euclid_strong', 'bar-euclid_weak', 'bar-euclid_no', 'bulge-size-euclid_dominant', 'bulge-size-euclid_large', 'bulge-size-euclid_moderate', 'bulge-size-euclid_small', 'bulge-size-euclid_none', 'how-rounded-euclid_round', 'how-rounded-euclid_in-between', 'how-rounded-euclid_cigar-shaped', 'edge-on-bulge-euclid_boxy', 'edge-on-bulge-euclid_none', 'edge-on-bulge-euclid_rounded', 'spiral-winding-euclid_tight', 'spiral-winding-euclid_medium', 'spiral-winding-euclid_loose', 'spiral-arm-count-euclid_1', 'spiral-arm-count-euclid_2', 'spiral-arm-count-euclid_3', 'spiral-arm-count-euclid_4', 'spiral-arm-count-euclid_more-than-4', 'spiral-arm-count-euclid_cant-tell', 'merging-euclid_none', 'merging-euclid_minor-disturbance', 'merging-euclid_major-disturbance', 'merging-euclid_merger', 'clumps-euclid_yes', 'clumps-euclid_no', 'problem-euclid_star', 'problem-euclid_artifact', 'problem-euclid_zoom', 'artifact-euclid_satellite', 'artifact-euclid_scattered', 'artifact-euclid_diffraction', 'artifact-euclid_ray', 'artifact-euclid_saturation', 'artifact-euclid_other', 'artifact-euclid_ghost'], + 'questions': ['smooth-or-featured-euclid', 'indices 0 to 2', 'asked after None', 'disk-edge-on-euclid', 'indices 3 to 4', 'asked after smooth-or-featured-euclid_featured-or-disk', 'index 1', 'has-spiral-arms-euclid', 'indices 5 to 6', 'asked after disk-edge-on-euclid_no', 'index 4', 'bar-euclid', 'indices 7 to 9', 'asked after disk-edge-on-euclid_no', 'index 4', 'bulge-size-euclid', 'indices 10 to 14', 'asked after disk-edge-on-euclid_no', 'index 4', 'how-rounded-euclid', 'indices 15 to 17',' asked after smooth-or-featured-euclid_smooth', 'index 0', 'edge-on-bulge-euclid', 'indices 18 to 20', 'asked after disk-edge-on-euclid_yes', 'index 3', 'spiral-winding-euclid', 'indices 21 to 23', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'spiral-arm-count-euclid', 'indices 24 to 29', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'merging-euclid', 'indices 30 to 33', 'asked after None', 'clumps-euclid', 'indices 34 to 35', 'asked after disk-edge-on-euclid_no', 'index 4', 'problem-euclid', 'indices 36 to 38', 'asked after smooth-or-featured-euclid_problem', 'index 2', 'artifact-euclid', 'indices 39 to 45', 'asked after problem-euclid_artifact', 'index 37'], + 'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']} + } + } + schema = args.schema # setup the error reporting tool - https://app.honeybadger.io/projects/ honeybadger_api_key = os.getenv('HONEYBADGER_API_KEY') if honeybadger_api_key: @@ -61,7 +71,7 @@ kade_catalog['file_loc'].iloc[len(kade_catalog.index) - 1])) datamodule = GalaxyDataModule( - label_cols=cosmic_dawn_ortho_schema.label_cols, + label_cols=schema_dict[args.schema].label_cols, catalog=kade_catalog, batch_size=args.batch_size, num_workers=args.num_workers, @@ -99,7 +109,7 @@ model = finetune.FinetuneableZoobotTree( checkpoint_loc=args.checkpoint, # params specific to tree finetuning - schema=cosmic_dawn_ortho_schema, + schema=schema_dict[args.schema], # params for superclass i.e. any finetuning encoder_dim=args.encoder_dim, n_layers=args.n_layers, diff --git a/bajor/apis/predictions.py b/bajor/apis/predictions.py index b5eab0d..1a96ae7 100644 --- a/bajor/apis/predictions.py +++ b/bajor/apis/predictions.py @@ -38,7 +38,7 @@ async def create_job(job: PredictionJob, response: Response, authorized: bool = else: log.debug('No active jobs running - lets get scheduling!') results = predictions.schedule_job( - job_id, job.manifest_url, job.run_opts) + job_id, job.manifest_url, job.opts) job.id = results['submitted_job_id'] job.status = results['job_task_status'] diff --git a/bajor/apis/training.py b/bajor/apis/training.py index 9e20ec4..3bf93ed 100644 --- a/bajor/apis/training.py +++ b/bajor/apis/training.py @@ -40,9 +40,8 @@ async def create_job(job: TrainingJob, response: Response, authorized: bool = De log.debug('No active jobs running - lets get scheduling!') # allow the env to specify default run opts like --debug on staging - run_opts = f'{job.run_opts} {training_run_opts()}' - - results = training.schedule_job(job_id, job.stripped_manifest_path(), run_opts) + job.opts.run_opts = f'{job.opts.run_opts} {training_run_opts()}' + results = training.schedule_job(job_id, job.stripped_manifest_path(), job.opts) job.id = results['submitted_job_id'] job.status = results['job_task_status'] diff --git a/bajor/batch/predictions.py b/bajor/batch/predictions.py index 95f4462..6644096 100644 --- a/bajor/batch/predictions.py +++ b/bajor/batch/predictions.py @@ -9,6 +9,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log +from bajor.models.job import Options # Zoobot Azure Batch predictions pool ID predictions_pool_id = os.getenv('POOL_ID', 'predictions_0') @@ -24,17 +25,19 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(predictions_pool_id) # schedule a training job -def schedule_job(job_id, manifest_url, run_opts=''): +def schedule_job(job_id:str, manifest_url:str, options:Options=Options()): + checkpoint_target = 'EUCLID_ZOOBOT_CHECKPOINT_TARGET' if options.workflow_name == 'euclid' else 'ZOOBOT_CHECKPOINT_TARGET' + submitted_job_id = create_batch_job( - job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id) + job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, checkpoint_target=checkpoint_target) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=run_opts) + job_id=job_id, run_opts=options.run_opts) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_url, pool_id): +def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url)) log.debug(f'BatchJobManager, create_job, job_id: {job_id}') @@ -67,7 +70,7 @@ def create_batch_job(job_id, manifest_url, pool_id): # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv('ZOOBOT_CHECKPOINT_TARGET', 'zoobot.ckpt')), + value=os.getenv(checkpoint_target, 'zoobot.ckpt')), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', diff --git a/bajor/batch/train_finetuning.py b/bajor/batch/train_finetuning.py index 9cd2e28..6902698 100644 --- a/bajor/batch/train_finetuning.py +++ b/bajor/batch/train_finetuning.py @@ -9,6 +9,7 @@ from bajor.batch.client import azure_batch_client import bajor.batch.jobs as batch_jobs from bajor.log_config import log +from bajor.models.job import Options # Zoobot Azure Batch training pool ID training_pool_id = os.getenv('POOL_ID', 'training_1') @@ -25,16 +26,18 @@ def get_non_active_batch_job_list(): return batch_jobs.get_non_active_batch_job_list(training_pool_id) # schedule a training job -def schedule_job(job_id, manifest_path, run_opts=''): +def schedule_job(job_id: str, manifest_path:str, options: Options=Options()): + checkpoint_target = 'EUCLID_ZOOBOT_CHECKPOINT_TARGET' if options.workflow_name == 'euclid' else 'ZOOBOT_CHECKPOINT_TARGET' + submitted_job_id = create_batch_job( - job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id) + job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, checkpoint_target=checkpoint_target) job_task_submission_status = create_job_tasks( - job_id=job_id, run_opts=run_opts) + job_id=job_id, run_opts=options.run_opts) # return the submitted job_id and task submission status dict return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status) -def create_batch_job(job_id, manifest_container_path, pool_id): +def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'): log.debug('server_job, create_batch_job, using manifest from path: {}'.format( manifest_container_path)) @@ -78,7 +81,7 @@ def create_batch_job(job_id, manifest_container_path, pool_id): # set the zoobot saved model checkpoint file path batchmodels.EnvironmentSetting( name='ZOOBOT_CHECKPOINT_TARGET', - value=os.getenv('ZOOBOT_CHECKPOINT_TARGET', 'zoobot.ckpt')), + value=os.getenv(checkpoint_target, 'zoobot.ckpt')), # setup error reporting service batchmodels.EnvironmentSetting( name='HONEYBADGER_API_KEY', diff --git a/bajor/models/job.py b/bajor/models/job.py index 13b5976..96cb85d 100644 --- a/bajor/models/job.py +++ b/bajor/models/job.py @@ -1,10 +1,16 @@ from pydantic import BaseModel, HttpUrl +from typing import Optional, Dict + +class Options(BaseModel): + run_opts: str = "" + workflow_name: str = 'cosmic_dawn' + class TrainingJob(BaseModel): manifest_path: str - id: str | None - status: str | None - run_opts: str = '' + id: Optional[str] = None + status: Optional[str] = None + opts: Options = Options() # remove the leading / from the manifest url # as it's added via the blob storage paths in schedule_job @@ -14,6 +20,6 @@ def stripped_manifest_path(self): class PredictionJob(BaseModel): manifest_url: HttpUrl - id: str | None - status: str | None - run_opts: str = '' + id: Optional[str] = None + status: Optional[str] = None + opts: Options = Options() diff --git a/kubernetes/deployment-staging.tmpl b/kubernetes/deployment-staging.tmpl index b3cd043..4f6a61b 100644 --- a/kubernetes/deployment-staging.tmpl +++ b/kubernetes/deployment-staging.tmpl @@ -50,6 +50,8 @@ spec: value: 'DEBUG' - name: TRAINING_RUN_OPTS value: '--debug --wandb' + - name: EUCLID_ZOOBOT_CHECKPOINT_TARGET + value: 'staging-zoobot.ckpt' - name: ZOOBOT_CHECKPOINT_TARGET value: 'staging-zoobot.ckpt' - name: ZOOBOT_FINETUNE_CHECKPOINT_FILE diff --git a/tests/batch/test_training.py b/tests/batch/test_training.py index 8f3381c..64f1f15 100644 --- a/tests/batch/test_training.py +++ b/tests/batch/test_training.py @@ -40,7 +40,7 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job): def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job): train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv') mock_create_batch_job.assert_called_once_with( - job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1') + job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', checkpoint_target= 'ZOOBOT_CHECKPOINT_TARGET') mock_create_job_tasks.assert_called_once_with( job_id=fake_job_id, run_opts='') diff --git a/tests/test_api.py b/tests/test_api.py index c9caa61..c4d8073 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ from fastapi.testclient import TestClient import uuid, os, pytest from unittest import mock +from bajor.env_helpers import training_run_opts fake_revision = str(uuid.uuid4()) submitted_job_id = 'fake-job-id' @@ -69,8 +70,11 @@ def test_batch_scheduling_code_is_called(mocked_client): response = mocked_client.post( "/training/jobs/", auth=('bajor', 'bajor'), - json={"manifest_path": "test_manifest_file_path.csv"}, + json={"manifest_path": "test_manifest_file_path.csv", "opt": { "run_opts": "", "workflow_name": 'cosmic_dawn'}}, ) + + run_opts = f' {training_run_opts()}' + assert response.status_code == 201 assert response.json() == { - 'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'run_opts': '', 'status': {"status": "started", "message": "Job submitted successfully"}} + 'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'opts': {'run_opts': run_opts, 'workflow_name': 'cosmic_dawn'}, 'status': {"status": "started", "message": "Job submitted successfully"}}