diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 24da8eba4..34026a583 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -365,7 +365,7 @@ def __init__( self.query_embed = None self.query_dim = 0 self.shared_fc = shared_fc - num_groups = num_classes if num_groups < 1 + num_groups = num_classes if num_groups < 1 else num_groups # case using class embed if have_class_embed: