
Commit

Add no_weight_decay fns for use with timm optimizers, add option to use non-scalar tensor for logit scale/bias
rwightman committed Nov 12, 2024
1 parent 4fd0260 commit 2cae8c4
Showing 4 changed files with 56 additions and 6 deletions.
src/open_clip/coca_model.py: 4 additions & 2 deletions
@@ -96,6 +96,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             pad_id: int = 0,
     ):
@@ -131,9 +132,10 @@ def __init__(
             cast_dtype=cast_dtype,
         )

-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None
         self.pad_id = pad_id
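
The lshape switch above only changes the parameter's shape, not its value; a minimal illustration of the two cases (init value np.log(1 / 0.07) ≈ 2.659):

    import numpy as np
    import torch
    import torch.nn as nn

    init_logit_scale = np.log(1 / 0.07)  # ~2.659
    scalar = nn.Parameter(torch.ones([]) * init_logit_scale)      # default: 0-dim tensor, shape torch.Size([])
    nonscalar = nn.Parameter(torch.ones([1]) * init_logit_scale)  # nonscalar_logit_scale=True: shape torch.Size([1])
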
src/open_clip/factory.py: 8 additions & 0 deletions
@@ -188,6 +188,14 @@ def load_checkpoint(
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)

+    # correct if logit_scale differs in being scalar vs 1d param
+    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
+        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
+
+    # correct if logit_bias differs in being scalar vs 1d param
+    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
+        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
+
     # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
     if 'logit_bias' not in state_dict and model.logit_bias is not None:
         state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
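
The load_checkpoint additions above only reshape a saved scalar to match a [1]-shaped parameter (or the reverse) before the state dict is loaded; a minimal sketch of the scalar-to-1d case:

    import torch

    # checkpoint saved with the default scalar logit_scale, model built with nonscalar_logit_scale=True
    saved = torch.tensor(2.659)             # ndim == 0
    param = torch.ones([1]) * 2.659         # ndim == 1
    if saved.ndim != param.ndim:
        saved = saved.reshape(param.shape)  # tensor([2.6590]), loads without a shape mismatch
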
src/open_clip/model.py: 30 additions & 4 deletions
@@ -230,6 +230,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -249,9 +250,10 @@ def __init__(
         self.text_pool_type = text.pool_type
         self.register_buffer('attn_mask', text.attn_mask, persistent=False)

-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None

@@ -264,6 +266,15 @@ def set_grad_checkpointing(self, enable=True):
         self.visual.set_grad_checkpointing(enable)
         self.transformer.grad_checkpointing = enable

+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding'}
+        if hasattr(self.visual, 'no_weight_decay'):
+            for n in self.visual.no_weight_decay():
+                no_wd.add('visual.' + n)
+        return no_wd
+
     def encode_image(self, image, normalize: bool = False):
         features = self.visual(image)
         return F.normalize(features, dim=-1) if normalize else features
@@ -328,6 +339,7 @@ def __init__(
             quick_gelu: bool = False,
             init_logit_scale: float = np.log(1 / 0.07),
             init_logit_bias: Optional[float] = None,
+            nonscalar_logit_scale: bool = False,
             cast_dtype: Optional[torch.dtype] = None,
             output_dict: bool = False,
     ):
@@ -337,9 +349,11 @@ def __init__(
         self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
         self.context_length = self.text.context_length
         self.vocab_size = self.text.vocab_size
-        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+
+        lshape = [1] if nonscalar_logit_scale else []
+        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
         if init_logit_bias is not None:
-            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
+            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
         else:
             self.logit_bias = None

@@ -355,6 +369,18 @@ def set_grad_checkpointing(self, enable=True):
         self.visual.set_grad_checkpointing(enable)
         self.text.set_grad_checkpointing(enable)

+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = set()
+        if hasattr(self.visual, 'no_weight_decay'):
+            for n in self.visual.no_weight_decay():
+                no_wd.add('visual.' + n)
+        if hasattr(self.text, 'no_weight_decay'):
+            for n in self.text.no_weight_decay():
+                no_wd.add('text.' + n)
+        return no_wd
+
     def encode_image(self, image, normalize: bool = False):
         features = self.visual(image)
         return F.normalize(features, dim=-1) if normalize else features
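
Intended use of no_weight_decay() is with timm's optimizer factory, which picks the method up automatically when it is present; a minimal sketch, assuming timm and open_clip are installed ('ViT-B-32' is just an example model name):

    import open_clip
    from timm.optim import create_optimizer_v2

    model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')
    print(model.no_weight_decay())
    # e.g. {'positional_embedding', 'visual.positional_embedding', 'visual.class_embedding'}

    # create_optimizer_v2 calls model.no_weight_decay() when available and, by default, also
    # excludes 1d params (logit_scale/logit_bias, norm scales, biases) from weight decay
    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.2)
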
src/open_clip/transformer.py: 14 additions & 0 deletions
@@ -595,6 +595,12 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable

+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding', 'class_embedding'}
+        return no_wd
+
     def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.pool_type == 'avg':
             pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
@@ -759,6 +765,14 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable

+    @torch.jit.ignore
+    def no_weight_decay(self):
+        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
+        no_wd = {'positional_embedding'}
+        if self.cls_emb is not None:
+            no_wd.add('cls_emb')
+        return no_wd
+
     def build_causal_mask(self):
         # lazily create causal attention mask, with full attention between the tokens
         # pytorch uses additive attention mask; fill with -inf
