Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make improvements to parallel tile fetching #108

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hips/draw/tests/test_paint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setup_class(cls):
width=2000, height=1000, fov="3 deg",
coordsys='icrs', projection='AIT',
)
fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10)
fetch_opts = dict(timeout=30, n_parallel=10)
cls.painter = HipsPainter(cls.geometry, cls.hips_survey, 'fits', fetch_opts=fetch_opts)

def test_draw_hips_order(self):
Expand All @@ -44,7 +44,7 @@ def test_compute_matching_hips_order(self, pars):
coordsys='icrs', projection='AIT',
)

fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10)
fetch_opts = dict(timeout=30, n_parallel=10)
simple_tile_painter = HipsPainter(geometry, self.hips_survey, 'fits', fetch_opts=fetch_opts)
assert simple_tile_painter.draw_hips_order == pars['order']

Expand Down
2 changes: 1 addition & 1 deletion hips/draw/tests/test_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_make_sky_image(tmpdir, pars):
hips_survey = HipsSurveyProperties.fetch(url=pars['url'])
geometry = make_test_wcs_geometry()

fetch_opts = dict(fetch_package='urllib', timeout=30, n_parallel=10)
fetch_opts = dict(timeout=30, n_parallel=10)
result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format'],
precise=pars['precise'], fetch_opts=fetch_opts)

Expand Down
2 changes: 1 addition & 1 deletion hips/draw/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def plot(self, show_grid: bool = False) -> None:
def report(self) -> None:
"""Print a brief report for the fetched data."""

print (
print(
f"Time for fetching tiles = {self.stats['fetch_time']} seconds\n"
f"Time for drawing tiles = {self.stats['draw_time']} seconds\n"
f"Total memory consumed = {self.stats['consumed_memory'] / 1e6} MB\n"
Expand Down
117 changes: 32 additions & 85 deletions hips/tiles/fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import socket
import asyncio
import urllib.request
import concurrent.futures
Expand All @@ -16,7 +17,7 @@

def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties,
progress_bar: bool = True, n_parallel: int = 5,
timeout: float = 10, fetch_package: str = 'urllib') -> List[HipsTile]:
timeout: float = 10) -> List[HipsTile]:
"""Fetch a list of HiPS tiles.

This function fetches a list of HiPS tiles based
Expand All @@ -37,8 +38,6 @@ def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyPropertie
Number of tile fetch web requests to make in parallel
timeout : float
Seconds to timeout for fetching a HiPS tile
fetch_package : {'urllib', 'aiohttp'}
Package to use for fetching HiPS tiles

Examples
--------
Expand Down Expand Up @@ -68,90 +67,38 @@ def fetch_tiles(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyPropertie
tiles : list
A Python list of `~hips.HipsTile`
"""
if fetch_package == 'aiohttp':
fetch_fct = tiles_aiohttp
elif fetch_package == 'urllib':
fetch_fct = tiles_urllib
else:
raise ValueError(f'Invalid package name: {fetch_package}')

tiles = fetch_fct(tile_metas, hips_survey, progress_bar, n_parallel, timeout)

# Sort tiles to match the tile_meta list
# TODO: this doesn't seem like a great solution.
# Use dict instead?
out = []
for tile_meta in tile_metas:
for tile in tiles:
if tile.meta == tile_meta:
out.append(tile)
continue
return out


def fetch_tile_urllib(url: str, meta: HipsTileMeta, timeout: float) -> HipsTile:
"""Fetch a HiPS tile asynchronously."""
with urllib.request.urlopen(url, timeout=timeout) as conn:
raw_data = conn.read()
return HipsTile(meta, raw_data)


def tiles_urllib(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties,
progress_bar: bool, n_parallel, timeout: float) -> List[HipsTile]:
"""Generator function to fetch HiPS tiles from a remote URL."""
with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor:
futures = []
for meta in tile_metas:
url = hips_survey.tile_url(meta)
future = executor.submit(fetch_tile_urllib, url, meta, timeout)
futures.append(future)

futures = concurrent.futures.as_completed(futures)
if progress_bar:
from tqdm import tqdm
futures = tqdm(futures, total=len(tile_metas), desc='Fetching tiles')

tiles = []
for future in futures:
tiles.append(future.result())

return tiles
tile_urls = [hips_survey.tile_url(meta) for meta in tile_metas]
response_all = do_fetch_tiles(tile_urls, hips_survey, progress_bar, n_parallel, timeout)


async def fetch_tile_aiohttp(url: str, meta: HipsTileMeta, session, timeout: float) -> HipsTile:
"""Fetch a HiPS tile asynchronously using aiohttp."""
async with session.get(url, timeout=timeout) as response:
raw_data = await response.read()
return HipsTile(meta, raw_data)


async def fetch_all_tiles_aiohttp(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties,
progress_bar: bool, n_parallel: int, timeout: float) -> List[HipsTile]:
"""Generator function to fetch HiPS tiles from a remote URL using aiohttp."""
import aiohttp

connector = aiohttp.TCPConnector(limit=n_parallel)
async with aiohttp.ClientSession(connector=connector) as session:
futures = []
for meta in tile_metas:
url = hips_survey.tile_url(meta)
future = asyncio.ensure_future(fetch_tile_aiohttp(url, meta, session, timeout))
futures.append(future)

futures = asyncio.as_completed(futures)
if progress_bar:
from tqdm import tqdm
futures = tqdm(futures, total=len(tile_metas), desc='Fetching tiles')

tiles = []
for future in futures:
tiles.append(await future)
tiles = []
for idx, response in enumerate(response_all):
try:
response['raw_data']
tiles.append(HipsTile(tile_metas[idx], response['raw_data']))
except KeyError:
tiles.append(HipsTile(tile_metas[idx], b'', is_missing=True))

return tiles


def tiles_aiohttp(tile_metas: List[HipsTileMeta], hips_survey: HipsSurveyProperties,
progress_bar: bool, n_parallel: int, timeout: float) -> List[HipsTile]:
return asyncio.get_event_loop().run_until_complete(
fetch_all_tiles_aiohttp(tile_metas, hips_survey, progress_bar, n_parallel, timeout)
)
def do_fetch_single_tile(url: str, timeout: float) -> dict:
"""Fetch a HiPS tile asynchronously."""
try:
with urllib.request.urlopen(url, timeout=timeout) as conn:
return {'raw_data': conn.read(), 'url': url}
except urllib.error.HTTPError as error:
# If the tile is missing, enable the `is_missing` flag in HipsTile.
if error.code == 404:
print(f'Tile not found at:\n{url}')
return {'is_missing': True}
except urllib.error.URLError as error:
if isinstance(error.reason, socket.timeout):
print(f'The server timed out while fetching the tile at:\n{url}')
return {'is_missing': True}


def do_fetch_tiles(tile_urls: List[str], hips_survey: HipsSurveyProperties,
progress_bar: bool, n_parallel, timeout: float) -> List[dict]:
"""Generator function to fetch HiPS tiles from a remote URL."""
with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor:
return list(executor.map(do_fetch_single_tile, tile_urls, [timeout] * len(tile_urls)))
19 changes: 13 additions & 6 deletions hips/tiles/tests/test_fetch.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import pytest
from astropy.tests.helper import remote_data
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal
from ..fetch import fetch_tiles
from ..survey import HipsSurveyProperties
from ..tile import HipsTileMeta

TILE_FETCH_TEST_CASES = [
dict(
tile_indices=[69623, 69627, 69628, 69629, 69630, 69631],
is_missing=[False, False, False, False, False, False],
tile_format='fits',
order=7,
url='http://alasky.unistra.fr/DSS/DSS2Merged/properties',
progress_bar=True,
data=[2101, 1945, 1828, 1871, 2079, 2336],
fetch_package='urllib',
),
dict(
tile_indices=[69623, 69627, 69628, 69629, 69630, 69631],
tile_indices=[69623,
9999999, # missing
69628,
9999999, # missing
9999999, # missing
69631],
is_missing=[False, True, False, True, True, False],
tile_format='fits',
order=7,
url='http://alasky.unistra.fr/DSS/DSS2Merged/properties',
progress_bar=True,
data=[2101, 1945, 1828, 1871, 2079, 2336],
fetch_package='aiohttp',
data=[2101, 0, 1828, 0, 0, 2336],
),
]

Expand All @@ -48,8 +53,10 @@ def test_fetch_tiles(pars):
tiles = fetch_tiles(
tile_metas, hips_survey,
progress_bar=pars['progress_bar'],
fetch_package=pars['fetch_package'],
)

for idx, val in enumerate(pars['data']):
assert_allclose(tiles[idx].data[0][5], val)

for idx, val in enumerate(pars['is_missing']):
assert_equal(tiles[idx].is_missing, val)
36 changes: 28 additions & 8 deletions hips/tiles/tile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from typing import List, Tuple
from copy import deepcopy
import socket
import warnings
import urllib.request
from io import BytesIO
Expand Down Expand Up @@ -154,6 +155,8 @@ class HipsTile:
Metadata of HiPS tile
raw_data : `bytes`
Raw data (copy of bytes from file)
is_missing : `bool`
Specifies whether the tile is missing or not

Examples
--------
Expand All @@ -175,9 +178,10 @@ class HipsTile:
int16
"""

def __init__(self, meta: HipsTileMeta, raw_data: bytes) -> None:
def __init__(self, meta: HipsTileMeta, raw_data: bytes, is_missing: bool = False) -> None:
self.meta = meta
self.raw_data = raw_data
self.is_missing = is_missing
self._data = None

def __eq__(self, other: "HipsTile") -> bool:
Expand Down Expand Up @@ -259,8 +263,11 @@ def data(self) -> np.ndarray:

See the `to_numpy` function.
"""
if self._data is None:
if self._data is None and not self.is_missing:
self._data = self.to_numpy(self.raw_data, self.meta.file_format)
elif self.is_missing:
self._data = np.zeros((compute_image_shape(self.meta.width, self.meta.width, self.meta.file_format)),
dtype=np.uint8)

return self._data

Expand Down Expand Up @@ -295,6 +302,7 @@ def to_numpy(raw_data: bytes, fmt: str) -> np.ndarray:
elif fmt in {"jpg", "png"}:
with Image.open(bio) as image:
data = np.array(image)

# Flip tile to be consistent with FITS orientation
data = np.flipud(data)
else:
Expand All @@ -316,8 +324,12 @@ def read(cls, meta: HipsTileMeta, filename: str = None) -> "HipsTile":
filename : str
Filename
"""
raw_data = Path(filename).read_bytes()
return cls(meta, raw_data)
if Path(filename).exists():
return cls(meta, Path(filename).read_bytes())
else:
print(f'Tile not found at:\n{filename}')
return cls(meta, b'', is_missing=True)


@classmethod
def fetch(cls, meta: HipsTileMeta, url: str) -> "HipsTile":
Expand All @@ -330,10 +342,18 @@ def fetch(cls, meta: HipsTileMeta, url: str) -> "HipsTile":
url : str
URL containing HiPS tile
"""
with urllib.request.urlopen(url) as response:
raw_data = response.read()

return cls(meta, raw_data)
try:
with urllib.request.urlopen(url, timeout=10) as response:
raw_data = response.read()
return cls(meta, raw_data)
except urllib.error.HTTPError as error:
if error.code == 404:
print(f'Tile not found at:\n{url}')
return cls(meta, b'', is_missing=True)
except urllib.error.URLError as error:
if isinstance(error.reason, socket.timeout):
print(f'The server timed out while fetching the tile at:\n{url}')
return cls(meta, b'', is_missing=True)

def write(self, filename: str) -> None:
"""Write to file.
Expand Down