Network not learning #22
Great work plotting all this information! I'm afraid I'm not familiar enough with the code or the model to really help. But, in the meantime, would it be possible to share a link to the Weights & Biases page for this training run, please? It would help us debug.
Here is a link to this particular run: w&b-run. Casper, one of the authors, mentioned that the loss seemed too big. Right now I'm using torch.nn.CrossEntropyLoss(y_hat-y). Is this the same loss they use in the paper? See his comment below. My question:
His answer:
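A note on the call quoted above. `torch.nn.CrossEntropyLoss` is a module: it is constructed once and then called with two arguments, the raw logits and the integer class targets. Written literally, `torch.nn.CrossEntropyLoss(y_hat-y)` passes the difference to the constructor, whose first parameter is the per-class `weight`, so prediction and target are never compared.

Below is a minimal sketch of the usual pattern. It assumes the model emits one logit per precipitation bin per pixel and that targets are continuous rain rates. The bin edges, `BIN_EDGES`, and `precipitation_loss` are hypothetical, and the sketch does not claim to be the loss used in the paper.

```python
import torch
import torch.nn as nn

# Hypothetical bin edges in mm/h: 5 edges split rain rates into 6 classes,
# matching output_channels=6 in the hyperparameters of this run.
BIN_EDGES = torch.tensor([0.1, 1.0, 2.0, 5.0, 10.0])

# nn.CrossEntropyLoss is a module: build it once, then call criterion(input, target).
# It applies log-softmax internally, so y_hat must be raw logits (no softmax first).
criterion = nn.CrossEntropyLoss()


def precipitation_loss(y_hat: torch.Tensor, rain_rate: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between per-pixel logits and binned rain rates.

    y_hat:     raw logits of shape (batch, n_classes, H, W)
    rain_rate: continuous targets of shape (batch, H, W)
    """
    # bucketize maps each rain rate to the index of its bin (int64),
    # which is the class-index target format CrossEntropyLoss expects.
    target = torch.bucketize(rain_rate, BIN_EDGES)
    return criterion(y_hat, target)


y_hat = torch.randn(2, 6, 8, 8, requires_grad=True)  # stand-in model output
rain_rate = torch.rand(2, 8, 8) * 20.0               # stand-in radar target
loss = precipitation_loss(y_hat, rain_rate)
loss.backward()
print(loss.item())
```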
Sorry to chip in, @ValterFallenius.
Hey @bmkor, I am actually using neither; you can find my raw data in #27. However, I have not published the elevation and longitude/latitude data I used, so let me know if you need it. But unless you are writing a thesis for the Swedish government, I think you might be better off using the original dataset available on Hugging Face ^^ Also, I am not using any GOES data, since its quality over Sweden is poor because of the lack of geostationary satellite coverage. /Valter
Thanks a lot for your prompt reply and comments. I'll try the datasets available on Hugging Face first and see if I can get the model running.
Hi, just so you know, the GOES dataset currently doesn't have any data in it; I'm working on adding data for it. The MRMS dataset does, although I am still finishing up the dataset script. But if you want to get started with that radar data, you can just download the Zarr files themselves and open them locally quite easily.
Hello, can you tell me where I can download the MRMS dataset? Thank you very much!
Training loss is not decreasing
I have implemented the network with PyTorch Lightning in the "lightning" branch of the PR and tried to find any bugs. The network compiles without issues and seems to produce gradients, but it still fails to learn anything. I have tried adjusting the learning rate and plotting the data at different stages, but even with 4 training samples (which it should be able to overfit), the loss does not decrease even after 100 epochs...
Here is the training loss plotted:
It seems to be doing something, but not nearly quickly enough to overfit the small dataset. Something is wrong...
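The overfitting test described above can also be run in isolation, outside Lightning, to separate model bugs from training-loop bugs. This is a minimal sketch assuming a plain PyTorch loop with Adam and `CrossEntropyLoss`. `overfit_check` and `TinyPixelClassifier` are hypothetical names, and the toy linear model only stands in for the real network, which would be passed in its place. A correctly wired model should drive the loss on a few fixed samples close to zero.

```python
import torch
import torch.nn as nn


def overfit_check(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
                  steps: int = 200, lr: float = 1e-2) -> list[float]:
    """Train repeatedly on one fixed small batch and return the loss history."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    history = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return history


class TinyPixelClassifier(nn.Module):
    """Hypothetical stand-in: maps (N, C, H, W) inputs to (N, K, H, W) logits."""

    def __init__(self, in_ch: int, n_classes: int, size: int):
        super().__init__()
        self.n_classes, self.size = n_classes, size
        self.linear = nn.Linear(in_ch * size * size, n_classes * size * size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.linear(x.flatten(1))
        return logits.view(-1, self.n_classes, self.size, self.size)


torch.manual_seed(0)
x = torch.randn(4, 15, 8, 8)        # 4 samples, 15 input channels (as in this run)
y = torch.randint(0, 6, (4, 8, 8))  # 6 classes per pixel
history = overfit_check(TinyPixelClassifier(15, 6, 8), x, y)
print(f"first loss {history[0]:.3f}, last loss {history[-1]:.4f}")
```

If the real model stalls in this loop while the toy model does not, the problem more likely lies in the model or its loss wiring than in the Lightning setup or data loading.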
Hyperparameters:
n_samples = 4
hidden_dim=8,
forecast_steps=1,
input_channels=15,
output_channels=6, #512
input_size=112, # 112
n_samples = 100,
num_workers = 8,
batch_size = 1,
learning_rate = 1e-2
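A useful reference point for these settings: with output_channels=6, a network that outputs the same uniform distribution at every pixel already scores ln 6 ≈ 1.792 nats under cross-entropy. The train and validation losses in the W&B summary below (1.728 and 1.761) sit only slightly under that. This suggests, though it does not prove, that the network has learned little beyond a near-uniform guess.

If the bins are imbalanced, the entropy of the class frequencies is the stricter baseline. The class counts in the example are hypothetical.

```python
import math

N_CLASSES = 6  # output_channels in the hyperparameters above


def uniform_baseline(n_classes: int) -> float:
    """Cross-entropy (nats) of predicting 1/n_classes for every class: -log(1/n) = log(n)."""
    return math.log(n_classes)


def marginal_baseline(class_counts: list[int]) -> float:
    """Cross-entropy (nats) of always predicting the empirical class frequencies.

    This equals the entropy of the label distribution and is never larger than
    log(n_classes), so it is the stricter baseline for imbalanced labels.
    """
    total = sum(class_counts)
    probs = [c / total for c in class_counts if c > 0]
    return -sum(p * math.log(p) for p in probs)


# Losses as logged in the W&B summary at epoch 83 of this run.
observed = {"train/loss_epoch": 1.72826, "validation/loss_epoch": 1.76064}
baseline = uniform_baseline(N_CLASSES)
for name, value in observed.items():
    print(f"{name}: {value:.4f} vs uniform baseline {baseline:.4f} (gap {baseline - value:.4f})")

# Hypothetical, heavily imbalanced counts (mostly "no rain"): a much lower bar.
print(f"marginal baseline: {marginal_baseline([97, 1, 1, 1, 0, 0]):.4f}")
```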
Below is a Weights & Biases gradient report. As you can see, most gradients are non-zero; I'm not sure why the image_encoder biases have such small gradients (some are reported as 0.0)...
wandb: epoch 83
wandb: grad_2.0_norm/head.bias_epoch 0.0746
wandb: grad_2.0_norm/head.bias_step 0.049
wandb: grad_2.0_norm/head.weight_epoch 0.0862
wandb: grad_2.0_norm/head.weight_step 0.081
wandb: grad_2.0_norm/image_encoder.module.module.0.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.0.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.0.weight_epoch 0.06653
wandb: grad_2.0_norm/image_encoder.module.module.0.weight_step 0.043
wandb: grad_2.0_norm/image_encoder.module.module.2.bias_epoch 0.00017
wandb: grad_2.0_norm/image_encoder.module.module.2.bias_step 0.0001
wandb: grad_2.0_norm/image_encoder.module.module.2.weight_epoch 0.003
wandb: grad_2.0_norm/image_encoder.module.module.2.weight_step 0.0019
wandb: grad_2.0_norm/image_encoder.module.module.3.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.3.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.3.weight_epoch 0.16387
wandb: grad_2.0_norm/image_encoder.module.module.3.weight_step 0.1125
wandb: grad_2.0_norm/image_encoder.module.module.4.bias_epoch 0.00013
wandb: grad_2.0_norm/image_encoder.module.module.4.bias_step 0.0001
wandb: grad_2.0_norm/image_encoder.module.module.4.weight_epoch 0.00203
wandb: grad_2.0_norm/image_encoder.module.module.4.weight_step 0.0012
wandb: grad_2.0_norm/image_encoder.module.module.5.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.5.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.5.weight_epoch 0.15237
wandb: grad_2.0_norm/image_encoder.module.module.5.weight_step 0.1151
wandb: grad_2.0_norm/image_encoder.module.module.6.bias_epoch 0.0032
wandb: grad_2.0_norm/image_encoder.module.module.6.bias_step 0.0018
wandb: grad_2.0_norm/image_encoder.module.module.6.weight_epoch 0.00157
wandb: grad_2.0_norm/image_encoder.module.module.6.weight_step 0.0012
wandb: grad_2.0_norm/image_encoder.module.module.7.bias_epoch 0.00497
wandb: grad_2.0_norm/image_encoder.module.module.7.bias_step 0.003
wandb: grad_2.0_norm/image_encoder.module.module.7.weight_epoch 0.11753
wandb: grad_2.0_norm/image_encoder.module.module.7.weight_step 0.0915
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_kv.weight_epoch 0.03763
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_kv.weight_step 0.0277
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.bias_epoch 0.0412
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.bias_step 0.0289
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.weight_epoch 0.05167
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.weight_step 0.0369
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_q.weight_epoch 0.0008
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_q.weight_step 0.0008
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_kv.weight_epoch 0.04393
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_kv.weight_step 0.0216
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.bias_epoch 0.0412
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.bias_step 0.0289
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.weight_epoch 0.04287
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.weight_step 0.027
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_q.weight_epoch 0.0014
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_q.weight_step 0.0009
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.bias_epoch 0.00197
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.bias_step 0.001
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.weight_epoch 0.03313
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.weight_step 0.0216
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.bias_epoch 0.00103
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.bias_step 0.0004
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.weight_epoch 0.00353
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.weight_step 0.002
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.bias_epoch 0.00133
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.bias_step 0.0009
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.weight_epoch 0.02123
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.weight_step 0.0147
wandb: grad_2.0_norm_total_epoch 0.31513
wandb: grad_2.0_norm_total_step 0.2254
wandb: train/loss_epoch 1.72826
wandb: train/loss_step 1.73303
wandb: trainer/global_step 251
wandb: validation/loss_epoch 1.76064
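One possible explanation for the 0.0 bias gradients rests on an assumption not confirmed in this thread: that image_encoder.module.module.0, .3 and .5 are convolutions whose outputs reach a BatchNorm layer before any nonlinearity. In training mode, BatchNorm subtracts the per-channel batch mean, so a per-channel constant bias cancels out and its gradient is zero up to rounding. A max-pool in between would not change this, since max(x + b) = max(x) + b. If that is the situation here, the zero gradients are expected and harmless, and such convolutions are commonly built with bias=False.

The sketch below uses a hypothetical block, not the actual encoder, to show the effect.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical encoder stage: conv (with bias) -> max-pool -> BatchNorm -> ReLU,
# followed by a final conv that has no BatchNorm after it.
block = nn.Sequential(
    nn.Conv2d(15, 8, kernel_size=3, padding=1, bias=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 6, kernel_size=1, bias=True),
).double()  # float64 so the cancellation shows up at machine precision
block.train()  # BatchNorm uses batch statistics, as during training

x = torch.randn(1, 15, 16, 16, dtype=torch.float64)  # batch_size=1, as in this run
loss = block(x).square().mean()
loss.backward()

first, last = block[0], block[4]
print("first conv bias grad norm:  ", first.bias.grad.norm().item())   # ~0, cancelled by BatchNorm
print("first conv weight grad norm:", first.weight.grad.norm().item())  # non-zero
print("last conv bias grad norm:   ", last.bias.grad.norm().item())    # non-zero
```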
I have plotted the inputs as they flow through the layers, and none of them seems to do anything unexpected:
I'm out of ideas and would appreciate any input.
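The per-layer inspection described above can also be done numerically with PyTorch forward hooks (`register_forward_hook`), which makes collapsed or dead layers easy to spot across many layers at once. This is a minimal sketch. `attach_activation_stats` is a hypothetical helper, and the small `nn.Sequential` model stands in for the real network.

```python
import torch
import torch.nn as nn


def attach_activation_stats(model: nn.Module):
    """Record mean, std and fraction of exact zeros of every leaf module's output.

    Returns the stats dict (filled on each forward pass) and the hook handles;
    call handle.remove() on each handle when done.
    """
    stats = {}
    handles = []
    for name, module in model.named_modules():
        if list(module.children()):
            continue  # containers just forward to their children; skip them

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):
                out = output.detach().float()
                stats[name] = {
                    "mean": out.mean().item(),
                    "std": out.std().item(),
                    "frac_zero": (out == 0).float().mean().item(),  # dead ReLUs show up here
                }

        handles.append(module.register_forward_hook(hook))
    return stats, handles


model = nn.Sequential(nn.Conv2d(15, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 6, 1))
stats, handles = attach_activation_stats(model)
with torch.no_grad():
    model(torch.randn(1, 15, 16, 16))
for handle in handles:
    handle.remove()
for name, s in stats.items():
    print(name, s)
```

A layer whose std collapses toward zero, or whose frac_zero approaches 1, is a likely place for the signal to die.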