Commit 36cc16a: merging
snarayan21 committed Sep 30, 2024
2 parents 68f91dd + 107d246
Showing 51 changed files with 781 additions and 169 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -223,7 +223,7 @@ cd scripts

# Convert C4 dataset to StreamingDataset format
python data_prep/convert_dataset_hf.py \
--dataset c4 --data_subset en \
--dataset allenai/c4 --data_subset en \
--out_root my-copy-c4 --splits train_small val_small \
--concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>'

TUTORIAL.md (4 changes: 2 additions & 2 deletions)
@@ -216,7 +216,7 @@ Output the processed data to `./my-adaptation-data`. Note that we use smaller su
<!--pytest.mark.skip-->
```bash
python scripts/data_prep/convert_dataset_hf.py \
--dataset c4 --data_subset en \
--dataset allenai/c4 --data_subset en \
--out_root my-adaptation-data --splits train_small val_small \
--concat_tokens 4096 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' \
--compression zstd
@@ -248,7 +248,7 @@ The first step to training from scratch is to get your pretraining data prepared
<!--pytest.mark.skip-->
```bash
python scripts/data_prep/convert_dataset_hf.py \
--dataset c4 --data_subset en \
--dataset allenai/c4 --data_subset en \
--out_root my-copy-c4 --splits train_small val_small \
--concat_tokens 2048 --tokenizer gpt2 \
--eos_text '<|endoftext|>' \
llmfoundry/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -48,6 +48,7 @@
models,
optim,
tokenizers,
tp,
utils,
)
from llmfoundry._version import __version__
@@ -87,5 +88,6 @@
'models',
'optim',
'tokenizers',
'tp',
'utils',
]
llmfoundry/_version.py (2 changes: 1 addition & 1 deletion)
@@ -3,4 +3,4 @@

"""The LLM Foundry Version."""

__version__ = '0.12.0.dev0'
__version__ = '0.13.0.dev0'
llmfoundry/callbacks/hf_checkpointer.py (7 changes: 4 additions & 3 deletions)
@@ -588,9 +588,10 @@ def tensor_hook(
del new_base_model_instance
else:
new_model_instance = type(original_model)(new_config)
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
)
if new_model_instance.generation_config is not None:
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
)

# Then load the state dict in with "assign" so that the state dict
# is loaded properly even though the model is initially on meta device.
llmfoundry/command_utils/data_prep/convert_dataset_hf.py (4 changes: 2 additions & 2 deletions)
@@ -158,7 +158,7 @@ def __init__(
truncated_samples=100,
)

CONSTS = {'c4': c4constants, 'the_pile': pileconstants}
CONSTS = {'allenai/c4': c4constants, 'the_pile': pileconstants}


def build_hf_dataset(
@@ -335,7 +335,7 @@ def convert_dataset_hf(
dataset_constants = CONSTS[dataset]
except KeyError:
raise ValueError(
f'Constants for dataset "{dataset}" not found. Currently only "the_pile" and "c4" are supported.',
f'Constants for dataset "{dataset}" not found. Currently only "the_pile" and "allenai/c4" are supported.',
)

if concat_tokens is not None and tokenizer is not None:
llmfoundry/command_utils/data_prep/convert_dataset_json.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ def build_hf_dataset(
no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries
tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use
data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset.
Typically "all" (The Pile) or "en" (c4).
Typically "all" (The Pile) or "en" (allenai/c4).
Returns:
An IterableDataset.
llmfoundry/command_utils/data_prep/convert_delta_to_json.py (115 changes: 62 additions & 53 deletions)
@@ -20,6 +20,7 @@

from llmfoundry.utils.exceptions import (
ClusterDoesNotExistError,
ClusterInvalidAccessMode,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
InsufficientPermissionsError,
@@ -448,69 +449,66 @@ def fetch(
"""
cursor = dbsql.cursor() if dbsql is not None else None
try:
nrows = get_total_rows(
tablename,
method,
cursor,
sparkSession,
)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
# Get total rows
nrows = get_total_rows(tablename, method, cursor, sparkSession)

try:
# Get columns info
columns, order_by, columns_str = get_columns_info(
tablename,
method,
cursor,
sparkSession,
)

if method == 'dbconnect' and sparkSession is not None:
log.info(f'{processes=}')
df = sparkSession.table(tablename)

# Running the query and collecting the data as arrow or json.
signed, _, _ = df.collect_cf('arrow') # pyright: ignore
log.info(f'len(signed) = {len(signed)}')

args = get_args(signed, json_output_folder, columns)

# Stopping the SparkSession to avoid spilling connection state into the subprocesses.
sparkSession.stop()

with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_starargs, args))

elif method == 'dbsql' and cursor is not None:
for start in range(0, nrows, batch_size):
log.warning(f'batch {start}')
end = min(start + batch_size, nrows)
fetch_data(
method,
cursor,
sparkSession,
start,
end,
order_by,
tablename,
columns_str,
json_output_folder,
)

except Exception as e:
raise RuntimeError(
f'Error in get columns from {tablename}. Restart sparkSession and try again',
) from e
from databricks.sql.exc import ServerOperationError
from pyspark.errors import AnalysisException

if method == 'dbconnect' and sparkSession is not None:
log.info(f'{processes=}')
df = sparkSession.table(tablename)

# Running the query and collecting the data as arrow or json.
signed, _, _ = df.collect_cf('arrow') # pyright: ignore
log.info(f'len(signed) = {len(signed)}')

args = get_args(signed, json_output_folder, columns)

# Stopping the SparkSession to avoid spilling connection state into the subprocesses.
sparkSession.stop()

with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_starargs, args))

elif method == 'dbsql' and cursor is not None:
for start in range(0, nrows, batch_size):
log.warning(f'batch {start}')
end = min(start + batch_size, nrows)
fetch_data(
method,
cursor,
sparkSession,
start,
end,
order_by,
tablename,
columns_str,
json_output_folder,
)
if isinstance(e, (AnalysisException, ServerOperationError)):
if 'INSUFFICIENT_PERMISSIONS' in str(e):
raise InsufficientPermissionsError(str(e)) from e

if cursor is not None:
cursor.close()
if isinstance(e, InsufficientPermissionsError):
raise

# For any other exception, raise a general error
raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e

finally:
if cursor is not None:
cursor.close()


def validate_and_get_cluster_info(
@@ -546,6 +544,17 @@ def validate_and_get_cluster_info(
if res is None:
raise ClusterDoesNotExistError(cluster_id)

data_security_mode = str(
res.data_security_mode,
).upper()[len('DATASECURITYMODE.'):]

# NONE stands for No Isolation Shared
if data_security_mode == 'NONE':
raise ClusterInvalidAccessMode(
cluster_id=cluster_id,
access_mode=data_security_mode,
)

assert res.spark_version is not None
stripped_runtime = re.sub(
r'[a-zA-Z]',
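The new cluster access-mode check above works on the stringified form of the Databricks SDK's `DataSecurityMode` enum, stripping the `DATASECURITYMODE.` prefix before comparing against `'NONE'` (the value corresponding to a No Isolation Shared cluster). A small self-contained illustration of that slicing, using a stand-in enum rather than the real SDK type:

```python
from enum import Enum


# Stand-in for the Databricks SDK's DataSecurityMode enum, used here only to
# show the string handling: str(member) renders as 'DataSecurityMode.NONE'.
class DataSecurityMode(Enum):
    NONE = 'NONE'
    SINGLE_USER = 'SINGLE_USER'
    USER_ISOLATION = 'USER_ISOLATION'


mode = DataSecurityMode.NONE
parsed = str(mode).upper()[len('DATASECURITYMODE.'):]
assert parsed == 'NONE'  # 'NONE' (No Isolation Shared) triggers ClusterInvalidAccessMode
```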
llmfoundry/command_utils/data_prep/convert_text_to_mds.py (15 changes: 10 additions & 5 deletions)
@@ -32,6 +32,7 @@
CannotUnicodeDecodeFile,
DatasetTooSmallError,
InputFolderMissingDataError,
InputFolderNotFound,
OutputFolderNotEmptyError,
)

@@ -125,11 +126,15 @@ def get_object_names(input_folder: str) -> list[str]:
object_store = maybe_create_object_store_from_uri(input_folder)
if object_store is not None:
_, _, folder_prefix = parse_uri(input_folder)
names = [
name for name in object_store.list_objects(folder_prefix)
if name.endswith('.txt')
]
log.info(f'Found {len(names)} text files in remote storage')
try:
names = [
name for name in object_store.list_objects(folder_prefix)
if name.endswith('.txt')
]
log.info(f'Found {len(names)} text files in remote storage')
except FileNotFoundError:
raise InputFolderNotFound(folder_prefix)

else:
# input_folder is a local folder
names = [
llmfoundry/command_utils/eval.py (4 changes: 2 additions & 2 deletions)
@@ -82,7 +82,7 @@ def evaluate_model(
warnings.warn(
VersionedDeprecationWarning(
'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
remove_version='0.13.0',
remove_version='0.14.0',
),
)
if fsdp_config and parallelism_config:
@@ -273,7 +273,7 @@ def evaluate(cfg: DictConfig) -> tuple[list[Trainer], pd.DataFrame]:
# Mandatory Evaluation Parameters
icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str
if icl_tasks is None:
raise ValueError('icl_tasks must be specified in the config')
icl_tasks = []

# Optional Evaluation Parameters with default values
eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders
llmfoundry/command_utils/train.py (32 changes: 27 additions & 5 deletions)
@@ -5,6 +5,7 @@
import os
import time
import warnings
from copy import deepcopy
from typing import Any, Optional, Union

import torch
@@ -43,6 +44,7 @@
build_save_planner,
build_scheduler,
build_tokenizer,
build_tp_strategies,
)
from llmfoundry.utils.config_utils import (
TRAIN_CONFIG_KEYS,
@@ -329,16 +331,27 @@ def train(cfg: DictConfig) -> Trainer:
changing autoresume default to True...',
)

# Warn if fsdp is enabled but user only has 1 GPU
if dist.get_world_size() == 1 and fsdp_config is not None:
# Optional tp config
tp_config: Optional[dict[str, Any]] = train_cfg.tp_config

# Warn if FSDP or TP is enabled but user only has 1 GPU
if dist.get_world_size(
) == 1 and (fsdp_config is not None or tp_config is not None):
parallelism = ''
if fsdp_config is not None:
parallelism += 'FSDP'
if tp_config is not None:
parallelism += '+TP' if fsdp_config is not None else 'TP'
warnings.warn(
'FSDP is not applicable for single-GPU training. Reverting to DDP.',
f'{parallelism} is not applicable for single-GPU training. Reverting to DDP.',
)
fsdp_config = None
tp_config = None

# Initialize context
init_context = process_init_device(model_config, fsdp_config)
init_context = process_init_device(model_config, fsdp_config, tp_config)
logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True)

# Build tokenizer
log.info('Building tokenizer...')
@@ -502,6 +515,15 @@ def train(cfg: DictConfig) -> Trainer:

_log_num_params(model, logged_cfg)

# TP config
if tp_config is not None:
strategy = tp_config.pop('strategy', None)
assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
tp_config['layer_plan'] = build_tp_strategies(strategy, model)

# Parallelism config
parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}

# Optimizer
optimizer_name: str = train_cfg.optimizer.pop('name')
optimizer_cfg = train_cfg.optimizer
@@ -546,7 +568,7 @@ def train(cfg: DictConfig) -> Trainer:
precision=train_cfg.precision,
algorithms=algorithms,
device_train_microbatch_size=train_cfg.device_train_microbatch_size,
parallelism_config={'fsdp': fsdp_config},
parallelism_config=parallelism_config,
save_folder=train_cfg.save_folder,
save_filename=save_filename,
save_latest_filename=save_latest_filename,
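With the tensor-parallelism changes above, `train()` now reads an optional `tp_config` alongside `fsdp_config`, pops its `strategy` entry, expands it into a layer plan with `build_tp_strategies`, and passes both configs to the Trainer as a single `parallelism_config`. A minimal sketch of that flow with placeholder values (the `'ffn'` strategy name and the FSDP settings are assumptions, not values taken from this commit):

```python
from typing import Any, Optional

# Illustrative configs only: in practice these come from the training YAML
# under the fsdp_config / tp_config keys.
fsdp_config: Optional[dict[str, Any]] = {'sharding_strategy': 'FULL_SHARD'}
tp_config: Optional[dict[str, Any]] = {'strategy': 'ffn'}

# Mirrors the new logic in train(): the strategy is popped and replaced by a
# layer plan (built via build_tp_strategies(strategy, model) in the real code),
# then both configs are handed to the Trainer together.
strategy = tp_config.pop('strategy', None)
assert isinstance(strategy, str), '`strategy` must be in `tp_config`.'
tp_config['layer_plan'] = f'<layer plan built from {strategy!r}>'  # stand-in
parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config}
print(parallelism_config)
```

Note that on a single GPU the diff now discards both configs and falls back to DDP, with a warning naming whichever of FSDP and TP was requested.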
llmfoundry/loggers/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -4,6 +4,7 @@
from composer.loggers import (
InMemoryLogger,
MLFlowLogger,
MosaicMLLogger,
TensorboardLogger,
WandBLogger,
)
@@ -18,3 +19,4 @@
func=InMemoryLogger,
) # for backwards compatibility
loggers.register('mlflow', func=MLFlowLogger)
loggers.register('mosaicml', func=MosaicMLLogger)
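With `MosaicMLLogger` now registered under the key `'mosaicml'`, it can be resolved from the loggers registry just like the existing `'mlflow'` and `'inmemory'` entries. A rough sketch, assuming the registry exposes catalogue's standard `get` (llm-foundry's registries are catalogue-based) and that `MosaicMLLogger` can be constructed with default arguments:

```python
from llmfoundry.registry import loggers

# Resolve the newly registered logger class by its registry key and build it.
# In a training YAML this would instead appear under the loggers section,
# keyed by 'mosaicml'.
logger_cls = loggers.get('mosaicml')
mosaicml_logger = logger_cls()
```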
llmfoundry/models/hf/model_wrapper.py (2 changes: 1 addition & 1 deletion)
@@ -48,7 +48,7 @@ def __init__(
warnings.warn(
VersionedDeprecationWarning(
'`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
remove_version='0.13.0',
remove_version='0.14.0',
),
)
super().__init__(