
Commit

Merge pull request #38 from LCOGT/refactor/image-data-class
fits_file_reader and output_handler classes, changes to operations, tests
LTDakin authored Oct 11, 2024
2 parents 55d5395 + 7d0416a commit a9f5658
Showing 10 changed files with 200 additions and 89 deletions.
62 changes: 62 additions & 0 deletions datalab/datalab_session/data_operations/fits_output_handler.py
@@ -0,0 +1,62 @@
import tempfile
import numpy as np
from astropy.io import fits

from datalab.datalab_session.file_utils import create_jpgs
from datalab.datalab_session.s3_utils import save_fits_and_thumbnails


class FITSOutputHandler():
"""A class to handle FITS output files and create jpgs.
Class handles the creation of Datalab output for developers.
The class inits with a cache_key and data, and creates a FITS file with the data.
The FITS file is then saved to the cache and the large and small jpgs are created.
Attributes:
datalab_id (str): The cache key for the FITS file.
primary_hdu (fits.PrimaryHDU): The primary HDU for the FITS file.
image_hdu (fits.ImageHDU): The image HDU for the FITS file.
        data (np.ndarray): The data for the image HDU.
"""

    def __init__(self, cache_key: str, data: np.ndarray, comment: str = None) -> None:
"""Inits FITSOutputHandler with cache_key and data.
Args:
cache_key (str): The cache key for the FITS file, used as an ID when stored in S3.
            data (np.ndarray): The data that will create the image HDU.
            comment (str): Optional comment to add to the FITS file.
"""
self.datalab_id = cache_key
self.primary_hdu = fits.PrimaryHDU(header=fits.Header([('KEY', cache_key)]))
self.image_hdu = fits.ImageHDU(data=data, name='SCI')
if comment: self.set_comment(comment)

    def __str__(self) -> str:
        return f"Key: {self.datalab_id}\nData:\n{self.image_hdu.data}"

def set_comment(self, comment: str):
"""Add a comment to the FITS file."""
self.primary_hdu.header.add_comment(comment)

def create_and_save_data_products(self, index: int=None, large_jpg_path: str=None, small_jpg_path: str=None):
"""Create the FITS file and save it to S3.
        Call this when you're done with the operation and want to save the FITS file and JPGs to S3.
        It returns a Datalab output dictionary formatted to be readable by the frontend.
Args:
            index (int): Optional index to append to the FITS file name; appended to cache_key when an operation has multiple outputs.
            large_jpg_path (str): Optional path to a pregenerated large JPG; when given, a new JPG will not be created.
            small_jpg_path (str): Optional path to a pregenerated small JPG; when given, a new JPG will not be created.
"""
hdu_list = fits.HDUList([self.primary_hdu, self.image_hdu])
fits_output_path = tempfile.NamedTemporaryFile(suffix=f'{self.datalab_id}.fits').name
hdu_list.writeto(fits_output_path, overwrite=True)

        # allow operations to pregenerate the JPGs, e.g. RGB stacking
if not large_jpg_path or not small_jpg_path:
large_jpg_path, small_jpg_path = create_jpgs(self.datalab_id, fits_output_path)

return save_fits_and_thumbnails(self.datalab_id, fits_output_path, large_jpg_path, small_jpg_path, index)
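
A minimal usage sketch of the new handler (the cache key and array below are hypothetical, and working S3 credentials for save_fits_and_thumbnails are assumed to be configured):

import numpy as np
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler

# hypothetical science array standing in for a real operation result
data = np.zeros((100, 100), dtype=np.float32)

handler = FITSOutputHandler('example-cache-key', data, comment='Example output')
# writes the FITS file, creates large/small JPGs, uploads everything to S3,
# and returns the output dictionary the frontend expects
output = handler.create_and_save_data_products()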
44 changes: 44 additions & 0 deletions datalab/datalab_session/data_operations/input_data_handler.py
@@ -0,0 +1,44 @@
from astropy.io import fits

from datalab.datalab_session.s3_utils import get_fits
from datalab.datalab_session.file_utils import get_hdu

class InputDataHandler():
"""A class to read FITS files and provide access to the data.
The class inits with a basename and source, and reads the FITS file
this data is then stored in the class attributes for easy access.
Attributes:
basename (str): The basename of the FITS file.
fits_file (str): The path to the FITS file.
        sci_data (np.ndarray): The data from the 'SCI' extension of the FITS file.
"""

def __init__(self, basename: str, source: str = None) -> None:
"""Inits InputDataHandler with basename and source.
Uses the basename to query the archive for the matching FITS file.
        It can also take a source argument to specify a different source for the FITS file.
At the time of writing two common sources are 'datalab' and 'archive'.
New sources will need to be added in the get_fits function in s3_utils.py.
Args:
basename (str): The basename of the FITS file.
            source (str): Optional source of the FITS file, in case it is not the LCO archive.
"""
self.basename = basename
self.fits_file = get_fits(basename, source)
self.sci_data = get_hdu(self.fits_file, 'SCI').data

    def __str__(self) -> str:
        with fits.open(self.fits_file) as hdul:
            # output=False makes HDUList.info() return its summary instead of printing it
            return f"{self.basename}@{self.fits_file}\nHDU List\n{hdul.info(output=False)}"

def get_hdu(self, extension: str=None):
"""Return an HDU from the FITS file.
Args:
extension (str): The extension to return from the FITS file. Default is 'SCI'.
"""
return get_hdu(self.fits_file, extension)
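
A brief usage sketch (the basename below is a placeholder; a real call passes a basename known to the LCO archive, source='archive', or to a prior Datalab product, source='datalab'):

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler

image = InputDataHandler('example-basename', source='archive')
sci = image.sci_data             # numpy array from the 'SCI' extension
cat_hdu = image.get_hdu('CAT')   # any other extension, if the file provides one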
39 changes: 20 additions & 19 deletions datalab/datalab_session/data_operations/median.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import crop_arrays, create_output
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -40,25 +42,24 @@ def wizard_description():
}

def operate(self):

input = self.input_data.get('input_files', [])

if len(input) <= 1:
raise ClientAlertException('Median needs at least 2 files')

log.info(f'Executing median operation on {len(input)} files')

image_data_list = self.get_fits_npdata(input)

cropped_data_list = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data_list, axis=2)

# using the numpy library's median method
median = np.median(stacked_data, axis=2)
# Getting/Checking the Input
input_list = self.input_data.get('input_files', [])
if len(input_list) <= 1: raise ClientAlertException('Median needs at least 2 files')
comment = f'Datalab Median on {", ".join([image["basename"] for image in input_list])}'
log.info(comment)

input_fits_list = []
for index, input in enumerate(input_list, start=1):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.5 * (index / len(input_list)))

# Creating the Median array
cropped_data = crop_arrays([image.sci_data for image in input_fits_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
median = np.median(stacked_ndarray, axis=2)

self.set_operation_progress(0.80)

output = create_output(self.cache_key, median, comment=f'Product of Datalab Median on files {", ".join([image["basename"] for image in input])}')

output = FITSOutputHandler(self.cache_key, median, comment).create_and_save_data_products()
log.info(f'Median output: {output}')
self.set_output(output)
log.info(f'Median output: {self.get_output()}')
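
The core of the new operate() is plain numpy: crop the inputs to a common shape, stack them along a new axis, and take the per-pixel median. A self-contained sketch with toy arrays standing in for real SCI data:

import numpy as np

# toy stand-ins for SCI arrays of slightly different shapes
images = [np.ones((10, 12)), 2 * np.ones((11, 10)), 3 * np.ones((10, 10))]

# equivalent of crop_arrays: trim every array to the smallest common shape
min_rows = min(a.shape[0] for a in images)
min_cols = min(a.shape[1] for a in images)
cropped = [a[:min_rows, :min_cols] for a in images]

stacked = np.stack(cropped, axis=2)   # shape (10, 10, 3)
median = np.median(stacked, axis=2)   # per-pixel median, shape (10, 10)
print(median[0, 0])                   # 2.0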
26 changes: 15 additions & 11 deletions datalab/datalab_session/data_operations/normalization.py
@@ -2,8 +2,9 @@

import numpy as np

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.file_utils import create_output
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -40,20 +41,23 @@ def wizard_description():

def operate(self):

input = self.input_data.get('input_files', [])
input_list = self.input_data.get('input_files', [])
log.info(f'Normalization operation on {len(input_list)} file(s)')

log.info(f'Executing normalization operation on {len(input)} file(s)')

image_data_list = self.get_fits_npdata(input)
input_fits_list = []
for index, input in enumerate(input_list, start=1):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.5 * (index / len(input_list)))

output_files = []
for index, image in enumerate(image_data_list, start=1):
median = np.median(image)
normalized_image = image / median
for index, image in enumerate(input_fits_list, start=1):
median = np.median(image.sci_data)
normalized_image = image.sci_data / median

output = create_output(self.cache_key, normalized_image, index=index, comment=f'Product of Datalab Normalization on file {input[index-1]["basename"]}')
comment = f'Datalab Normalization on file {input_list[index-1]["basename"]}'
            output = FITSOutputHandler(self.cache_key, normalized_image, comment).create_and_save_data_products(index=index)
output_files.append(output)
self.set_operation_progress(0.5 + index/len(image_data_list) * 0.4)
self.set_operation_progress(0.5 + index/len(input_fits_list) * 0.4)

log.info(f'Normalization output: {output_files}')
self.set_output(output_files)
log.info(f'Normalization output: {self.get_output()}')
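
Normalization itself is a one-line numpy operation: divide each image by its own median so pixel values center around 1.0. A toy example:

import numpy as np

image = np.array([[1.0, 2.0], [3.0, 4.0]])
median = np.median(image)    # 2.5
normalized = image / median  # [[0.4 0.8] [1.2 1.6]]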
37 changes: 18 additions & 19 deletions datalab/datalab_session/data_operations/rgb_stack.py
@@ -3,10 +3,11 @@
from astropy.io import fits
import numpy as np

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import create_output, crop_arrays, create_jpgs
from datalab.datalab_session.s3_utils import get_fits
from datalab.datalab_session.file_utils import crop_arrays, create_jpgs

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -58,28 +59,26 @@ def wizard_description():

def operate(self):
rgb_input_list = self.input_data['red_input'] + self.input_data['green_input'] + self.input_data['blue_input']
if len(rgb_input_list) != 3: raise ClientAlertException('RGB stack requires exactly 3 files')
rgb_comment = f'Datalab RGB Stack on files {", ".join([image["basename"] for image in rgb_input_list])}'
log.info(rgb_comment)

if len(rgb_input_list) != 3:
raise ClientAlertException('RGB stack requires exactly 3 files')

log.info(f'Executing RGB Stack operation on files: {rgb_input_list}')
input_fits_list = []
for index, input in enumerate(rgb_input_list, start=1):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.4 * (index / len(rgb_input_list)))

fits_paths = []
for index, file in enumerate(rgb_input_list, start=1):
fits_paths.append(get_fits(file.get('basename')))
self.set_operation_progress(index * 0.2)

large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_paths, color=True)
fits_file_list = [image.fits_file for image in input_fits_list]
large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_file_list, color=True)
self.set_operation_progress(0.6)

        # color images take three input files, so we store them as one FITS file with a 3D SCI ndarray
arrays = [fits.open(file)['SCI'].data for file in fits_paths]
cropped_data_list = crop_arrays(arrays)
stacked_data = np.stack(cropped_data_list, axis=2)

sci_data_list = [image.sci_data for image in input_fits_list]
cropped_data_list = crop_arrays(sci_data_list)
stacked_ndarray = np.stack(cropped_data_list, axis=2)
self.set_operation_progress(0.8)

rgb_comment = f'Product of Datalab RGB Stack on files {", ".join([image["basename"] for image in rgb_input_list])}'
output = create_output(self.cache_key, stacked_data, large_jpg=large_jpg_path, small_jpg=small_jpg_path, comment=rgb_comment)
output = FITSOutputHandler(self.cache_key, stacked_ndarray, rgb_comment).create_and_save_data_products(large_jpg_path=large_jpg_path, small_jpg_path=small_jpg_path)

log.info(f'RGB Stack output: {output}')
self.set_output(output)
log.info(f'RGB Stack output: {self.get_output()}')
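
As the comment in operate() notes, the three aligned channels are stored as a single 3-D SCI ndarray. A sketch of that layout with toy channels:

import numpy as np

# three aligned single-channel images (toy values)
red, green, blue = (np.full((8, 8), v) for v in (1.0, 2.0, 3.0))
rgb = np.stack([red, green, blue], axis=2)  # shape (8, 8, 3), channel-last
print(rgb[0, 0])                            # [1. 2. 3.]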
28 changes: 15 additions & 13 deletions datalab/datalab_session/data_operations/stacking.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import create_output, crop_arrays
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -43,24 +45,24 @@ def wizard_description():
def operate(self):

input_files = self.input_data.get('input_files', [])
if len(input_files) <= 1: raise ClientAlertException('Stack needs at least 2 files')
        comment = f'Datalab Stacking on {", ".join([image["basename"] for image in input_files])}'
log.info(comment)

if len(input_files) <= 1:
raise ClientAlertException('Stack needs at least 2 files')
input_fits_list = []
for index, input in enumerate(input_files, start=1):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.5 * (index / len(input_files)))

log.info(f'Executing stacking operation on {len(input_files)} files')

image_data_list = self.get_fits_npdata(input_files)

cropped_data = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data, axis=2)
cropped_data = crop_arrays([image.sci_data for image in input_fits_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
self.set_operation_progress(0.6)

# using the numpy library's sum method
stacked_sum = np.sum(stacked_data, axis=2)
stacked_sum = np.sum(stacked_ndarray, axis=2)
self.set_operation_progress(0.8)

stacking_comment = f'Product of Datalab Stacking. Stack of {", ".join([image["basename"] for image in input_files])}'
output = create_output(self.cache_key, stacked_sum, comment=stacking_comment)
output = FITSOutputHandler(self.cache_key, stacked_sum, comment).create_and_save_data_products()

log.info(f'Stacked output: {output}')
self.set_output(output)
log.info(f'Stacked output: {self.get_output()}')
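
Stacking differs from Median only in the reduction applied to the stacked cube: np.sum instead of np.median. For instance:

import numpy as np

stacked = np.stack([np.ones((5, 5)), 2 * np.ones((5, 5))], axis=2)
total = np.sum(stacked, axis=2)  # per-pixel sum; every value is 3.0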
35 changes: 17 additions & 18 deletions datalab/datalab_session/data_operations/subtraction.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.input_data_handler import InputDataHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import crop_arrays, create_output
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -52,30 +54,27 @@ def operate(self):
input_files = self.input_data.get('input_files', [])
subtraction_file_input = self.input_data.get('subtraction_file', [])

if not subtraction_file_input:
raise ClientAlertException('Missing a subtraction file')
if not subtraction_file_input: raise ClientAlertException('Missing a subtraction file')
if len(input_files) < 1: raise ClientAlertException('Need at least one input file')

if len(input_files) < 1:
raise ClientAlertException('Need at least one input file')
log.info(f'Subtraction operation on {len(input_files)} files')

log.info(f'Executing subtraction operation on {len(input_files)} files')

input_image_data_list = self.get_fits_npdata(input_files)

subtraction_image = self.get_fits_npdata(subtraction_file_input)[0]
self.set_operation_progress(0.70)
subtraction_fits = InputDataHandler(subtraction_file_input[0]['basename'], subtraction_file_input[0]['source'])
input_fits_list = []
for index, input in enumerate(input_files, start=1):
input_fits_list.append(InputDataHandler(input['basename'], input['source']))
self.set_operation_progress(0.5 * (index / len(input_files)))

outputs = []
for index, input_image in enumerate(input_image_data_list):
for index, input_image in enumerate(input_fits_list, start=1):
# crop the input_image and subtraction_image to the same size
input_image, subtraction_image = crop_arrays([input_image, subtraction_image])
input_image, subtraction_image = crop_arrays([input_image.sci_data, subtraction_fits.sci_data])

difference_array = np.subtract(input_image, subtraction_image)

subtraction_comment = f'Product of Datalab Subtraction of {subtraction_file_input[0]["basename"]} subtracted from {input_files[index]["basename"]}'
outputs.append(create_output(self.cache_key, difference_array, index=index, comment=subtraction_comment))

self.set_operation_progress(0.90)
subtraction_comment = f'Datalab Subtraction of {subtraction_file_input[0]["basename"]} subtracted from {input_files[index-1]["basename"]}'
            outputs.append(FITSOutputHandler(self.cache_key, difference_array, subtraction_comment).create_and_save_data_products(index=index))
self.set_operation_progress(0.5 + index/len(input_fits_list) * 0.4)

log.info(f'Subtraction output: {outputs}')
self.set_output(outputs)
log.info(f'Subtraction output: {self.get_output()}')
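
The per-image work reduces to cropping both arrays to a common shape and one numpy call. A toy example of the subtraction step:

import numpy as np

science = np.array([[5.0, 6.0], [7.0, 8.0]])
template = np.array([[1.0, 1.0], [2.0, 2.0]])
difference = np.subtract(science, template)  # equivalent to science - template
print(difference)                            # [[4. 5.] [5. 6.]]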
Binary file modified datalab/datalab_session/tests/test_files/median/median_1_2.fits
