Skip to content

Commit

Permalink
feat: add validation_split parameter to fit method
Browse files: browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 2, 2024
1 parent a538367 commit 652d0a4
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions neuralnetlib/models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,11 +11,11 @@
from neuralnetlib.activations import ActivationFunction
from neuralnetlib.callbacks import EarlyStopping
from neuralnetlib.layers import *
from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, BinaryCrossentropy, SparseCategoricalCrossentropy, Wasserstein
from neuralnetlib.losses import LossFunction, CategoricalCrossentropy, BinaryCrossentropy, SparseCategoricalCrossentropy
from neuralnetlib.metrics import Metric
from neuralnetlib.optimizers import Optimizer
from neuralnetlib.preprocessing import PCA, pad_sequences, clip_gradients, SpectralNorm
from neuralnetlib.utils import shuffle, progress_bar, is_interactive, is_display_available, format_number, log_softmax, softmax, History, GradientDebugger
from neuralnetlib.utils import shuffle, progress_bar, is_interactive, is_display_available, format_number, log_softmax, train_test_split, History, GradientDebugger


class BaseModel(ABC):
Expand Down Expand Up @@ -229,6 +229,7 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = [],
plot_decision_boundary: bool = False) -> dict:
"""
Expand All @@ -254,6 +255,13 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
'loss': [],
'val_loss': []
})

if validation_split is not None and validation_data is not None:
raise ValueError("Cannot specify both validation_data and validation_split")
elif validation_split is not None:
x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=validation_split,
random_state=random_state if random_state is not None else self.random_state)
validation_data = (x_test, y_test)

if plot_decision_boundary and not is_interactive() and not is_display_available():
raise ValueError("Cannot display the plot. Please run the script in an environment with a display.")
Expand Down Expand Up @@ -1097,12 +1105,19 @@ def fit(self, x_train: np.ndarray,
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = []) -> dict:

history = History({
'loss': [],
'val_loss': []
})

if validation_data is not None and validation_split is not None:
raise ValueError("Cannot specify both validation_data and validation_split")
elif validation_data is None and validation_split is not None:
x_train, x_val = train_test_split(x_train, test_size=validation_split, random_state=random_state)
validation_data = (x_val, x_val)

x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train

Expand Down Expand Up @@ -1705,12 +1720,19 @@ def fit(self, x_train: np.ndarray | list, y_train: np.ndarray | list,
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = []) -> dict:

history = History({
'loss': [],
'val_loss': []
})

if validation_data is not None and validation_split is not None:
raise ValueError("Cannot specify both validation_data and validation_split")
elif validation_data is None and validation_split is not None:
x_val, y_val = train_test_split(x_train, test_size=validation_split, random_state=random_state)
validation_data = (x_val, y_val)

encoder_input, decoder_input, decoder_target = self.prepare_data(x_train, y_train)

Expand Down Expand Up @@ -2223,6 +2245,7 @@ def fit(
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = [],
plot_generated: bool = False,
plot_interval: int = 1,
Expand Down Expand Up @@ -2256,6 +2279,12 @@ def fit(
'val_discriminator_loss': [],
'val_generator_loss': []
})

if validation_data is not None and validation_split is not None:
raise ValueError("Cannot specify both validation_data and validation_split")
elif validation_data is None and validation_split is not None:
x_train, x_val = train_test_split(x_train, test_size=validation_split, random_state=random_state)
validation_data = (x_val, None)

x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train

Expand Down

0 comments on commit 652d0a4

Please sign in to comment.