Skip to content

Commit

Permalink
[update] skip Xavier initialization when loading a pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
Jourdelune committed Jul 2, 2024
1 parent 9cab912 commit 1ab490b
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,13 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if os.path.exists(args.model_path + args.base_model_path):
model.load_state_dict(torch.load(args.model_path + args.base_model_path))
if os.path.exists(args.base_model_path):
model.load_state_dict(torch.load(args.base_model_path))
print("Model loaded")
else:
for p in model.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)

print(f"Using device {device}")
model.to(device, dtype=dtype)
Expand Down Expand Up @@ -151,12 +155,6 @@
# total_iters=train_size * EPOCH // (GRADIENT_ACCUMULATION_STEPS * BATCH_SIZE),
# )

# xavier initialization

for p in model.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)

# print number of parameters
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M")

Expand Down

0 comments on commit 1ab490b

Please sign in to comment.