Merge branch 'main' into execution_prediction
bmosaicml authored Nov 17, 2023
2 parents 6014dd4 + 25bb63f commit 90ef388
Showing 27 changed files with 490 additions and 128 deletions.
16 changes: 8 additions & 8 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -74,12 +74,13 @@ def __init__(
         if self.mlflow_registered_model_name is not None:
             # Both the metadata and the task are needed in order for mlflow
             # and databricks optimized model serving to work
-            if 'metadata' not in mlflow_logging_config:
-                mlflow_logging_config['metadata'] = {
-                    'task': 'llm/v1/completions'
-                }
-            if 'task' not in mlflow_logging_config:
-                mlflow_logging_config['task'] = 'text-generation'
+            default_metadata = {'task': 'llm/v1/completions'}
+            passed_metadata = mlflow_logging_config.get('metadata', {})
+            mlflow_logging_config['metadata'] = {
+                **default_metadata,
+                **passed_metadata
+            }
+            mlflow_logging_config.setdefault('task', 'text-generation')
         self.mlflow_logging_config = mlflow_logging_config

         self.huggingface_folder_name_fstr = os.path.join(
@@ -93,7 +94,6 @@ def __init__(
         self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
-
         self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
             save_folder, loggers=[])
         if self.remote_ud is not None:
@@ -204,7 +204,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 state_dict[k] = v.to(dtype=self.dtype)

         if dist.get_global_rank() == 0:
-            log.debug('Saving Hugging Face checkpoint to disk')
+            log.debug('Saving Hugging Face checkpoint in global rank 0')

             copied_config = copy.deepcopy(original_model.config)
             if copied_config.model_type == 'mpt':
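The rewritten block above merges rather than overwrites: defaults are applied first, and any user-supplied metadata wins key by key. A minimal sketch of the merge semantics, with a plain dict standing in for the callback's mlflow_logging_config (the 'llm/v1/chat' value is a hypothetical user override):

# Plain-dict sketch of the merge introduced above. In {**a, **b}, keys from
# b (the user-supplied metadata) override keys from a (the defaults).
default_metadata = {'task': 'llm/v1/completions'}
mlflow_logging_config = {'metadata': {'task': 'llm/v1/chat'}}

passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata}
mlflow_logging_config.setdefault('task', 'text-generation')

print(mlflow_logging_config)
# {'metadata': {'task': 'llm/v1/chat'}, 'task': 'text-generation'}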
6 changes: 3 additions & 3 deletions llmfoundry/data/denoising.py
@@ -477,13 +477,13 @@ def build_text_denoising_dataloader(
         remote=cfg.dataset.get('remote'),
         split=cfg.dataset.get('split'),
         shuffle=cfg.dataset.get('shuffle', False),
-        predownload=cfg.dataset.get('predownload', 100_000),
+        predownload=cfg.dataset.get('predownload', None),
         keep_zip=cfg.dataset.get('keep_zip', False),
         download_retry=cfg.dataset.get('download_retry', 2),
         download_timeout=cfg.dataset.get('download_timeout', 60),
-        validate_hash=cfg.dataset.get('validate_hash'),
+        validate_hash=cfg.dataset.get('validate_hash', None),
         shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
-        num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128),
+        num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
         batch_size=device_batch_size,
     )
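Of the three changes above, predownload and num_canonical_nodes are substantive: a None default lets the streaming library derive suitable values at runtime instead of hard-coding 100_000 and 128. The validate_hash change is purely cosmetic, since dict-style .get already returns None when no default is given. A quick sketch, with a plain dict and placeholder values standing in for cfg.dataset:

# Plain-dict stand-in for cfg.dataset: .get with no default and .get with an
# explicit None default behave identically.
cfg_dataset = {'remote': 's3://bucket/dataset', 'split': 'train'}
assert cfg_dataset.get('validate_hash') is None
assert cfg_dataset.get('validate_hash', None) is None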
6 changes: 3 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -136,13 +136,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
             epoch_size=cfg.dataset.get('epoch_size', None),
             predownload=cfg.dataset.get('predownload', None),
             cache_limit=cfg.dataset.get('cache_limit', None),
-            partition_algo=cfg.dataset.get('partition_algo', 'orig'),
+            partition_algo=cfg.dataset.get('partition_algo', 'relaxed'),
             num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
             batch_size=device_batch_size,
             shuffle=cfg.dataset.get('shuffle', False),
-            shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
+            shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'),
             shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
-            shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
+            shuffle_block_size=cfg.dataset.get('shuffle_block_size', None),
             sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
             sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
             batching_method=cfg.dataset.get('batching_method', 'random'),
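Because these are .get fallbacks, a run that depended on the previous behavior can pin the old values explicitly in its dataset config. A hypothetical sketch mirroring the fields read above (llm-foundry configs are omegaconf DictConfigs; the values shown are simply the pre-change defaults):

from omegaconf import OmegaConf

# Pin the pre-change defaults explicitly if a run depends on them.
cfg = OmegaConf.create({
    'dataset': {
        'partition_algo': 'orig',       # previous default
        'shuffle_algo': 'py1b',         # previous default
        'shuffle_block_size': 1 << 18,  # previous default
    },
})
assert cfg.dataset.get('shuffle_algo', 'py1e') == 'py1b'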
22 changes: 12 additions & 10 deletions llmfoundry/data/finetuning/tasks.py
@@ -88,28 +88,30 @@ class StreamingFinetuningDataset(StreamingDataset):
         keep_zip (bool): Whether to keep or delete the compressed form when decompressing
             downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
             ``False``.
-        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
+        epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
             streams. If ``None``, takes its value from the total number of underlying samples.
             Provide this field if you are weighting streams relatively to target a larger or
             smaller epoch size. Defaults to ``None``.
         predownload (int, optional): Target number of samples ahead to download the shards of while
-            iterating. Defaults to ``100_000``.
+            iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
         cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
             shard cache. Before downloading a shard, the least recently used resident shard(s) may
             be evicted (deleted from the local cache) in order to stay under the limit. Set to None
             to disable shard eviction. Supports integer bytes as well as string human-readable
             bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
         partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
         num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
-            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
-            initial run.
+            resumption. If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.
         batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
             partitioned over the workers. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
         shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        shuffle_block_size (int, optional): Unit of shuffle. If ``None``, its value is calculated as
+            ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to ``None``.
         sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
             Defaults to ``balanced``.
         sampling_granularity (int): When picking samples for a stream's final partial repeat,
@@ -129,16 +131,16 @@ def __init__(self,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 keep_zip: bool = False,
-                epoch_size: Optional[int] = None,
+                epoch_size: Optional[Union[int, str]] = None,
                 predownload: Optional[int] = None,
                 cache_limit: Optional[Union[int, str]] = None,
-                partition_algo: str = 'orig',
+                partition_algo: str = 'relaxed',
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None,
                 shuffle: bool = False,
-                shuffle_algo: str = 'py1b',
+                shuffle_algo: str = 'py1e',
                 shuffle_seed: int = 9176,
-                shuffle_block_size: int = 1 << 18,
+                shuffle_block_size: Optional[int] = None,
                 sampling_method: str = 'balanced',
                 sampling_granularity: int = 1,
                 batching_method: str = 'random',
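A worked example of the derived defaults documented above; the formulas come straight from the docstring text, and the actual resolution happens inside the streaming library rather than in this file:

# Assumed formulas, per the docstring: predownload falls back to
# 8 * batch_size, and shuffle_block_size to a 1 << 18 floor scaled by
# num_canonical_nodes.
batch_size = 32
num_canonical_nodes = 8

predownload = 8 * batch_size                                         # 256
shuffle_block_size = max(4_000_000 // num_canonical_nodes, 1 << 18)  # 500_000

print(predownload, shuffle_block_size)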
26 changes: 15 additions & 11 deletions llmfoundry/data/text_data.py
@@ -46,28 +46,32 @@ class StreamingTextDataset(StreamingDataset):
         keep_zip (bool): Whether to keep or delete the compressed form when decompressing
             downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
             ``False``.
-        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
+        epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
             streams. If ``None``, takes its value from the total number of underlying samples.
             Provide this field if you are weighting streams relatively to target a larger or
             smaller epoch size. Defaults to ``None``.
         predownload (int, optional): Target number of samples ahead to download the shards of while
-            iterating. Defaults to ``100_000``.
+            iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
         cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
             shard cache. Before downloading a shard, the least recently used resident shard(s) may
             be evicted (deleted from the local cache) in order to stay under the limit. Set to None
             to disable shard eviction. Supports integer bytes as well as string human-readable
             bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
         partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
         num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
-            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
-            initial run.
+            resumption. If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.
         batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
             partitioned over the workers. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
         shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
+            into blocks of this size, and samples within each block are shuffled. If ``None``, its
+            value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
+            ``None``.
         sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
             Defaults to ``balanced``.
         sampling_granularity (int): When picking samples for a stream's final partial repeat,
@@ -89,16 +93,16 @@ def __init__(self,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 keep_zip: bool = False,
-                epoch_size: Optional[int] = None,
-                predownload: int = 100_000,
+                epoch_size: Optional[Union[int, str]] = None,
+                predownload: Optional[int] = None,
                 cache_limit: Optional[Union[int, str]] = None,
-                partition_algo: str = 'orig',
+                partition_algo: str = 'relaxed',
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None,
                 shuffle: bool = False,
-                shuffle_algo: str = 'py1b',
+                shuffle_algo: str = 'py1e',
                 shuffle_seed: int = 9176,
-                shuffle_block_size: int = 1 << 18,
+                shuffle_block_size: Optional[int] = None,
                 sampling_method: str = 'balanced',
                 sampling_granularity: int = 1,
                 batching_method: str = 'random',
8 changes: 7 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -59,6 +59,7 @@ def __init__(
         use_cache: bool = False,
         init_config: Dict = init_config_defaults,
         fc_type: str = 'torch',
+        tie_word_embeddings: bool = True,
         verbose: Optional[int] = None,
         **kwargs: Any,
     ):
@@ -128,6 +129,7 @@ def __init__(
             ---
             See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
             fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+            tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
         """
         self.d_model = d_model
         self.n_heads = n_heads
@@ -164,7 +166,11 @@ def __init__(
             warnings.warn(
                 f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
             )
-        super().__init__(**kwargs)
+        # tie_word_embeddings is set in Hugging Face's PretrainedConfig __init__
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

         self._validate_config()
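With tie_word_embeddings now forwarded to PretrainedConfig, untying the input embedding from the output head becomes a one-line config change. A hedged usage sketch (assumes llm-foundry is installed; every other field keeps its default):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# PretrainedConfig.__init__ stores the flag, so downstream code (e.g., the
# model's weight-tying logic) can read it from the config as usual.
config = MPTConfig(tie_word_embeddings=False)
assert config.tie_word_embeddings is False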