From bb9322e63239ed36febf407e088080a3d4998bd6 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Thu, 1 Feb 2024 12:38:21 -0500 Subject: [PATCH] ft tweaks --- zoobot/pytorch/estimators/define_model.py | 2 +- zoobot/pytorch/predictions/predict_on_catalog.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 5fb8e3e3..bd33427e 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -238,7 +238,7 @@ def __init__( if question_answer_pairs is not None: logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__') - assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" + # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" assert dependencies is not None self.schema = schemas.Schema(question_answer_pairs, dependencies) # replace with schema-derived version diff --git a/zoobot/pytorch/predictions/predict_on_catalog.py b/zoobot/pytorch/predictions/predict_on_catalog.py index 3a68ab88..0b99270f 100644 --- a/zoobot/pytorch/predictions/predict_on_catalog.py +++ b/zoobot/pytorch/predictions/predict_on_catalog.py @@ -38,14 +38,14 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la # crucial to specify the stage, or will error (as missing other catalogs) predict_datamodule.setup(stage='predict') # for images in predict_datamodule.predict_dataloader(): - # print(images) - # print(images.shape) + # print(images) + # print(images.shape) + # exit() # set up trainer (again) trainer = pl.Trainer( max_epochs=-1, # does nothing in this context, suppresses warning - inference_mode=True, # no grads needed **trainer_kwargs # e.g. gpus )