Commit
remove double wandb initialization
sadamov committed Jun 6, 2024
1 parent 334ab42 commit f2a8180
Showing 2 changed files with 7 additions and 15 deletions.
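For context: `pl.loggers.WandbLogger` manages the underlying `wandb` run itself and calls `wandb.init` lazily the first time its `experiment` attribute is touched (or when the `Trainer` starts logging), so a separate `wandb.init` in `init_wandb` duplicated the initialization. A minimal sketch of the logger-only pattern; the project and run names below are illustrative, not taken from this repository:

```python
# Minimal sketch: let WandbLogger own the wandb run instead of calling
# wandb.init separately. Project/run names are illustrative only.
import pytorch_lightning as pl

logger = pl.loggers.WandbLogger(project="example-project", name="example-run")

# The wandb run is created lazily: the first access to `experiment`
# (or the start of a Trainer using this logger) triggers wandb.init.
run = logger.experiment
run.log({"sanity_check": 1.0})
```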
4 changes: 4 additions & 0 deletions neural_lam/models/ar_model.py
@@ -597,3 +597,7 @@ def on_load_checkpoint(self, checkpoint):
         if not self.restore_opt:
             opt = self.configure_optimizers()
             checkpoint["optimizer_states"] = [opt.state_dict()]
+
+    def on_run_end(self):
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
18 changes: 3 additions & 15 deletions neural_lam/utils.py
@@ -1,14 +1,13 @@
 # Standard library
 import os
-import shutil
 import random
+import shutil
 import time

 # Third-party
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import wandb  # pylint: disable=wrong-import-order
 from pytorch_lightning.utilities import rank_zero_only
 from torch import nn
 from tueplots import bundles, figsizes
@@ -134,7 +133,8 @@ def loads_file(fn):
     hierarchical = n_levels > 1  # Nor just single level mesh graph

     # Load static edge features
-    m2m_features = loads_file("m2m_features.pt")  # List of (M_m2m[l], d_edge_f)
+    # List of (M_m2m[l], d_edge_f)
+    m2m_features = loads_file("m2m_features.pt")
     g2m_features = loads_file("g2m_features.pt")  # (M_g2m, d_edge_f)
     m2g_features = loads_file("m2g_features.pt")  # (M_m2g, d_edge_f)

@@ -288,24 +288,12 @@ def init_wandb(args):
             f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
             f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}"
         )
-        wandb.init(
-            name=run_name,
-            project=args.wandb_project,
-            config=args,
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
             name=run_name,
             config=args,
         )
-        wandb.save("neural_lam/data_config.yaml")
     else:
-        wandb.init(
-            project=args.wandb_project,
-            config=args,
-            id=args.resume_run,
-            resume="must",
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
             id=args.resume_run,
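Based on the context lines above, `init_wandb` now only constructs the `WandbLogger`. A rough sketch of the resulting shape; the branch condition, the `prefix`/`random_int` definitions, the tail of the resume branch, and the return statement are not visible in this hunk and are assumptions:

```python
# Rough sketch of init_wandb after this commit, inferred from the hunk above.
# Lines marked "assumed" are not visible in the diff and may differ.
import random
import time

import pytorch_lightning as pl


def init_wandb(args):
    if args.resume_run is None:  # assumed condition
        prefix = ""  # assumed: real prefix logic not shown in the hunk
        random_int = random.randint(0, 10000)  # assumed range
        run_name = (
            f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
            f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}"
        )
        logger = pl.loggers.WandbLogger(
            project=args.wandb_project,
            name=run_name,
            config=args,
        )
    else:
        logger = pl.loggers.WandbLogger(
            project=args.wandb_project,
            id=args.resume_run,
            resume="must",  # assumed: carried over from the removed wandb.init
            config=args,  # assumed
        )
    return logger  # assumed
```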
