Skip to content

Commit

Permalink
Add fixed CV
Browse files Browse the repository at this point in the history
  • Loading branch information
skearnes committed Nov 14, 2014
1 parent 2a5cfd4 commit 1228bd2
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions osprey/cross_validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import print_function, absolute_import, division

import numpy as np


class BaseCVFactory(object):
short_name = None
Expand Down Expand Up @@ -89,3 +91,26 @@ def create(self, X, y):

return StratifiedKFold(y, n_folds=self.n_folds, shuffle=self.shuffle,
random_state=self.random_state)


class FixedCVFactory(BaseCVFactory):
"""
Cross-validator to use with a fixed, held-out validation set.
Parameters
----------
start : int
Start index of validation set.
stop : int, optional
Stop index of validation set.
"""
short_name = 'fixed'

def __init__(self, start, stop=None):
self.valid = slice(start, stop)

def create(self, X, y):
indices = np.arange(len(X))
valid = indices[self.valid]
train = np.setdiff1d(indices, valid)
return (train, valid), # return a nested tuple

0 comments on commit 1228bd2

Please sign in to comment.