diff --git a/docs/getting_started.rst b/docs/getting_started.rst index ac958e1..0771e95 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -37,22 +37,34 @@ To make a sky image with the `hips` package, follow the following three steps: 3. Call the `~hips.make_sky_image` function to fetch the HiPS data - and draw it, returning the sky image pixel data as a Numpy array:: + and draw it, returning an object of `~hips.HipsDrawResult`:: from hips import make_sky_image - data = make_sky_image(geometry, hips_survey, 'fits') + result = make_sky_image(geometry, hips_survey, 'fits') That's it. Go ahead and try it out for your favourite sky region and survey. Now you can then save the sky image to local disk e.g. FITS file format:: - from astropy.io import fits - hdu = fits.PrimaryHDU(data=data, header=geometry.fits_header) - hdu.writeto('my_image.fits') + result.write_image('my_image.fits') -or plot and analyse the sky image however you like. +The ``result`` object also contains other useful information, such as:: + + result.image + +will return a NumPy array containing pixel data, you can also get the WCS information using:: + + result.geometry + +If you want, you could also print out information about the ``result``:: + + print(result) + +or plot and analyse the sky image using:: + + result.plot() If you execute the example above, you will get this sky image which was plotted using `astropy.visualization.wcsaxes` diff --git a/docs/plot_fits.py b/docs/plot_fits.py index 14bcb18..bc1f860 100644 --- a/docs/plot_fits.py +++ b/docs/plot_fits.py @@ -10,11 +10,11 @@ width=2000, height=1000, fov="3 deg", coordsys='galactic', projection='AIT', ) -image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='fits') +result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='fits') # Draw the sky image import matplotlib.pyplot as plt from astropy.visualization.mpl_normalize import simple_norm ax = plt.subplot(projection=geometry.wcs) -norm = simple_norm(image, 'sqrt', min_percent=1, max_percent=99) -ax.imshow(image, origin='lower', norm=norm, cmap='gray') +norm = simple_norm(result.image, 'sqrt', min_percent=1, max_percent=99) +ax.imshow(result.image, origin='lower', norm=norm, cmap='gray') diff --git a/docs/plot_jpg.py b/docs/plot_jpg.py index ba3b852..1eb111b 100644 --- a/docs/plot_jpg.py +++ b/docs/plot_jpg.py @@ -10,9 +10,9 @@ width=2000, height=1000, fov="3 deg", coordsys='galactic', projection='AIT', ) -image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='jpg') +result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format='jpg') # Draw the sky image import matplotlib.pyplot as plt ax = plt.subplot(projection=geometry.wcs) -ax.imshow(image, origin='lower') +ax.imshow(result.image, origin='lower') diff --git a/hips/draw/simple.py b/hips/draw/simple.py index b53bad8..a6dc4d7 100644 --- a/hips/draw/simple.py +++ b/hips/draw/simple.py @@ -1,6 +1,8 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """HiPS tile drawing -- simple method.""" import numpy as np +from PIL import Image +from astropy.io import fits from typing import List, Tuple from astropy.wcs.utils import proj_plane_pixel_scales from skimage.transform import ProjectiveTransform, warp @@ -125,6 +127,11 @@ def tiles(self) -> List[HipsTile]: return self._tiles + @property + def result(self) -> 'HipsDrawResult': + """Return an object of `~hips.HipsDrawResult` class.""" + return HipsDrawResult(self.image, self.geometry, self.tile_format, self.tiles) + def warp_image(self, tile: HipsTile) -> np.ndarray: """Warp a HiPS tile and a sky image.""" return warp( @@ -172,6 +179,71 @@ def plot_mpl_hips_tile_grid(self) -> None: ax.imshow(self.image, origin='lower') +class HipsDrawResult: + """Container class for reporting information related with fetching / drawing of HiPS tiles. + + Parameters + ---------- + image: `~numpy.ndarray` + Container for HiPS tile data + geometry : `~hips.utils.WCSGeometry` + An object of WCSGeometry + tile_format : {'fits', 'jpg', 'png'} + Format of HiPS tile + tiles: List[HipsTile] + """ + + def __init__(self, image: np.ndarray, geometry: WCSGeometry, tile_format: str, tiles: List[HipsTile]) -> None: + self.image = image + self.geometry = geometry + self.tile_format = tile_format + self.tiles = tiles + + def __str__(self): + return ( + 'HiPS draw result:\n' + f'Sky image: shape={self.image.shape}, dtype={self.image.dtype}\n' + f'WCS geometry: {self.geometry}\n' + ) + + def __repr__(self): + return ( + 'HipsDrawResult(' + f'width={self.image.shape[0]}, ' + f'height={self.image.shape[1]}, ' + f'channels={self.image.ndim}, ' + f'dtype={self.image.dtype}, ' + f'format={self.tile_format}' + ')' + ) + + def write_image(self, filename: str) -> None: + """Write image to file. + + Parameters + ---------- + filename : str + Filename + """ + if self.tile_format == 'fits': + hdu = fits.PrimaryHDU(data=self.image, header=self.geometry.fits_header) + hdu.writeto(filename) + else: + image = Image.fromarray(self.image) + image.save(filename) + + def plot(self) -> None: + """Plot the all sky image using `astropy.visualization.wcsaxes` and showing the HEALPix grid.""" + import matplotlib.pyplot as plt + for tile in self.tiles: + corners = tile.meta.skycoord_corners.transform_to(self.geometry.celestial_frame) + ax = plt.subplot(projection=self.geometry.wcs) + opts = dict(color='red', lw=1, ) + ax.plot(corners.data.lon.deg, corners.data.lat.deg, + transform=ax.get_transform('world'), **opts) + ax.imshow(self.image, origin='lower') + + def measure_tile_shape(corners: tuple) -> Tuple[List[float]]: """Compute length of tile edges and diagonals.""" x, y = corners @@ -268,7 +340,7 @@ def plot_mpl_single_tile(geometry: WCSGeometry, tile: HipsTile, image: np.ndarra ax.imshow(image, origin='lower') -def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, tile_format: str) -> np.ndarray: +def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, tile_format: str) -> 'HipsDrawResult': """Make sky image: fetch tiles and draw. The example for this can be found on the :ref:`gs` page. @@ -291,4 +363,4 @@ def make_sky_image(geometry: WCSGeometry, hips_survey: HipsSurveyProperties, til painter = SimpleTilePainter(geometry, hips_survey, tile_format) painter.run() - return painter.image + return painter.result diff --git a/hips/draw/tests/test_simple.py b/hips/draw/tests/test_simple.py index 6dbf340..4287862 100644 --- a/hips/draw/tests/test_simple.py +++ b/hips/draw/tests/test_simple.py @@ -19,6 +19,7 @@ data_2=2296, data_sum=8756493140, dtype='>i2', + repr='HipsDrawResult(width=1000, height=2000, channels=2, dtype=>i2, format=fits)' ), dict( file_format='jpg', @@ -28,6 +29,7 @@ data_2=[137, 116, 114], data_sum=828908873, dtype='uint8', + repr='HipsDrawResult(width=1000, height=2000, channels=3, dtype=uint8, format=jpg)' ), dict( file_format='png', @@ -37,22 +39,25 @@ data_2=[227, 217, 205, 255], data_sum=1635622838, dtype='uint8', + repr='HipsDrawResult(width=1000, height=2000, channels=3, dtype=uint8, format=png)' ), ] @remote_data @pytest.mark.parametrize('pars', make_sky_image_pars) -def test_make_sky_image(pars): +def test_make_sky_image(tmpdir, pars): hips_survey = HipsSurveyProperties.fetch(url=pars['url']) geometry = make_test_wcs_geometry() - image = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format']) - assert image.shape == pars['shape'] - assert image.dtype == pars['dtype'] - assert_allclose(np.sum(image), pars['data_sum']) - assert_allclose(image[200, 994], pars['data_1']) - assert_allclose(image[200, 995], pars['data_2']) - + result = make_sky_image(geometry=geometry, hips_survey=hips_survey, tile_format=pars['file_format']) + assert result.image.shape == pars['shape'] + assert result.image.dtype == pars['dtype'] + assert repr(result) == pars['repr'] + assert_allclose(np.sum(result.image), pars['data_sum']) + assert_allclose(result.image[200, 994], pars['data_1']) + assert_allclose(result.image[200, 995], pars['data_2']) + result.write_image(str(tmpdir / 'test.' + pars['file_format'])) + result.plot() @remote_data class TestSimpleTilePainter: diff --git a/hips/utils/wcs.py b/hips/utils/wcs.py index 6268dc8..e43b049 100644 --- a/hips/utils/wcs.py +++ b/hips/utils/wcs.py @@ -58,6 +58,13 @@ def __init__(self, wcs: WCS, width: int, height: int) -> None: self.wcs = wcs self.shape = Shape(width=width, height=height) + def __str__(self): + return ( + 'WCSGeometry data:\n' + f'WCS: {self.wcs}\n' + f'Shape: {self.shape}\n' + ) + @property def center_pix(self) -> Tuple[float, float]: """Image center in pixel coordinates (tuple of x, y)."""