
Commit

Merge pull request #91 from girder/dataset-download
jbeezley authored Jun 24, 2021
2 parents 363edb8 + 0bbb9e3 commit fc158b7
Showing 3 changed files with 119 additions and 67 deletions.
4 changes: 3 additions & 1 deletion swcc/swcc/api.py
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections import defaultdict
from contextlib import contextmanager
from typing import List, Optional
from typing import Any, Dict, List, Optional

import requests
from requests_toolbelt.sessions import BaseUrlSession
@@ -42,6 +43,7 @@ def __init__(
base_url = f'{base_url.rstrip("/")}/' # tolerate input with or without trailing slash
super().__init__(base_url=base_url, **kwargs)

self.cache: Dict[Any, Dict[int, Any]] = defaultdict(dict)
retry = Retry()
adapter = requests.adapters.HTTPAdapter(max_retries=retry)
self.mount(base_url, adapter)
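Note on the new attribute: `self.cache` gives the session a two-level lookup, model class -> object id -> instance, which `from_id` in models.py (below) consults before issuing a request. A minimal sketch of the structure with plain dicts; the `Example` class here is a hypothetical stand-in, not part of swcc:

    from collections import defaultdict
    from typing import Any, Dict

    cache: Dict[Any, Dict[int, Any]] = defaultdict(dict)  # class -> {id -> instance}

    class Example:  # hypothetical stand-in for an ApiModel subclass
        def __init__(self, id: int):
            self.id = id

    cache[Example][13] = Example(13)   # store a fetched object under its id
    assert 13 in cache[Example]        # later lookups can skip the HTTP round trip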
117 changes: 87 additions & 30 deletions swcc/swcc/models.py
@@ -2,7 +2,19 @@

from pathlib import Path, PurePath
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generic, Iterator, Literal, Optional, Type, TypeVar, Union, get_args
from typing import (
Any,
Dict,
Generic,
Iterator,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
)
from urllib.parse import unquote

from openpyxl import load_workbook
@@ -152,16 +164,19 @@ def fetch_entity(cls, v, field: ModelField):
@classmethod
def from_id(cls: Type[ModelType], id: int, **kwargs) -> ModelType:
session = current_session()
cache = session.cache[cls]

r: requests.Response = session.get(f'{cls._endpoint}/{id}/')
raise_for_status(r)
json = r.json()
for key, value in cls.__fields__.items():
if key in kwargs:
json[key] = kwargs[key]
elif value.type_ is not Any and issubclass(value.type_, ApiModel):
json[key] = value.type_.from_id(json[key])
return cls(**json)
if id not in cache:
r: requests.Response = session.get(f'{cls._endpoint}/{id}/')
raise_for_status(r)
json = r.json()
for key, value in cls.__fields__.items():
if key in kwargs:
json[key] = kwargs[key]
elif value.type_ is not Any and issubclass(value.type_, ApiModel):
json[key] = value.type_.from_id(json[key])
cache[id] = cls(**json)
return cache[id]

@classmethod
def list(cls: Type[ModelType], **kwargs) -> Iterator[ModelType]:
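With the cache in place, repeated `from_id` calls for the same object return the one instance already fetched. A hedged usage sketch, assuming an authenticated swcc session is active (see the notebook below) and that a dataset with id 13 exists on the server:

    from swcc.models import Dataset

    d1 = Dataset.from_id(13)  # first call issues GET datasets/13/ and caches the result
    d2 = Dataset.from_id(13)  # second call is served from session.cache
    assert d1 is d2           # same in-memory object, no extra request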
@@ -260,7 +275,12 @@ def segmentations(self) -> Iterator[Segmentation]:
yield segmentation

def add_project(self, file: Path, keywords: str = '', description: str = '') -> Project:
project = Project(file=file, keywords=keywords, description=description, dataset=self)
project = Project(
file=file,
keywords=keywords,
description=description,
dataset=self,
).create()
return project.load_project_spreadsheet()

@classmethod
@@ -301,6 +321,16 @@ def load_data_spreadsheet(self, file: Union[Path, str]) -> Dataset:

return self

def download(self, path: Union[Path, str]):
self.assert_remote()
path = Path(path)
for segmentation in self.segmentations:
# TODO: add a dataset spreadsheet to the data model and get the path from it
segmentation.file.download(path / 'segmentations')

for project in self.projects:
project.download(path)


class Subject(ApiModel):
_endpoint = 'subjects'
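The new bulk download above writes segmentation files to `<path>/segmentations` and then delegates to each project's own `download`. A hedged usage sketch, assuming an authenticated session and an existing dataset named 'left_atrium' as in the notebook:

    from pathlib import Path
    from swcc.models import Dataset

    dataset = Dataset.from_name('left_atrium')  # must already exist on the server
    dataset.download(Path('downloads'))
    # Expected layout (per the method above):
    #   downloads/segmentations/...   segmentation files
    #   downloads/...                 project spreadsheet plus the files it references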
@@ -365,8 +395,6 @@ def add_shape_model(self, parameters: Dict[str, Any]) -> OptimizedShapeModel:
).create()

def load_project_spreadsheet(self) -> Project:
self.assert_local()

file = self.file.path
assert file # should be guaranteed by assert_local

@@ -386,6 +414,29 @@ def load_project_spreadsheet(self) -> Project:

return self

def _iter_data_sheet(
self, sheet: Any, root: Path
) -> Iterator[Tuple[Path, Path, str, Path, Path]]:
headers = next(sheet)
if headers != (
'shape_file',
'groomed_file',
'alignment_file',
'local_particles_file',
'world_particles_file',
):
raise Exception('Unknown spreadsheet format')

for row in sheet:
shape_file, groomed_file, alignment_file, local, world = row

shape_file = root / shape_file
groomed_file = root / groomed_file
local = root / local
world = root / world

yield shape_file, groomed_file, alignment_file, local, world

def _parse_optimize_sheet(self, sheet: Any) -> OptimizedShapeModel:
headers = next(sheet)
if headers != ('key', 'value'):
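For reference, the `_iter_data_sheet` helper above validates the header row of the project spreadsheet's data sheet and yields one tuple of resolved paths per row. A small illustration with plain tuples standing in for an openpyxl worksheet's `.values` iterator; the file names are hypothetical:

    from pathlib import Path

    rows = iter([
        ('shape_file', 'groomed_file', 'alignment_file',
         'local_particles_file', 'world_particles_file'),
        ('segmentations/s1.nrrd', 'groomed/s1.nrrd', 'transform',
         'particles/s1_local.particles', 'particles/s1_world.particles'),
    ])
    root = Path('downloads')

    headers = next(rows)  # _iter_data_sheet raises on any other header layout
    for shape, groomed, alignment, local, world in rows:
        # every path column is resolved against root; alignment_file passes through as-is
        print(root / shape, root / groomed, alignment, root / local, root / world)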
@@ -409,24 +460,9 @@ def _parse_data_sheet(
sheet: Any,
root: Path,
):
headers = next(sheet)
if headers != (
'shape_file',
'groomed_file',
'alignment_file',
'local_particles_file',
'world_particles_file',
for shape_file, groomed_file, alignment_file, local, world in self._iter_data_sheet(
sheet, root
):
raise Exception('Unknown spreadsheet format')

for row in sheet:
shape_file, groomed_file, alignment_file, local, world = row

shape_file = root / shape_file
groomed_file = root / groomed_file
local = root / local
world = root / world

segmentation = segmentations.get(shape_file.stem)
if not segmentation:
raise Exception(f'Could not find segmentation for "{shape_file}"')
@@ -447,6 +483,27 @@ def _parse_data_sheet(
groomed_segmentation=groomed_segmentation,
)

def download(self, path: Union[Path, str]):
path = Path(path)
project_file = self.file.download(path)
xls = load_workbook(str(project_file), read_only=True)
sheet = xls['data'].values

shape_model = next(self.shape_models)
groomed_segmentations = {
PurePath(gs.file.name).stem: gs for gs in self.groomed_segmentations
}
local_files = {PurePath(p.local.name).stem: p for p in shape_model.particles}

# TODO: Do we detect if alignment_file (transform) is a path?
for _, groomed_file, _, local, world in self._iter_data_sheet(sheet, path):
gs = groomed_segmentations[groomed_file.stem]
gs.file.download(groomed_file.parent)

particles = local_files[local.stem]
particles.local.download(local.parent)
particles.world.download(world.parent)


class GroomedSegmentation(ApiModel):
_endpoint = 'groomed-segmentations'
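`Project.download` fetches the project spreadsheet, then walks its data sheet to pull the matching groomed segmentations and particle files into the directories the sheet names. A hedged usage sketch, assuming an authenticated session and a dataset that already has a project:

    from pathlib import Path
    from swcc.models import Dataset

    dataset = Dataset.from_name('left_atrium')
    project = next(dataset.projects)      # Dataset.download iterates self.projects the same way
    project.download(Path('downloads'))   # project spreadsheet, groomed files, and particle files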
65 changes: 29 additions & 36 deletions swcc/usage.ipynb
@@ -12,13 +12,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 42,
"id": "f34f8ace",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import re\n",
"from shutil import rmtree\n",
"\n",
"import getpass\n",
"from tqdm.notebook import tqdm\n",
@@ -32,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 19,
"id": "fa5aa4f6",
"metadata": {
"scrolled": true
@@ -44,7 +45,8 @@
"text": [
"username········\n",
"password········\n",
"[(13, 'left_atrium')]\n"
"list <class 'swcc.models.Dataset'>\n",
"[(13, 'left_atrium'), (14, 'left_atrium_test'), (22, 'left_atrium_test_1')]\n"
]
}
],
@@ -58,15 +60,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 21,
"id": "2eb23385",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(13, 'left_atrium')]\n"
"list <class 'swcc.models.Dataset'>\n",
"[(13, 'left_atrium'), (14, 'left_atrium_test'), (22, 'left_atrium_test_1')]\n"
]
}
],
@@ -213,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 44,
"id": "95b612de",
"metadata": {},
"outputs": [],
@@ -231,38 +234,36 @@
"with (dataset_path / 'License.txt').open('r') as f:\n",
" license = f.read()\n",
"\n",
"dataset = Dataset(name='left_atrium_test_1', license=license, description=description, acknowledgement=acknowledgement).create()\n",
"dataset = Dataset(name='left_atrium', license=license, description=description, acknowledgement=acknowledgement).create()\n",
"dataset.load_data_spreadsheet(root_path / 'data.xlsx')\n",
"\n",
"project = dataset.add_project(file=dataset_path / 'project.xlsx')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "ca774206",
"metadata": {},
"outputs": [],
"source": [
"# Download a full dataset in bulk\n",
"dataset = Dataset.from_name('left_atrium')\n",
"download_path = Path('downloads')\n",
"if download_path.exists():\n",
" rmtree(str(download_path))\n",
" \n",
"dataset.download(download_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcd0e591",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Anatomy type: left_atrium\n"
]
},
{
"data": {
"text/plain": [
"{'number_of_particles': 128}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Now we can explore the data using the models\n",
"# We can also explore the data using the models\n",
"segmentation = next(subject.segmentations)\n",
"print(f'Anatomy type: {segmentation.anatomy_type}')\n",
"\n",
@@ -318,14 +319,6 @@
"dataset.delete()\n",
"project.delete()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f31bd7b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
