
Commit

Merge branch 'main' into anna/asynceval
aspfohl authored Nov 14, 2023
2 parents d85ee5e + 45113eb commit 0e96fea
Showing 22 changed files with 354 additions and 109 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
@@ -204,7 +204,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
-            log.debug('Saving Hugging Face checkpoint to disk')
+            log.debug('Saving Hugging Face checkpoint in global rank 0')

copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
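For reference, the rank-guard pattern this log line lives in looks like the following (a minimal sketch using composer's dist utilities; the surrounding checkpoint logic is elided):

import logging

from composer.utils import dist

log = logging.getLogger(__name__)

# Only the global rank 0 process serializes the checkpoint; the new message
# records where the save happens, not just that it happens.
if dist.get_global_rank() == 0:
    log.debug('Saving Hugging Face checkpoint in global rank 0')
    # ... write config and weights to disk (elided) ...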
6 changes: 3 additions & 3 deletions llmfoundry/data/denoising.py
@@ -477,13 +477,13 @@ def build_text_denoising_dataloader(
remote=cfg.dataset.get('remote'),
split=cfg.dataset.get('split'),
shuffle=cfg.dataset.get('shuffle', False),
-        predownload=cfg.dataset.get('predownload', 100_000),
+        predownload=cfg.dataset.get('predownload', None),
keep_zip=cfg.dataset.get('keep_zip', False),
download_retry=cfg.dataset.get('download_retry', 2),
download_timeout=cfg.dataset.get('download_timeout', 60),
-        validate_hash=cfg.dataset.get('validate_hash'),
+        validate_hash=cfg.dataset.get('validate_hash', None),
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
-        num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128),
+        num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
batch_size=device_batch_size,
)

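With these defaults now None, a dataset config can simply omit the tuned knobs and let the streaming library derive them. A minimal sketch (paths and bucket names are hypothetical):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'dataset': {
        'local': '/tmp/mds-cache',      # hypothetical local cache dir
        'remote': 's3://my-bucket/c4',  # hypothetical remote store
        'split': 'train',
        'shuffle': True,
        # predownload, validate_hash, and num_canonical_nodes are omitted;
        # the .get(key, None) lookups above then defer to streaming's own
        # derived defaults.
    },
})
assert cfg.dataset.get('predownload', None) is None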
6 changes: 3 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -136,13 +136,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
epoch_size=cfg.dataset.get('epoch_size', None),
predownload=cfg.dataset.get('predownload', None),
cache_limit=cfg.dataset.get('cache_limit', None),
-        partition_algo=cfg.dataset.get('partition_algo', 'orig'),
+        partition_algo=cfg.dataset.get('partition_algo', 'relaxed'),
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
batch_size=device_batch_size,
shuffle=cfg.dataset.get('shuffle', False),
-        shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
+        shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'),
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
-        shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
+        shuffle_block_size=cfg.dataset.get('shuffle_block_size', None),
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
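Usage-wise, nothing changes for callers: a finetuning config that leaves these keys out now picks up 'relaxed' / 'py1e' / None. A hedged sketch of a call site (the import path, dataset name, and required keys are my assumptions about this repo's API, not code from this PR):

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.finetuning import build_finetuning_dataloader

cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'tatsu-lab/alpaca',  # example instruction dataset
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
        # partition_algo, shuffle_algo, and shuffle_block_size are omitted;
        # they now default to 'relaxed', 'py1e', and None respectively.
    },
    'drop_last': True,
    'num_workers': 8,
})
tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)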
22 changes: 12 additions & 10 deletions llmfoundry/data/finetuning/tasks.py
@@ -88,28 +88,30 @@ class StreamingFinetuningDataset(StreamingDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
``False``.
-        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
+        epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
-            iterating. Defaults to ``100_000``.
+            iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s) may
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
to disable shard eviction. Supports integer bytes as well as string human-readable
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
-            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
-            initial run.
+            resumption. If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        shuffle_block_size (int, optional): Unit of shuffle. If ``None``, its value is calculated as
+            ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to ``None``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
@@ -129,16 +131,16 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
-                 epoch_size: Optional[int] = None,
+                 epoch_size: Optional[Union[int, str]] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
-                 partition_algo: str = 'orig',
+                 partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
-                 shuffle_algo: str = 'py1b',
+                 shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
-                 shuffle_block_size: int = 1 << 18,
+                 shuffle_block_size: Optional[int] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
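The derived defaults described in the docstring work out as follows (a sketch of streaming's documented behavior, not code from this PR):

def derived_defaults(batch_size: int, num_canonical_nodes: int) -> tuple:
    # predownload: stay 8 device batches ahead of the iterator.
    predownload = 8 * batch_size
    # shuffle_block_size: shrink blocks as canonical nodes grow, with a floor.
    shuffle_block_size = max(4_000_000 // num_canonical_nodes, 1 << 18)
    return predownload, shuffle_block_size

print(derived_defaults(batch_size=32, num_canonical_nodes=8))   # (256, 500000)
print(derived_defaults(batch_size=32, num_canonical_nodes=64))  # (256, 262144)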
26 changes: 15 additions & 11 deletions llmfoundry/data/text_data.py
@@ -46,28 +46,32 @@ class StreamingTextDataset(StreamingDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
``False``.
-        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
+        epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or
smaller epoch size. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
-            iterating. Defaults to ``100_000``.
+            iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s) may
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
to disable shard eviction. Supports integer bytes as well as string human-readable
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
-            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
-            initial run.
+            resumption. If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
+            into blocks of this size, and samples within each block are shuffled. If ``None``, its
+            value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
+            ``None``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
Expand All @@ -89,16 +93,16 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
-                 epoch_size: Optional[int] = None,
-                 predownload: int = 100_000,
+                 epoch_size: Optional[Union[int, str]] = None,
+                 predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
-                 partition_algo: str = 'orig',
+                 partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
-                 shuffle_algo: str = 'py1b',
+                 shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
-                 shuffle_block_size: int = 1 << 18,
+                 shuffle_block_size: Optional[int] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
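Instantiating the dataset directly shows the updated signature; everything newly-None can be left unset (a sketch; the paths and tokenizer are examples, not from this PR):

from transformers import AutoTokenizer

from llmfoundry.data.text_data import StreamingTextDataset

tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
dataset = StreamingTextDataset(
    tokenizer=tokenizer,
    max_seq_len=2048,
    local='/tmp/text-cache',       # hypothetical local cache
    remote='s3://my-bucket/text',  # hypothetical remote
    batch_size=8,
    shuffle=True,
    # predownload, num_canonical_nodes, and shuffle_block_size default to
    # None and are derived by streaming at runtime.
)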
8 changes: 7 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -59,6 +59,7 @@ def __init__(
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
+        tie_word_embeddings: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -128,6 +129,7 @@ def __init__(
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+            tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -164,7 +166,11 @@ def __init__(
warnings.warn(
f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
)
-        super().__init__(**kwargs)
+        # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

self._validate_config()

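A minimal sketch of the new flag (sizes arbitrary); because it is forwarded to PretrainedConfig.__init__, it should also round-trip through save_pretrained/from_pretrained:

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    vocab_size=50368,
    tie_word_embeddings=False,  # untie the input embedding and output head
)
assert config.tie_word_embeddings is False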
106 changes: 84 additions & 22 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -45,7 +45,9 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

-from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
+from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY,
+                                                attn_bias_shape,
+                                                build_attn_bias)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -231,10 +233,11 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

-    def get_input_embeddings(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.wte

-    def set_input_embeddings(self, value: nn.Embedding) -> None:
+    def set_input_embeddings(
+            self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.wte = value

@torch.no_grad()
@@ -574,14 +577,20 @@ class MPTForCausalLM(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
super().__init__(config)
-        if not config.tie_word_embeddings:
-            raise ValueError(
-                'MPTForCausalLM only supports tied word embeddings')

log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)

+        self.lm_head = None
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(
+                config.d_model,
+                config.vocab_size,
+                bias=False,
+                device=config.init_device,
+            )
+            self.lm_head._fsdp_wrap = True

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
continue
@@ -602,19 +611,38 @@ def __init__(self, config: MPTConfig):
)
self.logit_scale = logit_scale

-    def get_input_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
+        return self.transformer.get_input_embeddings()

def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = value
+        self.transformer.set_input_embeddings(value)

-    def get_output_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_output_embeddings(
+            self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
+        if self.lm_head is not None:
+            return self.lm_head
+        return self.transformer.get_input_embeddings()

def set_output_embeddings(
-            self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = new_embeddings
+            self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
+                                        nn.Linear]) -> None:
+        if self.lm_head is not None:
+            self.lm_head = new_embeddings
+        else:
+            if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
+                raise ValueError(
+                    'new_embeddings must be an instance of SharedEmbedding ' +
+                    f'or nn.Embedding, but got {type(new_embeddings)}.')
+            warnings.warn(
+                'Using `set_output_embeddings` to set the embedding layer of ' +
+                'MPTForCausalLM with tied weights. Given weights are tied, ' +
+                'using `set_input_embeddings` is recommended over using ' +
+                '`set_output_embeddings`.')
+            self.transformer.set_input_embeddings(new_embeddings)

+    def tie_weights(self) -> None:
+        self.lm_head = None

def set_decoder(self, decoder: MPTModel) -> None:
self.transformer = decoder
@@ -658,12 +686,14 @@ def forward(
use_cache=use_cache,
)

-        # move outputs to same device as weights for token embedding
-        # needed to support HF `device_map`
-        logits = self.transformer.wte(
-            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
-            True,
-        )
+        if self.lm_head is not None:
+            logits = self.lm_head(outputs.last_hidden_state)
+        else:
+            # move outputs to same device as weights for token embedding
+            # needed to support HF `device_map`
+            out = outputs.last_hidden_state
+            out = out.to(self.transformer.wte.weight.device)
+            logits = self.transformer.wte(out, True)

if self.logit_scale is not None:
if self.logit_scale == 0:
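End to end, the untied path now works instead of raising (a sketch, not a test from this PR; the attn_config key is my assumption, chosen so the model runs on CPU):

import torch

from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM

config = MPTConfig(
    d_model=128, n_heads=4, n_layers=2, vocab_size=512, max_seq_len=256,
    attn_config={'attn_impl': 'torch'},  # avoid triton kernels on CPU
    tie_word_embeddings=False,
)
model = MPTForCausalLM(config)

# With untied weights, the output embedding is the new lm_head Linear.
assert isinstance(model.get_output_embeddings(), torch.nn.Linear)

tokens = torch.randint(0, 512, (1, 8))
logits = model(input_ids=tokens).logits  # shape: (1, 8, 512)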
@@ -705,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
+                                None) or ['MPTBlock']
+
+        if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
+            if len(act_ckpt_list) > 1:
+                log.info(
+                    'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).'
+                )
+            return isinstance(module, MPTBlock)
+
+        mod_types = ()
+        for mod_name in act_ckpt_list:
+            if mod_name.lower() == 'mptblock':
+                mod_types += (MPTBlock,)
+            elif mod_name in ATTN_CLASS_REGISTRY:
+                mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in FFN_CLASS_REGISTRY:
+                mod_types += (FFN_CLASS_REGISTRY[mod_name],)
+            elif mod_name in NORM_CLASS_REGISTRY:
+                mod_types += (NORM_CLASS_REGISTRY[mod_name],)
+            else:
+                msg = ', '.join(
+                    list(ATTN_CLASS_REGISTRY.keys()) +
+                    list(FFN_CLASS_REGISTRY.keys()) +
+                    list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
+                raise ValueError(
+                    f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
+                )
+        return isinstance(module, mod_types)

def prepare_inputs_for_generation(
self,
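The new target list is read off the model config. A sketch of how a user might opt into checkpointing only attention sub-modules ('multihead_attention' is assumed to be a valid ATTN_CLASS_REGISTRY key):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

config = MPTConfig(d_model=768, n_heads=12, n_layers=12)
config.activation_checkpointing_target = ['multihead_attention']
# activation_checkpointing_fn now matches only that class; listing
# 'MPTBlock' anywhere in the list wins and checkpoints whole blocks,
# ignoring the rest of the list.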
@@ -859,7 +917,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
-        params_flops_per_token = 2 * self.n_active_params
+        params = self.n_active_params
+        if not self.model.transformer.config.tie_word_embeddings:
+            # embedding layers are lookup tables, therefore are not counted in the FLOP computation
+            params -= self.model.transformer.wte.weight.numel()
+        params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
(self.model.config.d_model * (msl**2)))
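The adjusted FLOP accounting, as standalone arithmetic (a sketch with hypothetical shapes): embedding lookups do no multiply-accumulates, so untied runs subtract the wte table from the active-parameter count.

n_active_params = 1_000_000_000  # hypothetical active parameter count
wte_params = 50_368 * 2_048      # hypothetical vocab_size * d_model
n_layers, d_model = 24, 2_048    # hypothetical model shape
bs, msl = 8, 2_048               # device batch size, sequence length

params = n_active_params - wte_params  # only if tie_word_embeddings=False
params_flops_per_seq = 2 * params * msl
attn_flops_per_seq = n_layers * 2 * 2 * (d_model * msl**2)
# forward plus roughly 2x backward, summed over the batch:
total_flops = (params_flops_per_seq + attn_flops_per_seq) * 3 * bs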
(15 of the 22 changed files are not shown above.)
