From 2a994dd475811fb7f35a8cd2644edb7460bc2296 Mon Sep 17 00:00:00 2001 From: Narendran Thiagarajan Date: Fri, 5 Oct 2018 10:53:34 -0700 Subject: [PATCH] Update data mount errors (#208) * Warn user about namespace mismatch Before mounting a dataset, ensure that the dataset namespace matches if it is private * Remove incorrect message * Refactor the logic for getting the data object * Use the refactored function call from data * Remove unused import * Fix mocked modules --- floyd/cli/data.py | 49 ++++++++++++++++++--------------------- floyd/cli/run.py | 42 ++++++++++++++++++++------------- floyd/client/data.py | 2 +- floyd/model/data.py | 5 +++- tests/cli/run/test_run.py | 4 ++-- 5 files changed, 55 insertions(+), 47 deletions(-) diff --git a/floyd/cli/data.py b/floyd/cli/data.py index b56dbec..7313531 100644 --- a/floyd/cli/data.py +++ b/floyd/cli/data.py @@ -100,18 +100,28 @@ def status(id): The command also accepts a specific dataset version. """ if id: - data_source = DataClient().get(normalize_data_name(id)) - - if not data_source: - # Try with the raw ID - data_source = DataClient().get(id) - + data_source = get_data_object(id, use_data_config=False) print_data([data_source] if data_source else []) else: data_sources = DataClient().get_all() print_data(data_sources) +def get_data_object(data_id, use_data_config=True): + """ + Normalize the data_id and query the server. + If that is unavailable try the raw ID + """ + normalized_data_reference = normalize_data_name(data_id, use_data_config=use_data_config) + data_obj = DataClient().get(normalized_data_reference) + + # Try with the raw ID + if not data_obj and data_id != normalized_data_reference: + data_obj = DataClient().get(data_id) + + return data_obj + + def print_data(data_sources): """ Print dataset information in tabular form @@ -134,11 +144,7 @@ def clone(id): """ Download all files in a dataset. 
""" - - data_source = DataClient().get(normalize_data_name(id, use_data_config=False)) - if id and not data_source: - # Try with the raw ID - data_source = DataClient().get(id) + data_source = get_data_object(id, use_data_config=False) if not data_source: if 'output' in id: @@ -159,10 +165,7 @@ def listfiles(data_name): List files in a dataset. """ - data_source = DataClient().get(normalize_data_name(data_name, use_data_config=False)) - if data_name and not data_source: - # Try with the raw ID - data_source = DataClient().get(data_name) + data_source = get_data_object(data_name, use_data_config=False) if not data_source: if 'output' in data_name: @@ -202,10 +205,7 @@ def getfile(data_name, path): Download a specific file from a dataset. """ - data_source = DataClient().get(normalize_data_name(data_name, use_data_config=False)) - if data_name and not data_source: - # Try with the raw ID - data_source = DataClient().get(data_name) + data_source = get_data_object(data_name, use_data_config=False) if not data_source: if 'output' in data_name: @@ -215,6 +215,7 @@ def getfile(data_name, path): url = "{}/api/v1/resources/{}/{}?content=true".format(floyd.floyd_host, data_source.resource_id, path) fname = os.path.basename(path) DataClient().download(url, filename=fname) + floyd_logger.info("Download finished") @click.command() @@ -225,10 +226,7 @@ def output(id, url): """ View the files from a dataset. 
""" - data_source = DataClient().get(normalize_data_name(id)) - if id and not data_source: - # Try with the raw ID - data_source = DataClient().get(id) + data_source = get_data_object(id, use_data_config=False) if not data_source: sys.exit() @@ -252,10 +250,7 @@ def delete(ids, yes): failures = False for id in ids: - data_source = DataClient().get(normalize_data_name(id)) - if not data_source: - # Try with the raw ID - data_source = DataClient().get(id) + data_source = get_data_object(id, use_data_config=True) if not data_source: failures = True diff --git a/floyd/cli/run.py b/floyd/cli/run.py index 1642ffd..112b398 100644 --- a/floyd/cli/run.py +++ b/floyd/cli/run.py @@ -16,7 +16,6 @@ INSTANCE_NAME_MAP, INSTANCE_TYPE_MAP, ) -from floyd.client.data import DataClient from floyd.client.project import ProjectClient from floyd.cli.utils import ( get_data_name, normalize_data_name, normalize_job_name @@ -35,12 +34,13 @@ from floyd.model.experiment import ExperimentRequest from floyd.log import logger as floyd_logger from floyd.exceptions import BadRequestException +from floyd.cli.data import get_data_object from floyd.cli.experiment import get_log_id, follow_logs -from floyd.cli.utils import read_yaml_config +from floyd.cli.utils import current_project_namespace, read_yaml_config -def process_data_ids(data): - if len(data) > 5: +def process_data_ids(data_ids): + if len(data_ids) > 5: floyd_logger.error( "Cannot attach more than 5 datasets to a job") return False, None @@ -48,27 +48,37 @@ def process_data_ids(data): # Get the data entity from the server to: # 1. Confirm that the data id or uri exists and has the right permissions # 2. 
If uri is used, get the id of the dataset - data_ids = [] - for data_name_or_id in data: + processed_data_ids = [] + + for data_name_or_id in data_ids: path = None if ':' in data_name_or_id: data_name_or_id, path = data_name_or_id.split(':') - data_name_or_id = normalize_data_name(data_name_or_id, use_data_config=False) - - data_obj = DataClient().get(normalize_data_name(data_name_or_id, use_data_config=False)) - if not data_obj: - # Try with the raw ID - data_obj = DataClient().get(data_name_or_id) + data_obj = get_data_object(data_id=data_name_or_id, use_data_config=False) if not data_obj: - floyd_logger.error("Data not found for name or id: {}".format(data_name_or_id)) + floyd_logger.error( + "Data not found for name: {}. " + "Check if the data name is correct and you have permission to access it.".format(data_name_or_id) + ) return False, None + + # If data is private, check if the namespaces match + if not data_obj.public: + data_namespace = data_obj.name.split('/')[0] + if not data_namespace == current_project_namespace(): + floyd_logger.error( + "Data is private and can only be attached to projects in its own namespace ({}): {}".format( + data_namespace, data_name_or_id) + ) + return False, None + if path: - data_ids.append("%s:%s" % (data_obj.id, path)) + processed_data_ids.append("%s:%s" % (data_obj.id, path)) else: - data_ids.append(data_obj.id) - return True, data_ids + processed_data_ids.append(data_obj.id) + return True, processed_data_ids def resolve_final_instance_type(instance_type_override, yaml_str, task, cli_default): diff --git a/floyd/client/data.py b/floyd/client/data.py index 58b96ef..d534fdc 100644 --- a/floyd/client/data.py +++ b/floyd/client/data.py @@ -70,7 +70,7 @@ def get(self, id): return Data.from_dict(data_dict) except FloydException as e: - floyd_logger.info("Data %s: ERROR! %s\nIf you have already created the dataset, make sure you have uploaded at least one version.", id, e.message) + floyd_logger.info("Data %s: ERROR! 
%s\n", id, e.message) return None def get_all(self): diff --git a/floyd/model/data.py b/floyd/model/data.py index 5653ca7..ea5726e 100644 --- a/floyd/model/data.py +++ b/floyd/model/data.py @@ -33,6 +33,7 @@ class DataSchema(Schema): data = fields.Nested(DataDetailsSchema) version = fields.Str(allow_none=True) resource_id = fields.Str(allow_none=True) + public = fields.Boolean(allow_none=True) @post_load def make_data(self, data): @@ -49,7 +50,8 @@ def __init__(self, description, data, version=None, - resource_id=None): + resource_id=None, + public=None): self.id = id self.name = name self.created = self.localize_date(created) @@ -58,6 +60,7 @@ def __init__(self, self.state = data.state self.version = int(float(version)) if version else None self.resource_id = resource_id + self.public = public def localize_date(self, date): if not date.tzinfo: diff --git a/tests/cli/run/test_run.py b/tests/cli/run/test_run.py index 532dd54..a8562bd 100644 --- a/tests/cli/run/test_run.py +++ b/tests/cli/run/test_run.py @@ -42,7 +42,7 @@ def test_with_no_data(self, assert_exit_code(result, 0) @patch('floyd.manager.data_config.DataConfigManager.get_config', side_effect=mock_data_config) - @patch('floyd.cli.run.DataClient.get') + @patch('floyd.cli.data.DataClient.get') @patch('floyd.cli.run.EnvClient.get_all', return_value={'cpu': {'default': 'bar'}}) @patch('floyd.cli.run.AuthConfigManager.get_access_token', side_effect=mock_access_token) @patch('floyd.cli.run.AuthConfigManager.get_auth_header', return_value="Bearer " + mock_access_token().token) @@ -135,8 +135,8 @@ def test_multiple_envs_fails(self, result = self.runner.invoke(run, ['--env', 'foo', 'ls']) assert_exit_code(result, 0) - @patch('floyd.cli.run.DataClient.get') @patch('floyd.model.access_token.assert_token_not_expired') + @patch('floyd.cli.data.DataClient.get') @patch('floyd.cli.run.EnvClient.get_all', return_value={'cpu': {'foo': 'foo', 'bar': 'bar'}}) @patch('floyd.cli.run.AuthConfigManager.get_access_token', 
side_effect=mock_access_token) @patch('floyd.cli.run.AuthConfigManager.get_auth_header', return_value="Bearer " + mock_access_token().token)