Merge branch 'main' into tempdir-finetuning
irenedea authored Oct 23, 2024
2 parents 145388b + 21c7ec8 commit 03a3717
Showing 8 changed files with 50 additions and 10 deletions.
1 change: 1 addition & 0 deletions llmfoundry/command_utils/data_prep/convert_dataset_hf.py
@@ -451,6 +451,7 @@ def convert_dataset_hf_from_args(
ValueError: If the output directory already contains the requested splits
ValueError: If `concat_tokens` is set but `tokenizer` is not
"""
+ os.environ['WORLD_SIZE'] = '1'
if tokenizer_kwargs:
parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs)
else:
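The same one-line `WORLD_SIZE` guard recurs in every data-prep entry point touched by this commit. Distributed-aware libraries conventionally read the `WORLD_SIZE` environment variable to decide how many ranks are participating, so pinning it to '1' makes these CLI utilities behave as single-process runs even when launched from a multi-GPU environment. A minimal sketch of that convention (illustrative only, not the actual llm-foundry helper):

```python
# Sketch: how distributed-aware code typically derives its world size.
# Setting WORLD_SIZE to '1' before any such code runs forces the
# single-process path during data preparation.
import os

def get_world_size() -> int:
    # Unset means a non-distributed run, i.e. a world size of 1.
    return int(os.environ.get('WORLD_SIZE', '1'))

os.environ['WORLD_SIZE'] = '1'
assert get_world_size() == 1
```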
1 change: 1 addition & 0 deletions llmfoundry/command_utils/data_prep/convert_dataset_json.py
@@ -186,6 +186,7 @@ def convert_dataset_json_from_args(
ValueError: If the out_root directory exists and contains files that overlap with the requested splits
ValueError: If concat_tokens is set and a tokenizer is not provided
"""
+ os.environ['WORLD_SIZE'] = '1'
if os.path.isdir(out_root) and len(
set(os.listdir(out_root)).intersection(set(split)),
) > 0:
4 changes: 4 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -550,6 +550,9 @@ def validate_and_get_cluster_info(
).upper()[len('DATASECURITYMODE.'):]

# NONE stands for No Isolation Shared
+ # This check verifies Unity Catalog governance compatibility; it does not
+ # check whether a particular user has access to the cluster. Checking cluster
+ # access controls is difficult, and there is no single existing API for it.
if data_security_mode == 'NONE':
raise ClusterInvalidAccessMode(
cluster_id=cluster_id,
@@ -767,6 +770,7 @@ def convert_delta_to_json_from_args(
use_serverless (bool): Whether to use serverless compute. Make sure the workspace is entitled to use serverless
json_output_filename (str): The name of the final jsonl file that combines all partitioned jsonl files
"""
+ os.environ['WORLD_SIZE'] = '1'
_check_imports()
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()
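For context on the data-security check above: the `data_security_mode` string it inspects comes from the Databricks SDK's cluster details, normalized from the SDK enum's string form. A hedged sketch of that lookup (the cluster id is hypothetical; the normalization mirrors the snippet shown in the hunk above):

```python
# Sketch: fetching a cluster's data security mode with the Databricks SDK.
# str() of the enum looks like 'DataSecurityMode.NONE'; the prefix is
# stripped before comparing against 'NONE' (No Isolation Shared).
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
cluster = w.clusters.get(cluster_id='1234-567890-abcdefgh')  # hypothetical id
data_security_mode = str(
    cluster.data_security_mode,
).upper()[len('DATASECURITYMODE.'):]

if data_security_mode == 'NONE':
    print('Cluster has no data governance; pick a different access mode.')
```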
1 change: 1 addition & 0 deletions llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py
@@ -309,6 +309,7 @@ def convert_finetuning_dataset_from_args(
ValueError: If the target settings are invalid.
ValueError: If the output directory already contains the requested splits.
"""
+ os.environ['WORLD_SIZE'] = '1'
if os.path.isdir(out_root) and len(
set(os.listdir(out_root)).intersection(set(splits)),
) > 0:
1 change: 1 addition & 0 deletions llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -557,6 +557,7 @@ def convert_text_to_mds_from_args(
Raises:
ValueError: If `use_tokenizer_eos` is True and `eos_text` is not None
"""
+ os.environ['WORLD_SIZE'] = '1'
if use_tokenizer_eos:
# Ensure that eos text is not specified twice.
if eos_text is not None:
28 changes: 21 additions & 7 deletions llmfoundry/data/finetuning/tasks.py
@@ -507,6 +507,22 @@ def is_valid_ift_example(
return True


+ def _get_num_processes() -> int:
+     """Get the number of processes to use for dataset processing."""
+     detected_cpu_count = os.cpu_count() or 1
+     detected_cpus_with_margin = detected_cpu_count - 8
+     num_proc = max(1, detected_cpus_with_margin)
+
+     # Check if the user has set the MAX_NUM_PROC environment variable,
+     # which caps the number of processes used for dataset processing.
+     if 'MAX_NUM_PROC' in os.environ:
+         max_num_proc_env = int(os.environ['MAX_NUM_PROC'])
+         if max_num_proc_env < num_proc:
+             num_proc = max_num_proc_env
+
+     return num_proc


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
@@ -956,18 +972,16 @@ def dataset_mapper(example: dict):
)
return mapping_fn(example, tokenizer)

- detected_cpu_count = os.cpu_count() or 1
- detected_cpus_with_margin = detected_cpu_count - 8
- num_cpus_to_use = max(1, detected_cpus_with_margin)
- if len(dataset) < num_cpus_to_use:
-     num_cpus_to_use = 1
+ num_proc = _get_num_processes()
+ if len(dataset) < num_proc:
+     num_proc = 1

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
- num_proc=num_cpus_to_use,
+ num_proc=num_proc,
desc='Tokenizing dataset',
)

@@ -979,7 +993,7 @@ def dataset_mapper(example: dict):
target_responses,
decoder_only_format,
),
- num_proc=num_cpus_to_use,
+ num_proc=num_proc,
desc='Filtering out long prompts',
)

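To make the new helper's arithmetic concrete, a small sanity-check sketch (the mocked CPU counts are illustrative, and it assumes `MAX_NUM_PROC` is not already set in the surrounding environment):

```python
import os
from unittest import mock

from llmfoundry.data.finetuning.tasks import _get_num_processes

# 16 detected CPUs, no cap: 16 minus the 8-CPU margin -> 8 processes.
with mock.patch('os.cpu_count', return_value=16):
    assert _get_num_processes() == 8

# 4 detected CPUs: the margin would go negative, so floor at 1 process.
with mock.patch('os.cpu_count', return_value=4):
    assert _get_num_processes() == 1

# MAX_NUM_PROC only ever lowers the result, never raises it.
with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '2'}):
    with mock.patch('os.cpu_count', return_value=16):
        assert _get_num_processes() == 2
```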
4 changes: 2 additions & 2 deletions llmfoundry/utils/exceptions.py
@@ -322,8 +322,8 @@ class ClusterInvalidAccessMode(UserError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str, access_mode: str) -> None:
- message = f'Cluster with id {cluster_id} has access mode {access_mode}. ' + \
-     'please make sure the cluster used has access mode Shared or Single User!'
+ message = f'The cluster you have provided: {cluster_id} does not have data governance enabled. ' + \
+     'Please use a cluster with a data security mode other than NONE.'
super().__init__(
message,
cluster_id=cluster_id,
20 changes: 19 additions & 1 deletion tests/data/test_dataset.py
@@ -1,15 +1,33 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
+ import os
from contextlib import nullcontext
from typing import Optional
+ from unittest import mock

import pytest

- from llmfoundry.data.finetuning.tasks import dataset_constructor
+ from llmfoundry.data.finetuning.tasks import (
+     _get_num_processes,
+     dataset_constructor,
+ )
from llmfoundry.utils.exceptions import DatasetTooSmallError


+ def test_get_num_processes():
+     with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '4'}):
+         with mock.patch('os.cpu_count', return_value=16):
+             assert _get_num_processes() == 4
+
+     with mock.patch.dict(os.environ, {'MAX_NUM_PROC': '32'}):
+         with mock.patch('os.cpu_count', return_value=16):
+             assert _get_num_processes() == 8
+
+     with mock.patch.dict(os.environ, {}):
+         with mock.patch('os.cpu_count', return_value=16):
+             assert _get_num_processes() == 8


@pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2])
def test_finetuning_streaming_dataset_too_small(
num_canonical_nodes: Optional[int],
