From 0bbb9e307ad9bb5bdc11f4b311f50316093f018e Mon Sep 17 00:00:00 2001
From: Jonathan Beezley
Date: Thu, 24 Jun 2021 09:17:35 -0600
Subject: [PATCH] Add the ability to download a dataset object

---
 swcc/swcc/api.py    |   4 +-
 swcc/swcc/models.py | 117 ++++++++++++++++++++++++++++++++------------
 swcc/usage.ipynb    |  65 +++++++++++-------------
 3 files changed, 119 insertions(+), 67 deletions(-)

diff --git a/swcc/swcc/api.py b/swcc/swcc/api.py
index 1aa84a88..db98715e 100644
--- a/swcc/swcc/api.py
+++ b/swcc/swcc/api.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+from collections import defaultdict
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import requests
 from requests_toolbelt.sessions import BaseUrlSession
@@ -42,6 +43,7 @@ def __init__(
         base_url = f'{base_url.rstrip("/")}/'  # tolerate input with or without trailing slash
         super().__init__(base_url=base_url, **kwargs)
 
+        self.cache: Dict[Any, Dict[int, Any]] = defaultdict(dict)
         retry = Retry()
         adapter = requests.adapters.HTTPAdapter(max_retries=retry)
         self.mount(base_url, adapter)
diff --git a/swcc/swcc/models.py b/swcc/swcc/models.py
index 99b70ad9..2a7036d3 100644
--- a/swcc/swcc/models.py
+++ b/swcc/swcc/models.py
@@ -2,7 +2,19 @@
 
 from pathlib import Path, PurePath
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, Generic, Iterator, Literal, Optional, Type, TypeVar, Union, get_args
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterator,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    get_args,
+)
 from urllib.parse import unquote
 
 from openpyxl import load_workbook
@@ -152,16 +164,19 @@ def fetch_entity(cls, v, field: ModelField):
     @classmethod
     def from_id(cls: Type[ModelType], id: int, **kwargs) -> ModelType:
         session = current_session()
+        cache = session.cache[cls]
 
-        r: requests.Response = session.get(f'{cls._endpoint}/{id}/')
-        raise_for_status(r)
-        json = r.json()
-        for key, value in cls.__fields__.items():
-            if key in kwargs:
-                json[key] = kwargs[key]
-            elif value.type_ is not Any and issubclass(value.type_, ApiModel):
-                json[key] = value.type_.from_id(json[key])
-        return cls(**json)
+        if id not in cache:
+            r: requests.Response = session.get(f'{cls._endpoint}/{id}/')
+            raise_for_status(r)
+            json = r.json()
+            for key, value in cls.__fields__.items():
+                if key in kwargs:
+                    json[key] = kwargs[key]
+                elif value.type_ is not Any and issubclass(value.type_, ApiModel):
+                    json[key] = value.type_.from_id(json[key])
+            cache[id] = cls(**json)
+        return cache[id]
 
     @classmethod
     def list(cls: Type[ModelType], **kwargs) -> Iterator[ModelType]:
@@ -260,7 +275,12 @@ def segmentations(self) -> Iterator[Segmentation]:
             yield segmentation
 
     def add_project(self, file: Path, keywords: str = '', description: str = '') -> Project:
-        project = Project(file=file, keywords=keywords, description=description, dataset=self)
+        project = Project(
+            file=file,
+            keywords=keywords,
+            description=description,
+            dataset=self,
+        ).create()
         return project.load_project_spreadsheet()
 
     @classmethod
@@ -301,6 +321,16 @@ def load_data_spreadsheet(self, file: Union[Path, str]) -> Dataset:
 
         return self
 
+    def download(self, path: Union[Path, str]):
+        self.assert_remote()
+        path = Path(path)
+        for segmentation in self.segmentations:
+            # TODO: add a dataset spreadsheet to the data model and get the path from it
+            segmentation.file.download(path / 'segmentations')
+
+        for project in self.projects:
+            project.download(path)
+
 
 class Subject(ApiModel):
     _endpoint = 'subjects'
@@ -365,8 +395,6 @@ def add_shape_model(self, parameters: Dict[str, Any]) -> OptimizedShapeModel:
         ).create()
 
     def load_project_spreadsheet(self) -> Project:
-        self.assert_local()
-
         file = self.file.path
         assert file  # should be guaranteed by assert_local
 
@@ -386,6 +414,29 @@ def load_project_spreadsheet(self) -> Project:
 
         return self
 
+    def _iter_data_sheet(
+        self, sheet: Any, root: Path
+    ) -> Iterator[Tuple[Path, Path, str, Path, Path]]:
+        headers = next(sheet)
+        if headers != (
+            'shape_file',
+            'groomed_file',
+            'alignment_file',
+            'local_particles_file',
+            'world_particles_file',
+        ):
+            raise Exception('Unknown spreadsheet format')
+
+        for row in sheet:
+            shape_file, groomed_file, alignment_file, local, world = row
+
+            shape_file = root / shape_file
+            groomed_file = root / groomed_file
+            local = root / local
+            world = root / world
+
+            yield shape_file, groomed_file, alignment_file, local, world
+
     def _parse_optimize_sheet(self, sheet: Any) -> OptimizedShapeModel:
         headers = next(sheet)
         if headers != ('key', 'value'):
@@ -409,24 +460,9 @@ def _parse_data_sheet(
         sheet: Any,
         root: Path,
     ):
-        headers = next(sheet)
-        if headers != (
-            'shape_file',
-            'groomed_file',
-            'alignment_file',
-            'local_particles_file',
-            'world_particles_file',
+        for shape_file, groomed_file, alignment_file, local, world in self._iter_data_sheet(
+            sheet, root
         ):
-            raise Exception('Unknown spreadsheet format')
-
-        for row in sheet:
-            shape_file, groomed_file, alignment_file, local, world = row
-
-            shape_file = root / shape_file
-            groomed_file = root / groomed_file
-            local = root / local
-            world = root / world
-
             segmentation = segmentations.get(shape_file.stem)
             if not segmentation:
                 raise Exception(f'Could not find segmentation for "{shape_file}"')
@@ -447,6 +483,27 @@ def _parse_data_sheet(
                 groomed_segmentation=groomed_segmentation,
             )
 
+    def download(self, path: Union[Path, str]):
+        path = Path(path)
+        project_file = self.file.download(path)
+        xls = load_workbook(str(project_file), read_only=True)
+        sheet = xls['data'].values
+
+        shape_model = next(self.shape_models)
+        groomed_segmentations = {
+            PurePath(gs.file.name).stem: gs for gs in self.groomed_segmentations
+        }
+        local_files = {PurePath(p.local.name).stem: p for p in shape_model.particles}
+
+        # TODO: Do we detect if alignment_file (transform) is a path?
+        for _, groomed_file, _, local, world in self._iter_data_sheet(sheet, path):
+            gs = groomed_segmentations[groomed_file.stem]
+            gs.file.download(groomed_file.parent)
+
+            particles = local_files[local.stem]
+            particles.local.download(local.parent)
+            particles.world.download(world.parent)
+
 
 class GroomedSegmentation(ApiModel):
     _endpoint = 'groomed-segmentations'
diff --git a/swcc/usage.ipynb b/swcc/usage.ipynb
index 36747f29..34f717da 100644
--- a/swcc/usage.ipynb
+++ b/swcc/usage.ipynb
@@ -12,13 +12,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 42,
    "id": "f34f8ace",
    "metadata": {},
    "outputs": [],
    "source": [
     "from pathlib import Path\n",
     "import re\n",
+    "from shutil import rmtree\n",
     "\n",
     "import getpass\n",
     "from tqdm.notebook import tqdm\n",
@@ -32,7 +33,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 19,
    "id": "fa5aa4f6",
    "metadata": {
     "scrolled": true
@@ -44,7 +45,8 @@
      "text": [
       "username········\n",
       "password········\n",
-      "[(13, 'left_atrium')]\n"
+      "list \n",
+      "[(13, 'left_atrium'), (14, 'left_atrium_test'), (22, 'left_atrium_test_1')]\n"
      ]
     }
    ],
@@ -58,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 21,
    "id": "2eb23385",
    "metadata": {},
    "outputs": [
@@ -66,7 +68,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[(13, 'left_atrium')]\n"
+      "list \n",
+      "[(13, 'left_atrium'), (14, 'left_atrium_test'), (22, 'left_atrium_test_1')]\n"
      ]
     }
    ],
@@ -213,7 +216,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 44,
    "id": "95b612de",
    "metadata": {},
    "outputs": [],
    "source": [
     "with (dataset_path / 'License.txt').open('r') as f:\n",
     "    license = f.read()\n",
     "\n",
-    "dataset = Dataset(name='left_atrium_test_1', license=license, description=description, acknowledgement=acknowledgement).create()\n",
+    "dataset = Dataset(name='left_atrium', license=license, description=description, acknowledgement=acknowledgement).create()\n",
     "dataset.load_data_spreadsheet(root_path / 'data.xlsx')\n",
     "\n",
     "project = dataset.add_project(file=dataset_path / 'project.xlsx')"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": null,
+   "id": "ca774206",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Download a full dataset in bulk\n",
+    "dataset = Dataset.from_name('left_atrium')\n",
+    "download_path = Path('downloads')\n",
+    "if download_path.exists():\n",
+    "    rmtree(str(download_path))\n",
+    "    \n",
+    "dataset.download(download_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "id": "bcd0e591",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Anatomy type: left_atrium\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'number_of_particles': 128}"
-      ]
-     },
-     "execution_count": 14,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "# Now we can explore the data using the models\n",
+    "# We can also explore the data using the models\n",
     "segmentation = next(subject.segmentations)\n",
     "print(f'Anatomy type: {segmentation.anatomy_type}')\n",
@@ -318,14 +319,6 @@
     "dataset.delete()\n",
     "project.delete()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "0f31bd7b",
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
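A minimal usage sketch of what this patch adds, not part of the patch itself. It assumes an authenticated swcc session has already been set up (as in the earlier cells of usage.ipynb), that a dataset named 'left_atrium' exists on the server, and that the model exposes its server id as dataset.id:

    from pathlib import Path
    from shutil import rmtree

    from swcc.models import Dataset

    # Assumes an authenticated session is active and 'left_atrium' exists remotely.
    dataset = Dataset.from_name('left_atrium')

    # from_id() now goes through the per-session cache added in api.py, so a
    # repeated lookup of the same id returns the same cached instance.
    assert Dataset.from_id(dataset.id) is Dataset.from_id(dataset.id)

    # Download the dataset in bulk: segmentation files plus each project's
    # groomed segmentations and particle files, laid out under 'downloads/'.
    download_path = Path('downloads')
    if download_path.exists():
        rmtree(str(download_path))
    dataset.download(download_path)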