From a6d55b18f5c50733ee9d88f6e822da5246641c9f Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Mon, 8 Aug 2016 10:54:59 -0700 Subject: [PATCH] Add "module" estimator spec Import a (possibly local) module and look for a global named `estimator` --- osprey/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/osprey/config.py b/osprey/config.py index d014523..4a47349 100644 --- a/osprey/config.py +++ b/osprey/config.py @@ -44,7 +44,7 @@ FIELDS = { 'estimator': ['pickle', 'eval', 'eval_scope', 'entry_point', - 'params'], + 'params', 'module'], 'dataset_loader': ['name', 'params'], 'trials': ['uri', 'project_name'], 'search_space': dict, @@ -149,7 +149,18 @@ def estimator(self): pickle: path-to-pickle-file.pkl eval: "Pipeline([('cluster': KMeans())])" entry_point: sklearn.linear_model.LogisticRegression + module: myestimator """ + module_path = self.get_value('estimator/module') + if module_path is not None: + with prepend_syspath(dirname(abspath(self.path))): + estimator_module = importlib.import_module(module_path) + estimator = estimator_module.estimator() + if not isinstance(estimator, sklearn.base.BaseEstimator): + raise RuntimeError('estimator/pickle must load a ' + 'sklearn-derived Estimator') + return estimator + evalstring = self.get_value('estimator/eval') if evalstring is not None: got = self.get_value('estimator/eval_scope')