diff --git a/henchman/learning.py b/henchman/learning.py
index d8e3810..2711b6e 100644
--- a/henchman/learning.py
+++ b/henchman/learning.py
@@ -38,6 +38,13 @@ def _fit_predict(X_train, X_test, y_train, y_test, model, metric):
     return metric(y_test, preds), model
 
 
+def _get_dataframes(X, y, model, split_size):
+    '''Fit a model on an order-preserving split; return y_test and predicted probabilities.'''
+    X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=False, test_size=split_size)
+    model.fit(X_train, y_train)
+    return y_test, model.predict_proba(X_test)
+
+
 def _score_tt(X, y, model, metric, split_size):
     X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=False, test_size=split_size)
     score, fit_model = _fit_predict(X_train, X_test,
@@ -45,7 +52,8 @@ def _score_tt(X, y, model, metric, split_size):
     return [score], fit_model
 
 
-def create_model(X, y, model=None, metric=None, n_splits=1, split_size=.3):
+def create_model(X, y, model=None, metric=None,
+                 n_splits=1, split_size=.3, _return_df=False):
     '''Make a model. Returns a scorelist and a fit model.
     A wrapper around a standard scoring workflow. Uses
     train_test_split unless otherwise specified(in which case
@@ -56,13 +64,15 @@ def create_model(X, y, model=None, metric=None, n_splits=1, split_size=.3):
     recommended you just use the sklearn API.
 
     Args:
-        X(pd.DataFrame): A cleaned numeric feature matrix.
-        y(pd.Series): A column of labels.
+        X (pd.DataFrame): A cleaned numeric feature matrix.
+        y (pd.Series): A column of labels.
         model: A sklearn model with fit and predict methods.
         metric: A metric which takes y_test, preds and returns a score.
-        n_splits(int): If 1 use a train_test_split. Otherwise use tssplit.
+        n_splits (int): If 1 use a train_test_split. Otherwise use tssplit.
             Default value is 1.
-        split_size(float): Size of testing set. Default is .3.
+        split_size (float): Size of testing set. Default is .3.
+        _return_df (bool): If True, also return the train/test split
+            (X_train, X_test, y_train, y_test). Rarely needed outside plotting.
 
     Returns:
         scores, fit_model(list[float], sklearn.ensemble): A list of scores and a fit model.
@@ -83,6 +93,9 @@ def create_model(X, y, model=None, metric=None, n_splits=1, split_size=.3):
     assert model is not None
     assert metric is not None
     if n_splits == 1:
+        if _return_df:
+            scores, fit_model = _score_tt(X, y, model, metric, split_size)
+            return scores, fit_model, create_holdout(X, y, split_size)
         return _score_tt(X, y, model, metric, split_size)
 
     if n_splits > 1:
@@ -95,6 +108,8 @@ def create_model(X, y, model=None, metric=None, n_splits=1, split_size=.3):
         score, fit_model = _fit_predict(X_train, X_test,
                                         y_train, y_test, model, metric)
         scorelist.append(score)
+    if _return_df:
+        return scorelist, fit_model, (X_train, X_test, y_train, y_test)
     return scorelist, fit_model
 
 
diff --git a/henchman/plotting.py b/henchman/plotting.py
index ed417eb..61e1191 100644
--- a/henchman/plotting.py
+++ b/henchman/plotting.py
@@ -22,6 +22,10 @@ from bokeh.palettes import Category20
 
 from henchman.learning import _raw_feature_importances
+from henchman.learning import create_model
+
+from sklearn.metrics import (roc_auc_score, precision_score,
+                             recall_score, f1_score, roc_curve)
 
 
 def show_template():
@@ -600,3 +604,39 @@ def callback(attr, old, new):
         ))
 
     return lambda doc: modify_doc(doc, col)
+
+
+# def plot_roc_auc(X, y, model, pos_label=1, prob_col=1, n_splits=3):
+#     scores, model, df_list = create_model(X, y, model, roc_auc_score, _return_df=True, n_splits=n_splits)
+
+#     probs = model.predict_proba(df_list[1])
+#     fpr, tpr, thresholds = roc_curve(df_list[3],
+#                                      probs[:, prob_col],
+#                                      pos_label=pos_label)
+
+#     p = figure()
+#     p.line(x=fpr, y=tpr)
+#     p.title.text = 'Receiver operating characteristic'
+#     p.xaxis.axis_label = 'False Positive Rate'
+#     p.yaxis.axis_label = 'True Positive Rate'
+
+#     p.line(x=fpr, y=fpr, color='red', line_dash='dashed')
+#     return p
+
+
+# def plot_f1(X, y, model, nprecs, n_splits=3):
+#     scores, model, df_list = create_model(X, y, model, roc_auc_score, _return_df=True, n_splits=n_splits)
+#     probs = model.predict_proba(df_list[1])
+#     threshes = [x/1000. for x in range(50, nprecs)]
+#     precisions = [precision_score(df_list[3], probs[:, 1] > t) for t in threshes]
+#     recalls = [recall_score(df_list[3], probs[:, 1] > t) for t in threshes]
+#     fones = [f1_score(df_list[3], probs[:, 1] > t) for t in threshes]
+
+#     output_notebook()
+#     p = figure()
+#     p.line(x=threshes, y=precisions, color='green', legend='precision')
+#     p.line(x=threshes, y=recalls, color='blue', legend='recall')
+#     p.line(x=threshes, y=fones, color='red', legend='f1')
+#     p.xaxis.axis_label = 'Threshold'
+#     p.title.text = 'Precision, Recall, and F1 by Threshold'
+#     return p
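
For reviewers, a minimal sketch of how the new _return_df flag is meant to be exercised once the commented-out plotting helpers are enabled. This is not part of the patch: X and y stand in for any prepared feature matrix and label column, and RandomForestClassifier is an arbitrary stand-in classifier.

    # Usage sketch only -- not part of the patch. Assumes a prepared numeric
    # feature matrix X (pd.DataFrame) and binary labels y (pd.Series);
    # RandomForestClassifier is a stand-in for any sklearn classifier.
    from bokeh.plotting import figure, show
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import roc_auc_score, roc_curve

    from henchman.learning import create_model

    # With _return_df=True and n_splits > 1, create_model hands back the
    # final train/test split alongside the usual scores and fit model.
    scores, fit_model, (X_train, X_test, y_train, y_test) = create_model(
        X, y,
        model=RandomForestClassifier(),
        metric=roc_auc_score,
        n_splits=3,
        _return_df=True,
    )

    # Recompute positive-class probabilities on the held-out fold and draw
    # a ROC curve, mirroring the commented-out plot_roc_auc above.
    probs = fit_model.predict_proba(X_test)
    fpr, tpr, _ = roc_curve(y_test, probs[:, 1], pos_label=1)

    p = figure(title='Receiver operating characteristic')
    p.line(x=fpr, y=tpr)
    p.line(x=fpr, y=fpr, color='red', line_dash='dashed')  # chance diagonal
    show(p)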