Skip to content

Commit

Permalink
Merge pull request #56 from zooniverse/bajor-zoobot-upgrade
Browse files Browse the repository at this point in the history
upgrade zoobot to latest version
  • Loading branch information
Tooyosi authored Dec 12, 2024
2 parents 0ef709d + 5e33361 commit d938268
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 29 deletions.
21 changes: 15 additions & 6 deletions azure/batch/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
FROM nvidia/cuda:11.3.1-base-ubuntu20.04
FROM nvidia/cuda:12.1.1-base-ubuntu22.04

ENV DEBIAN_FRONTEND noninteractive
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /usr/src/zoobot

# Install prerequisites and add deadsnakes PPA for Python 3.10
RUN apt-get update && apt-get -y upgrade && \
apt-get install --no-install-recommends -y \
software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install --no-install-recommends -y \
build-essential \
python3 \
python3.10 \
python3.10-distutils \
python3.10-dev \
python3-pip \
git && \
apt-get clean && rm -rf /var/lib/apt/lists/*

RUN ln -s /usr/bin/python3 /usr/bin/python
# Link Python 3.10 as default
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
ln -sf /usr/bin/python3.10 /usr/bin/python3

# install a newer version of pip
# as we can't use the ubuntu package pip version (20.0.2)
Expand All @@ -23,6 +31,7 @@ RUN apt-get remove -y python3-pip
RUN ln -s /usr/local/bin/pip3 /usr/bin/pip
RUN ln -s /usr/local/bin/pip3 /usr/bin/pip3

# install our dependencies (see setup.py)

# Install project dependencies (see setup.py)
COPY setup.py .
RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install . --extra-index-url https://download.pytorch.org/whl/cu121
12 changes: 2 additions & 10 deletions azure/batch/scripts/predict_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@ def __getitem__(self, idx):
# ensure we raise other response errors like 404 and 500 etc
# Note: we don't retry on errors that aren't in the `status_forcelist`, instead we fast fail!
response.raise_for_status()
url_mime_type = response.headers['content-type']
# handle PNG images
if url_mime_type == 'image/png':
# use PIL image to read the png file buffer
image = Image.open(response.raw)
else: # but assume all other images are JPEG
# HWC PIL image
image = Image.fromarray(
galaxy_dataset.decode_jpeg(response.raw.read()))
image = Image.open(response.raw)
except Exception as e:
# add some logging on the failed url
logging.critical('Cannot load {}'.format(url))
Expand All @@ -98,7 +90,7 @@ def __getitem__(self, idx):
class PredictionGalaxyDataModule(galaxy_datamodule.GalaxyDataModule):
# override the setup method to setup our prediction dataset on the prediction catalog
def setup(self, stage: Optional[str] = None):
self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.transform)
self.predict_dataset = PredictionGalaxyDataset(catalog=self.predict_catalog, transform=self.test_transform)


def save_predictions_to_json(predictions: np.ndarray, image_ids: List[str], label_cols: List[str], save_loc: str):
Expand Down
15 changes: 4 additions & 11 deletions azure/batch/scripts/train_model_finetune_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

from zoobot.pytorch.training import finetune
from zoobot.shared.schemas import cosmic_dawn_ortho_schema
from zoobot.shared.schemas import cosmic_dawn_ortho_schema, euclid_ortho_schema

if __name__ == '__main__':

Expand Down Expand Up @@ -41,11 +41,7 @@

schema_dict = {
'cosmic_dawn': cosmic_dawn_ortho_schema,
'euclid': {
'label_cols': ['smooth-or-featured-euclid_smooth', 'smooth-or-featured-euclid_featured-or-disk', 'smooth-or-featured-euclid_problem', 'disk-edge-on-euclid_yes', 'disk-edge-on-euclid_no', 'has-spiral-arms-euclid_yes', 'has-spiral-arms-euclid_no', 'bar-euclid_strong', 'bar-euclid_weak', 'bar-euclid_no', 'bulge-size-euclid_dominant', 'bulge-size-euclid_large', 'bulge-size-euclid_moderate', 'bulge-size-euclid_small', 'bulge-size-euclid_none', 'how-rounded-euclid_round', 'how-rounded-euclid_in-between', 'how-rounded-euclid_cigar-shaped', 'edge-on-bulge-euclid_boxy', 'edge-on-bulge-euclid_none', 'edge-on-bulge-euclid_rounded', 'spiral-winding-euclid_tight', 'spiral-winding-euclid_medium', 'spiral-winding-euclid_loose', 'spiral-arm-count-euclid_1', 'spiral-arm-count-euclid_2', 'spiral-arm-count-euclid_3', 'spiral-arm-count-euclid_4', 'spiral-arm-count-euclid_more-than-4', 'spiral-arm-count-euclid_cant-tell', 'merging-euclid_none', 'merging-euclid_minor-disturbance', 'merging-euclid_major-disturbance', 'merging-euclid_merger', 'clumps-euclid_yes', 'clumps-euclid_no', 'problem-euclid_star', 'problem-euclid_artifact', 'problem-euclid_zoom', 'artifact-euclid_satellite', 'artifact-euclid_scattered', 'artifact-euclid_diffraction', 'artifact-euclid_ray', 'artifact-euclid_saturation', 'artifact-euclid_other', 'artifact-euclid_ghost'],
'questions': ['smooth-or-featured-euclid', 'indices 0 to 2', 'asked after None', 'disk-edge-on-euclid', 'indices 3 to 4', 'asked after smooth-or-featured-euclid_featured-or-disk', 'index 1', 'has-spiral-arms-euclid', 'indices 5 to 6', 'asked after disk-edge-on-euclid_no', 'index 4', 'bar-euclid', 'indices 7 to 9', 'asked after disk-edge-on-euclid_no', 'index 4', 'bulge-size-euclid', 'indices 10 to 14', 'asked after disk-edge-on-euclid_no', 'index 4', 'how-rounded-euclid', 'indices 15 to 17',' asked after smooth-or-featured-euclid_smooth', 'index 0', 'edge-on-bulge-euclid', 'indices 18 to 20', 'asked after disk-edge-on-euclid_yes', 'index 3', 'spiral-winding-euclid', 'indices 21 to 23', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'spiral-arm-count-euclid', 'indices 24 to 29', 'asked after has-spiral-arms-euclid_yes', 'index 5', 'merging-euclid', 'indices 30 to 33', 'asked after None', 'clumps-euclid', 'indices 34 to 35', 'asked after disk-edge-on-euclid_no', 'index 4', 'problem-euclid', 'indices 36 to 38', 'asked after smooth-or-featured-euclid_problem', 'index 2', 'artifact-euclid', 'indices 39 to 45', 'asked after problem-euclid_artifact', 'index 37'],
'question_answer_pairs': {'smooth-or-featured-euclid': ['_smooth', '_featured-or-disk', '_problem'], 'disk-edge-on-euclid': ['_yes', '_no'], 'has-spiral-arms-euclid': ['_yes', '_no'], 'bar-euclid': ['_strong', '_weak', '_no'], 'bulge-size-euclid': ['_dominant', '_large', '_moderate', '_small', '_none'], 'how-rounded-euclid': ['_round', '_in-between', '_cigar-shaped'], 'edge-on-bulge-euclid': ['_boxy', '_none', '_rounded'], 'spiral-winding-euclid': ['_tight', '_medium', '_loose'], 'spiral-arm-count-euclid': ['_1', '_2', '_3', '_4', '_more-than-4', '_cant-tell'], 'merging-euclid': ['_none', '_minor-disturbance', '_major-disturbance', '_merger'], 'clumps-euclid': ['_yes', '_no'], 'problem-euclid': ['_star', '_artifact', '_zoom'], 'artifact-euclid': ['_satellite', '_scattered', '_diffraction', '_ray', '_saturation', '_other', '_ghost']}
}
'euclid': euclid_ortho_schema
}
schema = schema_dict.get(args.schema, cosmic_dawn_ortho_schema)
# setup the error reporting tool - https://app.honeybadger.io/projects/
Expand Down Expand Up @@ -105,15 +101,12 @@
else:
logger = None


# load the model from checkpoint
model = finetune.FinetuneableZoobotTree(
checkpoint_loc=args.checkpoint,
# params specific to tree finetuning
schema=schema,
# params for superclass i.e. any finetuning
encoder_dim=args.encoder_dim,
n_layers=args.n_layers,
prog_bar=args.progress_bar
zoobot_checkpoint_loc=args.checkpoint
)

trainer = finetune.get_trainer(
Expand Down
4 changes: 2 additions & 2 deletions azure/batch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"Environment :: GPU :: NVIDIA CUDA"
],
packages=setuptools.find_packages(),
python_requires=">=3.7", # tf 2.8.0 requires Python 3.7 and above
python_requires=">=3.9", # tf 2.8.0 requires Python 3.7 and above
install_requires=[
'zoobot[pytorch_cu113] >= 1.0', # the big cheese - bring in the zoobot!
'zoobot[pytorch-cu121] >= 2.0.0', # the big cheese - bring in the zoobot!
'requests >= 2.28.1', # used to download prediction images from a remote URL
'honeybadger' # used for error reporting
]
Expand Down

0 comments on commit d938268

Please sign in to comment.