[BUG] SAC loss masking #2612
If the problem is that the observation of done states could be NaN (which I argue it shouldn't be), we could consider replacing NaNs with 0 so that the forward pass can run.
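Here is a minimal sketch of the NaN-replacement idea. It assumes plain PyTorch tensors. The `next_obs`/`done` names and the linear `value_net` are illustrative stand-ins and are not TorchRL's API. The done rows are zero-filled with `torch.nan_to_num` so that the forward pass runs, and their outputs are ignored afterwards.

```python
import torch

torch.manual_seed(0)

# Illustrative batch of next observations; the last entry is a done state
# whose observation was written as NaN by the environment.
next_obs = torch.randn(4, 3)
next_obs[3] = float("nan")
done = torch.tensor([False, False, False, True])

value_net = torch.nn.Linear(3, 1)  # hypothetical stand-in for a critic

# Replace NaNs with 0 so that the forward pass produces finite numbers.
safe_obs = torch.nan_to_num(next_obs, nan=0.0)
next_value = value_net(safe_obs).squeeze(-1)

# The values of done states are computed here but discarded downstream.
print(bool(torch.isfinite(next_value).all()))  # True
```

Zero-filling only helps networks that accept zeros. The `torch.linalg.cholesky(torch.zeros(4, 4))` example in this thread shows that an all-zero input can still make an operation fail.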
This change was indeed required, because some networks cannot accept the observation values that are written when the environment is done. There should not be any … See #2590 for context.
The NaNs were there just as an example. In practice, the network should not be queried if the values are not used. For instance, you could have a model that updates some internal state at each query, and you wouldn't want that state to be modified by values that will be discarded. Re (1), we could decide not to mask with … If this can be addressed differently, I'm happy to give it a look! cc @fmeirinhos
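This toy example illustrates the internal-state concern. `CountingPolicy` is a hypothetical module and is not part of TorchRL. It counts every observation it is queried on, so querying the full batch also "counts" done entries whose outputs are thrown away.

```python
import torch


class CountingPolicy(torch.nn.Module):
    """Hypothetical module with an internal state updated on every query."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(obs_dim, act_dim)
        # Number of observations seen so far (stand-in for any internal state).
        self.register_buffer("num_queries", torch.zeros((), dtype=torch.long))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        self.num_queries += obs.shape[0]
        return self.linear(obs)


policy = CountingPolicy(3, 2)
next_obs = torch.randn(4, 3)
done = torch.tensor([False, True, False, True])

# Querying the full batch also updates the state with the two done entries.
policy(next_obs)
print(policy.num_queries.item())  # 4

# Querying only non-done entries keeps the state consistent with used outputs.
policy.num_queries.zero_()
policy(next_obs[~done])
print(policy.num_queries.item())  # 2
```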
I think we also need more test coverage of MARL usage of the losses, because this could have been easily spotted if running …
I understand the original issue, but this seems to be a very difficult pickle. Not all networks can be queried with sparse data or data of arbitrary shape (the list is long).
Not sure it makes sense to enforce this here without introducing problems bigger than the original one.
You know what, I'm happy to revert that PR as soon as we can figure out what to do for people whose networks simply cannot accept "done" observation values! I do think it's a valid concern and it should be addressed, but obviously by a non-buggy solution.
I have been checking a bit how other libraries do it, and they seem to pass the next obs anyway. Maybe it is just my opinion, but I don't see what is special about the observation of a done state; it should belong to the same observation space as the others. Furthermore, for policies that have an internal state or counter, this is a bit unnatural in SAC: the policy is called from two places anyway (actions and values), so keeping track of meaningful states is hard.
I don't think we should overfit too much to what other libraries do; we should focus on the issues our users are facing.
This is just an example. The point is that if an error is thrown when invalid data is passed to a network, we should never reach that error (or we should provide the tooling necessary to avoid it). We could add a flag in the constructor like … Then we capture errors where relevant, and if the actor network raises during a call on the … I gave it a shot in #2613 (without the capture of the error).
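The following sketch shows one possible shape for an opt-in flag. `next_state_value` and its `skip_done_states` argument are hypothetical names chosen for illustration. They are not the flag proposed in #2613, and they are not TorchRL's API. When the flag is set, the network is queried only on non-done entries, and done entries get a placeholder value that downstream code is expected to ignore. Error capturing is omitted, as in #2613.

```python
import torch


def next_state_value(value_net, next_obs, done, skip_done_states=True):
    """Compute V(s') for a batch (hypothetical helper, not TorchRL's API).

    If ``skip_done_states`` is True, the network only sees non-done entries
    and done entries receive 0.0, a placeholder that is never used.
    """
    if not skip_done_states:
        return value_net(next_obs).squeeze(-1)
    out = torch.zeros(done.shape, dtype=next_obs.dtype, device=next_obs.device)
    keep = ~done
    if keep.any():
        out[keep] = value_net(next_obs[keep]).squeeze(-1)
    return out


torch.manual_seed(0)
value_net = torch.nn.Linear(3, 1)
next_obs = torch.randn(4, 3)
next_obs[1] = float("nan")  # done entry with an invalid observation
done = torch.tensor([False, True, False, False])

values = next_state_value(value_net, next_obs, done)
print(values[1].item())  # 0.0; the NaN row never reached the network
```

This design keeps the output shape identical to the input batch shape. Only the network's input shrinks, which is exactly the trade-off discussed further down in this thread.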
For the record, this is an example of a function that errors when there's a NaN:

```python
>>> import torch
>>> matrix_with_nan = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
>>>
>>> result = torch.linalg.cholesky(matrix_with_nan)
```

Note that replacing NaNs with 0s is also problematic with cholesky:

```python
>>> torch.linalg.cholesky(torch.zeros(4, 4))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
torch._C._LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).
```
That seems like a reasonable solution to me; happy to review. I think we are just facing a difficult issue, as I can clearly understand where both problems are coming from. I also don't like padding or calls on useless data.
Also, isn't flattening the cholesky input and removing the NaN value problematic?
No, the matrix is in the feature dimension, not the batch dimension. It isn't flattened.
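Here is a small illustration of the batch-versus-feature point. It assumes a batch of 2×2 matrices stored with the batch dimension first, and the values are made up for illustration. A 1-D done mask selects whole matrices along the batch dimension, so each remaining matrix keeps its shape and the NaN entry never reaches `torch.linalg.cholesky`.

```python
import torch

# Batch of 3 covariance-like matrices (batch dim first, 2x2 feature matrix);
# entry 1 is a done state whose matrix was filled with NaN.
eye = torch.eye(2)
mats = torch.stack([2.0 * eye, eye, 3.0 * eye])
mats[1] = float("nan")
done = torch.tensor([False, True, False])

# Boolean indexing with a batch-shaped mask drops whole matrices only,
# so each selected matrix keeps its (2, 2) shape and is never flattened.
valid = mats[~done]
print(valid.shape)  # torch.Size([2, 2, 2])

# Succeeds: only positive-definite, NaN-free matrices are factorized.
chol = torch.linalg.cholesky(valid)
```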
PR #2606
Introduces indexing of the loss tensordict using done signals:
rl/torchrl/objectives/sac.py, line 718 in d537dcb
rl/torchrl/objectives/sac.py, line 1223 in d537dcb
I have multiple concerns regarding this PR:
`value_estimate()` will read the dones.
`value_estimate()` already reads `done` and discards `next_values` for done states. Plus, the target of a done state should be the reward, so by using 0s here we are actually introducing a further bug.
In my opinion this change was not needed, as the `target_values` of done states are already discarded in `value_estimate()`. Maybe I am wrong in this analysis; please let me know.
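To illustrate how a value estimator can ignore next values of done states, here is a one-step TD target sketch. It assumes the common form target = reward + gamma * (1 - done) * V(s'). This is a generic illustration and not TorchRL's `value_estimate()` implementation, and `td0_target` is a hypothetical name. It uses `torch.where` instead of multiplying by `(1 - done)` because `0 * nan` is still `nan`. Note that this only protects the forward values; it says nothing about gradients flowing through a network that was fed NaN inputs.

```python
import torch


def td0_target(reward, next_value, done, gamma=0.99):
    """One-step TD target: reward + gamma * V(s') if not done, else reward."""
    # torch.where selects per element, so whatever next_value holds for done
    # entries (0, garbage or NaN) never reaches the target.
    bootstrap = torch.where(done, torch.zeros_like(next_value), next_value)
    return reward + gamma * bootstrap


reward = torch.tensor([1.0, 2.0, 3.0])
next_value = torch.tensor([10.0, float("nan"), 5.0])
done = torch.tensor([False, True, False])

target = td0_target(reward, next_value, done, gamma=0.5)
print(target.tolist())  # [6.0, 2.0, 5.5]; the done entry's target is its reward

# Masking by multiplication instead lets the NaN leak into the target.
naive = reward + 0.5 * (1 - done.float()) * next_value
print(naive.isnan().tolist())  # [False, True, False]
```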
I do not think it is possible to avoid submitting the inputs of done states to the policy without changing the input shape (which we should avoid, as it could lead to errors).
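This sketch shows why changing the input shape is risky. It assumes a multi-agent layout `[batch, n_agents, obs_dim]` with per-agent done flags, and it is illustrative only, not TorchRL code. Boolean indexing with the done mask collapses the batch and agent dimensions into one, so a network that expects an explicit agent dimension (for example a centralized critic) can no longer recover it.

```python
import torch

batch, n_agents, obs_dim = 4, 3, 5
obs = torch.randn(batch, n_agents, obs_dim)

# Per-agent done flags, as in a multi-agent batch.
done = torch.zeros(batch, n_agents, dtype=torch.bool)
done[0, 1] = True
done[2, :] = True

# Indexing with the done mask flattens the batch and agent dimensions.
kept = obs[~done]
print(kept.shape)  # torch.Size([8, 5])

# 8 agent observations cannot be regrouped into [-1, n_agents, obs_dim].
try:
    kept.reshape(-1, n_agents, obs_dim)
except RuntimeError as err:
    print("cannot recover the agent dimension:", err)
```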