-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refactor dataset classes, use openml cache * fix example select_similar_datasets_by_knn.py * create DatasetIDType * create PredictorType * remove DataManager, refactor cache * update tests & test data * allow explicit OpenMLDataset creation from name/search * adapt examples to the last changes
- Loading branch information
1 parent
267e6f9
commit 5261b8f
Showing
59 changed files
with
2,350 additions
and
415 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,4 +129,4 @@ dmypy.json | |
.pyre/ | ||
|
||
# User data | ||
data/ | ||
/data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dataset_base import DatasetBase, DatasetData, DatasetIDType | ||
from .custom_dataset import DataNotFoundError, CustomDataset | ||
from .openml_dataset import OpenMLDataset, OpenMLDatasetIDType |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from __future__ import annotations | ||
|
||
import pickle | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from meta_automl.data_preparation.dataset import DatasetBase | ||
from meta_automl.data_preparation.dataset.dataset_base import DatasetData | ||
|
||
|
||
|
||
class DataNotFoundError(FileNotFoundError): | ||
pass | ||
|
||
|
||
class CustomDataset(DatasetBase): | ||
|
||
def get_data(self, cache_path: Optional[Path] = None) -> DatasetData: | ||
cache_path = cache_path or self.cache_path | ||
if not cache_path.exists(): | ||
raise DataNotFoundError(f'Dataset {self} is missing by the path "{cache_path}".') | ||
with open(cache_path, 'rb') as f: | ||
dataset_data = pickle.load(f) | ||
return dataset_data | ||
|
||
def dump_data(self, dataset_data: DatasetData, cache_path: Optional[Path] = None) -> CustomDataset: | ||
cache_path = cache_path or self.cache_path | ||
with open(cache_path, 'wb') as f: | ||
pickle.dump(dataset_data, f) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod, ABC | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Union, Optional, List, Any | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import scipy as sp | ||
|
||
from meta_automl.data_preparation.file_system import CacheOperator, get_dataset_cache_path | ||
|
||
DatasetIDType = Any | ||
|
||
|
||
@dataclass | ||
class DatasetData: | ||
x: Union[np.ndarray, pd.DataFrame, sp.sparse.csr_matrix] | ||
y: Optional[Union[np.ndarray, pd.DataFrame]] = None | ||
categorical_indicator: Optional[List[bool]] = None | ||
attribute_names: Optional[List[str]] = None | ||
|
||
|
||
class DatasetBase(ABC, CacheOperator): | ||
|
||
def __init__(self, id_: DatasetIDType, name: Optional[str] = None): | ||
self.id_ = id_ | ||
self.name = name | ||
|
||
def __repr__(self): | ||
return f'{self.__class__.__name__}(id_={self.id_}, name={self.name})' | ||
|
||
@abstractmethod | ||
def get_data(self) -> DatasetData: | ||
raise NotImplementedError() | ||
|
||
@property | ||
def cache_path(self) -> Path: | ||
return get_dataset_cache_path(self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Union | ||
|
||
import openml | ||
|
||
from meta_automl.data_preparation.dataset import DatasetBase | ||
from meta_automl.data_preparation.dataset.dataset_base import DatasetData | ||
from meta_automl.data_preparation.file_system import update_openml_cache_dir | ||
|
||
OpenMLDatasetIDType = int | ||
|
||
update_openml_cache_dir() | ||
|
||
|
||
class OpenMLDataset(DatasetBase): | ||
|
||
def __init__(self, id_: OpenMLDatasetIDType): | ||
if isinstance(id_, str): | ||
raise ValueError('Creating OpenMLDataset by dataset name is ambiguous. Please, use dataset id.' | ||
f'Otherwise, you can perform search by f{self.__class__.__name__}.from_search().') | ||
self._openml_dataset = openml.datasets.get_dataset(id_, download_data=False, download_qualities=False, | ||
error_if_multiple=True) | ||
id_ = self._openml_dataset.id | ||
name = self._openml_dataset.name | ||
super().__init__(id_, name) | ||
|
||
@classmethod | ||
def from_search(cls, id_: Union[OpenMLDatasetIDType, str], **get_dataset_kwargs) -> OpenMLDataset: | ||
openml_dataset = openml.datasets.get_dataset(id_, download_data=False, download_qualities=False, | ||
**get_dataset_kwargs) | ||
return cls(openml_dataset.id) | ||
|
||
def get_data(self, dataset_format: str = 'dataframe') -> DatasetData: | ||
X, y, categorical_indicator, attribute_names = self._openml_dataset.get_data( | ||
target=self._openml_dataset.default_target_attribute, | ||
dataset_format=dataset_format | ||
) | ||
return DatasetData(X, y, categorical_indicator, attribute_names) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .datasets_loader import DatasetsLoader | ||
from .openml_datasets_loader import OpenMLDatasetsLoader, OpenMLDatasetID | ||
from .openml_datasets_loader import OpenMLDatasetsLoader |
16 changes: 4 additions & 12 deletions
16
meta_automl/data_preparation/datasets_loaders/datasets_loader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,17 @@ | ||
from __future__ import annotations | ||
|
||
from abc import abstractmethod | ||
from typing import List, Type | ||
from typing import List | ||
|
||
from meta_automl.data_preparation.data_manager import DataManager | ||
from meta_automl.data_preparation.dataset import Dataset, DatasetCache, NoCacheError | ||
from meta_automl.data_preparation.dataset import DatasetBase | ||
|
||
|
||
class DatasetsLoader: | ||
data_manager: Type[DataManager] = DataManager | ||
|
||
@abstractmethod | ||
def load(self, *args, **kwargs) -> List[DatasetCache]: | ||
def load(self, *args, **kwargs) -> List[DatasetBase]: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def load_single(self, *args, **kwargs) -> DatasetCache: | ||
def load_single(self, *args, **kwargs) -> DatasetBase: | ||
raise NotImplementedError() | ||
|
||
def cache_to_memory(self, dataset: DatasetCache) -> Dataset: | ||
try: | ||
return dataset.from_cache() | ||
except NoCacheError: | ||
return self.load_single(dataset.id).from_cache() |
Oops, something went wrong.