Commit
update tests, simplified scale construction
xgui3783 committed Nov 6, 2023
1 parent 22711ae commit c3bd3d0
Showing 7 changed files with 276 additions and 40 deletions.
48 changes: 48 additions & 0 deletions script_tests/test_scripts.py
@@ -97,6 +97,54 @@ def test_all_in_one_conversion(examples_dir, tmpdir):
    # with --mmap / --load-full-volume


def test_sharded_conversion(examples_dir, tmpdir):
    input_nifti = examples_dir / "JuBrain" / "colin27T1_seg.nii.gz"
    # The file may be present but be a git-lfs pointer file, so we need to
    # open it to make sure that it is the actual correct file.
    try:
        gzip.open(str(input_nifti)).read(348)
    except OSError as exc:
        pytest.skip("Cannot find a valid example file {0} for testing: {1}"
                    .format(input_nifti, exc))

    output_dir = tmpdir / "MPM"
    assert subprocess.call([
        "volume-to-precomputed",
        "--generate-info",
        "--sharding", "2,2,0",
        str(input_nifti),
        str(output_dir)
    ], env=env) == 0
    assert subprocess.call([
        "generate-scales-info",
        "--type=segmentation",
        "--encoding=compressed_segmentation",
        str(output_dir / "info_fullres.json"),
        str(output_dir)
    ], env=env) == 0
    assert subprocess.call([
        "volume-to-precomputed",
        "--sharding", "2,2,0",
        str(input_nifti),
        str(output_dir)
    ], env=env) == 0
    assert subprocess.call([
        "compute-scales",
        "--downscaling-method=stride",  # for test speed
        str(output_dir)
    ], env=env) == 0
    assert subprocess.call([
        "scale-stats",
        str(output_dir),
    ], env=env) == 0
    assert subprocess.call([
        "convert-chunks",
        "--copy-info",
        str(output_dir),
        str(output_dir / "copy")
    ], env=env) == 0


def test_slice_conversion(tmpdir):
    # Prepare dummy slices
    path_to_slices = tmpdir / "slices"
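
Taken as a CLI walkthrough, the new test_sharded_conversion exercises the whole sharded pipeline: generate a full-resolution info with sharding enabled, derive the scales info for a compressed segmentation, convert the volume, compute the downscaled levels, print scale statistics, and convert the chunks into a copy. A rough stand-alone sketch of the same sequence (paths are illustrative, and "2,2,0" is simply the comma-separated bit-parameter string the test passes to --sharding):

import subprocess

steps = [
    ["volume-to-precomputed", "--generate-info", "--sharding", "2,2,0",
     "colin27T1_seg.nii.gz", "MPM"],
    ["generate-scales-info", "--type=segmentation",
     "--encoding=compressed_segmentation", "MPM/info_fullres.json", "MPM"],
    ["volume-to-precomputed", "--sharding", "2,2,0",
     "colin27T1_seg.nii.gz", "MPM"],
    ["compute-scales", "--downscaling-method=stride", "MPM"],
    ["scale-stats", "MPM"],
    ["convert-chunks", "--copy-info", "MPM", "MPM/copy"],
]
for step in steps:
    subprocess.check_call(step)  # raises CalledProcessError on failure
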
18 changes: 15 additions & 3 deletions src/neuroglancer_scripts/scripts/scale_stats.py
@@ -23,15 +23,27 @@ def show_scales_info(info):
     for scale in info["scales"]:
         scale_name = scale["key"]
         size = scale["size"]
+
+        shard_info = "Unsharded"
+        shard_spec = scale.get("sharding")
+        sharding_num_directories = None
+        if shard_spec:
+            shard_bits = shard_spec.get("shard_bits")
+            shard_info = f"Sharded: {shard_bits}bits"
+            sharding_num_directories = 2 ** shard_bits + 1
+
         for chunk_size in scale["chunk_sizes"]:
             size_in_chunks = [(s - 1) // cs + 1 for s,
                               cs in zip(size, chunk_size)]
             num_chunks = np.prod(size_in_chunks)
-            num_directories = size_in_chunks[0] * (1 + size_in_chunks[1])
+            num_directories = (
+                sharding_num_directories
+                if sharding_num_directories is not None
+                else size_in_chunks[0] * (1 + size_in_chunks[1]))
             size_bytes = np.prod(size) * dtype.itemsize * num_channels
-            print("Scale {}, chunk size {}:"
+            print("Scale {}, {}, chunk size {}:"
                   " {:,d} chunks, {:,d} directories, raw uncompressed size {}B"
-                  .format(scale_name, chunk_size,
+                  .format(scale_name, shard_info, chunk_size,
                           num_chunks, num_directories,
                           readable_count(size_bytes)))
         total_size += size_bytes
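
The change above means the directory count of a sharded scale is fixed by shard_bits instead of growing with the chunk grid. A minimal sketch of the two formulas, with illustrative numbers (the real code reads the size, chunk size and shard_bits from the scale's info entry):

size = [193, 229, 193]
chunk_size = [64, 64, 64]
size_in_chunks = [(s - 1) // cs + 1 for s, cs in zip(size, chunk_size)]

# Unsharded: one directory per X index plus one per (X, Y) pair.
unsharded_dirs = size_in_chunks[0] * (1 + size_in_chunks[1])  # 4 * (1 + 4) = 20

# Sharded: 2 ** shard_bits shard files, plus the scale directory itself.
shard_bits = 2
sharded_dirs = 2 ** shard_bits + 1  # 5
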
3 changes: 0 additions & 3 deletions src/neuroglancer_scripts/sharded_base.py
@@ -195,9 +195,6 @@ def get_cmc(self, chunk_coords: List[int]) -> np.uint64:

         return self.compressed_morton_code(grid_coords)

-    def generate_shard_spec(self) -> ShardSpec:
-        return ShardSpec(4, 4, "identity", "gzip", "gzip", 1)
-

 class CMCReadWrite(ABC):

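With this hard-coded default gone, shard parameters are taken from the scale's "sharding" entry in the info file instead. A hedged sketch of that derivation (the field names follow the neuroglancer precomputed sharded format; the concrete values and the import path are assumptions, not part of this commit):

from neuroglancer_scripts.sharded_base import ShardSpec

sharding = {
    "@type": "neuroglancer_uint64_sharded_v1",
    "preshift_bits": 0,
    "minishard_bits": 2,
    "shard_bits": 2,
    "hash": "identity",
    "minishard_index_encoding": "gzip",
    "data_encoding": "gzip",
}
# Everything except "@type" maps onto ShardSpec keyword arguments.
shard_spec = ShardSpec(**{k: v for k, v in sharding.items() if k != "@type"})
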
54 changes: 25 additions & 29 deletions src/neuroglancer_scripts/sharded_file_accessor.py
@@ -379,13 +379,11 @@ class ShardedFileAccessor(neuroglancer_scripts.accessor.Accessor,
     can_read = False
     can_write = True

-    def __init__(self, base_dir,
-                 shard_volume_spec_dict: Dict[str, ShardVolumeSpec] = {}):
+    def __init__(self, base_dir):
         ShardedAccessorBase.__init__(self)
         self.base_dir = pathlib.Path(base_dir)
         self.shard_dict: Dict[str, ShardedScale] = {}
         self.ro_shard_dict: Dict[str, ShardedScale] = {}
-        self.shard_volume_spec_dict = shard_volume_spec_dict

         try:
             self.info = json.loads(self.fetch_file("info"))
@@ -423,38 +423,34 @@ def fetch_chunk(self, key, chunk_coords):
         self.ro_shard_dict[key] = sharded_scale
         return self.ro_shard_dict[key].fetch_chunk(chunk_coords)

-    def store_chunk(self, buf, key, chunk_coords, **kwargs):
-        shard_volume_spec: ShardVolumeSpec = None
-        shard_spec: ShardSpec = None
-        if key in (self.shard_volume_spec_dict or {}):
-            shard_volume_spec = self.shard_volume_spec_dict[key]
+    def get_volume_shard_spec(self, key: str):
         try:
             found_scale = [s for s in self.info.get("scales", [])
                            if s.get("key") == key]
-            if len(found_scale) > 0:
-                scale, *_ = found_scale
-                chunk_sizes, = scale.get("chunk_sizes")
-                size = scale.get("size")
-                shard_volume_spec = ShardVolumeSpec(chunk_sizes, size)
-                sharding = scale.get("sharding")
-                if sharding:
-                    sharding_kwargs = {key: value
-                                       for key, value in sharding.items()
-                                       if key != "@type"}
-                    shard_spec = ShardSpec(**sharding_kwargs)
-        except Exception:
-            ...
-
-        if not shard_volume_spec:
-            raise ShardedIOError(
-                f"Expecting key {key} in shard_volume_spec_dict, or defined in"
-                " 'info' file, but were not. Existing keys: "
-                f"{list(self.shard_volume_spec_dict.keys())}. Existing info:"
-                f"{json.dumps(self.info, indent=4)}")
+            assert len(found_scale) == 1, ("Expecting one and only one scale "
+                                           f"with key {key}, but got "
+                                           f"{len(found_scale)}")
+
+            scale, *_ = found_scale
+            sharding = scale.get("sharding")
+
+            chunk_sizes, = scale.get("chunk_sizes")
+            size = scale.get("size")
+            shard_volume_spec = ShardVolumeSpec(chunk_sizes, size)
+
+            sharding_kwargs = {key: value
+                               for key, value in sharding.items()
+                               if key != "@type"}
+            shard_spec = ShardSpec(**sharding_kwargs)
+
+            return shard_volume_spec, shard_spec
+        except Exception as e:
+            raise ShardedIOError from e
+
+    def store_chunk(self, buf, key, chunk_coords, **kwargs):
         if key not in self.shard_dict:
-            if not shard_spec:
-                shard_spec = shard_volume_spec.generate_shard_spec()
+            shard_volume_spec, shard_spec = self.get_volume_shard_spec(key)
+
             sharded_scale = ShardedScale(base_dir=self.base_dir,
                                          key=key,
                                          shard_spec=shard_spec,
@@ -463,6 +457,8 @@ def store_chunk(self, buf, key, chunk_coords, **kwargs):
         self.shard_dict[key].store_chunk(buf, chunk_coords)

     def close(self):
+        if len(self.shard_dict) == 0:
+            return
         scale_arr = []
         for scale in self.shard_dict.values():
             scale.close()
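
Construction is now driven entirely by the info file: the accessor takes only a base directory, and get_volume_shard_spec() recovers both the volume layout and the shard parameters for a scale key, raising ShardedIOError when the key is missing or its sharding entry is unusable. A usage sketch under those assumptions (the scale key "foo", the paths and the chunk payload are illustrative, and chunk_coords is assumed to follow the accessor convention of (xmin, xmax, ymin, ymax, zmin, zmax)):

from neuroglancer_scripts.sharded_file_accessor import ShardedFileAccessor

accessor = ShardedFileAccessor("/path/to/volume")  # reads /path/to/volume/info
volume_spec, shard_spec = accessor.get_volume_shard_spec("foo")
chunk_bytes = b"\x00" * (64 * 64 * 64)  # illustrative raw chunk payload
accessor.store_chunk(chunk_bytes, "foo", (0, 64, 0, 64, 0, 64))
accessor.close()  # writes shard files; returns early if nothing was stored
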
56 changes: 56 additions & 0 deletions unit_tests/test_accessor.py
@@ -5,8 +5,10 @@

import argparse
import pathlib
import json

import pytest
from unittest.mock import patch

from neuroglancer_scripts.accessor import (
    get_accessor_for_url,
@@ -17,6 +19,9 @@
)
from neuroglancer_scripts.file_accessor import FileAccessor
from neuroglancer_scripts.http_accessor import HttpAccessor
from neuroglancer_scripts.sharded_base import ShardedAccessorBase
from neuroglancer_scripts.sharded_file_accessor import ShardedFileAccessor
from neuroglancer_scripts.sharded_http_accessor import ShardedHttpAccessor


@pytest.mark.parametrize("accessor_options", [
@@ -40,6 +45,57 @@ def test_get_accessor_for_url(accessor_options):
get_accessor_for_url("file:///%ff", accessor_options)


valid_info_str = json.dumps({
"scales": [
{
"key": "foo",
"sharding": {
"@type": "neuroglancer_uint64_sharded_v1"
}
}
]
})


@patch.object(ShardedAccessorBase, "info_is_sharded")
@pytest.mark.parametrize("scheme", ["https://", "http://", "file:///", ""])
@pytest.mark.parametrize("fetch_file_returns, info_is_sharded_returns, exp", [
(Exception("foobar"), None, False),
('mal formed json', None, False),
(valid_info_str, Exception("foobar"), False),
(valid_info_str, False, False),
(valid_info_str, True, True),
])
def test_sharded_accessor_via_info(info_is_sharded_mock, fetch_file_returns,
info_is_sharded_returns, exp, scheme):

if isinstance(info_is_sharded_returns, Exception):
info_is_sharded_mock.side_effect = info_is_sharded_returns
else:
info_is_sharded_mock.return_value = info_is_sharded_returns

assert scheme in ("https://", "http://", "file:///", "")
if scheme in ("file:///", ""):
base_acc_cls = FileAccessor
shard_accessor_cls = ShardedFileAccessor
if scheme in ("https://", "http://"):
base_acc_cls = HttpAccessor
shard_accessor_cls = ShardedHttpAccessor
with patch.object(base_acc_cls, "fetch_file") as fetch_file_mock:
if isinstance(fetch_file_returns, Exception):
fetch_file_mock.side_effect = fetch_file_returns
else:
fetch_file_mock.return_value = fetch_file_returns

result = get_accessor_for_url(f"{scheme}example/")
assert isinstance(result, shard_accessor_cls if exp else base_acc_cls)

if info_is_sharded_returns is None:
info_is_sharded_mock.assert_not_called()
else:
info_is_sharded_mock.assert_called_once()


@pytest.mark.parametrize("write_chunks", [True, False])
@pytest.mark.parametrize("write_files", [True, False])
def test_add_argparse_options(write_chunks, write_files):
Expand Down
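
The parametrized cases encode the fallback behaviour of accessor selection: any failure to fetch or parse the info file, and any negative or failing info_is_sharded() check, leaves the plain accessor in place. Paraphrased as a sketch (this mirrors the five cases above; the real decision lives in get_accessor_for_url, and the helper name here is hypothetical):

import json

from neuroglancer_scripts.sharded_base import ShardedAccessorBase

def pick_accessor(base_accessor, sharded_cls, path):
    try:
        info = json.loads(base_accessor.fetch_file("info"))
        if ShardedAccessorBase.info_is_sharded(info):
            return sharded_cls(path)
    except Exception:
        pass  # unreadable info, malformed JSON, or failing detection
    return base_accessor
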
13 changes: 13 additions & 0 deletions unit_tests/test_sharded_base.py
@@ -617,6 +617,19 @@ def test_to_json(mocked_get_shard_base):
        },
        True
    ),
    (
        {
            "scales": [
                {
                    "key": "foo",
                    "sharding": {
                        "@type": "foo-bar"
                    }
                }
            ]
        },
        True
    ),

    (
        {
Expand Down
