Skip to content

Commit

Permalink
Merge pull request #683 from oda-hub/support-head-download-endpoint
Browse files Browse the repository at this point in the history
Support HEAD requests on the download endpoints; use the bind configuration when products_url is invalid
  • Loading branch information
burnout87 authored May 28, 2024
2 parents 3f4eb17 + 08d0575 commit b2acf7d
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 15 deletions.
12 changes: 9 additions & 3 deletions cdci_data_analysis/analysis/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def parse_inputs_files(self,
use_scws,
upload_dir,
products_url,
bind_host,
bind_port,
request_files_dir,
decoded_token,
sentry_dsn=None):
Expand Down Expand Up @@ -277,7 +279,9 @@ def parse_inputs_files(self,
step = 'updating par_dic with the uploaded files'
self.update_par_dic_with_uploaded_files(par_dic=par_dic,
uploaded_files_obj=uploaded_files_obj,
products_url=products_url)
products_url=products_url,
bind_host=bind_host,
bind_port=bind_port)
step = 'updating ownership files'
self.update_ownership_files(uploaded_files_obj,
request_files_dir=request_files_dir,
Expand Down Expand Up @@ -704,13 +708,15 @@ def set_input_products_from_fronted(self, input_file_path, par_dic, verbose=Fals
else:
raise RuntimeError

def update_par_dic_with_uploaded_files(self, par_dic, uploaded_files_obj, products_url):
def update_par_dic_with_uploaded_files(self, par_dic, uploaded_files_obj, products_url, bind_host, bind_port):
if validators.url(products_url):
# TODO remove the dispatch-data part, better to have it extracted from the configuration file
basepath = os.path.join(products_url, 'dispatch-data/download_file')
else:
basepath = os.path.join(products_url, 'download_file')
basepath = os.path.join(f"http://{bind_host}:{bind_port}", 'download_file')
for f in uploaded_files_obj:
dpars = urlencode(dict(file_list=uploaded_files_obj[f],
_is_mmoda_url=True,
return_archive=False))
download_file_url = f"{basepath}?{dpars}"
par_dic[f] = download_file_url
Expand Down
10 changes: 6 additions & 4 deletions cdci_data_analysis/flask_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def run_api_instr_list():
logger.warning('\nThe endpoint \'/api/instr-list\' is deprecated and you will be automatically redirected to the '
'\'/instr-list\' endpoint. Please use this one in the future.\n')

if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url):
if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url, simple_host=True):
# TODO remove the dispatch-data part, better to have it extracted from the configuration file
redirection_url = os.path.join(app.config['conf'].products_url, 'dispatch-data/instr-list')
if request.args:
args_request = urlencode(request.args)
Expand Down Expand Up @@ -148,7 +149,7 @@ def meta_data_src():
return query.get_meta_data('src_query')


@app.route("/download_products", methods=['POST', 'GET'])
@app.route("/download_products", methods=['POST', 'GET', 'HEAD'])
def download_products():
from_request_files_dir = request.args.get('from_request_files_dir', 'False') == 'True'
download_file = request.args.get('download_file', 'False') == 'True'
Expand All @@ -157,9 +158,10 @@ def download_products():
return query.download_file(from_request_files_dir=from_request_files_dir)


@app.route("/download_file", methods=['POST', 'GET'])
@app.route("/download_file", methods=['POST', 'GET', 'HEAD'])
def download_file():
if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url):
if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url, simple_host=True):
# TODO remove the dispatch-data part, better to have it extracted from the configuration file
redirection_url = os.path.join(app.config['conf'].products_url, 'dispatch-data/download_products')
if request.args:
args_request = urlencode(request.args)
Expand Down
13 changes: 10 additions & 3 deletions cdci_data_analysis/flask_app/dispatcher_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(self, app,

try:
if par_dic is None:
self.set_args(request, verbose=verbose)
self.set_args(request, verbose=verbose, download_files=download_files, download_products=download_products)
else:
self.par_dic = par_dic
self.log_query_progression("after set args")
Expand Down Expand Up @@ -293,6 +293,8 @@ def __init__(self, app,
"When we find a solution we will try to reach you", status_code=500)
if self.instrument is not None and not isinstance(self.instrument, str):
products_url = self.app.config.get('conf').products_url
bind_host = self.app.config.get('conf').bind_host
bind_port = self.app.config.get('conf').bind_port
self.instrument.parse_inputs_files(
par_dic=self.par_dic,
request=request,
Expand All @@ -301,6 +303,8 @@ def __init__(self, app,
use_scws=self.use_scws,
upload_dir=self.request_files_dir,
products_url=products_url,
bind_host=bind_host,
bind_port=bind_port,
request_files_dir=self.request_files_dir,
decoded_token=self.decoded_token,
sentry_dsn=self.sentry_dsn
Expand Down Expand Up @@ -848,8 +852,11 @@ def set_scws_call_back_related_params(self):
if self.use_scws is None:
self.use_scws = 'form_list'

def set_args(self, request, verbose=False):
if request.method in ['GET', 'POST']:
def set_args(self, request, verbose=False, download_products=False, download_files=False):
supported_methods = ['GET', 'POST']
if download_files or download_products:
supported_methods.append('HEAD')
if request.method in supported_methods:
args = request.values
else:
raise NotImplementedError
Expand Down
33 changes: 33 additions & 0 deletions cdci_data_analysis/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,19 @@ def dispatcher_test_conf_with_external_products_url_fn(dispatcher_test_conf_fn):
yield fn


@pytest.fixture
def dispatcher_test_conf_with_default_route_products_url_fn(dispatcher_test_conf_fn):
    """Yield the test configuration path with ``products_url`` set to a default-route URL.

    The file named by ``dispatcher_test_conf_fn`` is rewritten in place: every
    ``products_url:`` entry is replaced by ``http://0.0.0.0:1234/mmoda/``.
    """
    fn = dispatcher_test_conf_fn
    with open(fn, "r+") as f:
        data = f.read()
        # The pattern is a raw string: "\s" inside a normal literal is an invalid
        # string escape and triggers a warning on recent Python versions.
        # NOTE(review): the leading whitespace of the replacement has to match the
        # YAML nesting of products_url in the config file -- confirm.
        data = re.sub(r'(\s+products_url:).*\n', '\n products_url: http://0.0.0.0:1234/mmoda/\n', data)
        # Overwrite from the start and drop any leftover tail of the old content.
        f.seek(0)
        f.write(data)
        f.truncate()

    yield fn


@pytest.fixture
def dispatcher_test_conf_no_resubmit_timeout_fn(dispatcher_test_conf_fn):
fn = dispatcher_test_conf_fn
Expand Down Expand Up @@ -677,6 +690,13 @@ def dispatcher_test_conf_with_external_products_url(dispatcher_test_conf_with_ex
yield loaded_yaml['dispatcher']


@pytest.fixture
def dispatcher_test_conf_with_default_route_products_url(dispatcher_test_conf_with_default_route_products_url_fn):
    """Yield the ``dispatcher`` section parsed from the default-route products_url config file."""
    conf_path = dispatcher_test_conf_with_default_route_products_url_fn
    with open(conf_path) as conf_file:
        parsed_conf = yaml.load(conf_file, Loader=yaml.SafeLoader)
        yield parsed_conf['dispatcher']


def dispatcher_test_conf_with_no_resubmit_timeout(dispatcher_test_conf_with_no_resubmit_timeout_fn):
with open(dispatcher_test_conf_with_no_resubmit_timeout_fn) as yaml_f:
loaded_yaml = yaml.load(yaml_f, Loader=yaml.SafeLoader)
Expand Down Expand Up @@ -1126,6 +1146,19 @@ def dispatcher_live_fixture_with_external_products_url(pytestconfig, dispatcher_
os.kill(pid, signal.SIGINT)


@pytest.fixture
def dispatcher_live_fixture_with_default_route_products_url(pytestconfig, dispatcher_test_conf_with_default_route_products_url_fn, dispatcher_debug):
    """Start a dispatcher with the default-route products_url config and yield its URL.

    After the test, SIGINT is sent to the dispatcher's child processes and then
    to the dispatcher process itself.
    """
    state = start_dispatcher(
        pytestconfig.rootdir,
        dispatcher_test_conf_with_default_route_products_url_fn,
    )
    service_url = state['url']
    dispatcher_pid = state['pid']

    yield service_url

    # Teardown: children first, then the dispatcher process.
    kill_child_processes(dispatcher_pid, signal.SIGINT)
    os.kill(dispatcher_pid, signal.SIGINT)


@pytest.fixture
def dispatcher_live_fixture_with_renku_options(pytestconfig, dispatcher_test_conf_with_renku_options_fn, dispatcher_debug):
dispatcher_state = start_dispatcher(pytestconfig.rootdir, dispatcher_test_conf_with_renku_options_fn)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ MarkupSafe==2.0.1
# TODO: needed by some plugins: migrate
simple_logger
matplotlib
validators==0.20.0
validators==0.28.3
pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"nbformat",
"giturlparse",
"sentry-sdk",
"validators==0.20.0",
"validators==0.28.3",
"jsonschema"
]

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
dispatcher_test_conf_with_gallery_fn,
dispatcher_test_conf_with_gallery_no_resolver_fn,
dispatcher_live_fixture_with_external_products_url,
dispatcher_live_fixture_with_default_route_products_url,
dispatcher_test_conf_with_external_products_url_fn,
dispatcher_test_conf_with_default_route_products_url_fn,
dispatcher_test_conf_with_external_products_url,
dispatcher_test_conf_with_default_route_products_url,
dispatcher_test_conf_no_resubmit_timeout_fn,
dispatcher_test_conf_with_matrix_options,
dispatcher_test_conf_with_matrix_options_fn,
Expand Down
130 changes: 127 additions & 3 deletions tests/test_server_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,39 @@ def test_download_products_public(dispatcher_long_living_fixture, empty_products

assert data_downloaded == empty_products_files_fixture['content']


@pytest.mark.fast
def test_head_download_products_public(dispatcher_long_living_fixture, empty_products_files_fixture):
    """Check that HEAD on ``/download_products`` answers 200 with the expected Content-Length.

    The expected length is obtained by gzip-compressing the product file locally
    under the requested download name and measuring the resulting file.
    """
    server = dispatcher_long_living_fixture

    logger.info("constructed server: %s", server)

    session_id = empty_products_files_fixture['session_id']
    job_id = empty_products_files_fixture['job_id']

    response = requests.head(
        server + "/download_products",
        params={
            'query_status': 'ready',
            'file_list': 'test.fits.gz',
            'download_file_name': 'output_test',
            'session_id': session_id,
            'job_id': job_id,
        },
    )

    assert response.status_code == 200

    scratch_dir = f'scratch_sid_{session_id}_jid_{job_id}'
    with open(f'{scratch_dir}/test.fits.gz', "rb") as source_file:
        payload = source_file.read()
    archive_path = f'{scratch_dir}/output_test'
    with gzip.open(archive_path, 'wb') as archive:
        archive.write(payload)
    # A HEAD response carries no body, so only the advertised size is compared
    # against the size of the locally built archive.
    expected_size = os.path.getsize(archive_path)
    assert int(response.headers['Content-Length']) == expected_size


@pytest.mark.fast
def test_download_products_aliased_dir(dispatcher_live_fixture):
DispatcherJobState.remove_scratch_folders()
Expand Down Expand Up @@ -627,6 +660,31 @@ def test_download_file_redirection_external_products_url(dispatcher_live_fixture
assert redirection_url == redirection_header_location_url


@pytest.mark.fast
@pytest.mark.parametrize("include_args", [True, False])
def test_download_file_redirection_default_route_products_url(
        dispatcher_live_fixture_with_default_route_products_url,
        dispatcher_test_conf_with_default_route_products_url,
        include_args):
    """Check that ``/download_file`` redirects (302) to the products_url download endpoint.

    The expected Location is built from the configured ``products_url``, the
    optional request arguments, and the download flags appended to the query.
    """
    server = dispatcher_live_fixture_with_default_route_products_url

    logger.info("constructed server: %s", server)

    # Same query string is sent and expected back in the redirection target.
    query_string = '?a=4566&token=aaaaaaaaaa' if include_args else ''
    url_request = os.path.join(server, "download_file") + query_string

    response = requests.get(url_request, allow_redirects=False)

    assert response.status_code == 302
    actual_location = response.headers["Location"]

    products_url = dispatcher_test_conf_with_default_route_products_url['products_url']
    expected_location = (
        os.path.join(products_url, 'dispatch-data/download_products')
        + query_string
        + '&from_request_files_dir=True&download_file=True&download_products=False'
    )
    assert expected_location == actual_location


@pytest.mark.fast
@pytest.mark.parametrize("include_args", [True, False])
def test_download_file_redirection_no_custom_products_url(dispatcher_live_fixture_no_products_url,
Expand Down Expand Up @@ -687,6 +745,45 @@ def test_download_file_public(dispatcher_long_living_fixture, request_files_fixt

assert data_downloaded == request_files_fixture['content']


@pytest.mark.fast
@pytest.mark.parametrize('return_archive', [True, False])
@pytest.mark.parametrize('matching_file_name', [True, False])
def test_head_download_file(dispatcher_long_living_fixture, request_files_fixture, return_archive, matching_file_name):
    """Check that HEAD on ``/download_file`` answers 200 with the expected Content-Length.

    With ``return_archive`` the expected length is that of a locally built gzip
    archive of the request file; otherwise it is the size of the file itself.
    ``matching_file_name`` makes the download name equal to the requested file name.
    """
    DispatcherJobState.create_local_request_files_folder()
    server = dispatcher_long_living_fixture

    logger.info("constructed server: %s", server)

    source_path = request_files_fixture['file_path']
    requested_name = os.path.basename(source_path)
    download_name = requested_name if matching_file_name else 'output_test'

    response = requests.head(
        server + "/download_file",
        allow_redirects=True,
        params={
            'file_list': requested_name,
            'download_file_name': download_name,
            'return_archive': return_archive,
        },
    )

    assert response.status_code == 200

    if not return_archive:
        expected_size = os.path.getsize(source_path)
    else:
        with open(source_path, "rb") as source_file:
            payload = source_file.read()
        archive_path = f'local_request_files/{download_name}'
        with gzip.open(archive_path, 'wb') as archive:
            archive.write(payload)
        # Only the advertised size can be checked: HEAD returns no body.
        expected_size = os.path.getsize(archive_path)

    assert int(response.headers['Content-Length']) == expected_size


def test_query_restricted_instrument(dispatcher_live_fixture):
server = dispatcher_live_fixture

Expand Down Expand Up @@ -782,6 +879,30 @@ def test_instrument_list_redirection_external_products_url(dispatcher_live_fixtu
assert redirection_url == redirection_header_location_url


@pytest.mark.fast
@pytest.mark.parametrize("include_args", [True, False])
def test_instrument_list_redirection_default_route_products_url(
        dispatcher_live_fixture_with_default_route_products_url,
        dispatcher_test_conf_with_default_route_products_url,
        include_args):
    """Check that ``/api/instr-list`` redirects (302) to the products_url instr-list endpoint.

    Any request arguments are expected to be carried over unchanged to the
    redirection target.
    """
    server = dispatcher_live_fixture_with_default_route_products_url

    logger.info("constructed server: %s", server)

    query_string = '?a=4566&token=aaaaaaaaaa' if include_args else ''
    url_request = os.path.join(server, "api/instr-list") + query_string

    response = requests.get(url_request, allow_redirects=False)

    assert response.status_code == 302
    actual_location = response.headers["Location"]

    products_url = dispatcher_test_conf_with_default_route_products_url['products_url']
    expected_location = os.path.join(products_url, 'dispatch-data/instr-list') + query_string
    assert expected_location == actual_location


@pytest.mark.fast
@pytest.mark.parametrize("allow_redirect", [True, False])
@pytest.mark.parametrize("include_args", [True, False])
Expand Down Expand Up @@ -1442,7 +1563,7 @@ def test_numerical_authorization_user_roles(dispatcher_live_fixture, roles):


@pytest.mark.parametrize("public_download_request", [True, False])
def test_arg_file(dispatcher_live_fixture, public_download_request):
def test_arg_file(dispatcher_live_fixture, dispatcher_test_conf, public_download_request):
DispatcherJobState.remove_scratch_folders()
DispatcherJobState.empty_request_files_folders()
server = dispatcher_live_fixture
Expand Down Expand Up @@ -1491,12 +1612,15 @@ def test_arg_file(dispatcher_live_fixture, public_download_request):
assert len(args_dict['file_list']) == 1
assert os.path.exists(f'request_files/{args_dict["file_list"][0]}')

arg_download_url = jdata['products']['analysis_parameters']['dummy_file'].replace('PRODUCTS_URL/', server)
products_host_port = f"http://{dispatcher_test_conf['bind_options']['bind_host']}:{dispatcher_test_conf['bind_options']['bind_port']}"

arg_download_url = jdata['products']['analysis_parameters']['dummy_file'].replace('PRODUCTS_URL/', products_host_port)

file_hash = make_hash_file(p_file_path)
dpars = urlencode(dict(file_list=file_hash,
_is_mmoda_url=True,
return_archive=False))
local_download_url = f"{os.path.join(server, 'download_file')}?{dpars}"
local_download_url = f"{os.path.join(products_host_port, 'download_file')}?{dpars}"

assert arg_download_url == local_download_url

Expand Down

0 comments on commit b2acf7d

Please sign in to comment.