diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 184ddb9bcb..1279b219d9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -16,10 +16,12 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python: ['3.10']
-        torch: ['1.13.0']
-        torchvision: ['0.14.0']
+        python: ['3.10', '3.11']
+        torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}]
         testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
+        exclude:
+          - python: '3.11'
+            torch: {base: '1.13.0', vision: '0.14.0'}
     runs-on: ${{ matrix.os }}
 
     steps:
@@ -34,17 +36,17 @@ jobs:
         pip install -r requirements-dev.txt
     - name: Install torch on mac
       if: startsWith(matrix.os, 'macOS')
-      run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
+      run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
     - name: Install torch on Windows
       if: startsWith(matrix.os, 'windows')
-      run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
+      run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
     - name: Install torch on ubuntu
       if: startsWith(matrix.os, 'ubuntu')
       run: |
         sudo sed -i 's/azure\.//' /etc/apt/sources.list
         sudo apt update
         sudo apt install -y google-perftools
-        pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
     - name: Install requirements
       run: |
         pip install -r requirements.txt
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 737674e5cf..9bdfd6825d 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -10,7 +10,7 @@
 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.autograd import Variable
+from torch.nn import Parameter
 
 from timm.scheduler import PlateauLRScheduler
 from timm.optim import create_optimizer_v2
@@ -21,9 +21,9 @@
 
 
 def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
     optimizer = constructor(weight, bias)
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
@@ -55,9 +55,9 @@ def fn():
 
 
 def _test_state_dict(weight, bias, input, constructor):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
 
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
@@ -73,8 +73,9 @@ def fn_base(optimizer, weight, bias):
     for _i in range(20):
         optimizer.step(fn)
     # Clone the weights and construct new optimizer for them
-    weight_c = Variable(weight.data.clone(), requires_grad=True)
-    bias_c = Variable(bias.data.clone(), requires_grad=True)
+    with torch.no_grad():
+        weight_c = Parameter(weight.clone().detach())
+        bias_c = Parameter(bias.clone().detach())
     optimizer_c = constructor(weight_c, bias_c)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
     # Load state dict
@@ -86,12 +87,8 @@ def fn_base(optimizer, weight, bias):
     for _i in range(20):
         optimizer.step(fn)
         optimizer_c.step(fn_c)
-        #assert torch.equal(weight, weight_c)
-        #assert torch.equal(bias, bias_c)
         torch_tc.assertEqual(weight, weight_c)
         torch_tc.assertEqual(bias, bias_c)
-    # Make sure state dict wasn't modified
-    torch_tc.assertEqual(state_dict, state_dict_c)
     # Make sure state dict is deterministic with equal but not identical parameters
     torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
     # Make sure repeated parameters have identical representation in state dict
@@ -103,9 +100,10 @@ def fn_base(optimizer, weight, bias):
     if not torch.cuda.is_available():
         return
 
-    input_cuda = Variable(input.data.float().cuda())
-    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
-    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+    with torch.no_grad():
+        input_cuda = Parameter(input.clone().detach().float().cuda())
+        weight_cuda = Parameter(weight.clone().detach().cuda())
+        bias_cuda = Parameter(bias.clone().detach().cuda())
     optimizer_cuda = constructor(weight_cuda, bias_cuda)
     fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
@@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
         scheduler_constructors = []
     params_t = torch.tensor([1.5, 1.5])
 
-    params = Variable(params_t, requires_grad=True)
+    params = Parameter(params_t)
     optimizer = constructor([params])
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
         schedulers.append(scheduler_constructor(optimizer))
 
     solution = torch.tensor([1, 1])
-    initial_dist = params.data.dist(solution)
+    initial_dist = params.clone().detach().dist(solution)
 
     def eval(params, w):
         # Depending on w, provide only the x or y gradient
         optimizer.zero_grad()
         loss = rosenbrock(params)
         loss.backward()
-        grad = drosenbrock(params.data)
+        grad = drosenbrock(params.clone().detach())
         # NB: We torture test the optimizer by returning an
         # uncoalesced sparse tensor
         if w:
@@ -256,7 +254,7 @@ def eval(params, w):
         else:
             scheduler.step()
 
-    torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
+    torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)
 
 
 def _build_params_dict(weight, bias, **kwargs):
diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py
index a68ae4bcca..11d9eeca32 100644
--- a/timm/layers/mlp.py
+++ b/timm/layers/mlp.py
@@ -130,8 +130,6 @@ def __init__(
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
-        self.drop = nn.Dropout(drop)
-
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
         nn.init.ones_(self.fc1_g.bias)
diff --git a/timm/models/beit.py b/timm/models/beit.py
index 3863198f12..663dcc4bd4 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -155,7 +155,7 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/cait.py b/timm/models/cait.py
index 4bc7dafc53..40d56061d3 100644
--- a/timm/models/cait.py
+++ b/timm/models/cait.py
@@ -50,7 +50,7 @@ def forward(self, x):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 81bcce525d..68b315386c 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -126,7 +126,7 @@ def forward(
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index d3d9bfdf65..b3143ae58b 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -514,7 +514,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 12709f5818..6283443ce5 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -190,7 +190,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale
@@ -259,7 +259,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py
index 98a79f598b..7b026a2e43 100644
--- a/timm/models/metaformer.py
+++ b/timm/models/metaformer.py
@@ -198,7 +198,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/timm/models/nest.py b/timm/models/nest.py
index de57ec6e99..d1901cee21 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -59,14 +59,14 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.)
 
     def forward(self, x):
         """
         x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
-        """ 
+        """
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
@@ -330,7 +330,7 @@ def __init__(
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
         self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
-        
+
         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size,
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 00379b158a..16302002eb 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -130,7 +130,7 @@ def forward(self, x, feat_size: List[int]):
         k, v = kv.unbind(0)
 
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 41b45afb69..34452c7cf1 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -164,7 +164,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/twins.py b/timm/models/twins.py
index b96a0234d0..3cd25fb433 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -75,7 +75,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -172,7 +172,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 9f5da60be5..953fc64d5e 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -95,7 +95,7 @@ def forward(self, x):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index c992c14f9f..2d374085c9 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -85,7 +85,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -285,7 +285,7 @@ def forward(self, x):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -1208,7 +1208,7 @@ def _cfg(url='', **kwargs):
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    
+
     # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
     'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
@@ -1528,7 +1528,7 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    
+
     'vit_huge_patch14_ijepa_224.in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
@@ -2182,7 +2182,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
     model_args = dict(
-        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, 
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index ea428587c1..2cd37cfe7e 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -71,7 +71,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 53c49b071e..59b354fb3d 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -168,7 +168,7 @@ def forward(self, x):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
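
Note on the recurring model change above (illustrative, not part of the patch): F.scaled_dot_product_attention is a functional API, so passing a raw dropout_p=self.attn_drop.p keeps applying attention dropout even after model.eval(), unlike an nn.Dropout module which disables itself outside training. Gating on self.training, as every attention hunk in this diff does, restores the expected eval-time behavior. Below is a minimal, self-contained sketch of the pattern with hypothetical names (FusedAttention), assuming torch >= 2.0; it is not code from timm.

# Sketch only: shows why dropout_p must be gated on self.training when using the
# fused SDPA path. Class and variable names here are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, attn_drop: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)  # kept as a module so .p stays configurable
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        x = F.scaled_dot_product_attention(
            q, k, v,
            # the fix: no attention dropout unless the module is in train() mode
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


# With the gate in place, eval-mode outputs are deterministic even though attn_drop > 0.
m = FusedAttention(64, attn_drop=0.1).eval()
inp = torch.randn(2, 16, 64)
assert torch.allclose(m(inp), m(inp))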