Skip to content

Commit

Permalink
Add "module" estimator spec
Browse files Browse the repository at this point in the history
Import a (possibly local) module and look for a global named `estimator`
  • Loading branch information
mpharrigan committed Aug 8, 2016
1 parent 14f5dc6 commit 098f9d4
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions osprey/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

FIELDS = {
'estimator': ['pickle', 'eval', 'eval_scope', 'entry_point',
'params'],
'params', 'module'],
'dataset_loader': ['name', 'params'],
'trials': ['uri', 'project_name'],
'search_space': dict,
Expand Down Expand Up @@ -149,7 +149,18 @@ def estimator(self):
pickle: path-to-pickle-file.pkl
eval: "Pipeline([('cluster': KMeans())])"
entry_point: sklearn.linear_model.LogisticRegression
module: myestimator
"""
module_path = self.get_value('estimator/module')
if module_path is not None:
with prepend_syspath(dirname(abspath(self.path))):
estimator_module = importlib.import_module(module_path)
estimator = estimator_module.estimator()
if not isinstance(estimator, sklearn.base.BaseEstimator):
raise RuntimeError('estimator/pickle must load a '
'sklearn-derived Estimator')
return estimator

evalstring = self.get_value('estimator/eval')
if evalstring is not None:
got = self.get_value('estimator/eval_scope')
Expand All @@ -166,7 +177,8 @@ def estimator(self):
scope.update(getattr(eval_scopes, pkg_name)())
else:
try:
pkg = importlib.import_module(pkg_name)
with prepend_syspath(dirname(abspath(self.path))):
pkg = importlib.import_module(pkg_name)
except ImportError as e:
raise RuntimeError(str(e))
scope.update(eval_scopes.import_all_estimators(pkg))
Expand Down

0 comments on commit 098f9d4

Please sign in to comment.