diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index dda3faba5..539616332 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -96,6 +96,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             pad_id: int = 0,
     ):
@@ -131,9 +132,10 @@ def __init__(
             cast_dtype=cast_dtype,
         )
 
-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None
         self.pad_id = pad_id
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 358b51fb0..03ccc4f06 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -188,6 +188,14 @@ def load_checkpoint(
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
 
+    # correct if logit_scale differs in being scalar vs 1d param
+    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
+        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
+
+    # correct if logit_bias differs in being scalar vs 1d param
+    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
+        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
+
     # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
     if 'logit_bias' not in state_dict and model.logit_bias is not None:
         state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 989662ebb..9a9443603 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -230,6 +230,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -249,9 +250,10 @@ def __init__(
         self.text_pool_type = text.pool_type
         self.register_buffer('attn_mask', text.attn_mask, persistent=False)
 
-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None
 
@@ -264,6 +266,15 @@ def set_grad_checkpointing(self, enable=True):
         self.visual.set_grad_checkpointing(enable)
         self.transformer.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding'}
+        if hasattr(self.visual, 'no_weight_decay'):
+            for n in self.visual.no_weight_decay():
+                no_wd.add('visual.' + n)
+        return no_wd
+
     def encode_image(self, image, normalize: bool = False):
         features = self.visual(image)
         return F.normalize(features, dim=-1) if normalize else features
@@ -328,6 +339,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -337,9 +349,11 @@ def __init__(
         self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
         self.context_length = self.text.context_length
         self.vocab_size = self.text.vocab_size
-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None
 
@@ -355,6 +369,18 @@ def set_grad_checkpointing(self, enable=True):
         self.visual.set_grad_checkpointing(enable)
         self.text.set_grad_checkpointing(enable)
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = set()
+        if hasattr(self.visual, 'no_weight_decay'):
+            for n in self.visual.no_weight_decay():
+                no_wd.add('visual.' + n)
+        if hasattr(self.text, 'no_weight_decay'):
+            for n in self.text.no_weight_decay():
+                no_wd.add('text.' + n)
+        return no_wd
+
     def encode_image(self, image, normalize: bool = False):
         features = self.visual(image)
         return F.normalize(features, dim=-1) if normalize else features
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index bf85dc8e7..860d77503 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -595,6 +595,12 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding', 'class_embedding'}
+        return no_wd
+
     def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.pool_type == 'avg':
             pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
@@ -759,6 +765,14 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding'}
+        if self.cls_emb is not None:
+            no_wd.add('cls_emb')
+        return no_wd
+
     def build_causal_mask(self):
         # lazily create causal attention mask, with full attention between the tokens
         # pytorch uses additive attention mask; fill with -inf