From 5d796f4e7612cfac0882233d0aff72623788c432 Mon Sep 17 00:00:00 2001 From: LTluttmann Date: Sun, 9 Jun 2024 19:36:04 +0200 Subject: [PATCH] td dataset without dict comprehension --- rl4co/data/dataset.py | 21 +++++++++++++++++++++ rl4co/envs/common/base.py | 4 ++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/rl4co/data/dataset.py b/rl4co/data/dataset.py index a1f03391..9eeb5043 100644 --- a/rl4co/data/dataset.py +++ b/rl4co/data/dataset.py @@ -6,6 +6,27 @@ from torch.utils.data import Dataset +class FastTdDataset(Dataset): + """ + Note: + Check out the issue on tensordict for more details: + https://github.com/pytorch-labs/tensordict/issues/374. + """ + + def __init__(self, td: TensorDict): + self.data_len = td.batch_size[0] + self.data = td + + def __len__(self): + return self.data_len + + def __getitems__(self, idx): + return self.data[idx] + + def add_key(self, key, value): + return ExtraKeyDataset(self, value, key_name=key) + + class TensorDictDataset(Dataset): """Dataset compatible with TensorDicts with low CPU usage. Fast loading but somewhat slow instantiation due to list comprehension since we diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index e477d7e3..c9a0068f 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -8,7 +8,7 @@ from tensordict.tensordict import TensorDict from torchrl.envs import EnvBase -from rl4co.data.dataset import TensorDictDataset +from rl4co.data.dataset import FastTdDataset from rl4co.data.utils import load_npz_to_tensordict from rl4co.utils.ops import get_num_starts, select_start_nodes from rl4co.utils.pylogger import get_pylogger @@ -52,7 +52,7 @@ def __init__( val_dataloader_names: list = None, test_dataloader_names: list = None, check_solution: bool = True, - dataset_cls: callable = TensorDictDataset, + dataset_cls: callable = FastTdDataset, seed: int = None, device: str = "cpu", batch_size: torch.Size = None,