From f2a818093d1b5a05db7363d91a839b2baf94d8b0 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 17:27:13 +0200
Subject: [PATCH] remove double wandb initialization

---
 neural_lam/models/ar_model.py |  4 ++++
 neural_lam/utils.py           | 18 +++---------------
 2 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 6ced211f..9448edae 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -597,3 +597,7 @@ def on_load_checkpoint(self, checkpoint):
         if not self.restore_opt:
             opt = self.configure_optimizers()
             checkpoint["optimizer_states"] = [opt.state_dict()]
+
+    def on_run_end(self):
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 3f7a27c6..19021204 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -1,14 +1,13 @@
 # Standard library
 import os
-import shutil
 import random
+import shutil
 import time
 
 # Third-party
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import wandb  # pylint: disable=wrong-import-order
 from pytorch_lightning.utilities import rank_zero_only
 from torch import nn
 from tueplots import bundles, figsizes
@@ -134,7 +133,8 @@ def loads_file(fn):
     hierarchical = n_levels > 1  # Nor just single level mesh graph
 
     # Load static edge features
-    m2m_features = loads_file("m2m_features.pt")  # List of (M_m2m[l], d_edge_f)
+    # List of (M_m2m[l], d_edge_f)
+    m2m_features = loads_file("m2m_features.pt")
     g2m_features = loads_file("g2m_features.pt")  # (M_g2m, d_edge_f)
     m2g_features = loads_file("m2g_features.pt")  # (M_m2g, d_edge_f)
 
@@ -288,24 +288,12 @@ def init_wandb(args):
             f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
             f"{time.strftime('%m_%d_%H_%M_%S')}-{random_int}"
         )
-        wandb.init(
-            name=run_name,
-            project=args.wandb_project,
-            config=args,
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
            name=run_name,
             config=args,
         )
-        wandb.save("neural_lam/data_config.yaml")
     else:
-        wandb.init(
-            project=args.wandb_project,
-            config=args,
-            id=args.resume_run,
-            resume="must",
-        )
         logger = pl.loggers.WandbLogger(
             project=args.wandb_project,
             id=args.resume_run,