From 098f9d43504534813e40676cc2211f6e7a00d60a Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Mon, 8 Aug 2016 10:54:59 -0700 Subject: [PATCH] Add "module" estimator spec Import a (possibly local) module and look for a global named `estimator` --- osprey/config.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/osprey/config.py b/osprey/config.py index d014523..27a5b42 100644 --- a/osprey/config.py +++ b/osprey/config.py @@ -44,7 +44,7 @@ FIELDS = { 'estimator': ['pickle', 'eval', 'eval_scope', 'entry_point', - 'params'], + 'params', 'module'], 'dataset_loader': ['name', 'params'], 'trials': ['uri', 'project_name'], 'search_space': dict, @@ -149,7 +149,18 @@ def estimator(self): pickle: path-to-pickle-file.pkl eval: "Pipeline([('cluster': KMeans())])" entry_point: sklearn.linear_model.LogisticRegression + module: myestimator """ + module_path = self.get_value('estimator/module') + if module_path is not None: + with prepend_syspath(dirname(abspath(self.path))): + estimator_module = importlib.import_module(module_path) + estimator = estimator_module.estimator() + if not isinstance(estimator, sklearn.base.BaseEstimator): + raise RuntimeError('estimator/pickle must load a ' + 'sklearn-derived Estimator') + return estimator + evalstring = self.get_value('estimator/eval') if evalstring is not None: got = self.get_value('estimator/eval_scope') @@ -166,7 +177,8 @@ def estimator(self): scope.update(getattr(eval_scopes, pkg_name)()) else: try: - pkg = importlib.import_module(pkg_name) + with prepend_syspath(dirname(abspath(self.path))): + pkg = importlib.import_module(pkg_name) except ImportError as e: raise RuntimeError(str(e)) scope.update(eval_scopes.import_all_estimators(pkg))