diff --git a/cdci_data_analysis/analysis/instrument.py b/cdci_data_analysis/analysis/instrument.py index 5e9e3195..205cc587 100644 --- a/cdci_data_analysis/analysis/instrument.py +++ b/cdci_data_analysis/analysis/instrument.py @@ -247,6 +247,8 @@ def parse_inputs_files(self, use_scws, upload_dir, products_url, + bind_host, + bind_port, request_files_dir, decoded_token, sentry_dsn=None): @@ -277,7 +279,9 @@ def parse_inputs_files(self, step = 'updating par_dic with the uploaded files' self.update_par_dic_with_uploaded_files(par_dic=par_dic, uploaded_files_obj=uploaded_files_obj, - products_url=products_url) + products_url=products_url, + bind_host=bind_host, + bind_port=bind_port) step = 'updating ownership files' self.update_ownership_files(uploaded_files_obj, request_files_dir=request_files_dir, @@ -704,13 +708,15 @@ def set_input_products_from_fronted(self, input_file_path, par_dic, verbose=Fals else: raise RuntimeError - def update_par_dic_with_uploaded_files(self, par_dic, uploaded_files_obj, products_url): + def update_par_dic_with_uploaded_files(self, par_dic, uploaded_files_obj, products_url, bind_host, bind_port): if validators.url(products_url): + # TODO remove the dispatch-data part, better to have it extracted from the configuration file basepath = os.path.join(products_url, 'dispatch-data/download_file') else: - basepath = os.path.join(products_url, 'download_file') + basepath = os.path.join(f"http://{bind_host}:{bind_port}", 'download_file') for f in uploaded_files_obj: dpars = urlencode(dict(file_list=uploaded_files_obj[f], + _is_mmoda_url=True, return_archive=False)) download_file_url = f"{basepath}?{dpars}" par_dic[f] = download_file_url diff --git a/cdci_data_analysis/flask_app/app.py b/cdci_data_analysis/flask_app/app.py index 1404e740..3188949f 100644 --- a/cdci_data_analysis/flask_app/app.py +++ b/cdci_data_analysis/flask_app/app.py @@ -100,7 +100,8 @@ def run_api_instr_list(): logger.warning('\nThe endpoint \'/api/instr-list\' is deprecated and you will be automatically redirected to the ' '\'/instr-list\' endpoint. Please use this one in the future.\n') - if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url): + if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url, simple_host=True): + # TODO remove the dispatch-data part, better to have it extracted from the configuration file redirection_url = os.path.join(app.config['conf'].products_url, 'dispatch-data/instr-list') if request.args: args_request = urlencode(request.args) @@ -148,7 +149,7 @@ def meta_data_src(): return query.get_meta_data('src_query') -@app.route("/download_products", methods=['POST', 'GET']) +@app.route("/download_products", methods=['POST', 'GET', 'HEAD']) def download_products(): from_request_files_dir = request.args.get('from_request_files_dir', 'False') == 'True' download_file = request.args.get('download_file', 'False') == 'True' @@ -157,9 +158,10 @@ def download_products(): return query.download_file(from_request_files_dir=from_request_files_dir) -@app.route("/download_file", methods=['POST', 'GET']) +@app.route("/download_file", methods=['POST', 'GET', 'HEAD']) def download_file(): - if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url): + if app.config['conf'].products_url is not None and validators.url(app.config['conf'].products_url, simple_host=True): + # TODO remove the dispatch-data part, better to have it extracted from the configuration file redirection_url = os.path.join(app.config['conf'].products_url, 'dispatch-data/download_products') if request.args: args_request = urlencode(request.args) diff --git a/cdci_data_analysis/flask_app/dispatcher_query.py b/cdci_data_analysis/flask_app/dispatcher_query.py index 2ead6243..aed09c7e 100644 --- a/cdci_data_analysis/flask_app/dispatcher_query.py +++ b/cdci_data_analysis/flask_app/dispatcher_query.py @@ -139,7 +139,7 @@ def __init__(self, app, try: if par_dic is None: - self.set_args(request, verbose=verbose) + self.set_args(request, verbose=verbose, download_files=download_files, download_products=download_products) else: self.par_dic = par_dic self.log_query_progression("after set args") @@ -293,6 +293,8 @@ def __init__(self, app, "When we find a solution we will try to reach you", status_code=500) if self.instrument is not None and not isinstance(self.instrument, str): products_url = self.app.config.get('conf').products_url + bind_host = self.app.config.get('conf').bind_host + bind_port = self.app.config.get('conf').bind_port self.instrument.parse_inputs_files( par_dic=self.par_dic, request=request, @@ -301,6 +303,8 @@ def __init__(self, app, use_scws=self.use_scws, upload_dir=self.request_files_dir, products_url=products_url, + bind_host=bind_host, + bind_port=bind_port, request_files_dir=self.request_files_dir, decoded_token=self.decoded_token, sentry_dsn=self.sentry_dsn @@ -848,8 +852,11 @@ def set_scws_call_back_related_params(self): if self.use_scws is None: self.use_scws = 'form_list' - def set_args(self, request, verbose=False): - if request.method in ['GET', 'POST']: + def set_args(self, request, verbose=False, download_products=False, download_files=False): + supported_methods = ['GET', 'POST'] + if download_files or download_products: + supported_methods.append('HEAD') + if request.method in supported_methods: args = request.values else: raise NotImplementedError diff --git a/cdci_data_analysis/pytest_fixtures.py b/cdci_data_analysis/pytest_fixtures.py index a6d19069..85649a53 100644 --- a/cdci_data_analysis/pytest_fixtures.py +++ b/cdci_data_analysis/pytest_fixtures.py @@ -566,6 +566,19 @@ def dispatcher_test_conf_with_external_products_url_fn(dispatcher_test_conf_fn): yield fn +@pytest.fixture +def dispatcher_test_conf_with_default_route_products_url_fn(dispatcher_test_conf_fn): + fn = dispatcher_test_conf_fn + with open(fn, "r+") as f: + data = f.read() + data = re.sub('(\s+products_url:).*\n', '\n products_url: http://0.0.0.0:1234/mmoda/\n', data) + f.seek(0) + f.write(data) + f.truncate() + + yield fn + + @pytest.fixture def dispatcher_test_conf_no_resubmit_timeout_fn(dispatcher_test_conf_fn): fn = dispatcher_test_conf_fn @@ -677,6 +690,13 @@ def dispatcher_test_conf_with_external_products_url(dispatcher_test_conf_with_ex yield loaded_yaml['dispatcher'] +@pytest.fixture +def dispatcher_test_conf_with_default_route_products_url(dispatcher_test_conf_with_default_route_products_url_fn): + with open(dispatcher_test_conf_with_default_route_products_url_fn) as yaml_f: + loaded_yaml = yaml.load(yaml_f, Loader=yaml.SafeLoader) + yield loaded_yaml['dispatcher'] + + def dispatcher_test_conf_with_no_resubmit_timeout(dispatcher_test_conf_with_no_resubmit_timeout_fn): with open(dispatcher_test_conf_with_no_resubmit_timeout_fn) as yaml_f: loaded_yaml = yaml.load(yaml_f, Loader=yaml.SafeLoader) @@ -1126,6 +1146,19 @@ def dispatcher_live_fixture_with_external_products_url(pytestconfig, dispatcher_ os.kill(pid, signal.SIGINT) +@pytest.fixture +def dispatcher_live_fixture_with_default_route_products_url(pytestconfig, dispatcher_test_conf_with_default_route_products_url_fn, dispatcher_debug): + dispatcher_state = start_dispatcher(pytestconfig.rootdir, dispatcher_test_conf_with_default_route_products_url_fn) + + service = dispatcher_state['url'] + pid = dispatcher_state['pid'] + + yield service + + kill_child_processes(pid, signal.SIGINT) + os.kill(pid, signal.SIGINT) + + @pytest.fixture def dispatcher_live_fixture_with_renku_options(pytestconfig, dispatcher_test_conf_with_renku_options_fn, dispatcher_debug): dispatcher_state = start_dispatcher(pytestconfig.rootdir, dispatcher_test_conf_with_renku_options_fn) diff --git a/requirements.txt b/requirements.txt index 650d5089..94404d99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,5 +36,5 @@ MarkupSafe==2.0.1 # TODO: needed by some plugins: migrate simple_logger matplotlib -validators==0.20.0 +validators==0.28.3 pillow>=10.0.1 # not directly required, pinned by Snyk to avoid a vulnerability \ No newline at end of file diff --git a/setup.py b/setup.py index e34595bc..d8cc4e21 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ "nbformat", "giturlparse", "sentry-sdk", - "validators==0.20.0", + "validators==0.28.3", "jsonschema" ] diff --git a/tests/conftest.py b/tests/conftest.py index fd6cd174..1f6659bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,8 +23,11 @@ dispatcher_test_conf_with_gallery_fn, dispatcher_test_conf_with_gallery_no_resolver_fn, dispatcher_live_fixture_with_external_products_url, + dispatcher_live_fixture_with_default_route_products_url, dispatcher_test_conf_with_external_products_url_fn, + dispatcher_test_conf_with_default_route_products_url_fn, dispatcher_test_conf_with_external_products_url, + dispatcher_test_conf_with_default_route_products_url, dispatcher_test_conf_no_resubmit_timeout_fn, dispatcher_test_conf_with_matrix_options, dispatcher_test_conf_with_matrix_options_fn, diff --git a/tests/test_server_basic.py b/tests/test_server_basic.py index fddd877d..f9305f96 100644 --- a/tests/test_server_basic.py +++ b/tests/test_server_basic.py @@ -482,6 +482,39 @@ def test_download_products_public(dispatcher_long_living_fixture, empty_products assert data_downloaded == empty_products_files_fixture['content'] + +@pytest.mark.fast +def test_head_download_products_public(dispatcher_long_living_fixture, empty_products_files_fixture): + server = dispatcher_long_living_fixture + + logger.info("constructed server: %s", server) + + session_id = empty_products_files_fixture['session_id'] + job_id = empty_products_files_fixture['job_id'] + + params = { + 'query_status': 'ready', + 'file_list': 'test.fits.gz', + 'download_file_name': 'output_test', + 'session_id': session_id, + 'job_id': job_id + } + + c = requests.head(server + "/download_products", + params=params) + + assert c.status_code == 200 + file_path = f'scratch_sid_{session_id}_jid_{job_id}/test.fits.gz' + with open(file_path, "rb") as f_in: + in_data = f_in.read() + archived_file_path = f'scratch_sid_{session_id}_jid_{job_id}/output_test' + with gzip.open(archived_file_path, 'wb') as f: + f.write(in_data) + # download the output, read it and then compare it + size = os.path.getsize(archived_file_path) + assert int(c.headers['Content-Length']) == size + + @pytest.mark.fast def test_download_products_aliased_dir(dispatcher_live_fixture): DispatcherJobState.remove_scratch_folders() @@ -627,6 +660,31 @@ def test_download_file_redirection_external_products_url(dispatcher_live_fixture assert redirection_url == redirection_header_location_url +@pytest.mark.fast +@pytest.mark.parametrize("include_args", [True, False]) +def test_download_file_redirection_default_route_products_url(dispatcher_live_fixture_with_default_route_products_url, + dispatcher_test_conf_with_default_route_products_url, + include_args): + server = dispatcher_live_fixture_with_default_route_products_url + + logger.info("constructed server: %s", server) + + url_request = os.path.join(server, "download_file") + + if include_args: + url_request += '?a=4566&token=aaaaaaaaaa' + + c = requests.get(url_request, allow_redirects=False) + + assert c.status_code == 302 + redirection_header_location_url = c.headers["Location"] + redirection_url = os.path.join(dispatcher_test_conf_with_default_route_products_url['products_url'], 'dispatch-data/download_products') + if include_args: + redirection_url += '?a=4566&token=aaaaaaaaaa' + redirection_url += '&from_request_files_dir=True&download_file=True&download_products=False' + assert redirection_url == redirection_header_location_url + + @pytest.mark.fast @pytest.mark.parametrize("include_args", [True, False]) def test_download_file_redirection_no_custom_products_url(dispatcher_live_fixture_no_products_url, @@ -687,6 +745,45 @@ def test_download_file_public(dispatcher_long_living_fixture, request_files_fixt assert data_downloaded == request_files_fixture['content'] + +@pytest.mark.fast +@pytest.mark.parametrize('return_archive', [True, False]) +@pytest.mark.parametrize('matching_file_name', [True, False]) +def test_head_download_file(dispatcher_long_living_fixture, request_files_fixture, return_archive, matching_file_name): + DispatcherJobState.create_local_request_files_folder() + server = dispatcher_long_living_fixture + + logger.info("constructed server: %s", server) + + params = { + 'file_list': os.path.basename(request_files_fixture['file_path']), + 'download_file_name': 'output_test', + 'return_archive': return_archive, + } + + if matching_file_name: + params['download_file_name'] = params['file_list'] + + c = requests.head(server + "/download_file", + allow_redirects=True, + params=params) + + assert c.status_code == 200 + + if return_archive: + with open(request_files_fixture['file_path'], "rb") as f_in: + in_data = f_in.read() + archived_file_path = f'local_request_files/{params["download_file_name"]}' + with gzip.open(archived_file_path, 'wb') as f: + f.write(in_data) + # download the output, read it and then compare it + size = os.path.getsize(archived_file_path) + else: + size = os.path.getsize(request_files_fixture['file_path']) + + assert int(c.headers['Content-Length']) == size + + def test_query_restricted_instrument(dispatcher_live_fixture): server = dispatcher_live_fixture @@ -782,6 +879,30 @@ def test_instrument_list_redirection_external_products_url(dispatcher_live_fixtu assert redirection_url == redirection_header_location_url +@pytest.mark.fast +@pytest.mark.parametrize("include_args", [True, False]) +def test_instrument_list_redirection_default_route_products_url(dispatcher_live_fixture_with_default_route_products_url, + dispatcher_test_conf_with_default_route_products_url, + include_args): + server = dispatcher_live_fixture_with_default_route_products_url + + logger.info("constructed server: %s", server) + + url_request = os.path.join(server, "api/instr-list") + + if include_args: + url_request += '?a=4566&token=aaaaaaaaaa' + + c = requests.get(url_request, allow_redirects=False) + + assert c.status_code == 302 + redirection_header_location_url = c.headers["Location"] + redirection_url = os.path.join(dispatcher_test_conf_with_default_route_products_url['products_url'], 'dispatch-data/instr-list') + if include_args: + redirection_url += '?a=4566&token=aaaaaaaaaa' + assert redirection_url == redirection_header_location_url + + @pytest.mark.fast @pytest.mark.parametrize("allow_redirect", [True, False]) @pytest.mark.parametrize("include_args", [True, False]) @@ -1442,7 +1563,7 @@ def test_numerical_authorization_user_roles(dispatcher_live_fixture, roles): @pytest.mark.parametrize("public_download_request", [True, False]) -def test_arg_file(dispatcher_live_fixture, public_download_request): +def test_arg_file(dispatcher_live_fixture, dispatcher_test_conf, public_download_request): DispatcherJobState.remove_scratch_folders() DispatcherJobState.empty_request_files_folders() server = dispatcher_live_fixture @@ -1491,12 +1612,15 @@ def test_arg_file(dispatcher_live_fixture, public_download_request): assert len(args_dict['file_list']) == 1 assert os.path.exists(f'request_files/{args_dict["file_list"][0]}') - arg_download_url = jdata['products']['analysis_parameters']['dummy_file'].replace('PRODUCTS_URL/', server) + products_host_port = f"http://{dispatcher_test_conf['bind_options']['bind_host']}:{dispatcher_test_conf['bind_options']['bind_port']}" + + arg_download_url = jdata['products']['analysis_parameters']['dummy_file'].replace('PRODUCTS_URL/', products_host_port) file_hash = make_hash_file(p_file_path) dpars = urlencode(dict(file_list=file_hash, + _is_mmoda_url=True, return_archive=False)) - local_download_url = f"{os.path.join(server, 'download_file')}?{dpars}" + local_download_url = f"{os.path.join(products_host_port, 'download_file')}?{dpars}" assert arg_download_url == local_download_url