diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py index abbb32f1..cc5d5800 100644 --- a/zoobot/pytorch/datasets/webdatamodule.py +++ b/zoobot/pytorch/datasets/webdatamodule.py @@ -1,4 +1,5 @@ import os +from collections import defaultdict from typing import Callable import logging import torch.utils.data @@ -89,6 +90,8 @@ def make_image_transform(self, mode="train"): def do_transform(img): + assert img.shape[2] < 4 # 1 or 3 channels in shape[2] dim, i.e. numpy/pil HWC convention + # if not, check decode mode is 'rgb' not 'torchrgb' return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32) return do_transform @@ -105,11 +108,14 @@ def make_loader(self, urls, mode="train"): if self.train_transform is None: logging.info('Using default transform') + decode_mode = 'rgb' # np.array, for albumentations transform_image = self.make_image_transform(mode=mode) else: logging.info('Ignoring hparams and using directly-passed transforms') + decode_mode = 'torchrgb' # tensor, for torchvision transform_image = self.train_transform if mode == 'train' else self.inference_transform + transform_label = dict_to_label_cols_factory(self.label_cols) dataset = wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func) @@ -119,8 +125,7 @@ def make_loader(self, urls, mode="train"): if shuffle > 0: dataset = dataset.shuffle(shuffle) - # dataset = dataset.decode("rgb") # np.array, for albumentations - dataset = dataset.decode("torchrgb") # tensor, for torchvision + dataset = dataset.decode(decode_mode) if mode == 'predict': if self.label_cols != ['id_str']: @@ -222,9 +227,18 @@ def label_transform(label_dict): return identity # do nothing def dict_to_filled_dict_factory(label_cols): + logging.info(f'label cols: {label_cols}') # might be a little slow, but very safe def label_transform(label_dict: dict): + # modifies inplace with 0 iff key missing - [label_dict.setdefault(col, 0) for col in label_cols] + # [label_dict.setdefault(col, 0) for col in label_cols] + + for col in label_cols: + label_dict[col] = label_dict.get(col, 0) + + # label_dict_with_default = defaultdict(0) + # label_dict_with_default.update(label_dict) + return label_dict return label_transform \ No newline at end of file diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 3471d668..5fb8e3e3 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -87,6 +87,7 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n def forward(self, x): + assert x.shape[1] < 4 # torchlike BCHW x = self.encoder(x) return self.head(x) @@ -142,7 +143,7 @@ def log_all_metrics(self, subset=None): prog_bar = metric_collection == self.loss_metrics for name, metric in metric_collection.items(): if subset in name: - logging.info(name) + # logging.info(name) self.log(name, metric, on_epoch=True, on_step=False, prog_bar=prog_bar, logger=True) else: # just log everything self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True)