From 4238d0d3cdf9ec252bd258fd4ca261bc9e7fcd0f Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 9 Apr 2024 17:18:00 +0000 Subject: [PATCH] making allow_unsafe_types and replication configurable through dataset configs --- llmfoundry/data/finetuning/dataloader.py | 2 ++ llmfoundry/data/finetuning/tasks.py | 6 ++++-- llmfoundry/data/text_data.py | 6 ++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 38c9673a14..1d8711d280 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig, sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), max_seq_len=cfg.dataset.max_seq_len, + allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False), + replication=cfg.dataset.get('replication', None), ) else: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 1c25a685e1..e6992d1ce9 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -516,6 +516,8 @@ def __init__(self, sampling_granularity: int = 1, batching_method: str = 'random', max_seq_len: int = 2048, + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -552,8 +554,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, - allow_unsafe_types=False, - replication=None, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 221f959368..af1c02bf79 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -109,6 +109,8 @@ def __init__(self, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', + allow_unsafe_types: bool = False, + replication: Optional[int] = None, **kwargs: Any): if len(kwargs) > 0: @@ -151,8 +153,8 @@ def __init__(self, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method, - allow_unsafe_types=False, - replication=None, + allow_unsafe_types=allow_unsafe_types, + replication=replication, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len