Skip to content

Commit

Permalink
implementation of range requests, with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
krokicki committed Oct 4, 2024
1 parent 4630e11 commit 2a91149
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 53 deletions.
65 changes: 57 additions & 8 deletions tests/test_awss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from x2s3.settings import Target, Settings
from x2s3.utils import parse_xml

obj_path = '/janelia-data-examples/jrc_mus_lung_covid.n5/render/v1_acquire_align___20210609_224836/s0/0/0/0'

@pytest.fixture
def get_settings():
settings = Settings()
Expand Down Expand Up @@ -196,6 +198,61 @@ def test_virtual_host_get_object(app):
assert 'n5' in json_obj


def test_get_object_hidden(app):
    """Fetch attributes.json under /hidden-with-endpoint/ and expect 200 plus an 'n5' key in the JSON body."""
    url = "/hidden-with-endpoint/jrc_mus_lung_covid.n5/attributes.json"
    with TestClient(app) as http:
        reply = http.get(url)
        assert reply.status_code == 200
        assert 'n5' in reply.json()


def test_get_object_range_first(app):
    """Requesting the first 100 bytes yields 206 with the matching Content-Range and a 100-byte body."""
    first_hundred = {"Range": "bytes=0-99"}
    with TestClient(app) as http:
        reply = http.get(obj_path, headers=first_hundred)
        assert reply.status_code == 206  # Partial Content
        assert 'Content-Range' in reply.headers
        assert reply.headers['Content-Range'] == 'bytes 0-99/987143'
        assert len(reply.content) == 100

def test_get_object_range_mid(app):
    """Requesting bytes 100-199 yields 206 with the matching Content-Range and a 100-byte body."""
    second_hundred = {"Range": "bytes=100-199"}
    with TestClient(app) as http:
        reply = http.get(obj_path, headers=second_hundred)
        assert reply.status_code == 206  # Partial Content
        assert 'Content-Range' in reply.headers
        assert reply.headers['Content-Range'] == 'bytes 100-199/987143'
        assert len(reply.content) == 100

def test_get_object_range_last(app):
    """Requesting the last 100 bytes via a suffix range yields 206 and the exact trailing window.

    The sibling range tests pin the object size at 987143 bytes, so the
    suffix range ``bytes=-100`` must resolve to bytes 987043-987142. Checking
    the value (not just the header's presence) catches a proxy that returns
    100 bytes from the wrong offset.
    """
    with TestClient(app) as client:
        # Test a valid suffix range request (last 100 bytes)
        response = client.get(
            obj_path,
            headers={"Range": "bytes=-100"}
        )
        assert response.status_code == 206  # Partial Content
        assert 'Content-Range' in response.headers
        assert response.headers['Content-Range'] == 'bytes 987043-987142/987143'
        assert len(response.content) == 100

def test_get_object_range_out_of_bounds(app):
    """A range starting beyond the object's 987143 bytes is rejected with 416 and an InvalidRange error document."""
    beyond_eof = {"Range": "bytes=1000000-2000000"}
    with TestClient(app) as http:
        reply = http.get(obj_path, headers=beyond_eof)
        assert reply.status_code == 416  # Range Not Satisfiable
        error_doc = parse_xml(reply.text)
        assert error_doc.find('Code').text == 'InvalidRange'


def test_prefixed_list_objects(app):
with TestClient(app) as client:
bucket_name = 'with-prefix'
Expand All @@ -209,14 +266,6 @@ def test_prefixed_list_objects(app):
assert root.find('IsTruncated').text == "false"


def test_get_object_hidden(app):
with TestClient(app) as client:
response = client.get("/hidden-with-endpoint/jrc_mus_lung_covid.n5/attributes.json")
assert response.status_code == 200
json_obj = response.json()
assert 'n5' in json_obj


def test_get_object_missing(app):
with TestClient(app) as client:
response = client.get("/janelia-data-examples/missing")
Expand Down
6 changes: 4 additions & 2 deletions x2s3/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ async def target_dispatcher(request: Request,
else:
raise HTTPException(status_code=400, detail="Invalid list type")
else:
return await client.get_object(target_path)
range_header = request.headers.get("range")
return await client.get_object(target_path, range_header)

if not target_path or target_path.endswith("/"):
if app.settings.ui:
Expand All @@ -259,7 +260,8 @@ async def target_dispatcher(request: Request,
else:
return get_nosuchbucket_response(target_name)
else:
return await client.get_object(target_path)
range_header = request.headers.get("range")
return await client.get_object(target_path, range_header)



Expand Down
2 changes: 1 addition & 1 deletion x2s3/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def head_object(self, key: str):
https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html
"""

async def get_object(self, key: str):
async def get_object(self, key: str, range_header: str = None):
"""
Basic interface for AWS S3's GetObject API.
https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html
Expand Down
85 changes: 59 additions & 26 deletions x2s3/client_aioboto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ def handle_s3_exception(e, key=None):
elif isinstance(e, botocore.exceptions.ReadTimeoutError):
return JSONResponse({"error":"Upstream endpoint timed out"}, status_code=408)
elif isinstance(e, botocore.exceptions.ClientError):
code = e.response['ResponseMetadata']['HTTPStatusCode']
if e.response["Error"]["Code"] == "NoSuchKey":
return get_nosuchkey_response(key)
elif int(code) == 404 and key:
status_code = e.response['ResponseMetadata']['HTTPStatusCode']
error = e.response['Error']
error_code = error['Code'] if 'Code' in error else 'Unknown'
if error_code == "NoSuchKey":
return get_nosuchkey_response(key)
else:
logger.opt(exception=sys.exc_info()).error("Error using boto S3 API")
return JSONResponse({"error":"Error communicating with AWS S3"}, status_code=code)
message = error['Message'] if 'Message' in error else 'Unknown'
resource = error['Resource'] if 'Resource' in error else 'Unknown'
return get_error_response(status_code, error_code, message, resource)
else:
logger.opt(exception=sys.exc_info()).error("Error communicating with AWS S3")
return JSONResponse({"error":"Error communicating with AWS S3"}, status_code=500)
Expand Down Expand Up @@ -85,8 +86,9 @@ async def head_object(self, key: str):
s3_res = await client.head_object(Bucket=self.bucket_name, Key=real_key)
headers = {
"ETag": s3_res.get("ETag"),
"Accept-Ranges": "bytes",
"Content-Length": str(s3_res.get("ContentLength")),
"Last-Modified": s3_res.get("LastModified").strftime("%a, %d %b %Y %H:%M:%S GMT")
"Last-Modified": s3_res.get("LastModified").strftime("%a, %d %b %Y %H:%M:%S GMT"),
}

content_type = guess_content_type(real_key)
Expand All @@ -98,27 +100,32 @@ async def head_object(self, key: str):


@override
async def get_object(self, key: str):
async def get_object(self, key: str, range_header: str = None):
real_key = key
if self.bucket_prefix:
real_key = os.path.join(self.bucket_prefix, key) if key else self.bucket_prefix

filename = os.path.basename(real_key)
headers = {}

content_type = guess_content_type(filename)
headers['Content-Type'] = content_type

headers = {
'Accept-Ranges': "bytes",
'Content-Type': content_type,
}

if content_type=='application/octet-stream':
headers['Content-Disposition'] = f'attachment; filename="{filename}"'

try:
return S3Stream(
self.get_client_creator,
headers=headers,
media_type=content_type,
bucket=self.bucket_name,
key=key,
real_key=real_key,
media_type=content_type,
headers=headers)
range_header=range_header,
)
except Exception as e:
return handle_s3_exception(e, key)

Expand Down Expand Up @@ -211,17 +218,52 @@ def __init__(
bucket: str = None,
key: str = None,
real_key: str = None,
range_header: str = None
):
super(S3Stream, self).__init__(content, status_code, headers, media_type, background)
self.client_creator = client_creator
self.bucket = bucket
self.key = key
self.real_key = real_key
self.range_header = range_header

async def stream_response(self, send) -> None:

async def send_response(r):
await send({
"type": "http.response.start",
"status": r.status_code,
"headers": r.raw_headers,
})
await send({
"type": "http.response.body",
"body": r.body,
"more_body": False,
})

async with self.client_creator() as client:
result = None
try:
result = await client.get_object(Bucket=self.bucket, Key=self.real_key)
# Get the object with the range specified in headers
get_object_params = {
"Bucket": self.bucket,
"Key": self.real_key,
}
if self.range_header:
get_object_params["Range"] = self.range_header

result = await client.get_object(**get_object_params)
res_headers = result["ResponseMetadata"]["HTTPHeaders"]

# Determine if this is a Range result
if "content-range" in res_headers:
self.status_code = 206 # Partial Content
self.raw_headers.append((b"content-range",
res_headers["content-range"].encode('utf-8')))

if "content-length" in res_headers:
self.raw_headers.append((b"content-length",
res_headers["content-length"].encode('utf-8')))

await send({
"type": "http.response.start",
Expand All @@ -245,15 +287,6 @@ async def stream_response(self, send) -> None:
"body": b"",
"more_body": False})

except client.exceptions.NoSuchKey:
r = get_nosuchkey_response(self.key)
await send({
"type": "http.response.start",
"status": r.status_code,
"headers": r.raw_headers,
})
await send({
"type": "http.response.body",
"body": r.body,
"more_body": False,
})
except Exception as e:
r = handle_s3_exception(e, self.key)
await send_response(r)
2 changes: 1 addition & 1 deletion x2s3/client_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def head_object(self, key: str):


@override
async def get_object(self, key: str):
async def get_object(self, key: str, range_header: str = None):
try:
path = os.path.join(self.root_path, key)
if not os.path.isfile(path):
Expand Down
43 changes: 28 additions & 15 deletions x2s3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,29 +164,42 @@ def get_nosuchkey_response(key):

def get_nosuchbucket_response(bucket_name):
return Response(content=inspect.cleandoc(f"""
<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchBucket</Code>
<Message>The specified bucket does not exist</Message>
<BucketName>{bucket_name}</BucketName>
</Error>
"""), status_code=404, media_type="application/xml")
<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>NoSuchBucket</Code>
<Message>The specified bucket does not exist</Message>
<BucketName>{bucket_name}</BucketName>
</Error>
"""), status_code=404, media_type="application/xml")


def get_accessdenied_response():
return Response(content=inspect.cleandoc("""
<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>AccessDenied</Code>
<Message>Access Denied</Message>
</Error>
"""), status_code=403, media_type="application/xml")
<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>AccessDenied</Code>
<Message>Access Denied</Message>
</Error>
"""), status_code=403, media_type="application/xml")


def get_error_response(status_code, error_code, message, resource):
    """Build an S3-style XML ``<Error>`` response.

    Args:
        status_code: HTTP status code to set on the response.
        error_code: Value for the ``<Code>`` element (e.g. ``InvalidRange``).
        message: Human-readable text for the ``<Message>`` element.
        resource: Value for the ``<Resource>`` element.

    Returns:
        A ``Response`` with media type ``application/xml``.
    """
    # Local import so this function does not depend on a file-level import.
    from xml.sax.saxutils import escape

    # Dedent the template *before* substituting values, so a multi-line
    # message cannot change how cleandoc computes the common indentation.
    template = inspect.cleandoc("""
        <?xml version="1.0" encoding="UTF-8"?>
        <Error>
        <Code>{code}</Code>
        <Message>{message}</Message>
        <Resource>{resource}</Resource>
        </Error>
        """)
    # Values originate from upstream error payloads and may contain '&', '<'
    # or '>'; escape them so the document stays well-formed XML.
    body = template.format(
        code=escape(str(error_code)),
        message=escape(str(message)),
        resource=escape(str(resource)),
    )
    return Response(content=body,
                    status_code=status_code,
                    media_type="application/xml")


def get_read_access_acl():
""" Returns an S3 ACL that grants full read access
"""
acl_xml = """
acl_xml = inspect.cleandoc("""
<AccessControlPolicy>
<Owner>
<ID>1</ID>
Expand All @@ -201,7 +214,7 @@ def get_read_access_acl():
</Grant>
</AccessControlList>
</AccessControlPolicy>
"""
""")
return Response(content=acl_xml, media_type="application/xml")


Expand Down

0 comments on commit 2a91149

Please sign in to comment.