[Feature Request] ActionDiscretizer scalar integration #2615

Open
1 task done
oslumbers opened this issue Nov 28, 2024 · 2 comments

@oslumbers

Motivation

The ActionDiscretizer only offers the option of converting input_spec["full_action_spec"] to MultiCategorical or MultiOneHot. This adds a dimension to the spec's shape:

MultiCategorical(
    shape=torch.Size([1]),
    space=BoxList(boxes=[CategoricalBox(n=4)]),
    dtype=torch.int64,
    domain=discrete)

For me this causes an error in the collector, which expects a scalar shape:

  File "runner.py", line 347, in run
    rollout = next(self.collector_iter)
  File "torchrl/collectors/collectors.py", line 1031, in iterator
    tensordict_out = self.rollout()
  File "torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "torchrl/collectors/collectors.py", line 1162, in rollout
    env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
  File "torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "torchrl/envs/batched_envs.py", line 67, in decorated_fun
    return fun(self, *args, **kwargs)
  File "torchrl/envs/batched_envs.py", line 1572, in step_and_maybe_reset
    shared_tensordict_parent.update_(
  File "tensordict/base.py", line 5339, in update_
    self._apply_nest(
  File "tensordict/_td.py", line 1330, in _apply_nest
    item_trsf = item._apply_nest(
  File "tensordict/_td.py", line 1330, in _apply_nest
    item_trsf = item._apply_nest(
  File "tensordict/_td.py", line 1330, in _apply_nest
    item_trsf = item._apply_nest(
  File "tensordict/_td.py", line 1350, in _apply_nest
    item_trsf = fn(
  File "tensordict/base.py", line 5318, in inplace_update
    dest.copy_(source, non_blocking=non_blocking)
RuntimeError: output with shape [2, 1] doesn't match the broadcast shape [2, 2]
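
The last frame of the traceback is a plain in-place copy_ between two tensors. The sketch below reproduces the same kind of failure in isolation, without TorchRL. The shapes are chosen only to mirror the error message; they are an assumption and are not taken from the collector's actual buffers.

```python
import torch

# Destination buffer with a trailing dimension of size 1 (assumed shape).
dest = torch.zeros(2, 1, dtype=torch.int64)
# Incoming value whose trailing dimension does not match (assumed shape).
source = torch.zeros(2, 2, dtype=torch.int64)

message = ""
try:
    # copy_ needs `source` to be broadcastable to `dest`; [2, 2] -> [2, 1] is not.
    dest.copy_(source)
except RuntimeError as err:
    message = str(err)

print(message)
```

A [2, 1] source would copy without error, which is why the shape of the action spec decides whether the update succeeds.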

Solution

To work around this, I can replace the MultiCategorical with a Categorical:

Categorical(
    shape=torch.Size([]),
    space=CategoricalBox(n=tensor([4])),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

However, _inv_call() does not handle a scalar action, so I have to change line 8658 from:

action = action.unsqueeze(-1)

to

action = action.unsqueeze(-1).unsqueeze(-1)

so that intervals.ndim == action.ndim.
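
Why the number of dimensions matters can be sketched with torch.gather, which requires the index tensor to have as many dimensions as the tensor it indexes. This is a minimal illustration only: the [1, 4] layout of intervals and the use of gather are assumptions made for the example, not a copy of the TorchRL source.

```python
import torch

# Assumed layout: one row of 4 bin values for a single action dimension.
intervals = torch.linspace(-1.0, 1.0, 4).unsqueeze(0)  # shape [1, 4]
action = torch.tensor(2)  # scalar category index, shape []

one = action.unsqueeze(-1)                # shape [1], ndim 1
two = action.unsqueeze(-1).unsqueeze(-1)  # shape [1, 1], ndim 2

try:
    # ndim 1 index against an ndim 2 input: gather rejects this.
    intervals.gather(-1, one)
    gather_failed = False
except RuntimeError:
    gather_failed = True

# With matching ndim the lookup works and returns the selected bin value.
value = intervals.gather(-1, two).squeeze()
print(gather_failed, value)
```

An alternative to hard-coding two unsqueeze calls would be to add dimensions until action.ndim matches intervals.ndim, which would cover both the scalar and the non-scalar case.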

Alternatives

Could we either:

  1. Add an argument for selecting between MultiCategorical and Categorical
  2. Or move the creation of the new action_spec out of the transform_input_spec method, so that any subclass of ActionDiscretizer can define the desired action_spec more specifically. Currently I have to override transform_input_spec, which I would rather leave untouched.

Also, could _inv_call be extended to handle a scalar action?
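
The two options above can be sketched in plain Python. Every name here (SpecSketch, DiscretizerSketch, categorical_kind, make_action_spec) is hypothetical and stands in for the real classes; none of it is TorchRL's API, and the sketch only shows the shape of the proposed design.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for a discrete action spec (not a TorchRL class)."""
    kind: str
    shape: Tuple[int, ...]
    n: int


class DiscretizerSketch:
    """Hypothetical stand-in for a discretizing transform."""

    def __init__(self, num_intervals: int, categorical_kind: str = "multi"):
        # Option 1: a constructor argument selects the kind of spec.
        self.num_intervals = num_intervals
        self.categorical_kind = categorical_kind

    def make_action_spec(self, action_shape: Tuple[int, ...]) -> SpecSketch:
        # Option 2: spec creation lives in its own overridable hook.
        if self.categorical_kind == "scalar":
            return SpecSketch("Categorical", action_shape, self.num_intervals)
        # A scalar action gains a trailing dimension, as described in the report.
        shape = action_shape if action_shape else (1,)
        return SpecSketch("MultiCategorical", shape, self.num_intervals)

    def transform_input_spec(self, action_shape: Tuple[int, ...]) -> SpecSketch:
        # The shared logic stays here; only the hook varies between subclasses.
        return self.make_action_spec(action_shape)


class ScalarDiscretizerSketch(DiscretizerSketch):
    def make_action_spec(self, action_shape: Tuple[int, ...]) -> SpecSketch:
        # A subclass overrides the hook without touching transform_input_spec.
        return SpecSketch("Categorical", action_shape, self.num_intervals)


print(DiscretizerSketch(4).transform_input_spec(()))
print(DiscretizerSketch(4, categorical_kind="scalar").transform_input_spec(()))
print(ScalarDiscretizerSketch(4).transform_input_spec(()))
```

The argument covers the common case with no subclassing, while the hook leaves room for specs that neither built-in choice covers.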

Checklist

  • I have checked that there is no similar issue in the repo (required)
@oslumbers oslumbers added the enhancement New feature or request label Nov 28, 2024
@vmoens
Contributor

vmoens commented Nov 29, 2024

Looking at it.

Here is an MRE for future use:

from typing import Optional

from torchrl.envs import EnvBase, ActionDiscretizer
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Bounded
import torch

class EnvWithScalarAction(EnvBase):
    _batch_size = torch.Size(())

    def _reset(self, td: TensorDict):
        return TensorDict(
            observation=torch.randn(3),
            done=torch.zeros(1, dtype=torch.bool),
            truncated=torch.zeros(1, dtype=torch.bool),
            terminated=torch.zeros(1, dtype=torch.bool),
        )

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        return TensorDict(
            observation=torch.randn(3),
            reward=torch.zeros(1),
            done=torch.zeros(1, dtype=torch.bool),
            truncated=torch.zeros(1, dtype=torch.bool),
            terminated=torch.zeros(1, dtype=torch.bool),
        )

    def _set_seed(self, seed: Optional[int]):
        ...

def policy(td):
    td.set("action", torch.rand(()))
    return td

env = EnvWithScalarAction()
env.auto_specs_(policy=policy)
env.action_spec = Bounded(-1, 1, shape=())

tenv = env.append_transform(ActionDiscretizer(num_intervals=4))

print(tenv.rollout(4))

@vmoens
Contributor

vmoens commented Nov 29, 2024

This is a first stab: #2619

It still needs more comprehensive tests, etc.

Question: does it also break with an action_spec of shape [1], or only with []?