From 652d0a41e5d9b6be889a33387646cc235f543af7 Mon Sep 17 00:00:00 2001 From: Marc Pinet <52708150+marcpinet@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:01:33 +0100 Subject: [PATCH] feat: add validation_split parameter to fit method --- neuralnetlib/models.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py index 84046ea..c91da9c 100644 --- a/neuralnetlib/models.py +++ b/neuralnetlib/models.py @@ -11,11 +11,11 @@ from neuralnetlib.activations import ActivationFunction from neuralnetlib.callbacks import EarlyStopping from neuralnetlib.layers import * -from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, BinaryCrossentropy, SparseCategoricalCrossentropy, Wasserstein +from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, BinaryCrossentropy, SparseCategoricalCrossentropy from neuralnetlib.metrics import Metric from neuralnetlib.optimizers import Optimizer from neuralnetlib.preprocessing import PCA, pad_sequences, clip_gradients, SpectralNorm -from neuralnetlib.utils import shuffle, progress_bar, is_interactive, is_display_available, format_number, log_softmax, softmax, History, GradientDebugger +from neuralnetlib.utils import shuffle, progress_bar, is_interactive, is_display_available, format_number, log_softmax, train_test_split, History, GradientDebugger class BaseModel(ABC): @@ -229,6 +229,7 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, metrics: list | None = None, random_state: int | None = None, validation_data: tuple | None = None, + validation_split: float | None = None, callbacks: list = [], plot_decision_boundary: bool = False) -> dict: """ @@ -254,6 +255,13 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, 'loss': [], 'val_loss': [] }) + + if validation_split is not None and validation_data is not None: + raise ValueError("Cannot specify both validation_data and validation_split") + elif validation_split is not None: + x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=validation_split, + random_state=random_state if random_state is not None else self.random_state) + validation_data = (x_test, y_test) if plot_decision_boundary and not is_interactive() and not is_display_available(): raise ValueError("Cannot display the plot. Please run the script in an environment with a display.") @@ -1097,12 +1105,19 @@ def fit(self, x_train: np.ndarray, metrics: list | None = None, random_state: int | None = None, validation_data: tuple | None = None, + validation_split: float | None = None, callbacks: list = []) -> dict: history = History({ 'loss': [], 'val_loss': [] }) + + if validation_data is not None and validation_split is not None: + raise ValueError("Cannot specify both validation_data and validation_split") + elif validation_data is None and validation_split is not None: + x_train, x_val = train_test_split(x_train, test_size=validation_split, random_state=random_state) + validation_data = (x_val, x_val) x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train @@ -1705,12 +1720,19 @@ def fit(self, x_train: np.ndarray | list, y_train: np.ndarray | list, metrics: list | None = None, random_state: int | None = None, validation_data: tuple | None = None, + validation_split: float | None = None, callbacks: list = []) -> dict: history = History({ 'loss': [], 'val_loss': [] }) + + if validation_data is not None and validation_split is not None: + raise ValueError("Cannot specify both validation_data and validation_split") + elif validation_data is None and validation_split is not None: + x_val, y_val = train_test_split(x_train, test_size=validation_split, random_state=random_state) + validation_data = (x_val, y_val) encoder_input, decoder_input, decoder_target = self.prepare_data(x_train, y_train) @@ -2223,6 +2245,7 @@ def fit( metrics: list | None = None, random_state: int | None = None, validation_data: tuple | None = None, + validation_split: float | None = None, callbacks: list = [], plot_generated: bool = False, plot_interval: int = 1, @@ -2256,6 +2279,12 @@ def fit( 'val_discriminator_loss': [], 'val_generator_loss': [] }) + + if validation_data is not None and validation_split is not None: + raise ValueError("Cannot specify both validation_data and validation_split") + elif validation_data is None and validation_split is not None: + x_train, x_val = train_test_split(x_train, test_size=validation_split, random_state=random_state) + validation_data = (x_val, None) x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train