diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 8d124afe4..980056b98 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -2,8 +2,10 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
+from contextlib import nullcontext
 
+import torch
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
 
@@ -411,10 +413,13 @@ def build_model_with_cfg(
         feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
 
     # Instantiate the model
-    if model_cfg is None:
-        model = model_cls(**kwargs)
-    else:
-        model = model_cls(cfg=model_cfg, **kwargs)
+    with torch.device("meta") if pretrained else nullcontext():
+        if model_cfg is None:
+            model = model_cls(**kwargs)
+        else:
+            model = model_cls(cfg=model_cfg, **kwargs)
+    if pretrained:
+        model.to_empty(device="cpu")
 
     model.pretrained_cfg = pretrained_cfg
     model.default_cfg = model.pretrained_cfg  # alias for backwards compat
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b3b0ddca0..9ae893910 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -539,7 +539,8 @@ def __init__(
         self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        with torch.device("cpu"):
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
                 dim=embed_dim,
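
For reference, a minimal sketch of the meta-device pattern the `_builder.py` hunk relies on (the module and dimensions here are illustrative, not timm code): constructing a module under `torch.device("meta")` records parameter shapes without allocating storage, and `to_empty(device="cpu")` then materializes uninitialized CPU tensors that the pretrained `state_dict` subsequently overwrites, so the random-init pass is skipped entirely.

```python
import torch
from torch import nn

# Build under the meta device: parameters are shape-only "meta" tensors,
# so no real memory is allocated and no init code runs on actual data.
with torch.device("meta"):
    model = nn.Linear(1024, 1024)
assert model.weight.is_meta

# Materialize uninitialized storage on CPU; the values are garbage until
# load_state_dict overwrites them with checkpoint tensors.
model.to_empty(device="cpu")
model.load_state_dict({
    "weight": torch.zeros(1024, 1024),  # stand-in for checkpoint tensors
    "bias": torch.zeros(1024),
})
```

The `vision_transformer.py` hunk is the complementary fix: under the meta-device context, `torch.linspace` would also yield a meta tensor, and `.item()` cannot read a value from a tensor with no data, so the stochastic-depth schedule is pinned to CPU explicitly.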