diff --git a/hips/draw/tests/test_paint.py b/hips/draw/tests/test_paint.py index dfb2db8..89a8b2a 100644 --- a/hips/draw/tests/test_paint.py +++ b/hips/draw/tests/test_paint.py @@ -20,7 +20,7 @@ def setup_class(cls): width=2000, height=1000, fov="3 deg", coordsys='icrs', projection='AIT', ) - fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10) + fetch_opts = dict(timeout=30, n_parallel=10) cls.painter = HipsPainter(cls.geometry, cls.hips_survey, 'fits', fetch_opts=fetch_opts) def test_draw_hips_order(self): @@ -44,7 +44,7 @@ def test_compute_matching_hips_order(self, pars): coordsys='icrs', projection='AIT', ) - fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10) + fetch_opts = dict(timeout=30, n_parallel=10) simple_tile_painter = HipsPainter(geometry, self.hips_survey, 'fits', fetch_opts=fetch_opts) assert simple_tile_painter.draw_hips_order == pars['order'] diff --git a/hips/draw/tests/test_ui.py b/hips/draw/tests/test_ui.py index 786f3d2..a7dd254 100644 --- a/hips/draw/tests/test_ui.py +++ b/hips/draw/tests/test_ui.py @@ -61,7 +61,7 @@ def test_make_sky_image(tmpdir, pars): hips_survey = HipsSurveyProperties.fetch(url=pars['url']) geometry = make_test_wcs_geometry() - fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10) + fetch_opts = dict(timeout=30, n_parallel=10) result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format'], precise=pars['precise'], fetch_opts=fetch_opts) diff --git a/hips/draw/ui.py b/hips/draw/ui.py index b12ee19..9bcc414 100644 --- a/hips/draw/ui.py +++ b/hips/draw/ui.py @@ -147,7 +147,7 @@ def plot(self, show_grid: bool = False) -> None: def report(self) -> None: """Print a brief report for the fetched data.""" - print ( + print( f"Time for fetching tiles = {self.stats['fetch_time']} seconds\n" f"Time for drawing tiles = {self.stats['draw_time']} seconds\n" f"Total memory consumed = {self.stats['consumed_memory'] / 1e6} MB\n" diff --git a/hips/tiles/fetch.py b/hips/tiles/fetch.py index 3d50fb2..c79fdde 100644 --- a/hips/tiles/fetch.py +++ b/hips/tiles/fetch.py @@ -1,4 +1,5 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst +import socket import asyncio import urllib.request import concurrent.futures @@ -16,7 +17,7 @@ def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties, progress_bar: bool = True, n_parallel: int = 5, - timeout: float = 10, fetch_package: str = 'urllib') -> List[HipsTile]: + timeout: float = 10) -> List[HipsTile]: """Fetch a list of HiPS tiles. This function fetches a list of HiPS tiles based @@ -37,8 +38,6 @@ def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyPropertie Number of tile fetch web requests to make in parallel timeout : float Seconds to timeout for fetching a HiPS tile - fetch_package : {'urllib', 'aiohttp'} - Package to use for fetching HiPS tiles Examples -------- @@ -68,90 +67,38 @@ def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyPropertie tiles : list A Python list of `~hips.HipsTile` """ - if fetch_package == 'aiohttp': - fetch_fct = tiles_aiohttp - elif fetch_package == 'urllib': - fetch_fct = tiles_urllib - else: - raise ValueError(f'Invalid package name: {fetch_package}') - - tiles = fetch_fct(tile_metas, hips_survey, progress_bar, n_parallel, timeout) - - # Sort tiles to match the tile_meta list - # TODO: this doesn't seem like a great solution. - # Use dict instead? - out = [] - for tile_meta in tile_metas: - for tile in tiles: - if tile.meta == tile_meta: - out.append(tile) - continue - return out - - -def fetch_tile_urllib(url: str, meta: HipsTileMeta, timeout: float) -> HipsTile: - """Fetch a HiPS tile asynchronously.""" - with urllib.request.urlopen(url, timeout=timeout) as conn: - raw_data = conn.read() - return HipsTile(meta, raw_data) - - -def tiles_urllib(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties, - progress_bar: bool, n_parallel, timeout: float) -> List[HipsTile]: - """Generator function to fetch HiPS tiles from a remote URL.""" - with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor: - futures = [] - for meta in tile_metas: - url = hips_survey.tile_url(meta) - future = executor.submit(fetch_tile_urllib, url, meta, timeout) - futures.append(future) - - futures = concurrent.futures.as_completed(futures) - if progress_bar: - from tqdm import tqdm - futures = tqdm(futures, total=len(tile_metas), desc='Fetching tiles') - - tiles = [] - for future in futures: - tiles.append(future.result()) - - return tiles + tile_urls = [hips_survey.tile_url(meta) for meta in tile_metas] + response_all = do_fetch_tiles(tile_urls, hips_survey, progress_bar, n_parallel, timeout) - -async def fetch_tile_aiohttp(url: str, meta: HipsTileMeta, session, timeout: float) -> HipsTile: - """Fetch a HiPS tile asynchronously using aiohttp.""" - async with session.get(url, timeout=timeout) as response: - raw_data = await response.read() - return HipsTile(meta, raw_data) - - -async def fetch_all_tiles_aiohttp(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties, - progress_bar: bool, n_parallel: int, timeout: float) -> List[HipsTile]: - """Generator function to fetch HiPS tiles from a remote URL using aiohttp.""" - import aiohttp - - connector = aiohttp.TCPConnector(limit=n_parallel) - async with aiohttp.ClientSession(connector=connector) as session: - futures = [] - for meta in tile_metas: - url = hips_survey.tile_url(meta) - future = asyncio.ensure_future(fetch_tile_aiohttp(url, meta, session, timeout)) - futures.append(future) - - futures = asyncio.as_completed(futures) - if progress_bar: - from tqdm import tqdm - futures = tqdm(futures, total=len(tile_metas), desc='Fetching tiles') - - tiles = [] - for future in futures: - tiles.append(await future) + tiles = [] + for idx, response in enumerate(response_all): + try: + response['raw_data'] + tiles.append(HipsTile(tile_metas[idx], response['raw_data'])) + except KeyError: + tiles.append(HipsTile(tile_metas[idx], b'', is_missing=True)) return tiles -def tiles_aiohttp(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties, - progress_bar: bool, n_parallel: int, timeout: float) -> List[HipsTile]: - return asyncio.get_event_loop().run_until_complete( - fetch_all_tiles_aiohttp(tile_metas, hips_survey, progress_bar, n_parallel, timeout) - ) +def do_fetch_single_tile(url: str, timeout: float) -> dict: + """Fetch a HiPS tile asynchronously.""" + try: + with urllib.request.urlopen(url, timeout=timeout) as conn: + return {'raw_data': conn.read(), 'url': url} + except urllib.error.HTTPError as error: + # If the tile is missing, enable the `is_missing` flag in HipsTile. + if error.code == 404: + print(f'Tile not found at:\n{url}') + return {'is_missing': True} + except urllib.error.URLError as error: + if isinstance(error.reason, socket.timeout): + print(f'The server timed out while fetching the tile at:\n{url}') + return {'is_missing': True} + + +def do_fetch_tiles(tile_urls: List[str], hips_survey: HipsSurveyProperties, + progress_bar: bool, n_parallel, timeout: float) -> List[dict]: + """Generator function to fetch HiPS tiles from a remote URL.""" + with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor: + return list(executor.map(do_fetch_single_tile, tile_urls, [timeout] * len(tile_urls))) diff --git a/hips/tiles/tests/test_fetch.py b/hips/tiles/tests/test_fetch.py index 2b55f4d..a067e0f 100644 --- a/hips/tiles/tests/test_fetch.py +++ b/hips/tiles/tests/test_fetch.py @@ -1,7 +1,7 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import pytest from astropy.tests.helper import remote_data -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_equal from ..fetch import fetch_tiles from ..survey import HipsSurveyProperties from ..tile import HipsTileMeta @@ -9,21 +9,26 @@ TILE_FETCH_TEST_CASES = [ dict( tile_indices=[69623, 69627, 69628, 69629, 69630, 69631], + is_missing=[False, False, False, False, False, False], tile_format='fits', order=7, url='http://alasky.unistra.fr/DSS/DSS2Merged/properties', progress_bar=True, data=[2101, 1945, 1828, 1871, 2079, 2336], - fetch_package='urllib', ), dict( - tile_indices=[69623, 69627, 69628, 69629, 69630, 69631], + tile_indices=[69623, + 9999999, # missing + 69628, + 9999999, # missing + 9999999, # missing + 69631], + is_missing=[False, True, False, True, True, False], tile_format='fits', order=7, url='http://alasky.unistra.fr/DSS/DSS2Merged/properties', progress_bar=True, - data=[2101, 1945, 1828, 1871, 2079, 2336], - fetch_package='aiohttp', + data=[2101, 0, 1828, 0, 0, 2336], ), ] @@ -48,8 +53,10 @@ def test_fetch_tiles(pars): tiles = fetch_tiles( tile_metas, hips_survey, progress_bar=pars['progress_bar'], - fetch_package=pars['fetch_package'], ) for idx, val in enumerate(pars['data']): assert_allclose(tiles[idx].data[0][5], val) + + for idx, val in enumerate(pars['is_missing']): + assert_equal(tiles[idx].is_missing, val) diff --git a/hips/tiles/tile.py b/hips/tiles/tile.py index 1fb362c..cc31dbb 100644 --- a/hips/tiles/tile.py +++ b/hips/tiles/tile.py @@ -1,6 +1,7 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst from typing import List, Tuple from copy import deepcopy +import socket import warnings import urllib.request from io import BytesIO @@ -154,6 +155,8 @@ class HipsTile: Metadata of HiPS tile raw_data : `bytes` Raw data (copy of bytes from file) + is_missing : `bool` + Specifies whether the tile is missing or not Examples -------- @@ -175,9 +178,10 @@ class HipsTile: int16 """ - def __init__(self, meta: HipsTileMeta, raw_data: bytes) -> None: + def __init__(self, meta: HipsTileMeta, raw_data: bytes, is_missing: bool = False) -> None: self.meta = meta self.raw_data = raw_data + self.is_missing = is_missing self._data = None def __eq__(self, other: "HipsTile") -> bool: @@ -259,8 +263,11 @@ def data(self) -> np.ndarray: See the `to_numpy` function. """ - if self._data is None: + if self._data is None and not self.is_missing: self._data = self.to_numpy(self.raw_data, self.meta.file_format) + elif self.is_missing: + self._data = np.zeros((compute_image_shape(self.meta.width, self.meta.width, self.meta.file_format)), + dtype=np.uint8) return self._data @@ -295,6 +302,7 @@ def to_numpy(raw_data: bytes, fmt: str) -> np.ndarray: elif fmt in {"jpg", "png"}: with Image.open(bio) as image: data = np.array(image) + # Flip tile to be consistent with FITS orientation data = np.flipud(data) else: @@ -316,8 +324,12 @@ def read(cls, meta: HipsTileMeta, filename: str = None) -> "HipsTile": filename : str Filename """ - raw_data = Path(filename).read_bytes() - return cls(meta, raw_data) + if Path(filename).exists(): + return cls(meta, Path(filename).read_bytes()) + else: + print(f'Tile not found at:\n{filename}') + return cls(meta, b'', is_missing=True) + @classmethod def fetch(cls, meta: HipsTileMeta, url: str) -> "HipsTile": @@ -330,10 +342,18 @@ def fetch(cls, meta: HipsTileMeta, url: str) -> "HipsTile": url : str URL containing HiPS tile """ - with urllib.request.urlopen(url) as response: - raw_data = response.read() - - return cls(meta, raw_data) + try: + with urllib.request.urlopen(url, timeout=10) as response: + raw_data = response.read() + return cls(meta, raw_data) + except urllib.error.HTTPError as error: + if error.code == 404: + print(f'Tile not found at:\n{url}') + return cls(meta, b'', is_missing=True) + except urllib.error.URLError as error: + if isinstance(error.reason, socket.timeout): + print(f'The server timed out while fetching the tile at:\n{url}') + return cls(meta, b'', is_missing=True) def write(self, filename: str) -> None: """Write to file.