From 1487ab18bf4990de77bd10abfe89ac77848caf24 Mon Sep 17 00:00:00 2001 From: Xiao Gui Date: Mon, 27 Nov 2023 13:13:41 +0100 Subject: [PATCH] feat: add kwargs to allow for shard on disk/in mem --- .../sharded_file_accessor.py | 31 ++++++++++++------- unit_tests/test_sharded_file_accessor.py | 16 ++++++++++ 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/neuroglancer_scripts/sharded_file_accessor.py b/src/neuroglancer_scripts/sharded_file_accessor.py index 6ffae2c..1d66bd5 100644 --- a/src/neuroglancer_scripts/sharded_file_accessor.py +++ b/src/neuroglancer_scripts/sharded_file_accessor.py @@ -124,9 +124,8 @@ class MiniShard(CMCReadWrite): can_read_cmc = False can_write_cmc = True - def __init__(self, shard_spec: ShardSpec, - offset: np.uint64 = np.uint64(0), - strategy="on disk"): + def __init__(self, shard_spec: ShardSpec, offset: np.uint64 = np.uint64(0), + strategy="on disk", **kwargs): super().__init__(shard_spec) self._offset = offset @@ -238,12 +237,14 @@ class Shard(ShardCMC): can_read_cmc = False can_write_cmc = True - def __init__(self, base_dir, shard_key: np.uint64, shard_spec: ShardSpec): + def __init__(self, base_dir, shard_key: np.uint64, shard_spec: ShardSpec, + **kwargs): self.root_dir = pathlib.Path(base_dir) super().__init__(shard_key, shard_spec) self.file_path = pathlib.Path(base_dir) / f"{self.shard_key_str}.shard" self.populate_minishard_dict() self.dirty = False + self.kwargs = kwargs def file_exists(self, filepath: str) -> bool: return (self.root_dir / filepath).is_file() @@ -271,7 +272,8 @@ def store_cmc_chunk(self, buf: bytes, cmc: np.uint64): minishard_key = self.get_minishard_key(cmc) if minishard_key not in self.minishard_dict: - self.minishard_dict[minishard_key] = MiniShard(self.shard_spec) + self.minishard_dict[minishard_key] = MiniShard(self.shard_spec, + **self.kwargs) self.minishard_dict[minishard_key].store_cmc_chunk(buf, cmc) def fetch_cmc_chunk(self, cmc: np.uint64): @@ -352,16 +354,19 @@ class ShardedScale(ShardedScaleBase): def __init__(self, base_dir, key: str, shard_spec: ShardSpec, - shard_volume_spec: ShardVolumeSpec): + shard_volume_spec: ShardVolumeSpec, + **kwargs): super().__init__(key, shard_spec, shard_volume_spec) self.base_dir = pathlib.Path(base_dir) / key self.shard_dict: Dict[np.uint64, Shard] = {} + self.kwargs = kwargs def get_shard(self, shard_key: np.uint64): if shard_key not in self.shard_dict: self.shard_dict[shard_key] = Shard(self.base_dir, shard_key, - self.shard_spec) + self.shard_spec, + **self.kwargs) return self.shard_dict[shard_key] def close(self): @@ -379,7 +384,7 @@ class ShardedFileAccessor(neuroglancer_scripts.accessor.Accessor, can_read = False can_write = True - def __init__(self, base_dir): + def __init__(self, base_dir, **kwargs): ShardedAccessorBase.__init__(self) self.base_dir = pathlib.Path(base_dir) self.base_dir.mkdir(exist_ok=True, parents=True) @@ -396,6 +401,8 @@ def __init__(self, base_dir): import atexit atexit.register(self.close) + self.kwargs = kwargs + def file_exists(self, relative_path: str): return (self.base_dir / relative_path).exists() @@ -419,7 +426,8 @@ def fetch_chunk(self, key, chunk_coords): sharded_scale = ShardedScale(base_dir=self.base_dir, key=key, shard_spec=shard_spec, - shard_volume_spec=shard_volume_spec) + shard_volume_spec=shard_volume_spec, + **self.kwargs) self.ro_shard_dict[key] = sharded_scale return self.ro_shard_dict[key].fetch_chunk(chunk_coords) @@ -454,9 +462,10 @@ def store_chunk(self, buf, key, chunk_coords, **kwargs): sharded_scale = ShardedScale(base_dir=self.base_dir, key=key, shard_spec=shard_spec, - shard_volume_spec=shard_volume_spec) + shard_volume_spec=shard_volume_spec, + **self.kwargs) self.shard_dict[key] = sharded_scale - self.shard_dict[key].store_chunk(buf, chunk_coords) + self.shard_dict[key].store_chunk(buf, chunk_coords, **kwargs) def close(self): if len(self.shard_dict) == 0: diff --git a/unit_tests/test_sharded_file_accessor.py b/unit_tests/test_sharded_file_accessor.py index dcd9b3e..981bba3 100644 --- a/unit_tests/test_sharded_file_accessor.py +++ b/unit_tests/test_sharded_file_accessor.py @@ -234,6 +234,11 @@ def test_minishard_close(shard_spec_2_2_2: ShardSpec, pre_store_cmc, # Shard +def test_shard_kwargs(tmpdir, shard_spec_2_2_2): + shard = Shard(tmpdir, np.uint64(0), shard_spec_2_2_2, foo="bar") + assert shard.kwargs == {"foo": "bar"} + + @pytest.mark.parametrize("write_files, readable, legacy", [ ((), False, False), (("{shard_key_str}.shard",), True, False), @@ -571,6 +576,12 @@ def test_shard_close_not_dirty(tmpdir, shard_spec_1_1_1): # ShardedScale +def test_shardscale_kwargs(tmpdir, shard_spec_2_2_2): + vol_spec = ShardVolumeSpec([64, 64, 64], [128, 128, 128]) + scale = ShardedScale(tmpdir, "5mm", shard_spec_2_2_2, vol_spec, foo="bar") + assert scale.kwargs == {"foo": "bar"} + + @pytest.fixture def sharded_scale(tmpdir, shard_spec_2_2_2): vol_spec = ShardVolumeSpec([64, 64, 64], [128, 128, 128]) @@ -646,6 +657,11 @@ def test_sharded_scale_close(sharded_scale: ShardedScale): } +def test_shardedfileaccessor_kwargs(tmpdir): + accessor = ShardedFileAccessor(tmpdir, foo="bar") + assert accessor.kwargs == {"foo": "bar"} + + @pytest.fixture def faccessor_r(tmpdir): with open(pathlib.Path(tmpdir) / "info", "w") as fp: