diff --git a/README.md b/README.md index 5024d88c9f..ca972d68ac 100644 --- a/README.md +++ b/README.md @@ -426,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - path: knowrohit07/know_sql type: context_qa.load_v2 train_on_split: validation + + # loading from s3 or gcs + # s3 creds will be loaded from the system default and gcs only supports public access + dataset: + - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs. + ... ``` - loading @@ -520,7 +526,7 @@ float16: true # A list of one or more datasets to finetune the model with datasets: - # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files + # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files - path: vicgalle/alpaca-gpt4 # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] type: alpaca # format | format: (chat/instruct) | .load_ diff --git a/requirements.txt b/requirements.txt index 9ed66033bd..dec9398327 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ deepspeed addict fire PyYAML>=6.0 -datasets +datasets>=2.14.0 flash-attn>=2.3.0 sentencepiece wandb @@ -33,3 +33,8 @@ art fschat==0.2.29 gradio tensorboard + +# remote filesystems +s3fs +gcsfs +# adlfs diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 2af85831ad..a62b34e1d9 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -170,30 +170,74 @@ def for_d_in_datasets(dataset_configs): except (FileNotFoundError, ConnectionError): pass + ds_from_cloud = False + storage_options = {} + remote_file_system = None + if config_dataset.path.startswith("s3://"): + try: + import aiobotocore.session # type: ignore + import s3fs # type: ignore + except ImportError as exc: + raise ImportError( + "s3:// paths require aiobotocore and s3fs to be installed" + ) from exc + + # Takes credentials from ~/.aws/credentials for default profile + s3_session = aiobotocore.session.AioSession(profile="default") + storage_options = {"session": s3_session} + remote_file_system = s3fs.S3FileSystem(**storage_options) + elif config_dataset.path.startswith( + "gs://" + ) or config_dataset.path.startswith("gcs://"): + try: + import gcsfs # type: ignore + except ImportError as exc: + raise ImportError( + "gs:// or gcs:// paths require gcsfs to be installed" + ) from exc + + # gcsfs will use default credentials from the environment else anon + # https://gcsfs.readthedocs.io/en/latest/#credentials + storage_options = {"token": None} + remote_file_system = gcsfs.GCSFileSystem(**storage_options) + # TODO: Figure out how to get auth creds passed + # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): + # try: + # import adlfs + # except ImportError as exc: + # raise ImportError( + # "adl:// or abfs:// paths require adlfs to be installed" + # ) from exc + + # # Gen 1 + # storage_options = { + # "tenant_id": TENANT_ID, + # "client_id": CLIENT_ID, + # "client_secret": CLIENT_SECRET, + # } + # # Gen 2 + # storage_options = { + # "account_name": ACCOUNT_NAME, + # "account_key": ACCOUNT_KEY, + # } + + # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) + try: + if remote_file_system and remote_file_system.exists( + config_dataset.path + ): + ds_from_cloud = True + except (FileNotFoundError, ConnectionError): + pass + # prefer local dataset, even if hub exists local_path = Path(config_dataset.path) if local_path.exists(): if local_path.is_dir(): - # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=False, - split=None, - ) + ds = load_from_disk(config_dataset.path) elif local_path.is_file(): - ds_type = "json" - if config_dataset.ds_type: - ds_type = config_dataset.ds_type - elif ".parquet" in config_dataset.path: - ds_type = "parquet" - elif ".arrow" in config_dataset.path: - ds_type = "arrow" - elif ".csv" in config_dataset.path: - ds_type = "csv" - elif ".txt" in config_dataset.path: - ds_type = "text" + ds_type = get_ds_type(config_dataset) + ds = load_dataset( ds_type, name=config_dataset.name, @@ -213,6 +257,22 @@ def for_d_in_datasets(dataset_configs): data_files=config_dataset.data_files, token=use_auth_token, ) + elif ds_from_cloud and remote_file_system: + if remote_file_system.isdir(config_dataset.path): + ds = load_from_disk( + config_dataset.path, + storage_options=storage_options, + ) + elif remote_file_system.isfile(config_dataset.path): + ds_type = get_ds_type(config_dataset) + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + storage_options=storage_options, + ) else: if isinstance(config_dataset.data_files, str): fp = hf_hub_download( @@ -304,6 +364,24 @@ def for_d_in_datasets(dataset_configs): return dataset, prompters +def get_ds_type(config_dataset: DictDefault): + """ + Get the dataset type from the path if it's not specified + """ + ds_type = "json" + if config_dataset.ds_type: + ds_type = config_dataset.ds_type + elif ".parquet" in config_dataset.path: + ds_type = "parquet" + elif ".arrow" in config_dataset.path: + ds_type = "arrow" + elif ".csv" in config_dataset.path: + ds_type = "csv" + elif ".txt" in config_dataset.path: + ds_type = "text" + return ds_type + + def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg,