From 2dff16fa582487a7117221fb56331f2a08bf4726 Mon Sep 17 00:00:00 2001 From: Augustin Godinot Date: Fri, 8 Nov 2024 18:15:03 +0100 Subject: [PATCH 1/2] Add --dataset-trust-remote-code to the train.py and validate.py scripts --- train.py | 4 ++++ validate.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/train.py b/train.py index 7aab20a76a..38fd3c6d9b 100755 --- a/train.py +++ b/train.py @@ -102,6 +102,8 @@ help='Dataset key for input images.') group.add_argument('--target-key', default=None, type=str, help='Dataset key for target labels.') +group.add_argument('--dataset-trust-remote-code', action='store_true', default=False, + help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.') # Model parameters group = parser.add_argument_group('Model parameters') @@ -641,6 +643,7 @@ def main(): input_key=args.input_key, target_key=args.target_key, num_samples=args.train_num_samples, + trust_remote_code=args.dataset_trust_remote_code, ) if args.val_split: @@ -656,6 +659,7 @@ def main(): input_key=args.input_key, target_key=args.target_key, num_samples=args.val_num_samples, + trust_remote_code=args.dataset_trust_remote_code, ) # setup mixup / cutmix diff --git a/validate.py b/validate.py index 602111bb6a..159bd0b1a1 100755 --- a/validate.py +++ b/validate.py @@ -66,6 +66,8 @@ help='Dataset image conversion mode for input images.') parser.add_argument('--target-key', default=None, type=str, help='Dataset key for target labels.') +parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False, + help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.') parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') @@ -268,6 +270,7 @@ def validate(args): input_key=args.input_key, input_img_mode=input_img_mode, target_key=args.target_key, + trust_remote_code=args.dataset_trust_remote_code, ) if args.valid_labels: From 
7573096eb8049acdec9b6681240088af25cf17ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 6 Dec 2024 11:40:04 -0800 Subject: [PATCH 2/2] Make sure trust_remote code only passed to HF datasets. Improve some docstrings. --- timm/data/dataset.py | 2 ++ timm/data/dataset_factory.py | 43 +++++++++++++++++-------------- timm/data/readers/reader_hfds.py | 2 +- timm/data/readers/reader_hfids.py | 7 ++++- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 1c481bffa5..14d484ba9f 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -103,6 +103,7 @@ def __init__( transform=None, target_transform=None, max_steps=None, + **kwargs, ): assert reader is not None if isinstance(reader, str): @@ -121,6 +122,7 @@ def __init__( input_key=input_key, target_key=target_key, max_steps=max_steps, + **kwargs, ) else: self.reader = reader diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 3b735c4e7b..021d50be1c 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -74,34 +74,37 @@ def create_dataset( seed: int = 42, repeats: int = 0, input_img_mode: str = 'RGB', + trust_remote_code: bool = False, **kwargs, ): """ Dataset factory method In parentheses after each arg are the type of dataset supported for each arg, one of: - * folder - default, timm folder (or tar) based ImageDataset - * torch - torchvision based datasets + * Folder - default, timm folder (or tar) based ImageDataset + * Torch - torchvision based datasets * HFDS - Hugging Face Datasets + * HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset) * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset * WDS - Webdataset - * all - any of the above + * All - any of the above Args: - name: dataset name, empty is okay for folder based datasets - root: root folder of dataset (all) - split: dataset split (all) - search_split: search for split specific 
child fold from root so one can specify - `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) - class_map: specify class -> index mapping via text file or dict (folder) - load_bytes: load data, return images as undecoded bytes (folder) - download: download dataset if not present and supported (HFDS, TFDS, torch) - is_training: create dataset in train mode, this is different from the split. - For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS) - batch_size: batch size hint for (TFDS, WDS) - seed: seed for iterable datasets (TFDS, WDS) - repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS) - input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS) - **kwargs: other args to pass to dataset + name: Dataset name, empty is okay for folder based datasets + root: Root folder of dataset (All) + split: Dataset split (All) + search_split: Search for split specific child fold from root so one can specify + `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch) + class_map: Specify class -> index mapping via text file or dict (Folder) + load_bytes: Load data, return images as undecoded bytes (Folder) + download: Download dataset if not present and supported (HFIDS, TFDS, Torch) + is_training: Create dataset in train mode, this is different from the split. + For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS) + batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS) + seed: Seed for iterable datasets (TFDS, WDS, HFIDS) + repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS) + input_img_mode: Input image color conversion mode e.g.
'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS) + trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS) + **kwargs: Other args to pass through to underlying Dataset and/or Reader classes Returns: Dataset object @@ -162,6 +165,7 @@ def create_dataset( split=split, class_map=class_map, input_img_mode=input_img_mode, + trust_remote_code=trust_remote_code, **kwargs, ) elif name.startswith('hfids/'): @@ -177,7 +181,8 @@ def create_dataset( repeats=repeats, seed=seed, input_img_mode=input_img_mode, - **kwargs + trust_remote_code=trust_remote_code, + **kwargs, ) elif name.startswith('tfds/'): ds = IterableImageDataset( diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py index b205447252..13f8e24488 100644 --- a/timm/data/readers/reader_hfds.py +++ b/timm/data/readers/reader_hfds.py @@ -48,7 +48,7 @@ def __init__( self.dataset = datasets.load_dataset( name, # 'name' maps to path arg in hf datasets split=split, - cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path + cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set trust_remote_code=trust_remote_code ) # leave decode for caller, plus we want easy access to original path names... 
diff --git a/timm/data/readers/reader_hfids.py b/timm/data/readers/reader_hfids.py index 9f7ce76db0..943ce2158b 100644 --- a/timm/data/readers/reader_hfids.py +++ b/timm/data/readers/reader_hfids.py @@ -44,6 +44,7 @@ def __init__( target_img_mode: str = '', shuffle_size: Optional[int] = None, num_samples: Optional[int] = None, + trust_remote_code: bool = False ): super().__init__() self.root = root @@ -60,7 +61,11 @@ def __init__( self.target_key = target_key self.target_img_mode = target_img_mode - self.builder = datasets.load_dataset_builder(name, cache_dir=root) + self.builder = datasets.load_dataset_builder( + name, + cache_dir=root, + trust_remote_code=trust_remote_code, + ) if download: self.builder.download_and_prepare()