Skip to content

Commit

Permalink
Bugfix for Opening Files from the Cloud (#121)
Browse files Browse the repository at this point in the history
* better handling of s3 file access

* Check access level of bucket, modify unit test

* Add print statement

* Add timeout to HTTP call
  • Loading branch information
snbianco authored Jul 26, 2024
1 parent f47fc89 commit 11b0567
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
25 changes: 19 additions & 6 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import pathlib
from typing import Union, Tuple
import requests

import asdf
import astropy
Expand All @@ -16,19 +17,28 @@
from astropy.modeling import models


def _get_cloud_http(s3_uri: Union[str, S3Path], verbose: bool = False) -> str:
    """
    Get the HTTP URI of a cloud resource from an S3 URI.

    Parameters
    ----------
    s3_uri : string | S3Path
        the S3 URI of the cloud resource
    verbose : bool
        Default False. If true intermediate information is printed.

    Returns
    -------
    str
        the signed HTTP URL of the file inside the bucket
    """
    # check if the bucket is public or private by sending an unauthenticated
    # HTTP HEAD request; a 403 means credentials are required (private bucket)
    s3_path = S3Path.from_uri(s3_uri) if isinstance(s3_uri, str) else s3_uri
    url = f'https://{s3_path.bucket}.s3.amazonaws.com/{s3_path.key}'
    resp = requests.head(url, timeout=10)  # timeout so a hung endpoint can't block the cutout
    is_anon = resp.status_code != 403
    if verbose and not is_anon:
        print(f'Attempting to access private S3 bucket: {s3_path.bucket}')

    # create the file system with the appropriate access mode and get the
    # (signed, when private) HTTP URL of the file
    fs = s3fs.S3FileSystem(anon=is_anon)
    with fs.open(s3_uri, 'rb') as f:
        return f.url()

Expand Down Expand Up @@ -242,7 +252,8 @@ def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile:

def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float, cutout_size: int = 20,
output_file: Union[str, pathlib.Path] = "example_roman_cutout.fits",
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: Union[int, float] = np.nan,
verbose: bool = False) -> astropy.nddata.Cutout2D:
"""
Takes a single ASDF input file (`input_file`) and generates a cutout of designated size `cutout_size`
around the given coordinates (`coordinates`).
Expand All @@ -265,6 +276,8 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float
Optional, default True. Flag to write the cutout to a file or not.
fill_value: int | float
Optional, default `np.nan`. The fill value for pixels outside the original image.
verbose : bool
Default False. If true intermediate information is printed.
Returns
-------
Expand All @@ -275,7 +288,7 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float
# if file comes from AWS cloud bucket, get HTTP URL to open with asdf
file = input_file
if (isinstance(input_file, str) and input_file.startswith('s3://')) or isinstance(input_file, S3Path):
file = _get_cloud_http(input_file)
file = _get_cloud_http(input_file, verbose)

# get the 2d image data
with asdf.open(file) as f:
Expand Down
20 changes: 15 additions & 5 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,27 +315,37 @@ def test_slice_gwcs(fakedata):
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()


@patch('requests.head')
@patch('s3fs.S3FileSystem')
def test_get_cloud_http(mock_s3fs, mock_requests):
    """ test we can get HTTP URI of cloud resource """
    # mock the HTTP HEAD response; 200 means the bucket is publicly readable
    mock_resp = MagicMock()
    mock_resp.status_code = 200  # public bucket
    mock_requests.return_value = mock_resp

    # mock s3 file system operations so no real AWS call is made
    HTTP_URI = "http_test"
    mock_file = MagicMock()
    mock_fs = MagicMock()
    mock_file.url.return_value = HTTP_URI
    mock_fs.open.return_value.__enter__.return_value = mock_file
    mock_s3fs.return_value = mock_fs

    # test function with string input; public bucket -> anonymous access
    s3_uri = "s3://test_bucket/test_file.asdf"
    http_uri = _get_cloud_http(s3_uri)
    assert http_uri == HTTP_URI
    mock_s3fs.assert_called_with(anon=True)
    mock_fs.open.assert_called_once_with(s3_uri, 'rb')
    mock_file.url.assert_called_once()

    # test function with S3Path input (S3Path keys are rooted at '/')
    s3_uri_path = S3Path("/test_bucket/test_file_2.asdf")
    http_uri_path = _get_cloud_http(s3_uri_path)
    assert http_uri_path == HTTP_URI
    mock_fs.open.assert_called_with(s3_uri_path, 'rb')

    # test function with private bucket; 403 -> authenticated access
    mock_resp.status_code = 403
    http_uri = _get_cloud_http(s3_uri)
    mock_s3fs.assert_called_with(anon=False)

0 comments on commit 11b0567

Please sign in to comment.