-
-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding capability of taking auxiliary data #436
base: master
Are you sure you want to change the base?
Changes from 2 commits
9f34bca
399e692
55077a2
53b86be
9165fc3
4fe07cb
73f8f16
95c824f
fd482b2
d67aef6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,84 @@ def load(self, dataset, fold=0): | |
else: | ||
raise ValueError(f"Unsupported file type: {ext}") | ||
|
||
@profile(logger=log) | ||
def load_auxiliary_data(self, auxiliary_data, fold=0): | ||
auxiliary_data = auxiliary_data if isinstance(auxiliary_data, ns) else ns(path=auxiliary_data) | ||
log.debug("Loading auxiliary data %s", auxiliary_data) | ||
paths = self._extract_auxiliary_paths(auxiliary_data.path if 'path' in auxiliary_data else auxiliary_data, fold=fold) | ||
train_path = paths['train'][fold] | ||
test_path = paths['test'][fold] | ||
paths = dict(train=train_path, test=test_path) | ||
return paths | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like the fact that it only returns paths. For all other data, we try to provide an object that allows a consistent loading of data, but still providing the possibility to access the path directly. I don't see why it can't be done here as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem here is that the auxiliary data could be in all different forms. For example, it could be a zip file containing bunch of images. We don't know how should we handle it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand, see my suggestion for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
you're making a design decision for them here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
hence the idea of adding a type/format attribute to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes sense. Thanks for the discussion! |
||
|
||
def _extract_auxiliary_paths(self, auxiliary_data, fold=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like a copy/paste of
Given the complexity of this logic, we can't afford duplicating it. def _extract_train_test_paths(self, dataset, fold=None):
return self._extract_paths(self._extract_train_test_paths, dataset, fold)
def _extract_auxiliary_paths(self, dataset, fold=None):
return self._extract_paths(self._extract_auxiliary_paths, dataset, fold,
train_suffix="train_auxiliary", test_suffix="test_auxiliary")
def _extract_paths(self, extract_paths_fn, data, fold=None, train_suffix='train', test_suffix='test'):
train_search_pat = re.compile(rf"(?:(.*)[_-]){train_suffix}(?:[_-](\d+))?\.\w+")
...
return extract_paths_fn(ns(train=…)
... |
||
train_search_pat = re.compile(r"(?:(.*)[_-])train_auxiliary(?:[_-](\d+))?\.\w+") | ||
test_search_pat = re.compile(r"(?:(.*)[_-])test_auxiliary(?:[_-](\d+))?\.\w+") | ||
if isinstance(auxiliary_data, (tuple, list)): | ||
assert len(auxiliary_data) % 2 == 0, "auxiliary data list must contain an even number of paths: [train_0, test_0, train_1, test_1, ...]." | ||
return self._extract_auxiliary_paths(ns(train=[p for i, p in enumerate(auxiliary_data) if i % 2 == 0], | ||
test=[p for i, p in enumerate(auxiliary_data) if i % 2 == 1]), | ||
fold=fold) | ||
elif isinstance(auxiliary_data, ns): | ||
return dict( | ||
train=[self._extract_auxiliary_paths(p)['train'][0] | ||
if i == fold else None | ||
for i, p in enumerate(as_list(auxiliary_data.train))], | ||
test=[self._extract_auxiliary_paths(p)['train'][0] | ||
if i == fold else None | ||
for i, p in enumerate(as_list(auxiliary_data.test))] if 'test' in auxiliary_data else [] | ||
) | ||
else: | ||
assert isinstance(auxiliary_data, str) | ||
auxiliary_data = os.path.expanduser(auxiliary_data) | ||
auxiliary_data = auxiliary_data.format(**rconfig().common_dirs) | ||
|
||
if os.path.exists(auxiliary_data): | ||
if os.path.isfile(auxiliary_data): | ||
# we leave the auxiliary data handling to the user | ||
return dict(train=[auxiliary_data], test=[]) | ||
elif os.path.isdir(auxiliary_data): | ||
files = list_all_files(auxiliary_data) | ||
log.debug("Files found in auxiliary data folder %s: %s", auxiliary_data, files) | ||
assert len(files) > 0, f"Empty folder: {auxiliary_data}" | ||
if len(files) == 1: | ||
return dict(train=files, test=[]) | ||
|
||
train_matches = [m for m in [train_search_pat.search(f) for f in files] if m] | ||
test_matches = [m for m in [test_search_pat.search(f) for f in files] if m] | ||
# verify they're for the same dataset (just based on name) | ||
assert train_matches, f"Folder {auxiliary_data} must contain at least one training auxiliary data." | ||
root_names = {m[1] for m in (train_matches+test_matches)} | ||
assert len(root_names) == 1, f"All dataset files in {auxiliary_data} should follow the same naming: xxxxx_train_N.ext or xxxxx_test_N.ext with N starting from 0." | ||
|
||
train_no_fold = next((m[0] for m in train_matches if m[2] is None), None) | ||
test_no_fold = next((m[0] for m in test_matches if m[2] is None), None) | ||
if train_no_fold and test_no_fold: | ||
return dict(train=[train_no_fold], test=[test_no_fold]) | ||
|
||
paths = dict(train=[], test=[]) | ||
fold = 0 | ||
while fold >= 0: | ||
train = next((m[0] for m in train_matches if m[2] == str(fold)), None) | ||
test = next((m[0] for m in test_matches if m[2] == str(fold)), None) | ||
if train and test: | ||
paths['train'].append(train) | ||
paths['test'].append(test) | ||
fold += 1 | ||
else: | ||
fold = -1 | ||
assert len(paths) > 0, f"No dataset file found in {auxiliary_data}: they should follow the naming xxxx_train.ext, xxxx_test.ext or xxxx_train_0.ext, xxxx_test_0.ext, xxxx_train_1.ext, ..." | ||
return paths | ||
elif is_valid_url(auxiliary_data): | ||
cached_file = os.path.join(self._cache_dir, os.path.basename(auxiliary_data)) | ||
if not os.path.exists(cached_file): # don't download if previously done | ||
handler = get_file_handler(auxiliary_data) | ||
assert handler.exists(auxiliary_data), f"Invalid path/url: {auxiliary_data}" | ||
handler.download(auxiliary_data, dest_path=cached_file) | ||
return self._extract_auxiliary_paths(cached_file) | ||
else: | ||
raise ValueError(f"Invalid dataset description: {auxiliary_data}") | ||
|
||
def _extract_train_test_paths(self, dataset, fold=None): | ||
if isinstance(dataset, (tuple, list)): | ||
assert len(dataset) % 2 == 0, "dataset list must contain an even number of paths: [train_0, test_0, train_1, test_1, ...]." | ||
|
@@ -167,6 +245,59 @@ def _get_metadata(self, prop): | |
return meta[prop] | ||
|
||
|
||
class DatasetWithauxiliaryData: | ||
PGijsbers marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__(self, dataset: FileDataset, auxiliary_data_path): | ||
self._dataset = dataset | ||
self._train_auxiliary_data = auxiliary_data_path.get('train', None) | ||
self._test_auxiliary_data = auxiliary_data_path.get('test', None) | ||
|
||
@property | ||
def train_auxiliary_data(self) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if it's a path, please name it |
||
return self._train_auxiliary_data | ||
|
||
@property | ||
def test_auxiliary_data(self) -> str: | ||
return self._test_auxiliary_data | ||
|
||
@property | ||
def type(self) -> DatasetType: | ||
assert self._dataset.target is not None | ||
return (DatasetType[self._dataset._type] if self._dataset._type is not None | ||
else DatasetType.regression if self._dataset.target.values is None | ||
else DatasetType.binary if len(self._dataset.target.values) == 2 | ||
else DatasetType.multiclass) | ||
|
||
@property | ||
def train(self) -> Datasplit: | ||
return self._dataset._train | ||
|
||
@property | ||
def test(self) -> Datasplit: | ||
return self._dataset._test | ||
|
||
@property | ||
def features(self) -> List[Feature]: | ||
return self._get_metadata('features') | ||
|
||
@property | ||
def target(self) -> Feature: | ||
return self._get_metadata('target') | ||
|
||
@memoize | ||
def _get_metadata(self, prop): | ||
meta = self._dataset._train.load_metadata() | ||
return meta[prop] | ||
|
||
@profile(logger=log) | ||
def release(self, properties=None): | ||
""" | ||
Call this to release cached properties and optimize memory once in-memory data are not needed anymore. | ||
:param properties: | ||
""" | ||
self._dataset.release(properties) | ||
|
||
|
||
class FileDatasplit(Datasplit): | ||
|
||
def __init__(self, dataset: FileDataset, format: str, path: str): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a test for this under
tests/unit/amlb/datasets/file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added