Dataset trust remote tweaks #2361

Merged · 3 commits · Dec 6, 2024
Changes from all commits
2 changes: 2 additions & 0 deletions timm/data/dataset.py
@@ -103,6 +103,7 @@ def __init__(
             transform=None,
             target_transform=None,
             max_steps=None,
+            **kwargs,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -121,6 +122,7 @@ def __init__(
                 input_key=input_key,
                 target_key=target_key,
                 max_steps=max_steps,
+                **kwargs,
             )
         else:
             self.reader = reader
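The `**kwargs` passthrough means reader-specific options no longer need to be named in the dataset class itself. A minimal usage sketch, assuming a hypothetical Hugging Face hub path (any keyword the dataset doesn't recognize, e.g. `trust_remote_code`, is simply forwarded to the reader):

    from timm.data import IterableImageDataset

    ds = IterableImageDataset(
        root=None,                          # None -> default HF cache location
        reader='hfids/org/some-dataset',    # hypothetical dataset path
        split='train',
        trust_remote_code=True,             # not a named parameter; forwarded via **kwargs
    )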
43 changes: 24 additions & 19 deletions timm/data/dataset_factory.py
@@ -74,34 +74,37 @@ def create_dataset(
         seed: int = 42,
         repeats: int = 0,
         input_img_mode: str = 'RGB',
+        trust_remote_code: bool = False,
         **kwargs,
 ):
     """ Dataset factory method

     In parentheses after each arg are the type of dataset supported for each arg, one of:
-      * folder - default, timm folder (or tar) based ImageDataset
-      * torch - torchvision based datasets
+      * Folder - default, timm folder (or tar) based ImageDataset
+      * Torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
+      * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
       * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
       * WDS - Webdataset
-      * all - any of the above
+      * All - any of the above

     Args:
-        name: dataset name, empty is okay for folder based datasets
-        root: root folder of dataset (all)
-        split: dataset split (all)
-        search_split: search for split specific child folder from root so one can specify
-            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
-        class_map: specify class -> index mapping via text file or dict (folder)
-        load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (HFDS, TFDS, torch)
-        is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
-        batch_size: batch size hint for (TFDS, WDS)
-        seed: seed for iterable datasets (TFDS, WDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
-        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
-        **kwargs: other args to pass to dataset
+        name: Dataset name, empty is okay for folder based datasets
+        root: Root folder of dataset (All)
+        split: Dataset split (All)
+        search_split: Search for split specific child folder from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
+        class_map: Specify class -> index mapping via text file or dict (Folder)
+        load_bytes: Load data, return images as undecoded bytes (Folder)
+        download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
+        is_training: Create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
+        batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
+        seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
+        repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (Folder, TFDS, WDS, HFDS, HFIDS)
+        trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
+        **kwargs: Other args to pass through to underlying Dataset and/or Reader classes

     Returns:
         Dataset object
@@ -162,6 +165,7 @@ def create_dataset(
             split=split,
             class_map=class_map,
             input_img_mode=input_img_mode,
+            trust_remote_code=trust_remote_code,
             **kwargs,
         )
     elif name.startswith('hfids/'):
@@ -177,7 +181,8 @@ def create_dataset(
             repeats=repeats,
             seed=seed,
             input_img_mode=input_img_mode,
-            **kwargs
+            trust_remote_code=trust_remote_code,
+            **kwargs,
         )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
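For illustration, a minimal sketch of the factory call with the new argument (the hub path below is hypothetical; `trust_remote_code` only has an effect for the `hfds/` and `hfids/` name prefixes):

    from timm.data import create_dataset

    ds = create_dataset(
        'hfds/org/some-dataset',  # hypothetical Hugging Face hub path
        root=None,                # None -> default HF cache location
        split='train',
        trust_remote_code=True,   # forwarded to datasets.load_dataset()
    )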
2 changes: 1 addition & 1 deletion timm/data/readers/reader_hfds.py
@@ -48,7 +48,7 @@ def __init__(
         self.dataset = datasets.load_dataset(
             name,  # 'name' maps to path arg in hf datasets
             split=split,
-            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path if root set
             trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
7 changes: 6 additions & 1 deletion timm/data/readers/reader_hfids.py
@@ -44,6 +44,7 @@ def __init__(
             target_img_mode: str = '',
             shuffle_size: Optional[int] = None,
             num_samples: Optional[int] = None,
+            trust_remote_code: bool = False
     ):
         super().__init__()
         self.root = root
@@ -60,7 +61,11 @@ def __init__(
         self.target_key = target_key
         self.target_img_mode = target_img_mode

-        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        self.builder = datasets.load_dataset_builder(
+            name,
+            cache_dir=root,
+            trust_remote_code=trust_remote_code,
+        )
         if download:
             self.builder.download_and_prepare()

4 changes: 4 additions & 0 deletions train.py
@@ -103,6 +103,8 @@
                    help='Dataset key for input images.')
 group.add_argument('--target-key', default=None, type=str,
                    help='Dataset key for target labels.')
+group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                   help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

 # Model parameters
 group = parser.add_argument_group('Model parameters')
@@ -653,6 +655,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.train_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )

     if args.val_split:
@@ -668,6 +671,7 @@ def main():
         input_key=args.input_key,
         target_key=args.target_key,
         num_samples=args.val_num_samples,
+        trust_remote_code=args.dataset_trust_remote_code,
     )

     # setup mixup / cutmix
3 changes: 3 additions & 0 deletions validate.py
@@ -66,6 +66,8 @@
                     help='Dataset image conversion mode for input images.')
 parser.add_argument('--target-key', default=None, type=str,
                     help='Dataset key for target labels.')
+parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
+                    help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
@@ -268,6 +270,7 @@ def validate(args):
         input_key=args.input_key,
         input_img_mode=input_img_mode,
         target_key=args.target_key,
+        trust_remote_code=args.dataset_trust_remote_code,
     )

     if args.valid_labels:
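With the plumbing above, enabling remote code becomes a single opt-in flag in both scripts; for example (dataset path hypothetical):

    python train.py --data-dir data --dataset hfds/org/some-dataset --dataset-trust-remote-code

Without the flag, `trust_remote_code` stays False, and recent `datasets` releases will raise rather than execute a loading script downloaded from the dataset's repo.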