Skip to content

Commit

Permalink
refactor codebase for new euclid workflow usage with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tooyosi committed Oct 29, 2024
1 parent 0e2836e commit b5e1bd5
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 25 deletions.
14 changes: 12 additions & 2 deletions azure/batch/scripts/train_model_finetune_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# expects path to csv
parser.add_argument('--catalog', dest='catalog_loc', type=str, required=True)
parser.add_argument('--checkpoint', dest='checkpoint', type=str, required=True)
parser.add_argument('--schema', dest='schema', type=str, default='cosmic_dawn')
parser.add_argument('--num-workers', dest='num_workers', type=int, default=11) # benchmarks show 11 work on our VM types - was int((os.cpu_count())
parser.add_argument('--prefetch-factor', dest='prefetch_factor', type=int, default=9) # benchmarks show 9 works on our VM types (lots of ram) - was 4 (default)
# V100 GPU can handle 128 - can look at --mixed-precision opt to decrease the ram use
Expand All @@ -38,6 +39,15 @@
parser.add_argument('--debug', dest='debug', default=False, action='store_true')
args = parser.parse_args()

schema_dict = {
'cosmic_dawn': cosmic_dawn_ortho_schema,
'euclid': {
'label_cols': ['smooth-or-featured-euclid_smooth', 'smooth-or-featured-euclid_featured-or-disk', 'smooth-or-featured-euclid_problem', 'disk-edge-on-euclid_yes', 'disk-edge-on-euclid_no', 'has-spiral-arms-euclid_yes', 'has-spiral-arms-euclid_no', 'bar-euclid_strong', 'bar-euclid_weak', 'bar-euclid_no', 'bulge-size-euclid_dominant', 'bulge-size-euclid_large', 'bulge-size-euclid_moderate', 'bulge-size-euclid_small', 'bulge-size-euclid_none', 'how-rounded-euclid_round', 'how-rounded-euclid_in-between', 'how-rounded-euclid_cigar-shaped', 'edge-on-bulge-euclid_boxy', 'edge-on-bulge-euclid_none', 'edge-on-bulge-euclid_rounded', 'spiral-winding-euclid_tight', 'spiral-winding-euclid_medium', 'spiral-winding-euclid_loose', 'spiral-arm-count-euclid_1', 'spiral-arm-count-euclid_2', 'spiral-arm-count-euclid_3', 'spiral-arm-count-euclid_4', 'spiral-arm-count-euclid_more-than-4', 'spiral-arm-count-euclid_cant-tell', 'merging-euclid_none', 'merging-euclid_minor-disturbance', 'merging-euclid_major-disturbance', 'merging-euclid_merger', 'clumps-euclid_yes', 'clumps-euclid_no', 'problem-euclid_star', 'problem-euclid_artifact', 'problem-euclid_zoom', 'artifact-euclid_satellite', 'artifact-euclid_scattered', 'artifact-euclid_diffraction', 'artifact-euclid_ray', 'artifact-euclid_saturation', 'artifact-euclid_other', 'artifact-euclid_ghost'],
'questions': ['smooth-or-featured-euclid', 'indices 0 to 2', 'asked after None', 'disk-edge-on-euclid', 'indices 3 to 4', 'asked after smooth-or-featured-euclid_featured-or-disk', 'index 1', 'has-spiral-arms-euclid', 'indices 5 to 6', 'asked after disk-edge-on-euclid_no', 'index 4', 'bar-euclid', 'indices 7 to 9', 'asked after disk-edge-on-euclid_no', 'index 4', 'bulge-size-euclid', 'indices 10 to 14', 'asked after disk-edge-on-euclid_no', 'index 4', 'how-rounded-euclid', 'indices 15 to 17',' asked after smooth-or-featured-euclid_smooth', 'index 0', 'edge-on-bulge-euclid', 'indices 18 to 20', 'asked after disk-edge-on-euclid_yes', 'index 3', 'spiral-winding-euclid', 'indices 21 to 23', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'spiral-arm-count-euclid', 'indices 24 to 29', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'merging-euclid', 'indices 30 to 33', 'asked after None', 'clumps-euclid', 'indices 34 to 35', 'asked after disk-edge-on-euclid_no', 'index 4', 'problem-euclid', 'indices 36 to 38', 'asked after smooth-or-featured-euclid_problem', 'index 2', 'artifact-euclid', 'indices 39 to 45', 'asked after problem-euclid_artifact', 'index 37'],
'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']}
}
}
schema = args.schema
# setup the error reporting tool - https://app.honeybadger.io/projects/
honeybadger_api_key = os.getenv('HONEYBADGER_API_KEY')
if honeybadger_api_key:
Expand All @@ -61,7 +71,7 @@
kade_catalog['file_loc'].iloc[len(kade_catalog.index) - 1]))

datamodule = GalaxyDataModule(
label_cols=cosmic_dawn_ortho_schema.label_cols,
label_cols=schema_dict[args.schema].label_cols,
catalog=kade_catalog,
batch_size=args.batch_size,
num_workers=args.num_workers,
Expand Down Expand Up @@ -99,7 +109,7 @@
model = finetune.FinetuneableZoobotTree(
checkpoint_loc=args.checkpoint,
# params specific to tree finetuning
schema=cosmic_dawn_ortho_schema,
schema=schema_dict[args.schema],
# params for superclass i.e. any finetuning
encoder_dim=args.encoder_dim,
n_layers=args.n_layers,
Expand Down
2 changes: 1 addition & 1 deletion bajor/apis/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def create_job(job: PredictionJob, response: Response, authorized: bool =
else:
log.debug('No active jobs running - lets get scheduling!')
results = predictions.schedule_job(
job_id, job.manifest_url, job.run_opts)
job_id, job.manifest_url, job.opts)
job.id = results['submitted_job_id']
job.status = results['job_task_status']

Expand Down
5 changes: 2 additions & 3 deletions bajor/apis/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ async def create_job(job: TrainingJob, response: Response, authorized: bool = De
log.debug('No active jobs running - lets get scheduling!')

# allow the env to specify default run opts like --debug on staging
run_opts = f'{job.run_opts} {training_run_opts()}'

results = training.schedule_job(job_id, job.stripped_manifest_path(), run_opts)
job.opts.run_opts = f'{job.opts.run_opts} {training_run_opts()}'
results = training.schedule_job(job_id, job.stripped_manifest_path(), job.opts)
job.id = results['submitted_job_id']
job.status = results['job_task_status']

Expand Down
13 changes: 8 additions & 5 deletions bajor/batch/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bajor.batch.client import azure_batch_client
import bajor.batch.jobs as batch_jobs
from bajor.log_config import log
from bajor.models.job import Options

# Zoobot Azure Batch predictions pool ID
predictions_pool_id = os.getenv('POOL_ID', 'predictions_0')
Expand All @@ -24,17 +25,19 @@ def get_non_active_batch_job_list():
return batch_jobs.get_non_active_batch_job_list(predictions_pool_id)

# schedule a training job
def schedule_job(job_id, manifest_url, run_opts=''):
def schedule_job(job_id:str, manifest_url:str, options:Options=Options()):
checkpoint_target = 'EUCLID_ZOOBOT_CHECKPOINT_TARGET' if options.workflow_name == 'euclid' else 'ZOOBOT_CHECKPOINT_TARGET'

submitted_job_id = create_batch_job(
job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id)
job_id=job_id, manifest_url=manifest_url, pool_id=predictions_pool_id, checkpoint_target=checkpoint_target)
job_task_submission_status = create_job_tasks(
job_id=job_id, run_opts=run_opts)
job_id=job_id, run_opts=options.run_opts)

# return the submitted job_id and task submission status dict
return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status)


def create_batch_job(job_id, manifest_url, pool_id):
def create_batch_job(job_id, manifest_url, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'):
log.debug('server_job, create_batch_job, using manifest at url: {}'.format(manifest_url))

log.debug(f'BatchJobManager, create_job, job_id: {job_id}')
Expand Down Expand Up @@ -67,7 +70,7 @@ def create_batch_job(job_id, manifest_url, pool_id):
# set the zoobot saved model checkpoint file path
batchmodels.EnvironmentSetting(
name='ZOOBOT_CHECKPOINT_TARGET',
value=os.getenv('ZOOBOT_CHECKPOINT_TARGET', 'zoobot.ckpt')),
value=os.getenv(checkpoint_target, 'zoobot.ckpt')),
# setup error reporting service
batchmodels.EnvironmentSetting(
name='HONEYBADGER_API_KEY',
Expand Down
13 changes: 8 additions & 5 deletions bajor/batch/train_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bajor.batch.client import azure_batch_client
import bajor.batch.jobs as batch_jobs
from bajor.log_config import log
from bajor.models.job import Options

# Zoobot Azure Batch training pool ID
training_pool_id = os.getenv('POOL_ID', 'training_1')
Expand All @@ -25,16 +26,18 @@ def get_non_active_batch_job_list():
return batch_jobs.get_non_active_batch_job_list(training_pool_id)

# schedule a training job
def schedule_job(job_id, manifest_path, run_opts=''):
def schedule_job(job_id: str, manifest_path:str, options: Options=Options()):
checkpoint_target = 'EUCLID_ZOOBOT_CHECKPOINT_TARGET' if options.workflow_name == 'euclid' else 'ZOOBOT_CHECKPOINT_TARGET'

submitted_job_id = create_batch_job(
job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id)
job_id=job_id, manifest_container_path=manifest_path, pool_id=training_pool_id, checkpoint_target=checkpoint_target)
job_task_submission_status = create_job_tasks(
job_id=job_id, run_opts=run_opts)
job_id=job_id, run_opts=options.run_opts)

# return the submitted job_id and task submission status dict
return batch_jobs.job_submission_response(submitted_job_id, job_task_submission_status)

def create_batch_job(job_id, manifest_container_path, pool_id):
def create_batch_job(job_id, manifest_container_path, pool_id, checkpoint_target='ZOOBOT_CHECKPOINT_TARGET'):
log.debug('server_job, create_batch_job, using manifest from path: {}'.format(
manifest_container_path))

Expand Down Expand Up @@ -78,7 +81,7 @@ def create_batch_job(job_id, manifest_container_path, pool_id):
# set the zoobot saved model checkpoint file path
batchmodels.EnvironmentSetting(
name='ZOOBOT_CHECKPOINT_TARGET',
value=os.getenv('ZOOBOT_CHECKPOINT_TARGET', 'zoobot.ckpt')),
value=os.getenv(checkpoint_target, 'zoobot.ckpt')),
# setup error reporting service
batchmodels.EnvironmentSetting(
name='HONEYBADGER_API_KEY',
Expand Down
18 changes: 12 additions & 6 deletions bajor/models/job.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from pydantic import BaseModel, HttpUrl
from typing import Optional, Dict

class Options(BaseModel):
run_opts: str = ""
workflow_name: str = 'cosmic_dawn'


class TrainingJob(BaseModel):
manifest_path: str
id: str | None
status: str | None
run_opts: str = ''
id: Optional[str] = None
status: Optional[str] = None
opts: Options = Options()

# remove the leading / from the manifest url
# as it's added via the blob storage paths in schedule_job
Expand All @@ -14,6 +20,6 @@ def stripped_manifest_path(self):

class PredictionJob(BaseModel):
manifest_url: HttpUrl
id: str | None
status: str | None
run_opts: str = ''
id: Optional[str] = None
status: Optional[str] = None
opts: Options = Options()
2 changes: 2 additions & 0 deletions kubernetes/deployment-staging.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ spec:
value: 'DEBUG'
- name: TRAINING_RUN_OPTS
value: '--debug --wandb'
- name: EUCLID_ZOOBOT_CHECKPOINT_TARGET
value: 'staging-zoobot.ckpt'
- name: ZOOBOT_CHECKPOINT_TARGET
value: 'staging-zoobot.ckpt'
- name: ZOOBOT_FINETUNE_CHECKPOINT_FILE
Expand Down
2 changes: 1 addition & 1 deletion tests/batch/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_schedule_job(mock_create_job_tasks, mock_create_batch_job):
def test_no_active_jobs(mock_create_job_tasks, mock_create_batch_job):
train_finetuning.schedule_job(fake_job_id, 'fake-manifest.csv')
mock_create_batch_job.assert_called_once_with(
job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1')
job_id=fake_job_id, manifest_container_path='fake-manifest.csv', pool_id='training_1', checkpoint_target= 'ZOOBOT_CHECKPOINT_TARGET')
mock_create_job_tasks.assert_called_once_with(
job_id=fake_job_id, run_opts='')

Expand Down
8 changes: 6 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi.testclient import TestClient
import uuid, os, pytest
from unittest import mock
from bajor.env_helpers import training_run_opts

fake_revision = str(uuid.uuid4())
submitted_job_id = 'fake-job-id'
Expand Down Expand Up @@ -69,8 +70,11 @@ def test_batch_scheduling_code_is_called(mocked_client):
response = mocked_client.post(
"/training/jobs/",
auth=('bajor', 'bajor'),
json={"manifest_path": "test_manifest_file_path.csv"},
json={"manifest_path": "test_manifest_file_path.csv", "opt": { "run_opts": "", "workflow_name": 'cosmic_dawn'}},
)

run_opts = f' {training_run_opts()}'

assert response.status_code == 201
assert response.json() == {
'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'run_opts': '', 'status': {"status": "started", "message": "Job submitted successfully"}}
'manifest_path': 'test_manifest_file_path.csv', 'id': submitted_job_id, 'opts': {'run_opts': run_opts, 'workflow_name': 'cosmic_dawn'}, 'status': {"status": "started", "message": "Job submitted successfully"}}

0 comments on commit b5e1bd5

Please sign in to comment.