Skip to content

Commit

Permalink
ft tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Feb 1, 2024
1 parent b652165 commit bb9322e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(

if question_answer_pairs is not None:
logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__')
assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
# assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies"
assert dependencies is not None
self.schema = schemas.Schema(question_answer_pairs, dependencies)
# replace with schema-derived version
Expand Down
6 changes: 3 additions & 3 deletions zoobot/pytorch/predictions/predict_on_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ def predict(catalog: pd.DataFrame, model: pl.LightningModule, n_samples: int, la
# crucial to specify the stage, or will error (as missing other catalogs)
predict_datamodule.setup(stage='predict')
# for images in predict_datamodule.predict_dataloader():
# print(images)
# print(images.shape)
# print(images)
# print(images.shape)
# exit()


# set up trainer (again)
trainer = pl.Trainer(
max_epochs=-1, # does nothing in this context, suppresses warning
inference_mode=True, # no grads needed
**trainer_kwargs # e.g. gpus
)

Expand Down

0 comments on commit bb9322e

Please sign in to comment.