Skip to content

Commit

Permalink
Update data mount errors (#208)
Browse files Browse the repository at this point in the history
* Warn user about namespace mismatch

Before mounting a dataset, ensure that the dataset namespace matches if
it is private

* Remove incorrect message

* Refactor the logic for getting the data object

* Use the refactored function call from data

* Remove unused import

* Fix mocked modules
  • Loading branch information
narenst authored Oct 5, 2018
1 parent 1e3fe30 commit 2a994dd
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 47 deletions.
49 changes: 22 additions & 27 deletions floyd/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,28 @@ def status(id):
The command also accepts a specific dataset version.
"""
if id:
data_source = DataClient().get(normalize_data_name(id))

if not data_source:
# Try with the raw ID
data_source = DataClient().get(id)

data_source = get_data_object(id, use_data_config=False)
print_data([data_source] if data_source else [])
else:
data_sources = DataClient().get_all()
print_data(data_sources)


def get_data_object(data_id, use_data_config=True):
    """
    Resolve a dataset reference to a Data object from the server.

    First normalizes ``data_id`` (via normalize_data_name) and queries the
    server with the normalized reference. If that lookup fails and the
    normalized reference differs from the original input, retries with the
    raw ``data_id`` (it may already be a server-side ID).

    :param data_id: dataset name, URI or raw ID supplied by the user
    :param use_data_config: forwarded to normalize_data_name; when True the
        local data config may be consulted during normalization
    :return: the Data object, or None if no match was found
    """
    normalized_data_reference = normalize_data_name(data_id, use_data_config=use_data_config)
    data_obj = DataClient().get(normalized_data_reference)

    # Fall back to the raw ID only when normalization actually changed the
    # reference (otherwise the second query would be identical to the first).
    # BUG FIX: the original called DataClient().get(id), passing the Python
    # builtin function `id` instead of the `data_id` parameter, so the
    # fallback could never succeed.
    if not data_obj and data_id != normalized_data_reference:
        data_obj = DataClient().get(data_id)

    return data_obj


def print_data(data_sources):
"""
Print dataset information in tabular form
Expand All @@ -134,11 +144,7 @@ def clone(id):
"""
Download all files in a dataset.
"""

data_source = DataClient().get(normalize_data_name(id, use_data_config=False))
if id and not data_source:
# Try with the raw ID
data_source = DataClient().get(id)
data_source = get_data_object(id, use_data_config=False)

if not data_source:
if 'output' in id:
Expand All @@ -159,10 +165,7 @@ def listfiles(data_name):
List files in a dataset.
"""

data_source = DataClient().get(normalize_data_name(data_name, use_data_config=False))
if data_name and not data_source:
# Try with the raw ID
data_source = DataClient().get(data_name)
data_source = get_data_object(data_name, use_data_config=False)

if not data_source:
if 'output' in data_name:
Expand Down Expand Up @@ -202,10 +205,7 @@ def getfile(data_name, path):
Download a specific file from a dataset.
"""

data_source = DataClient().get(normalize_data_name(data_name, use_data_config=False))
if data_name and not data_source:
# Try with the raw ID
data_source = DataClient().get(data_name)
data_source = get_data_object(data_name, use_data_config=False)

if not data_source:
if 'output' in data_name:
Expand All @@ -215,6 +215,7 @@ def getfile(data_name, path):
url = "{}/api/v1/resources/{}/{}?content=true".format(floyd.floyd_host, data_source.resource_id, path)
fname = os.path.basename(path)
DataClient().download(url, filename=fname)
floyd_logger.info("Download finished")


@click.command()
Expand All @@ -225,10 +226,7 @@ def output(id, url):
"""
View the files from a dataset.
"""
data_source = DataClient().get(normalize_data_name(id))
if id and not data_source:
# Try with the raw ID
data_source = DataClient().get(id)
data_source = get_data_object(id, use_data_config=False)

if not data_source:
sys.exit()
Expand All @@ -252,10 +250,7 @@ def delete(ids, yes):
failures = False

for id in ids:
data_source = DataClient().get(normalize_data_name(id))
if not data_source:
# Try with the raw ID
data_source = DataClient().get(id)
data_source = get_data_object(id, use_data_config=True)

if not data_source:
failures = True
Expand Down
42 changes: 26 additions & 16 deletions floyd/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
INSTANCE_NAME_MAP,
INSTANCE_TYPE_MAP,
)
from floyd.client.data import DataClient
from floyd.client.project import ProjectClient
from floyd.cli.utils import (
get_data_name, normalize_data_name, normalize_job_name
Expand All @@ -35,40 +34,51 @@
from floyd.model.experiment import ExperimentRequest
from floyd.log import logger as floyd_logger
from floyd.exceptions import BadRequestException
from floyd.cli.data import get_data_object
from floyd.cli.experiment import get_log_id, follow_logs
from floyd.cli.utils import read_yaml_config
from floyd.cli.utils import current_project_namespace, read_yaml_config


def process_data_ids(data):
if len(data) > 5:
def process_data_ids(data_ids):
if len(data_ids) > 5:
floyd_logger.error(
"Cannot attach more than 5 datasets to a job")
return False, None

# Get the data entity from the server to:
# 1. Confirm that the data id or uri exists and has the right permissions
# 2. If uri is used, get the id of the dataset
data_ids = []
for data_name_or_id in data:
processed_data_ids = []

for data_name_or_id in data_ids:
path = None
if ':' in data_name_or_id:
data_name_or_id, path = data_name_or_id.split(':')
data_name_or_id = normalize_data_name(data_name_or_id, use_data_config=False)

data_obj = DataClient().get(normalize_data_name(data_name_or_id, use_data_config=False))

if not data_obj:
# Try with the raw ID
data_obj = DataClient().get(data_name_or_id)
data_obj = get_data_object(data_id=data_name_or_id, use_data_config=False)

if not data_obj:
floyd_logger.error("Data not found for name or id: {}".format(data_name_or_id))
floyd_logger.error(
"Data not found for name: {}. "
"Check if the data name is correct and you have permission to access it.".format(data_name_or_id)
)
return False, None

# If data is private, check if the namespaces match
if not data_obj.public:
data_namespace = data_obj.name.split('/')[0]
if not data_namespace == current_project_namespace():
floyd_logger.error(
"Data is private and can only be attached to projects in its own namespace ({}): {}".format(
data_namespace, data_name_or_id)
)
return False, None

if path:
data_ids.append("%s:%s" % (data_obj.id, path))
processed_data_ids.append("%s:%s" % (data_obj.id, path))
else:
data_ids.append(data_obj.id)
return True, data_ids
processed_data_ids.append(data_obj.id)
return True, processed_data_ids


def resolve_final_instance_type(instance_type_override, yaml_str, task, cli_default):
Expand Down
2 changes: 1 addition & 1 deletion floyd/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get(self, id):

return Data.from_dict(data_dict)
except FloydException as e:
floyd_logger.info("Data %s: ERROR! %s\nIf you have already created the dataset, make sure you have uploaded at least one version.", id, e.message)
floyd_logger.info("Data %s: ERROR! %s\n", id, e.message)
return None

def get_all(self):
Expand Down
5 changes: 4 additions & 1 deletion floyd/model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DataSchema(Schema):
data = fields.Nested(DataDetailsSchema)
version = fields.Str(allow_none=True)
resource_id = fields.Str(allow_none=True)
public = fields.Boolean(allow_none=True)

@post_load
def make_data(self, data):
Expand All @@ -49,7 +50,8 @@ def __init__(self,
description,
data,
version=None,
resource_id=None):
resource_id=None,
public=None):
self.id = id
self.name = name
self.created = self.localize_date(created)
Expand All @@ -58,6 +60,7 @@ def __init__(self,
self.state = data.state
self.version = int(float(version)) if version else None
self.resource_id = resource_id
self.public = public

def localize_date(self, date):
if not date.tzinfo:
Expand Down
4 changes: 2 additions & 2 deletions tests/cli/run/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_with_no_data(self,
assert_exit_code(result, 0)

@patch('floyd.manager.data_config.DataConfigManager.get_config', side_effect=mock_data_config)
@patch('floyd.cli.run.DataClient.get')
@patch('floyd.cli.data.DataClient.get')
@patch('floyd.cli.run.EnvClient.get_all', return_value={'cpu': {'default': 'bar'}})
@patch('floyd.cli.run.AuthConfigManager.get_access_token', side_effect=mock_access_token)
@patch('floyd.cli.run.AuthConfigManager.get_auth_header', return_value="Bearer " + mock_access_token().token)
Expand Down Expand Up @@ -135,8 +135,8 @@ def test_multiple_envs_fails(self,
result = self.runner.invoke(run, ['--env', 'foo', 'ls'])
assert_exit_code(result, 0)

@patch('floyd.cli.run.DataClient.get')
@patch('floyd.model.access_token.assert_token_not_expired')
@patch('floyd.cli.data.DataClient.get')
@patch('floyd.cli.run.EnvClient.get_all', return_value={'cpu': {'foo': 'foo', 'bar': 'bar'}})
@patch('floyd.cli.run.AuthConfigManager.get_access_token', side_effect=mock_access_token)
@patch('floyd.cli.run.AuthConfigManager.get_auth_header', return_value="Bearer " + mock_access_token().token)
Expand Down

0 comments on commit 2a994dd

Please sign in to comment.