diff --git a/deepvision/models/segmentation/segformer/segformer_pt.py b/deepvision/models/segmentation/segformer/segformer_pt.py index 8d5bc1c..b8f4b8d 100644 --- a/deepvision/models/segmentation/segformer/segformer_pt.py +++ b/deepvision/models/segmentation/segformer/segformer_pt.py @@ -40,7 +40,8 @@ def __init__( backend="pytorch", ) self.softmax_output = softmax_output - self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes) + if num_classes > 1: + self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes) def forward(self, x): y = self.backbone(x) @@ -74,14 +75,15 @@ def training_step(self, train_batch, batch_idx): on_epoch=True, prog_bar=True, ) - acc = self.acc(outputs, targets) - self.log( - "acc", - acc, - on_step=True, - on_epoch=True, - prog_bar=True, - ) + if self.num_classes > 1: + acc = self.acc(outputs, targets) + self.log( + "acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + ) return loss def validation_step(self, val_batch, batch_idx): @@ -95,12 +97,13 @@ def validation_step(self, val_batch, batch_idx): on_epoch=True, prog_bar=True, ) - val_acc = self.acc(outputs, targets) - self.log( - "val_acc", - val_acc, - on_step=True, - on_epoch=True, - prog_bar=True, - ) + if self.num_classes > 1: + val_acc = self.acc(outputs, targets) + self.log( + "val_acc", + val_acc, + on_step=True, + on_epoch=True, + prog_bar=True, + ) return loss