-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support FunctionTransformer steps in definition
Allows user to specify a function included in gordo_components to be included in a sklearn FunctionTransformer within a config file
- Loading branch information
1 parent
72eda92
commit e157da7
Showing
7 changed files
with
127 additions
and
17 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
Functions to be used within sklearn's FunctionTransformer | ||
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html | ||
Each function SHALL take an X, and optionally a y. | ||
Functions CAN take additional arguments which should be given during the initialization of the FunctionTransformer | ||
Example: | ||
>>> from sklearn.preprocessing import FunctionTransformer | ||
>>> import numpy as np | ||
>>> def my_function(X, another_arg): | ||
... # Some fancy X manipulation... | ||
... return X | ||
>>> transformer = FunctionTransformer(func=my_function, kw_args={'another_arg': 'this thing'}) | ||
>>> out = transformer.fit_transform(np.random.random(100).reshape(10, 10)) | ||
""" | ||
|
||
|
||
def multiply_by(X, factor): | ||
""" | ||
Multiplies X by a given factor | ||
""" | ||
return X * factor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import unittest | ||
import numpy as np | ||
|
||
from sklearn.preprocessing import FunctionTransformer | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.decomposition import PCA | ||
|
||
|
||
class GordoFunctionTransformerFuncsTestCase(unittest.TestCase): | ||
""" | ||
Test all functions within gordo_components meants for use in a Scikit-Learn | ||
FunctionTransformer work as expected | ||
""" | ||
|
||
def _validate_transformer(self, transformer): | ||
""" | ||
Inserts a transformer into the middle of a pipeline and runs it | ||
""" | ||
pipe = Pipeline([ | ||
('pca1', PCA()), | ||
('custom', transformer), | ||
('pca2', PCA()) | ||
]) | ||
X = np.random.random(size=100).reshape(10, 10) | ||
pipe.fit_transform(X) | ||
|
||
def test_multiply_by_function_transformer(self): | ||
from gordo_components.model.transformer_funcs.general import multiply_by | ||
|
||
# Provide a require argument | ||
tf = FunctionTransformer(func=multiply_by, kw_args={'factor': 2}) | ||
self._validate_transformer(tf) | ||
|
||
# Ignore the required argument | ||
tf = FunctionTransformer(func=multiply_by) | ||
with self.assertRaises(TypeError): | ||
self._validate_transformer(tf) |