Skip to content

Commit

Permalink
feat: add kwargs to allow for shard on disk/in mem
Browse files Browse the repository at this point in the history
  • Loading branch information
xgui3783 committed Nov 27, 2023
1 parent 2ab8f19 commit 1487ab1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/neuroglancer_scripts/sharded_file_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ class MiniShard(CMCReadWrite):
can_read_cmc = False
can_write_cmc = True

def __init__(self, shard_spec: ShardSpec,
offset: np.uint64 = np.uint64(0),
strategy="on disk"):
def __init__(self, shard_spec: ShardSpec, offset: np.uint64 = np.uint64(0),
strategy="on disk", **kwargs):
super().__init__(shard_spec)
self._offset = offset

Expand Down Expand Up @@ -238,12 +237,14 @@ class Shard(ShardCMC):
can_read_cmc = False
can_write_cmc = True

def __init__(self, base_dir, shard_key: np.uint64, shard_spec: ShardSpec):
def __init__(self, base_dir, shard_key: np.uint64, shard_spec: ShardSpec,
**kwargs):
self.root_dir = pathlib.Path(base_dir)
super().__init__(shard_key, shard_spec)
self.file_path = pathlib.Path(base_dir) / f"{self.shard_key_str}.shard"
self.populate_minishard_dict()
self.dirty = False
self.kwargs = kwargs

def file_exists(self, filepath: str) -> bool:
return (self.root_dir / filepath).is_file()
Expand Down Expand Up @@ -271,7 +272,8 @@ def store_cmc_chunk(self, buf: bytes, cmc: np.uint64):

minishard_key = self.get_minishard_key(cmc)
if minishard_key not in self.minishard_dict:
self.minishard_dict[minishard_key] = MiniShard(self.shard_spec)
self.minishard_dict[minishard_key] = MiniShard(self.shard_spec,
**self.kwargs)
self.minishard_dict[minishard_key].store_cmc_chunk(buf, cmc)

def fetch_cmc_chunk(self, cmc: np.uint64):
Expand Down Expand Up @@ -352,16 +354,19 @@ class ShardedScale(ShardedScaleBase):

def __init__(self, base_dir, key: str,
shard_spec: ShardSpec,
shard_volume_spec: ShardVolumeSpec):
shard_volume_spec: ShardVolumeSpec,
**kwargs):
super().__init__(key, shard_spec, shard_volume_spec)
self.base_dir = pathlib.Path(base_dir) / key
self.shard_dict: Dict[np.uint64, Shard] = {}
self.kwargs = kwargs

def get_shard(self, shard_key: np.uint64):
if shard_key not in self.shard_dict:
self.shard_dict[shard_key] = Shard(self.base_dir,
shard_key,
self.shard_spec)
self.shard_spec,
**self.kwargs)
return self.shard_dict[shard_key]

def close(self):
Expand All @@ -379,7 +384,7 @@ class ShardedFileAccessor(neuroglancer_scripts.accessor.Accessor,
can_read = False
can_write = True

def __init__(self, base_dir):
def __init__(self, base_dir, **kwargs):
ShardedAccessorBase.__init__(self)
self.base_dir = pathlib.Path(base_dir)
self.base_dir.mkdir(exist_ok=True, parents=True)
Expand All @@ -396,6 +401,8 @@ def __init__(self, base_dir):
import atexit
atexit.register(self.close)

self.kwargs = kwargs

def file_exists(self, relative_path: str):
return (self.base_dir / relative_path).exists()

Expand All @@ -419,7 +426,8 @@ def fetch_chunk(self, key, chunk_coords):
sharded_scale = ShardedScale(base_dir=self.base_dir,
key=key,
shard_spec=shard_spec,
shard_volume_spec=shard_volume_spec)
shard_volume_spec=shard_volume_spec,
**self.kwargs)
self.ro_shard_dict[key] = sharded_scale
return self.ro_shard_dict[key].fetch_chunk(chunk_coords)

Expand Down Expand Up @@ -454,9 +462,10 @@ def store_chunk(self, buf, key, chunk_coords, **kwargs):
sharded_scale = ShardedScale(base_dir=self.base_dir,
key=key,
shard_spec=shard_spec,
shard_volume_spec=shard_volume_spec)
shard_volume_spec=shard_volume_spec,
**self.kwargs)
self.shard_dict[key] = sharded_scale
self.shard_dict[key].store_chunk(buf, chunk_coords)
self.shard_dict[key].store_chunk(buf, chunk_coords, **kwargs)

def close(self):
if len(self.shard_dict) == 0:
Expand Down
16 changes: 16 additions & 0 deletions unit_tests/test_sharded_file_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def test_minishard_close(shard_spec_2_2_2: ShardSpec, pre_store_cmc,


# Shard
def test_shard_kwargs(tmpdir, shard_spec_2_2_2):
shard = Shard(tmpdir, np.uint64(0), shard_spec_2_2_2, foo="bar")
assert shard.kwargs == {"foo": "bar"}


@pytest.mark.parametrize("write_files, readable, legacy", [
((), False, False),
(("{shard_key_str}.shard",), True, False),
Expand Down Expand Up @@ -571,6 +576,12 @@ def test_shard_close_not_dirty(tmpdir, shard_spec_1_1_1):


# ShardedScale
def test_shardscale_kwargs(tmpdir, shard_spec_2_2_2):
vol_spec = ShardVolumeSpec([64, 64, 64], [128, 128, 128])
scale = ShardedScale(tmpdir, "5mm", shard_spec_2_2_2, vol_spec, foo="bar")
assert scale.kwargs == {"foo": "bar"}


@pytest.fixture
def sharded_scale(tmpdir, shard_spec_2_2_2):
vol_spec = ShardVolumeSpec([64, 64, 64], [128, 128, 128])
Expand Down Expand Up @@ -646,6 +657,11 @@ def test_sharded_scale_close(sharded_scale: ShardedScale):
}


def test_shardedfileaccessor_kwargs(tmpdir):
accessor = ShardedFileAccessor(tmpdir, foo="bar")
assert accessor.kwargs == {"foo": "bar"}


@pytest.fixture
def faccessor_r(tmpdir):
with open(pathlib.Path(tmpdir) / "info", "w") as fp:
Expand Down

0 comments on commit 1487ab1

Please sign in to comment.