From 1ab490bc539f52ce0142b99928e4bcf13234995e Mon Sep 17 00:00:00 2001 From: Jourdelune Date: Tue, 2 Jul 2024 12:50:36 +0200 Subject: [PATCH] [update] disable Xavier initialization when loading a saved model --- scripts/train.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 8ec7b2e..dc294c2 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -108,9 +108,13 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -if os.path.exists(args.model_path + args.base_model_path): - model.load_state_dict(torch.load(args.model_path + args.base_model_path)) +if os.path.exists(args.base_model_path): + model.load_state_dict(torch.load(args.base_model_path)) print("Model loaded") +else: + for p in model.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) print(f"Using device {device}") model.to(device, dtype=dtype) @@ -151,12 +155,6 @@ # total_iters=train_size * EPOCH // (GRADIENT_ACCUMULATION_STEPS * BATCH_SIZE), # ) -# xavier initialization - -for p in model.parameters(): - if p.dim() > 1: - torch.nn.init.xavier_uniform_(p) - # print number of parameters print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M")