From 1228bd2d06221a6adeb84cf59051eaad547ee61d Mon Sep 17 00:00:00 2001 From: skearnes Date: Fri, 14 Nov 2014 13:01:05 -0800 Subject: [PATCH] Add fixed CV --- osprey/cross_validators.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/osprey/cross_validators.py b/osprey/cross_validators.py index af1657f..0e3f832 100644 --- a/osprey/cross_validators.py +++ b/osprey/cross_validators.py @@ -1,5 +1,7 @@ from __future__ import print_function, absolute_import, division +import numpy as np + class BaseCVFactory(object): short_name = None @@ -89,3 +91,26 @@ def create(self, X, y): return StratifiedKFold(y, n_folds=self.n_folds, shuffle=self.shuffle, random_state=self.random_state) + + +class FixedCVFactory(BaseCVFactory): + """ + Cross-validator to use with a fixed, held-out validation set. + + Parameters + ---------- + start : int + Start index of validation set. + stop : int, optional + Stop index of validation set. + """ + short_name = 'fixed' + + def __init__(self, start, stop=None): + self.valid = slice(start, stop) + + def create(self, X, y): + indices = np.arange(len(X)) + valid = indices[self.valid] + train = np.setdiff1d(indices, valid) + return (train, valid), # return a nested tuple