From d62a50ca96fe64b3694e7f26d472ef30aa82d4ab Mon Sep 17 00:00:00 2001 From: leschultz Date: Fri, 19 Apr 2024 11:48:41 -0500 Subject: [PATCH] Added early stopping --- examples/materials/combined/fit.py | 8 +++++--- examples/synthetic/fit.py | 6 ++++-- src/multilearn/datasets.py | 2 +- src/multilearn/utils.py | 32 +++++++++++++++++++++++++----- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/examples/materials/combined/fit.py b/examples/materials/combined/fit.py index 5c06930..9ad4627 100644 --- a/examples/materials/combined/fit.py +++ b/examples/materials/combined/fit.py @@ -14,7 +14,7 @@ def main(): # Data X, y = datasets.load(tasks) data = datasets.splitter( - X, + X, y, tasks, train_size=0.8, @@ -28,8 +28,9 @@ def main(): model = models.MultiNet( tasks=tasks, - input_arch={500: 1}, - mid_arch={1024: 1, 32: 1, 16: 1}, + input_arch={100: 1, 100: 1}, + mid_arch={100: 1, 50: 1}, + out_arch={50: 1, 10: 1} ) optimizer = optim.Adam @@ -40,6 +41,7 @@ def main(): n_epochs=n_epochs, batch_size=batch_size, lr=lr, + patience=10, save_dir=save_dir, ) diff --git a/examples/synthetic/fit.py b/examples/synthetic/fit.py index 2faeea0..fc3fd42 100644 --- a/examples/synthetic/fit.py +++ b/examples/synthetic/fit.py @@ -28,9 +28,11 @@ def main(): model = models.MultiNet( tasks=tasks, - input_arch={10: 1}, - mid_arch={1024: 1, 16: 1}, + input_arch={100: 1, 100: 1}, + mid_arch={100: 1, 50: 1}, + out_arch={50: 1, 10: 1} ) + optimizer = optim.Adam out = utils.train( diff --git a/src/multilearn/datasets.py b/src/multilearn/datasets.py index 54e6958..aa2415e 100644 --- a/src/multilearn/datasets.py +++ b/src/multilearn/datasets.py @@ -95,7 +95,7 @@ def load(names): elif name == 'toy2': - X = np.random.uniform(size=(900, 3)) + X = np.random.uniform(-100, 50, size=(900, 3)) y = 3+X[:, 0]+X[:, 1]**3+X[:, 2] elif name == 'friedman1': diff --git a/src/multilearn/utils.py b/src/multilearn/utils.py index 146c1f5..37b375f 100644 --- a/src/multilearn/utils.py +++ b/src/multilearn/utils.py @@ -115,6 +115,7 @@ def train( batch_size=32, lr=1e-4, save_dir='outputs', + patience=np.inf, print_n=100, ): @@ -150,6 +151,8 @@ def train( data_train = CombinedLoader(data_train, 'max_size') df_loss = [] + no_improv = 0 + best_loss = float('inf') for epoch in range(1, n_epochs+1): model.train() @@ -175,12 +178,14 @@ def train( with torch.no_grad(): model.eval() + all_loss = 0.0 for indx in data.keys(): y = data[indx]['y_train'] p = model(data[indx]['X_train'], indx) loss = data[indx]['loss'](p, y).item() - d = (epoch, loss, indx, 'train') + split = 'train' + d = (epoch, loss, indx, split) df_loss.append(d) if 'y_val' in data[indx].keys(): @@ -189,12 +194,29 @@ def train( p = model(data[indx]['X_val'], indx) loss = data[indx]['loss'](p, y).item() - d = (epoch, loss, indx, 'val') + split = 'val' + d = (epoch, loss, indx, split) df_loss.append(d) + all_loss += loss + + else: + all_loss += loss + + # Early stopping + if all_loss < best_loss: + best_model = copy.deepcopy(model) + best_loss = all_loss + no_improv = 0 + + else: + no_improv = 1 + + if no_improv >= patience: + break + if epoch % print_n == 0: - p = f'Epoch {epoch}/{n_epochs}: ' - print(p+f'Train loss {loss:.2f}') + print(f'Epoch {epoch}/{n_epochs}: {split} loss {loss:.2f}') # Loss curve columns = ['epoch', 'loss', 'data', 'split'] @@ -212,7 +234,7 @@ def train( ) out = { - 'model': model, + 'model': best_model, 'df_parity': df_parity, 'df_loss': df_loss, 'data': data,