Skip to content

Commit

Permalink
td dataset without dict comprehension
Browse files Browse the repository at this point in the history
  • Loading branch information
LTluttmann committed Jun 9, 2024
1 parent 8c3c2a0 commit 5d796f4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
21 changes: 21 additions & 0 deletions rl4co/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@
from torch.utils.data import Dataset


class FastTdDataset(Dataset):
    """Map-style dataset that keeps a whole TensorDict and indexes it directly.

    The TensorDict is stored as-is; nothing is split into per-sample entries at
    construction time. A batch is fetched with a single indexing call through
    ``__getitems__`` (the batched-fetch hook a PyTorch ``DataLoader`` uses when
    the dataset defines it), while ``__getitem__`` serves single-index access.

    Note:
        Check out the issue on tensordict for more details:
        https://github.com/pytorch-labs/tensordict/issues/374.

    Args:
        td: container holding the data; the size of its first batch dimension
            is used as the dataset length.
    """

    def __init__(self, td: TensorDict):
        self.data_len = td.batch_size[0]
        self.data = td

    def __len__(self):
        """Return the number of samples (first batch dimension of the data)."""
        return self.data_len

    def __getitem__(self, idx):
        """Return the entry (or entries) of the stored data selected by ``idx``.

        Without this method ``dataset[i]`` falls through to
        ``Dataset.__getitem__``, which raises ``NotImplementedError``.
        """
        return self.data[idx]

    def __getitems__(self, idx):
        """Return a whole batch by indexing the stored data once with ``idx``."""
        return self.data[idx]

    def add_key(self, key, value):
        """Return an ``ExtraKeyDataset`` wrapping this dataset with ``value`` under ``key``."""
        return ExtraKeyDataset(self, value, key_name=key)

    @staticmethod
    def collate_fn(batch):
        """Return ``batch`` unchanged.

        ``__getitems__`` already yields a batched container, so there is
        nothing left to stack or merge.
        """
        return batch


class TensorDictDataset(Dataset):
"""Dataset compatible with TensorDicts with low CPU usage.
Fast loading but somewhat slow instantiation due to list comprehension since we
Expand Down
4 changes: 2 additions & 2 deletions rl4co/envs/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tensordict.tensordict import TensorDict
from torchrl.envs import EnvBase

from rl4co.data.dataset import TensorDictDataset
from rl4co.data.dataset import FastTdDataset
from rl4co.data.utils import load_npz_to_tensordict
from rl4co.utils.ops import get_num_starts, select_start_nodes
from rl4co.utils.pylogger import get_pylogger
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
val_dataloader_names: list = None,
test_dataloader_names: list = None,
check_solution: bool = True,
dataset_cls: callable = TensorDictDataset,
dataset_cls: callable = FastTdDataset,
seed: int = None,
device: str = "cpu",
batch_size: torch.Size = None,
Expand Down

0 comments on commit 5d796f4

Please sign in to comment.