From 252978978d875aea8f744614bb2614aa721ed088 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 18 Dec 2021 14:54:33 -0600 Subject: [PATCH 1/2] Fix GeoDataset pickling --- tests/datasets/test_geo.py | 9 +++++++++ torchgeo/datasets/geo.py | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b3380612b17..b71b08eca93 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os +import pickle from pathlib import Path from typing import Dict @@ -121,6 +122,14 @@ def test_str(self, dataset: GeoDataset) -> None: assert "bbox: BoundingBox" in out assert "size: 1" in out + def test_picklable(self, dataset: GeoDataset) -> None: + x = pickle.dumps(dataset) + y = pickle.loads(x) + assert dataset.crs == y.crs + assert dataset.res == y.res + assert len(dataset) == len(y) + assert dataset.bounds == y.bounds + def test_abstract(self) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoDataset() # type: ignore[abstract] diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 54509ce70f7..d10869a3194 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -172,6 +172,30 @@ def __str__(self) -> str: bbox: {self.bounds} size: {len(self)}""" + # NOTE: This hack should be removed once the following issue is fixed: + # https://github.com/Toblerity/rtree/issues/87 + + def __getstate__(self): + """Define how instances are pickled. + + Returns: + the state necessary to unpickle the instance + """ + index = self.index.intersection(self.index.bounds, objects=True) + index = [(item.id, item.bounds, item.object) for item in index] + return self.__dict__, index + + def __setstate__(self, state): + """Define how to unpickle an instance. + + Args: + state: the state of the instance when it was pickled + """ + attrs, index = state + self.__dict__.update(attrs) + for item in index: + self.index.insert(*item) + @property def bounds(self) -> BoundingBox: """Bounds of the index. From 6b97307242d072b860dc089348c35fe97a325842 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 18 Dec 2021 15:26:22 -0600 Subject: [PATCH 2/2] mypy fixes --- torchgeo/datasets/geo.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d10869a3194..5c09a2d8930 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -175,25 +175,36 @@ def __str__(self) -> str: # NOTE: This hack should be removed once the following issue is fixed: # https://github.com/Toblerity/rtree/issues/87 - def __getstate__(self): + def __getstate__( + self, + ) -> Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ]: """Define how instances are pickled. Returns: the state necessary to unpickle the instance """ - index = self.index.intersection(self.index.bounds, objects=True) - index = [(item.id, item.bounds, item.object) for item in index] - return self.__dict__, index + objects = self.index.intersection(self.index.bounds, objects=True) + tuples = [(item.id, item.bounds, item.object) for item in objects] + return self.__dict__, tuples - def __setstate__(self, state): + def __setstate__( + self, + state: Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ], + ) -> None: """Define how to unpickle an instance. Args: state: the state of the instance when it was pickled """ - attrs, index = state + attrs, tuples = state self.__dict__.update(attrs) - for item in index: + for item in tuples: self.index.insert(*item) @property