Skip to content

Commit

Permalink
Merge pull request #736 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v2.8.1
  • Loading branch information
Kevin Musgrave authored Dec 11, 2024
2 parents f4c0a7c + 6b61de6 commit 216a792
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 13 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/test_datasets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: datasets

on:
pull_request:
branches: [ master, dev ]
paths:
- 'src/**'
- 'tests/**'
- '.github/workflows/**'

jobs:
call-base-test-workflow:
uses: ./.github/workflows/base_test_workflow.yml
with:
module-to-test: datasets
8 changes: 4 additions & 4 deletions docs/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ datasets.base_dataset.BaseDataset(
## CUB-200-2011

```python
datasets.cub.CUB(*args, **kwargs)
datasets.CUB(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -75,7 +75,7 @@ train_and_test_dataset = CUB(root="data",
## Cars196

```python
datasets.cars196.Cars196(*args, **kwargs)
datasets.Cars196(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -110,7 +110,7 @@ train_and_test_dataset = Cars196(root="data",
## INaturalist2018

```python
datasets.inaturalist2018.INaturalist2018(*args, **kwargs)
datasets.INaturalist2018(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -146,7 +146,7 @@ train_and_test_dataset = INaturalist2018(root="data",
## StanfordOnlineProducts

```python
datasets.sop.StanfordOnlineProducts(*args, **kwargs)
datasets.StanfordOnlineProducts(*args, **kwargs)
```

**Defined splits**:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.8.0"
__version__ = "2.8.1"
5 changes: 5 additions & 0 deletions src/pytorch_metric_learning/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base_dataset import BaseDataset
from .cars196 import Cars196
from .cub import CUB
from .inaturalist2018 import INaturalist2018
from .sop import StanfordOnlineProducts
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
)
device_from_environ = os.environ.get("TEST_DEVICE", "cuda")
with_collect_stats = os.environ.get("WITH_COLLECT_STATS", "false")
test_datasets = os.environ.get("TEST_DATASETS", "false")

TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
TEST_DEVICE = torch.device(device_from_environ)

assert c_f.COLLECT_STATS is False

WITH_COLLECT_STATS = True if with_collect_stats == "true" else False
TEST_DATASETS = True if test_datasets == "true" else False
c_f.COLLECT_STATS = WITH_COLLECT_STATS

print(
Expand Down
7 changes: 5 additions & 2 deletions tests/datasets/test_cars196.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cars196 import Cars196
from pytorch_metric_learning.datasets import Cars196
from .. import TEST_DATASETS


class TestCars196(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CARS_196_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_Cars196(self):
train_test_data = Cars196(
root=TestCars196.CARS_196_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_Cars196(self):
self.assertTrue(len(train_data) == 8054)
self.assertTrue(len(test_data) == 8131)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CARS_196_dataloader(self):
test_data = Cars196(
root=TestCars196.CARS_196_ROOT,
Expand All @@ -50,5 +53,5 @@ def test_CARS_196_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.CARS_196_ROOT):
shutil.rmtree(cls.CARS_196_ROOT)
7 changes: 5 additions & 2 deletions tests/datasets/test_cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cub import CUB
from pytorch_metric_learning.datasets import CUB
from .. import TEST_DATASETS


class TestCUB(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CUB_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB(self):
train_test_data = CUB(root=TestCUB.CUB_ROOT, split="train+test", download=True)
train_data = CUB(root=TestCUB.CUB_ROOT, split="train", download=True)
Expand All @@ -28,6 +30,7 @@ def test_CUB(self):
self.assertTrue(len(train_data) == 5864)
self.assertTrue(len(test_data) == 5924)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB_dataloader(self):
test_data = CUB(
root=TestCUB.CUB_ROOT,
Expand All @@ -44,5 +47,5 @@ def test_CUB_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.CUB_ROOT):
shutil.rmtree(cls.CUB_ROOT)
7 changes: 5 additions & 2 deletions tests/datasets/test_inaturalist2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018
from pytorch_metric_learning.datasets import INaturalist2018
from .. import TEST_DATASETS


class TestINaturalist2018(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.INATURALIST2018_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018(self):
train_test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand All @@ -36,6 +38,7 @@ def test_INaturalist2018(self):
self.assertTrue(len(train_data) == 325846)
self.assertTrue(len(test_data) == 136093)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018_dataloader(self):
test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand All @@ -52,5 +55,5 @@ def test_INaturalist2018_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.INATURALIST2018_ROOT):
shutil.rmtree(cls.INATURALIST2018_ROOT)
8 changes: 6 additions & 2 deletions tests/datasets/test_sop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts
from pytorch_metric_learning.datasets import StanfordOnlineProducts
from .. import TEST_DATASETS


class TestStanfordOnlineProducts(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.SOP_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP(self):
train_test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_SOP(self):
self.assertTrue(len(train_data) == 59551)
self.assertTrue(len(test_data) == 60502)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP_dataloader(self):
test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT,
Expand All @@ -48,7 +51,8 @@ def test_SOP_dataloader(self):
self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224))
self.assertTupleEqual(tuple(labels.shape), (8,))

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.SOP_ROOT):
shutil.rmtree(cls.SOP_ROOT)

0 comments on commit 216a792

Please sign in to comment.