Skip to content

Commit

Permalink
Testing: shared load_state_dict_from_url monkeypatch (#2223)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Aug 13, 2024
1 parent 9406f67 commit e885ccc
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 90 deletions.
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any

import pytest
import torch
import torchvision
from pytest import MonkeyPatch


def load(*args: Any, progress: bool = False, **kwargs: Any) -> Any:
return torch.load(*args, **kwargs)


@pytest.fixture
def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
21 changes: 10 additions & 11 deletions tests/models/test_dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum
Expand All @@ -22,11 +20,6 @@
)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestDOFA:
@pytest.mark.parametrize(
'wavelengths',
Expand Down Expand Up @@ -86,7 +79,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_base_patch16_224()
Expand All @@ -95,7 +92,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
Expand Down Expand Up @@ -123,7 +119,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = dofa_large_patch16_224()
Expand All @@ -132,7 +132,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_dofa(self) -> None:
Expand Down
21 changes: 10 additions & 11 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestResNet18:
@pytest.fixture(params=[*ResNet18_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet18', in_chans=weights.meta['in_chans'])
Expand All @@ -36,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
Expand Down Expand Up @@ -64,7 +60,11 @@ def weights(self, request: SubRequest) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
Expand All @@ -73,7 +73,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_resnet(self) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/models/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
Expand All @@ -14,19 +13,18 @@
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestSwin_V2_B:
@pytest.fixture(params=[*Swin_V2_B_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_b()
Expand All @@ -35,7 +33,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_swin_v2_b(self) -> None:
Expand Down
14 changes: 5 additions & 9 deletions tests/models/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -38,7 +35,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_vit(self) -> None:
Expand Down
14 changes: 5 additions & 9 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torch.nn as nn
import torchvision
from pytest import MonkeyPatch
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
Expand All @@ -21,11 +19,6 @@
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestBYOL:
def test_custom_augment_fn(self) -> None:
model = resnet18()
Expand Down Expand Up @@ -88,7 +81,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -99,7 +96,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import timm
import torch
import torch.nn as nn
import torchvision
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
Expand Down Expand Up @@ -56,11 +55,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
return None

Expand Down Expand Up @@ -125,7 +119,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -136,7 +134,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
13 changes: 5 additions & 8 deletions tests/trainers/test_moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
import timm
import torch
import torchvision
from pytest import MonkeyPatch
from torch.nn import Module
from torchvision.models._api import WeightsEnum
Expand All @@ -25,11 +24,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestMoCoTask:
@pytest.mark.parametrize(
'name',
Expand Down Expand Up @@ -89,7 +83,11 @@ def weights(self) -> WeightsEnum:

@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
Expand All @@ -100,7 +98,6 @@ def mocked_weights(
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down
Loading

0 comments on commit e885ccc

Please sign in to comment.