Skip to content

Commit

Permalink
Bugfix: wrong test assertion (#2316)
Browse files Browse the repository at this point in the history
* Assertion should be equal

* replace torch.manual_seed with torch.Generator().manual_seed

* Change manual see in example notebook

* ruff

* format tests, remove print
  • Loading branch information
sfalkena authored Sep 24, 2024
1 parent 19988f9 commit 1ea8e29
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
5 changes: 2 additions & 3 deletions docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,9 @@
},
"outputs": [],
"source": [
"torch.manual_seed(1)\n",
"\n",
"dataset = Sentinel2(root)\n",
"sampler = RandomGeoSampler(dataset, size=4096, length=3)\n",
"g = torch.Generator().manual_seed(1)\n",
"sampler = RandomGeoSampler(dataset, size=4096, length=3, generator=g)\n",
"dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)\n",
"\n",
"for batch in dataloader:\n",
Expand Down
6 changes: 4 additions & 2 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ def test_weighted_sampling(self) -> None:
def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=generator1)
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=generator2)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2
Expand Down
18 changes: 9 additions & 9 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ def test_weighted_sampling(self) -> None:
def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = RandomGeoSampler(ds, 1, 1, generator=generator1)
sampler2 = RandomGeoSampler(ds, 1, 1, generator=generator2)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2
Expand Down Expand Up @@ -302,15 +304,13 @@ def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
sampler1 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
sampler2 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.manual_seed(2)
)
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator1)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator2)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 != sample2
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
Expand Down

0 comments on commit 1ea8e29

Please sign in to comment.