diff --git a/webserver/forms.py b/webserver/forms.py
index b1d0268e4..b95f80176 100644
--- a/webserver/forms.py
+++ b/webserver/forms.py
@@ -1,3 +1,4 @@
+
from wtforms import BooleanField, SelectField, StringField, TextAreaField
from wtforms.validators import DataRequired
from flask_wtf import FlaskForm
@@ -22,6 +23,7 @@ class DatasetCSVImportForm(FlaskForm):
FileRequired(),
FileAllowed(["csv"], "Dataset needs to be in CSV format"),
])
+ public = BooleanField('Make this dataset public')
class DatasetEvaluationForm(FlaskForm):
@@ -35,4 +37,3 @@ class DatasetEvaluationForm(FlaskForm):
default = DATASET_EVAL_LOCAL
)
normalize = BooleanField("Normalize classes")
-
diff --git a/webserver/templates/datasets/import.html b/webserver/templates/datasets/import.html
index bf0c643fa..d428190a8 100644
--- a/webserver/templates/datasets/import.html
+++ b/webserver/templates/datasets/import.html
@@ -55,6 +55,13 @@
Import dataset
{{ form.file(class="form-control", required="required") }}
+
+
+
+
+
diff --git a/webserver/views/datasets.py b/webserver/views/datasets.py
index 7a222ba37..2cc93d54b 100644
--- a/webserver/views/datasets.py
+++ b/webserver/views/datasets.py
@@ -308,8 +308,9 @@ def import_csv():
"name": form.name.data,
"description": description if description else form.description.data,
"classes": classes,
- "public": True,
+ "public": form.public.data,
}
+
try:
dataset_id = db.dataset.create_from_dict(dataset_dict, current_user.id)
except dataset_validator.ValidationException as e:
diff --git a/webserver/views/test/test_datasets.py b/webserver/views/test/test_datasets.py
index 421ba6f28..cf53ac49f 100644
--- a/webserver/views/test/test_datasets.py
+++ b/webserver/views/test/test_datasets.py
@@ -403,7 +403,7 @@ def test_import_csv_valid(self, mock_create_from_dict, mock_parse_dataset_csv):
with open(test_csv_file) as csv_data:
resp = self.client.post(url_for('datasets.import_csv'),
- data={'name': 'dataset', 'file': (csv_data, 'dataset.csv')},
+ data={'name': 'dataset', 'public': 'y', 'file': (csv_data, 'dataset.csv')},
content_type='multipart/form-data')
# Validation succeeds, we redirect to the dataset view
self.assertRedirects(resp, '/datasets/%s' % ds_id)
@@ -415,6 +415,33 @@ def test_import_csv_valid(self, mock_create_from_dict, mock_parse_dataset_csv):
}
mock_create_from_dict.assert_called_with(expected_ds, self.test_user_id)
+ @mock.patch("webserver.views.datasets._parse_dataset_csv")
+ @mock.patch("db.dataset.create_from_dict")
+ def test_import_csv_private(self, mock_create_from_dict, mock_parse_dataset_csv):
+ """Upload a dataset with the visibility settings set to private"""
+ ds_id = '417fe34f-c124-47c0-b602-54d7164a8deb'
+ mock_create_from_dict.return_value = ds_id
+ mock_parse_dataset_csv.return_value = 'a desc', ['class_data']
+
+ self.temporary_login(self.test_user_id)
+ test_csv_file = os.path.join(TEST_DATA_PATH, 'test_dataset.csv')
+
+ with open(test_csv_file) as csv_data:
+ # Because the 'public' flag is a checkbox, unticking it has the result of
+ # not sending the field with the request, therefore it's not in data, below
+ resp = self.client.post(url_for('datasets.import_csv'),
+ data={'name': 'dataset', 'file': (csv_data, 'dataset.csv')},
+ content_type='multipart/form-data')
+ # Validation succeeds, we redirect to the dataset view
+ self.assertRedirects(resp, '/datasets/%s' % ds_id)
+ expected_ds = {
+ 'name': 'dataset',
+ 'description': 'a desc',
+ 'classes': ['class_data'],
+ 'public': False,
+ }
+ mock_create_from_dict.assert_called_with(expected_ds, self.test_user_id)
+
class DatasetsListTestCase(ServerTestCase):