diff --git a/cdci_data_analysis/flask_app/dispatcher_query.py b/cdci_data_analysis/flask_app/dispatcher_query.py index 2bc93744..3c9244d0 100644 --- a/cdci_data_analysis/flask_app/dispatcher_query.py +++ b/cdci_data_analysis/flask_app/dispatcher_query.py @@ -913,6 +913,21 @@ def clear_temp_dir(self, temp_scratch_dir=None): if temp_scratch_dir is not None and temp_scratch_dir != self.scratch_dir and os.path.exists(temp_scratch_dir): shutil.rmtree(temp_scratch_dir) + @staticmethod + def validated_download_file_path(basepath, filename, should_exist=True): + # basic arg validation + if "../" in filename or filename.startswith(os.sep): + raise RequestNotAuthorized('No such file') + + # still explicitly check if the file is in the dir + base_abs = os.path.realpath(basepath) + file_abs = os.path.realpath(os.path.join(basepath, filename)) + + if (os.path.commonpath([base_abs]) != os.path.commonpath([base_abs, file_abs]) + or (should_exist and not os.path.isfile(file_abs)) ): + raise RequestNotAuthorized('No such file') + return file_abs + def prepare_download(self, file_list, file_name, scratch_dir): file_name = file_name.replace(' ', '_') @@ -922,18 +937,17 @@ def prepare_download(self, file_list, file_name, scratch_dir): file_list = [file_list] for ID, f in enumerate(file_list): - file_list[ID] = os.path.join(scratch_dir + '/', f) + file_list[ID] = self.validated_download_file_path(scratch_dir, f) tmp_dir = tempfile.mkdtemp(prefix='download_', dir='./') - file_path = os.path.join(tmp_dir, file_name) + file_path = self.validated_download_file_path(tmp_dir, file_name, should_exist=False) out_dir = file_name.replace('.tar', '') out_dir = out_dir.replace('.gz', '') if len(file_list) > 1: tar = tarfile.open("%s" % (file_path), "w:gz") for name in file_list: - #print('add to tar', file_name,name) if name is not None: tar.add(name, arcname='%s/%s' % (out_dir, os.path.basename(name))) @@ -1055,6 +1069,7 @@ def download_products(self): self.validate_job_id(request_parameters_from_scratch_dir=True) file_list = self.args.get('file_list').split(',') + file_name = self.args.get('download_file_name') tmp_dir, target_file = self.prepare_download( diff --git a/tests/test_server_basic.py b/tests/test_server_basic.py index be39ae94..535a87c7 100644 --- a/tests/test_server_basic.py +++ b/tests/test_server_basic.py @@ -1,7 +1,7 @@ import re import shutil import urllib - +import io import requests import time import uuid @@ -439,6 +439,59 @@ def test_download_products_public(dispatcher_long_living_fixture, empty_products assert data_downloaded == empty_products_files_fixture['content'] +@pytest.mark.fast +@pytest.mark.parametrize('filelist', ['../external_file', '/tmp/external_file', 'test.fits.gz']) +@pytest.mark.parametrize('outname', ['/tmp/output_test', '../output_test', 'output_test']) +def test_download_products_outside_dir(dispatcher_long_living_fixture, + empty_products_files_fixture, + filelist, + outname): + server = dispatcher_long_living_fixture + + is_good = True if filelist == 'test.fits.gz' and outname == 'output_test' else False + logger.info("constructed server: %s", server) + + session_id = empty_products_files_fixture['session_id'] + job_id = empty_products_files_fixture['job_id'] + + if not is_good: + with open(filelist.replace('../', ''), 'w') as outb: + outb.write('__confidential__') + + params = { + 'instrument': 'any_name', + # since we are passing a job_id + 'query_status': 'ready', + 'file_list': filelist, + 'download_file_name': outname, + 'session_id': session_id, + 'job_id': job_id + } + + c = requests.get(server + "/download_products", + params=params) + + if is_good: + assert c.status_code == 200 + # further checks in previous test + else: + assert c.status_code == 403 + + # check the output anyway + assert b"__confidential__" not in c.content + if hasattr(c, 'text'): + assert "__confidential__" not in c.text + + with io.BytesIO() as outb: + outb.write(c.content) + outb.seek(0) + gz = gzip.GzipFile(fileobj=outb, mode='rb') + with pytest.raises(gzip.BadGzipFile): + gz.read() + try: + os.remove(filelist.replace('../', '')) + except: + pass def test_query_restricted_instrument(dispatcher_live_fixture): server = dispatcher_live_fixture