-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/combined loss #70
base: develop
Are you sure you want to change the base?
Conversation
if config.training.training_loss._target_ == 'anemoi.training.losses.combined.CombinedLoss': | ||
assert "loss_weights" in config.training.training_loss, "Loss weights must be provided for combined loss" | ||
losses = [] | ||
ignore_nans = config.training.training_loss.get("ignore_nans", False) # no point in doing this for each loss, nan+nan is nan | ||
for loss in config.training.training_loss.losses: | ||
node_weighting = instantiate(loss.node_weights) | ||
loss_node_weights = node_weighting.weights(graph_data) | ||
loss_node_weights = self.output_mask.apply(loss_node_weights, dim=0, fill_value=0.0) | ||
loss_instantiated = self.get_loss_function(loss, scalars=self.scalars, **{"node_weights": loss_node_weights, "ignore_nans": ignore_nans}) | ||
losses.append(loss_instantiated) | ||
assert isinstance(loss_instantiated, BaseWeightedLoss) | ||
self.loss = instantiate({"_target_": config.training.training_loss._target_}, losses=losses, loss_weights = config.training.training_loss.loss_weights, **loss_kwargs) | ||
else: | ||
self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs) | ||
assert isinstance(self.loss, BaseWeightedLoss) and not isinstance( | ||
self.loss, | ||
torch.nn.ModuleList, | ||
), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this is over specific for this use case, and instantiate's objects unneccessarily
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instantiating node_weights was necessary to call the combined loss but if you find a way around it, please let me know... I have another version where all of this is implemented in the get_loss_function from the forecaster. It is cleaner so I'll try to commit it soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, yeah, as I wrote the loss functions code originally, I was able to find a way around, and only update the CombinedLoss class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you'd like, we can work together on https://github.com/ecmwf/anemoi-core/tree/fix/combined_loss_hcookie to make sure your use case is addressed.
elif hasattr(loss, "__class__"): | ||
self.losses.append(loss) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we checking for __class__
? If checking for an object why not isinstance(loss, object)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because it could originally only take a class (of type "type", not instantiated) as losses arguments. Indeed, loss(**kwargs) called later in the function expects init arguments from the individual loss object and not forward arguments. As I said, I'll try to commit recent changes later.
Try to fix #68
Add tests