diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 29d55c6fd18..0e2785bb7ff 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -478,10 +478,9 @@ }, "outputs": [], "source": [ - "torch.manual_seed(1)\n", - "\n", "dataset = Sentinel2(root)\n", - "sampler = RandomGeoSampler(dataset, size=4096, length=3)\n", + "g = torch.Generator().manual_seed(1)\n", + "sampler = RandomGeoSampler(dataset, size=4096, length=3, generator=g)\n", "dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)\n", "\n", "for batch in dataloader:\n", diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 01dea100ca8..199239a0e79 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -148,8 +148,10 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=generator1) + sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=generator2) sample1 = next(iter(sampler1)) sample2 = next(iter(sampler2)) assert sample1 == sample2 diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 743c8be70da..e2c829f1b9e 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -143,8 +143,10 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = RandomGeoSampler(ds, 1, 1, generator=generator1) + sampler2 = RandomGeoSampler(ds, 1, 1, generator=generator2) sample1 = next(iter(sampler1)) sample2 = next(iter(sampler2)) assert sample1 == sample2 @@ -302,15 +304,13 @@ def test_shuffle_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (0, 11, 0, 11, 0, 11)) - sampler1 = PreChippedGeoSampler( - ds, shuffle=True, generator=torch.manual_seed(2) - ) - sampler2 = PreChippedGeoSampler( - ds, shuffle=True, generator=torch.manual_seed(2) - ) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator1) + sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator2) sample1 = next(iter(sampler1)) sample2 = next(iter(sampler2)) - assert sample1 != sample2 + assert sample1 == sample2 @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2])