Skip to content

Commit

Permalink
Added early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed Apr 19, 2024
1 parent d7eb499 commit d62a50c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
8 changes: 5 additions & 3 deletions examples/materials/combined/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def main():
# Data
X, y = datasets.load(tasks)
data = datasets.splitter(
X,
X,
y,
tasks,
train_size=0.8,
Expand All @@ -28,8 +28,9 @@ def main():

model = models.MultiNet(
tasks=tasks,
input_arch={500: 1},
mid_arch={1024: 1, 32: 1, 16: 1},
input_arch={100: 1, 100: 1},
mid_arch={100: 1, 50: 1},
out_arch={50: 1, 10: 1}
)
optimizer = optim.Adam

Expand All @@ -40,6 +41,7 @@ def main():
n_epochs=n_epochs,
batch_size=batch_size,
lr=lr,
patience=10,
save_dir=save_dir,
)

Expand Down
6 changes: 4 additions & 2 deletions examples/synthetic/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def main():

model = models.MultiNet(
tasks=tasks,
input_arch={10: 1},
mid_arch={1024: 1, 16: 1},
input_arch={100: 1, 100: 1},
mid_arch={100: 1, 50: 1},
out_arch={50: 1, 10: 1}
)

optimizer = optim.Adam

out = utils.train(
Expand Down
2 changes: 1 addition & 1 deletion src/multilearn/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def load(names):

elif name == 'toy2':

X = np.random.uniform(size=(900, 3))
X = np.random.uniform(-100, 50, size=(900, 3))
y = 3+X[:, 0]+X[:, 1]**3+X[:, 2]

elif name == 'friedman1':
Expand Down
32 changes: 27 additions & 5 deletions src/multilearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def train(
batch_size=32,
lr=1e-4,
save_dir='outputs',
patience=np.inf,
print_n=100,
):

Expand Down Expand Up @@ -150,6 +151,8 @@ def train(
data_train = CombinedLoader(data_train, 'max_size')

df_loss = []
no_improv = 0
best_loss = float('inf')
for epoch in range(1, n_epochs+1):

model.train()
Expand All @@ -175,12 +178,14 @@ def train(
with torch.no_grad():
model.eval()

all_loss = 0.0
for indx in data.keys():
y = data[indx]['y_train']
p = model(data[indx]['X_train'], indx)
loss = data[indx]['loss'](p, y).item()

d = (epoch, loss, indx, 'train')
split = 'train'
d = (epoch, loss, indx, split)
df_loss.append(d)

if 'y_val' in data[indx].keys():
Expand All @@ -189,12 +194,29 @@ def train(
p = model(data[indx]['X_val'], indx)
loss = data[indx]['loss'](p, y).item()

d = (epoch, loss, indx, 'val')
split = 'val'
d = (epoch, loss, indx, split)
df_loss.append(d)

all_loss += loss

else:
all_loss += loss

# Early stopping
if all_loss < best_loss:
best_model = copy.deepcopy(model)
best_loss = all_loss
no_improv = 0

else:
no_improv = 1

if no_improv >= patience:
break

if epoch % print_n == 0:
p = f'Epoch {epoch}/{n_epochs}: '
print(p+f'Train loss {loss:.2f}')
print(f'Epoch {epoch}/{n_epochs}: {split} loss {loss:.2f}')

# Loss curve
columns = ['epoch', 'loss', 'data', 'split']
Expand All @@ -212,7 +234,7 @@ def train(
)

out = {
'model': model,
'model': best_model,
'df_parity': df_parity,
'df_loss': df_loss,
'data': data,
Expand Down

0 comments on commit d62a50c

Please sign in to comment.