Skip to content

Commit

Permalink
Merge pull request #667 from oda-hub/fix-download-file-outside-session
Browse files Browse the repository at this point in the history
Fix vulnerability: arbitrary file download
  • Loading branch information
volodymyrss authored Mar 15, 2024
2 parents 6171102 + 6659547 commit c1b4607
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 4 deletions.
21 changes: 18 additions & 3 deletions cdci_data_analysis/flask_app/dispatcher_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,21 @@ def clear_temp_dir(self, temp_scratch_dir=None):
if temp_scratch_dir is not None and temp_scratch_dir != self.scratch_dir and os.path.exists(temp_scratch_dir):
shutil.rmtree(temp_scratch_dir)

@staticmethod
def validated_download_file_path(basepath, filename, should_exist=True):
    """Resolve ``filename`` under ``basepath`` and return its real absolute path.

    The name is rejected when it contains ``../``, starts with the path
    separator, resolves (after following symlinks) to a location that is
    not strictly inside ``basepath``, or - when ``should_exist`` is true -
    does not point to an existing regular file.

    :param basepath: directory the file must be located in
    :param filename: user-supplied path, relative to ``basepath``
    :param should_exist: when true, additionally require that the resolved
        path is an existing regular file
    :return: the resolved absolute path of the file
    :raises RequestNotAuthorized: if any of the checks above fails; the
        message is deliberately the same for every failure
    """
    # basic arg validation: cheap textual rejection of obvious traversal
    # attempts and of absolute paths
    if "../" in filename or filename.startswith(os.sep):
        raise RequestNotAuthorized('No such file')

    # still explicitly check if the file is in the dir; realpath resolves
    # symlinks and ".." so the comparison is made on the real locations
    base_abs = os.path.realpath(basepath)
    file_abs = os.path.realpath(os.path.join(basepath, filename))

    # A name such as "" or "." resolves to the base directory itself; that
    # is never a valid file path, even when the file is not required to exist.
    if file_abs == base_abs:
        raise RequestNotAuthorized('No such file')

    if os.path.commonpath([base_abs]) != os.path.commonpath([base_abs, file_abs]):
        raise RequestNotAuthorized('No such file')

    if should_exist and not os.path.isfile(file_abs):
        raise RequestNotAuthorized('No such file')

    return file_abs

def prepare_download(self, file_list, file_name, scratch_dir):
file_name = file_name.replace(' ', '_')

Expand All @@ -922,18 +937,17 @@ def prepare_download(self, file_list, file_name, scratch_dir):
file_list = [file_list]

for ID, f in enumerate(file_list):
file_list[ID] = os.path.join(scratch_dir + '/', f)
file_list[ID] = self.validated_download_file_path(scratch_dir, f)

tmp_dir = tempfile.mkdtemp(prefix='download_', dir='./')

file_path = os.path.join(tmp_dir, file_name)
file_path = self.validated_download_file_path(tmp_dir, file_name, should_exist=False)
out_dir = file_name.replace('.tar', '')
out_dir = out_dir.replace('.gz', '')

if len(file_list) > 1:
tar = tarfile.open("%s" % (file_path), "w:gz")
for name in file_list:
#print('add to tar', file_name,name)
if name is not None:
tar.add(name, arcname='%s/%s' %
(out_dir, os.path.basename(name)))
Expand Down Expand Up @@ -1055,6 +1069,7 @@ def download_products(self):

self.validate_job_id(request_parameters_from_scratch_dir=True)
file_list = self.args.get('file_list').split(',')

file_name = self.args.get('download_file_name')

tmp_dir, target_file = self.prepare_download(
Expand Down
55 changes: 54 additions & 1 deletion tests/test_server_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import shutil
import urllib

import io
import requests
import time
import uuid
Expand Down Expand Up @@ -439,6 +439,59 @@ def test_download_products_public(dispatcher_long_living_fixture, empty_products

assert data_downloaded == empty_products_files_fixture['content']

@pytest.mark.fast
@pytest.mark.parametrize('filelist', ['../external_file', '/tmp/external_file', 'test.fits.gz'])
@pytest.mark.parametrize('outname', ['/tmp/output_test', '../output_test', 'output_test'])
def test_download_products_outside_dir(dispatcher_long_living_fixture,
                                       empty_products_files_fixture,
                                       filelist,
                                       outname):
    """Check that /download_products refuses paths outside the expected directories.

    Only the combination ``filelist='test.fits.gz'`` with
    ``outname='output_test'`` is expected to succeed (HTTP 200). For every
    other combination a decoy file containing ``__confidential__`` is
    written first, the request is expected to return HTTP 403, and the
    response must neither contain the decoy content nor be a readable gzip
    stream.
    """
    server = dispatcher_long_living_fixture

    is_good = filelist == 'test.fits.gz' and outname == 'output_test'
    logger.info("constructed server: %s", server)

    session_id = empty_products_files_fixture['session_id']
    job_id = empty_products_files_fixture['job_id']

    # location of the decoy file the request must not be able to retrieve
    external_path = filelist.replace('../', '')

    if not is_good:
        with open(external_path, 'w') as outb:
            outb.write('__confidential__')

    try:
        params = {
            'instrument': 'any_name',
            # since we are passing a job_id
            'query_status': 'ready',
            'file_list': filelist,
            'download_file_name': outname,
            'session_id': session_id,
            'job_id': job_id
        }

        c = requests.get(server + "/download_products",
                         params=params)

        if is_good:
            assert c.status_code == 200
            # further checks in previous test
        else:
            assert c.status_code == 403

            # check the output anyway
            assert b"__confidential__" not in c.content
            if hasattr(c, 'text'):
                assert "__confidential__" not in c.text

            # the rejected response must not be a valid gzip archive
            with io.BytesIO() as outb:
                outb.write(c.content)
                outb.seek(0)
                gz = gzip.GzipFile(fileobj=outb, mode='rb')
                with pytest.raises(gzip.BadGzipFile):
                    gz.read()
    finally:
        # Remove the decoy even when an assertion above fails, so it cannot
        # leak into other tests; removal stays best-effort.
        if not is_good:
            try:
                os.remove(external_path)
            except OSError:
                pass

def test_query_restricted_instrument(dispatcher_live_fixture):
server = dispatcher_live_fixture
Expand Down

0 comments on commit c1b4607

Please sign in to comment.