Skip to content

Commit

Permalink
use meta device
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Nov 30, 2024
1 parent 82e8677 commit b95e335
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
15 changes: 10 additions & 5 deletions timm/models/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import logging
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
from contextlib import nullcontext

import torch
from torch import nn as nn
from torch.hub import load_state_dict_from_url

Expand Down Expand Up @@ -411,10 +413,13 @@ def build_model_with_cfg(
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
with torch.device("meta") if pretrained else nullcontext():
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
if pretrained:
model.to_empty(device="cpu")
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat

Expand Down
3 changes: 2 additions & 1 deletion timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ def __init__(
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
with torch.device("cpu"):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
Expand Down

0 comments on commit b95e335

Please sign in to comment.