Commit 80fcb80

fits_file_reader and output_handler classes, changes to operations, tests
LTDakin committed Oct 9, 2024
1 parent 55d5395 commit 80fcb80
Showing 10 changed files with 140 additions and 87 deletions.
25 changes: 25 additions & 0 deletions datalab/datalab_session/data_operations/fits_file_reader.py
@@ -0,0 +1,25 @@
from astropy.io import fits

from datalab.datalab_session.s3_utils import get_fits
from datalab.datalab_session.file_utils import get_hdu

class FITSFileReader:

basename = None
fits_file = None
hdu_list = None

def __init__(self, basename: str, source: str = None) -> None:
self.basename = basename
self.fits_file = get_fits(basename, source)
self.hdu_list = fits.open(self.fits_file)

def __str__(self) -> str:
return f"{self.basename}@{self.fits_file}\nHDU List\n{self.hdu_list.info()}"

@property
def sci_data(self):
return self.hdu_list['SCI'].data

def hdu(self, extension: str):
return get_hdu(self.fits_file, extension)
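A minimal usage sketch for the new reader class (illustrative only: the basename and source values below are placeholders, and it assumes the get_hdu helper returns an astropy HDU object):

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader

# 'example-frame-e91' and 'archive' are placeholder arguments, not values from this commit
reader = FITSFileReader('example-frame-e91', source='archive')
print(reader)                # basename, local FITS path, and HDU summary
sci = reader.sci_data        # ndarray from the 'SCI' extension
hdu = reader.hdu('SCI')      # any extension via the get_hdu helper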
35 changes: 35 additions & 0 deletions datalab/datalab_session/data_operations/fits_output_handler.py
@@ -0,0 +1,35 @@
import tempfile
import numpy as np
from astropy.io import fits

from datalab.datalab_session.file_utils import create_jpgs
from datalab.datalab_session.s3_utils import save_fits_and_thumbnails


class FITSOutputHandler():

def __init__(self, key: str, data: np.array, comment: str=None) -> None:
self.key = key
self.primary_hdu = fits.PrimaryHDU(header=fits.Header([('KEY', key)]))
self.image_hdu = fits.ImageHDU(data=data, name='SCI')
if comment: self.set_comment(comment)

def __str__(self) -> str:
return f"Key: {self.key}\nData:\n{self.data}"

def set_comment(self, comment: str):
self.primary_hdu.header.add_comment(comment)

def set_sci_data(self, new_data: np.array):
self.image_hdu.data = new_data

def create_save_fits(self, index: int=None, large_jpg: str=None, small_jpg: str=None):
hdu_list = fits.HDUList([self.primary_hdu, self.image_hdu])
fits_output_path = tempfile.NamedTemporaryFile(suffix=f'{self.key}.fits').name
hdu_list.writeto(fits_output_path, overwrite=True)

# allow for operations to pregenerate the jpgs, ex. RGB stacking
if not large_jpg or not small_jpg:
large_jpg, small_jpg = create_jpgs(self.key, fits_output_path)

return save_fits_and_thumbnails(self.key, fits_output_path, large_jpg, small_jpg, index)
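A rough sketch of driving the new output handler (the cache key and array are made up; create_save_fits writes a temporary FITS file, generates jpgs unless they are supplied, and returns whatever save_fits_and_thumbnails produces):

import numpy as np
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler

data = np.zeros((100, 100), dtype=np.float32)   # placeholder science array
handler = FITSOutputHandler('my-cache-key', data, comment='Product of a Datalab operation')
output = handler.create_save_fits()             # temp FITS + jpgs + upload
# RGB stacking pre-generates its own jpgs and passes them in instead:
# output = handler.create_save_fits(large_jpg='/tmp/large.jpg', small_jpg='/tmp/small.jpg')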
32 changes: 15 additions & 17 deletions datalab/datalab_session/data_operations/median.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import crop_arrays, create_output
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -40,25 +42,21 @@ def wizard_description():
}

def operate(self):
# Getting/Checking the Input
input_list = self.input_data.get('input_files', [])
if len(input_list) <= 1: raise ClientAlertException('Median needs at least 2 files')
comment = f'Datalab Median on {", ".join([image["basename"] for image in input_list])}'
log.info(comment)

input = self.input_data.get('input_files', [])
input_FITS_list = [FITSFileReader(input['basename'], input['source']) for input in input_list]

if len(input) <= 1:
raise ClientAlertException('Median needs at least 2 files')

log.info(f'Executing median operation on {len(input)} files')

image_data_list = self.get_fits_npdata(input)

cropped_data_list = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data_list, axis=2)

# using the numpy library's median method
median = np.median(stacked_data, axis=2)
# Creating the Median array
cropped_data = crop_arrays([image.sci_data for image in input_FITS_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
median = np.median(stacked_ndarray, axis=2)

self.set_operation_progress(0.80)

output = create_output(self.cache_key, median, comment=f'Product of Datalab Median on files {", ".join([image["basename"] for image in input])}')

output = FITSOutputHandler(self.cache_key, median, comment).create_save_fits()
log.info(f'Median output: {output}')
self.set_output(output)
log.info(f'Median output: {self.get_output()}')
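The crop/stack/median step itself is plain NumPy; a self-contained toy example of what operate() now computes (crop_arrays is approximated here by slicing to the smallest common shape, which is an assumption about its behaviour):

import numpy as np

def crop_to_common_shape(arrays):
    # stand-in for crop_arrays: trim every image to the smallest common shape
    rows = min(a.shape[0] for a in arrays)
    cols = min(a.shape[1] for a in arrays)
    return [a[:rows, :cols] for a in arrays]

images = [np.ones((4, 5)), 2 * np.ones((5, 4)), 3 * np.ones((4, 4))]
cropped = crop_to_common_shape(images)   # three 4x4 arrays
stacked = np.stack(cropped, axis=2)      # shape (4, 4, 3)
median = np.median(stacked, axis=2)      # per-pixel median, 2.0 everywhere here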
23 changes: 12 additions & 11 deletions datalab/datalab_session/data_operations/normalization.py
@@ -2,8 +2,9 @@

import numpy as np

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.file_utils import create_output
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -40,20 +41,20 @@ def wizard_description():

def operate(self):

input = self.input_data.get('input_files', [])
input_list = self.input_data.get('input_files', [])
log.info(f'Normalization operation on {len(input_list)} file(s)')

log.info(f'Executing normalization operation on {len(input)} file(s)')

image_data_list = self.get_fits_npdata(input)
input_FITS_list = [FITSFileReader(input['basename'], input['source']) for input in input_list]

output_files = []
for index, image in enumerate(image_data_list, start=1):
median = np.median(image)
normalized_image = image / median
for index, image in enumerate(input_FITS_list, start=1):
median = np.median(image.sci_data)
normalized_image = image.sci_data / median

output = create_output(self.cache_key, normalized_image, index=index, comment=f'Product of Datalab Normalization on file {input[index-1]["basename"]}')
comment = f'Datalab Normalization on file {input_list[index-1]["basename"]}'
output = FITSOutputHandler(f'{self.cache_key}', normalized_image, comment).create_save_fits(index=index)
output_files.append(output)
self.set_operation_progress(0.5 + index/len(image_data_list) * 0.4)
self.set_operation_progress(0.5 + index/len(input_FITS_list) * 0.4)

log.info(f'Normalization output: {output_files}')
self.set_output(output_files)
log.info(f'Normalization output: {self.get_output()}')
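Each loop iteration is a simple divide-by-median; a toy example of the arithmetic:

import numpy as np

image = np.array([[1.0, 2.0], [4.0, 8.0]])
median = np.median(image)        # 3.0
normalized = image / median      # [[0.333..., 0.666...], [1.333..., 2.666...]]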
35 changes: 16 additions & 19 deletions datalab/datalab_session/data_operations/rgb_stack.py
@@ -3,10 +3,11 @@
from astropy.io import fits
import numpy as np

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import create_output, crop_arrays, create_jpgs
from datalab.datalab_session.s3_utils import get_fits
from datalab.datalab_session.file_utils import crop_arrays, create_jpgs

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -58,28 +59,24 @@ def wizard_description():

def operate(self):
rgb_input_list = self.input_data['red_input'] + self.input_data['green_input'] + self.input_data['blue_input']
if len(rgb_input_list) != 3: raise ClientAlertException('RGB stack requires exactly 3 files')
rgb_comment = f'Datalab RGB Stack on files {", ".join([image["basename"] for image in rgb_input_list])}'
log.info(rgb_comment)

if len(rgb_input_list) != 3:
raise ClientAlertException('RGB stack requires exactly 3 files')

log.info(f'Executing RGB Stack operation on files: {rgb_input_list}')
input_FITS_list = [FITSFileReader(input['basename'], input['source']) for input in rgb_input_list]
self.set_operation_progress(0.4)

fits_paths = []
for index, file in enumerate(rgb_input_list, start=1):
fits_paths.append(get_fits(file.get('basename')))
self.set_operation_progress(index * 0.2)

large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_paths, color=True)
fits_file_list = [image.fits_file for image in input_FITS_list]
large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_file_list, color=True)
self.set_operation_progress(0.6)

# color photos take three files, so we store it as one fits file with a 3d SCI ndarray
arrays = [fits.open(file)['SCI'].data for file in fits_paths]
cropped_data_list = crop_arrays(arrays)
stacked_data = np.stack(cropped_data_list, axis=2)

sci_data_list = [image.sci_data for image in input_FITS_list]
cropped_data_list = crop_arrays(sci_data_list)
stacked_ndarray = np.stack(cropped_data_list, axis=2)
self.set_operation_progress(0.8)

rgb_comment = f'Product of Datalab RGB Stack on files {", ".join([image["basename"] for image in rgb_input_list])}'
output = create_output(self.cache_key, stacked_data, large_jpg=large_jpg_path, small_jpg=small_jpg_path, comment=rgb_comment)
output = FITSOutputHandler(self.cache_key, stacked_ndarray, rgb_comment).create_save_fits(large_jpg=large_jpg_path, small_jpg=small_jpg_path)

log.info(f'RGB Stack output: {output}')
self.set_output(output)
log.info(f'RGB Stack output: {self.get_output()}')
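The color product is a single FITS whose SCI array is three-dimensional, one plane per channel; a toy sketch of the array handling (crop_arrays again approximated by slicing to the smallest common shape):

import numpy as np

red, green, blue = np.random.rand(10, 12), np.random.rand(11, 10), np.random.rand(10, 10)

rows = min(a.shape[0] for a in (red, green, blue))
cols = min(a.shape[1] for a in (red, green, blue))
cropped = [a[:rows, :cols] for a in (red, green, blue)]

rgb_cube = np.stack(cropped, axis=2)   # shape (10, 10, 3): one plane per channel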
26 changes: 13 additions & 13 deletions datalab/datalab_session/data_operations/stacking.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import create_output, crop_arrays
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -43,24 +45,22 @@
def operate(self):

input_files = self.input_data.get('input_files', [])
if len(input_files) <= 1: raise ClientAlertException('Stack needs at least 2 files')
comment= f'Datalab Stacking on {", ".join([image["basename"] for image in input_files])}'
log.info(comment)

if len(input_files) <= 1:
raise ClientAlertException('Stack needs at least 2 files')
input_FITS_list = [FITSFileReader(input['basename'], input['source']) for input in input_files]
self.set_operation_progress(0.4)

log.info(f'Executing stacking operation on {len(input_files)} files')

image_data_list = self.get_fits_npdata(input_files)

cropped_data = crop_arrays(image_data_list)
stacked_data = np.stack(cropped_data, axis=2)
cropped_data = crop_arrays([image.sci_data for image in input_FITS_list])
stacked_ndarray = np.stack(cropped_data, axis=2)
self.set_operation_progress(0.6)

# using the numpy library's sum method
stacked_sum = np.sum(stacked_data, axis=2)
stacked_sum = np.sum(stacked_ndarray, axis=2)
self.set_operation_progress(0.8)

stacking_comment = f'Product of Datalab Stacking. Stack of {", ".join([image["basename"] for image in input_files])}'
output = create_output(self.cache_key, stacked_sum, comment=stacking_comment)
output = FITSOutputHandler(self.cache_key, stacked_sum, comment).create_save_fits()

log.info(f'Stacked output: {output}')
self.set_output(output)
log.info(f'Stacked output: {self.get_output()}')
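Stacking sums the aligned pixels; a minimal illustration of the crop/stack/sum step:

import numpy as np

frames = [np.full((3, 3), 1.0), np.full((3, 3), 2.0), np.full((3, 3), 3.0)]
cube = np.stack(frames, axis=2)       # shape (3, 3, 3)
stacked_sum = np.sum(cube, axis=2)    # every pixel is 6.0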
33 changes: 15 additions & 18 deletions datalab/datalab_session/data_operations/subtraction.py
@@ -2,9 +2,11 @@

import numpy as np

from datalab.datalab_session.data_operations.fits_file_reader import FITSFileReader
from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.data_operations.fits_output_handler import FITSOutputHandler
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.file_utils import crop_arrays, create_output
from datalab.datalab_session.file_utils import crop_arrays

log = logging.getLogger()
log.setLevel(logging.INFO)
@@ -52,30 +54,25 @@ def operate(self):
input_files = self.input_data.get('input_files', [])
subtraction_file_input = self.input_data.get('subtraction_file', [])

if not subtraction_file_input:
raise ClientAlertException('Missing a subtraction file')
if not subtraction_file_input: raise ClientAlertException('Missing a subtraction file')
if len(input_files) < 1: raise ClientAlertException('Need at least one input file')

if len(input_files) < 1:
raise ClientAlertException('Need at least one input file')
log.info(f'Subtraction operation on {len(input_files)} files')

log.info(f'Executing subtraction operation on {len(input_files)} files')

input_image_data_list = self.get_fits_npdata(input_files)

subtraction_image = self.get_fits_npdata(subtraction_file_input)[0]
self.set_operation_progress(0.70)
input_FITS_list = [FITSFileReader(input['basename'], input['source']) for input in input_files]
subtraction_FITS = FITSFileReader(subtraction_file_input[0]['basename'], subtraction_file_input[0]['source'])
self.set_operation_progress(0.5)

outputs = []
for index, input_image in enumerate(input_image_data_list):
for index, input_image in enumerate(input_FITS_list, start=1):
# crop the input_image and subtraction_image to the same size
input_image, subtraction_image = crop_arrays([input_image, subtraction_image])
input_image, subtraction_image = crop_arrays([input_image.sci_data, subtraction_FITS.sci_data])

difference_array = np.subtract(input_image, subtraction_image)

subtraction_comment = f'Product of Datalab Subtraction of {subtraction_file_input[0]["basename"]} subtracted from {input_files[index]["basename"]}'
outputs.append(create_output(self.cache_key, difference_array, index=index, comment=subtraction_comment))

self.set_operation_progress(0.90)
subtraction_comment = f'Datalab Subtraction of {subtraction_file_input[0]["basename"]} subtracted from {input_files[index-1]["basename"]}'
outputs.append(FITSOutputHandler(f'{self.cache_key}', difference_array, subtraction_comment).create_save_fits(index=index))
self.set_operation_progress(0.5 + index/len(input_FITS_list) * 0.4)

log.info(f'Subtraction output: {outputs}')
self.set_output(outputs)
log.info(f'Subtraction output: {self.get_output()}')
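Each output frame is the input minus the (cropped) subtraction frame; a toy version of one loop iteration:

import numpy as np

input_image = np.array([[5.0, 6.0], [7.0, 8.0]])
subtraction_image = np.array([[1.0, 1.0], [2.0, 2.0]])
difference = np.subtract(input_image, subtraction_image)   # [[4., 5.], [6., 7.]]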
Binary file modified datalab/datalab_session/tests/test_files/median/median_1_2.fits
18 changes: 9 additions & 9 deletions datalab/datalab_session/tests/test_operations.py
@@ -169,9 +169,9 @@ def tearDown(self):
return super().tearDown()

@mock.patch('datalab.datalab_session.file_utils.tempfile.NamedTemporaryFile')
@mock.patch('datalab.datalab_session.data_operations.data_operation.get_fits')
@mock.patch('datalab.datalab_session.file_utils.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.file_utils.create_jpgs')
@mock.patch('datalab.datalab_session.data_operations.fits_file_reader.get_fits')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.create_jpgs')
def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get_fits, mock_named_tempfile):

# return the test fits paths in order of the input_files instead of aws fetch
@@ -221,10 +221,10 @@ def tearDown(self):
self.clean_test_dir()
return super().tearDown()

@mock.patch('datalab.datalab_session.file_utils.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.file_utils.create_jpgs')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.create_jpgs')
@mock.patch('datalab.datalab_session.file_utils.tempfile.NamedTemporaryFile')
@mock.patch('datalab.datalab_session.data_operations.rgb_stack.get_fits')
@mock.patch('datalab.datalab_session.data_operations.fits_file_reader.get_fits')
def test_operate(self, mock_get_fits, mock_named_tempfile, mock_create_jpgs, mock_save_fits_and_thumbnails):

# return the test fits paths in order of the input_files instead of aws fetch
@@ -265,9 +265,9 @@ def tearDown(self):
return super().tearDown()

@mock.patch('datalab.datalab_session.file_utils.tempfile.NamedTemporaryFile')
@mock.patch('datalab.datalab_session.data_operations.data_operation.get_fits')
@mock.patch('datalab.datalab_session.file_utils.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.file_utils.create_jpgs')
@mock.patch('datalab.datalab_session.data_operations.fits_file_reader.get_fits')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.save_fits_and_thumbnails')
@mock.patch('datalab.datalab_session.data_operations.fits_output_handler.create_jpgs')
def test_operate(self, mock_create_jpgs, mock_save_fits_and_thumbnails, mock_get_fits, mock_named_tempfile):

# Create a negative images using numpy
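The @mock.patch targets above move from file_utils and data_operation to the new modules because unittest.mock replaces a name where it is looked up, not where it is defined: fits_output_handler does "from ...file_utils import create_jpgs", so the test must patch fits_output_handler.create_jpgs. A generic, runnable illustration of that rule (os.path.basename is used only as a stand-in helper):

from unittest import mock
from os.path import basename   # this module now holds its own reference to basename

def short_name(path):
    return basename(path)

# Patching os.path.basename would not affect short_name(), because the local
# binding still points at the original function; patch this module's name instead.
with mock.patch(f'{__name__}.basename', return_value='patched'):
    assert short_name('/tmp/file.fits') == 'patched'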