Skip to content

Commit

Permalink
Sampling random generator (#2309)
Browse files Browse the repository at this point in the history
* Add random generator

* Add tests for seed

* pass generator every sampler

* Simplification of tests, docstring updates

* try to pass docs build

* forgotten updates

* equal should have been unequal
  • Loading branch information
sfalkena authored Sep 23, 2024
1 parent 59e6531 commit dfebdae
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 8 deletions.
10 changes: 10 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,15 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
24 changes: 24 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -139,6 +140,15 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -288,6 +298,20 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
sampler1 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
sampler2 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 != sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 5 additions & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
11 changes: 10 additions & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from rtree.index import Index, Property
from torch import Generator
from torch.utils.data import Sampler

from ..datasets import BoundingBox, GeoDataset
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -86,6 +88,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -97,9 +102,11 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -144,7 +151,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
batch.append(bounding_box)

yield batch
Expand Down
26 changes: 23 additions & 3 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import torch
from rtree.index import Index, Property
from torch import Generator
from torch.utils.data import Sampler

from ..datasets import BoundingBox, GeoDataset
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -88,6 +91,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -98,13 +104,15 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -142,7 +150,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)

yield bounding_box

Expand Down Expand Up @@ -270,20 +280,30 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
.. versionadded:: 0.3
.. versionadded:: 0.7
The *generator* parameter.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: pseudo-random number generator (PRNG) used in combination with shuffle.
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -297,7 +317,7 @@ def __iter__(self) -> Iterator[BoundingBox]:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
generator = partial(torch.randperm, generator=self.generator)

for idx in generator(len(self)):
yield BoundingBox(*self.hits[idx].bounds)
Expand Down
14 changes: 11 additions & 3 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import overload

import torch
from torch import Generator

from ..datasets import BoundingBox

Expand Down Expand Up @@ -35,7 +36,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox, size: tuple[float, float] | float, res: float
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: Generator | None = None,
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
Expand All @@ -46,10 +50,14 @@ def get_random_bounding_box(
* a ``tuple`` of two floats - in which case, the first *float* is used for the
height dimension, and the second *float* for the width dimension
.. versionadded:: 0.7
The *generator* parameter.
Args:
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: pseudo-random number generator (PRNG).
Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -64,8 +72,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down

0 comments on commit dfebdae

Please sign in to comment.