diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 5fb8e3e3..bd33427e 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -238,7 +238,7 @@ def __init__( if question_answer_pairs is not None: logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__') - assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" + # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" assert dependencies is not None self.schema = schemas.Schema(question_answer_pairs, dependencies) # replace with schema-derived version diff --git a/zoobot/pytorch/predictions/predict_on_catalog.py b/zoobot/pytorch/predictions/predict_on_catalog.py index 3a68ab88..0b99270f 100644 --- a/zoobot/pytorch/predictions/predict_on_catalog.py +++ b/zoobot/pytorch/predictions/predict_on_catalog.py @@ -38,14 +38,14 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la # crucial to specify the stage, or will error (as missing other catalogs) predict_datamodule.setup(stage='predict') # for images in predict_datamodule.predict_dataloader(): - # print(images) - # print(images.shape) + # print(images) + # print(images.shape) + # exit() # set up trainer (again) trainer = pl.Trainer( max_epochs=-1, # does nothing in this context, suppresses warning - inference_mode=True, # no grads needed **trainer_kwargs # e.g. gpus )