Motivation
ActionDiscretizer can only convert input_spec["full_action_spec"] to a MultiCategorical or a MultiOneHot spec. Either choice adds a dimension to the action shape, which causes errors in my collector because it expects a scalar action shape:
File "runner.py", line 347, in run
rollout = next(self.collector_iter)
File "torchrl/collectors/collectors.py", line 1031, in iterator
tensordict_out = self.rollout()
File "torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
return func(self, *args, **kwargs)
File "torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "torchrl/collectors/collectors.py", line 1162, in rollout
env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
File "torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "torchrl/envs/batched_envs.py", line 67, in decorated_fun
return fun(self, *args, **kwargs)
File "torchrl/envs/batched_envs.py", line 1572, in step_and_maybe_reset
shared_tensordict_parent.update_(
File "tensordict/base.py", line 5339, in update_
self._apply_nest(
File "tensordict/_td.py", line 1330, in _apply_nest
item_trsf = item._apply_nest(
File "tensordict/_td.py", line 1330, in _apply_nest
item_trsf = item._apply_nest(
File "tensordict/_td.py", line 1330, in _apply_nest
item_trsf = item._apply_nest(
File "tensordict/_td.py", line 1350, in _apply_nest
item_trsf = fn(
File "tensordict/base.py", line 5318, in inplace_update
dest.copy_(source, non_blocking=non_blocking)
RuntimeError: output with shape [2, 1] doesn't match the broadcast shape [2, 2]
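The last frame of the traceback is an in-place copy_ between two tensors whose shapes cannot be broadcast together. The following is a minimal sketch in plain PyTorch that triggers the same kind of error. The shapes are assumptions read off the error message (a [2, 1] destination buffer and a [2, 2] source); the real buffers inside the collector may differ.

```python
import torch

# Assumed shapes, taken from the error message above.
dest = torch.zeros(2, 1)    # pre-allocated buffer expecting one value per env
source = torch.ones(2, 2)   # incoming tensor carrying an extra dimension

try:
    # In-place copy cannot resize dest, so the shape mismatch is fatal.
    dest.copy_(source)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```

The point of the sketch is only that the failure comes from a shape disagreement between a pre-allocated buffer and the data written into it, not from the values themselves.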
Solution
To work around this, I can replace the MultiCategorical spec with a Categorical one.
However, _inv_call() does not handle a scalar action, so I also have to change line 8658 from:
action = action.unsqueeze(-1)
to
action = action.unsqueeze(-1).unsqueeze(-1)
so that intervals.ndim == action.ndim.
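To illustrate why the extra unsqueeze matters, here is a rough sketch of an index lookup with torch.gather, which requires the index tensor to have the same number of dimensions as the tensor it reads from. The helper align_action_ndim and the intervals layout [batch, action_dim, num_intervals] are assumptions made for this example, not the actual torchrl implementation.

```python
import torch


def align_action_ndim(action, intervals):
    """Hypothetical helper: append trailing dims until ndim matches intervals."""
    while action.ndim < intervals.ndim:
        action = action.unsqueeze(-1)
    return action


batch, num_intervals = 2, 5
# Assumed layout: one row of bin values per env and per action dimension.
intervals = torch.linspace(-1.0, 1.0, num_intervals).expand(batch, 1, num_intervals)

# Categorical-style action: one bin index per env, shape [2].
scalar_action = torch.tensor([0, 4])
# MultiCategorical-style action: shape [2, 1].
multi_action = scalar_action.unsqueeze(-1)

results = []
for action in (scalar_action, multi_action):
    index = align_action_ndim(action, intervals)      # shape [2, 1, 1]
    assert index.ndim == intervals.ndim               # gather needs equal ndim
    continuous = intervals.gather(-1, index).squeeze(-1)  # shape [2, 1]
    results.append(continuous)
```

Padding dimensions in a loop, rather than calling unsqueeze a fixed number of times, lets the same code path serve both the scalar and the multi-dimensional case.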
Alternatives
Could we either:
Add an argument for selecting between MultiCategorical and Categorical
Or move the creation of the new action_spec out of the transform_input_spec method, so that any subclass of ActionDiscretizer can define the desired action_spec more specifically. Currently I have to override transform_input_spec, which I would rather leave untouched.
Also, could _inv_call be extended to handle a scalar action?
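As a rough illustration of the second option, the sketch below shows the general pattern of delegating spec creation to an overridable hook. Everything in it is hypothetical: SpecSketch, DiscretizerSketch, and make_action_spec are stand-ins invented for this example and are not part of torchrl.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for a torchrl spec: just a kind and a shape."""
    kind: str
    shape: tuple


class DiscretizerSketch:
    """Hypothetical stand-in for ActionDiscretizer (not torchrl code)."""

    def __init__(self, num_intervals):
        self.num_intervals = num_intervals

    def make_action_spec(self, action_dim):
        # Hook: the default keeps one entry per action dimension.
        return SpecSketch("MultiCategorical", (action_dim,))

    def transform_input_spec(self, action_dim):
        # Shared logic stays here; only spec creation is delegated.
        return {"full_action_spec": self.make_action_spec(action_dim)}


class ScalarDiscretizerSketch(DiscretizerSketch):
    """Subclass that overrides only the hook to get a scalar action."""

    def make_action_spec(self, action_dim):
        if action_dim != 1:
            raise ValueError("a scalar spec only makes sense for one action dim")
        return SpecSketch("Categorical", ())


default_spec = DiscretizerSketch(5).transform_input_spec(1)["full_action_spec"]
scalar_spec = ScalarDiscretizerSketch(5).transform_input_spec(1)["full_action_spec"]
```

With this layout a subclass changes the resulting spec without touching transform_input_spec, which is the behavior requested above.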
Checklist
I have checked that there is no similar issue in the repo (required)