Skip to content

Commit

Permalink
Enable tie_word_embeddings config setting to enable / disable weigh…
Browse files Browse the repository at this point in the history
…t tied embeddings (#728)

* enable disabling embed weight tying

* fix bug

* updt with descriptive var names

* fix hf config

* move comment with code

* bug fix

* add _tie_weights method

* undo mcli yaml change

* refactor

* add tests

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Sasha Doubov <[email protected]>

* pr comments

* updt tests to guard against numerical issues

---------

Co-authored-by: Sasha Doubov <[email protected]>
  • Loading branch information
vchiley and sashaDoubov authored Nov 13, 2023
1 parent d11ba82 commit 7899178
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 60 deletions.
8 changes: 7 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
use_cache: bool = False,
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
"""
self.d_model = d_model
self.n_heads = n_heads
Expand Down Expand Up @@ -164,7 +166,11 @@ def __init__(
warnings.warn(
f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
)
super().__init__(**kwargs)
# tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

self._validate_config()

Expand Down
72 changes: 52 additions & 20 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,11 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

def get_input_embeddings(self) -> nn.Embedding:
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.wte

def set_input_embeddings(self, value: nn.Embedding) -> None:
def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.wte = value

@torch.no_grad()
Expand Down Expand Up @@ -574,14 +575,20 @@ class MPTForCausalLM(MPTPreTrainedModel):

def __init__(self, config: MPTConfig):
super().__init__(config)
if not config.tie_word_embeddings:
raise ValueError(
'MPTForCausalLM only supports tied word embeddings')

log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer: MPTModel = MPTModel(config)

self.lm_head = None
if not config.tie_word_embeddings:
self.lm_head = nn.Linear(
config.d_model,
config.vocab_size,
bias=False,
device=config.init_device,
)
self.lm_head._fsdp_wrap = True

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
continue
Expand All @@ -602,19 +609,38 @@ def __init__(self, config: MPTConfig):
)
self.logit_scale = logit_scale

def get_input_embeddings(self) -> nn.Embedding:
return self.transformer.wte
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
return self.transformer.get_input_embeddings()

def set_input_embeddings(
self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
self.transformer.wte = value
self.transformer.set_input_embeddings(value)

def get_output_embeddings(self) -> nn.Embedding:
return self.transformer.wte
def get_output_embeddings(
self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
if self.lm_head is not None:
return self.lm_head
return self.transformer.get_input_embeddings()

def set_output_embeddings(
self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
self.transformer.wte = new_embeddings
self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
nn.Linear]) -> None:
if self.lm_head is not None:
self.lm_head = new_embeddings
else:
if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
raise ValueError(
'new_embeddings must be an instance of SharedEmbedding ' +
f'or nn.Embedding, but got {type(new_embeddings)}.')
warnings.warn(
'Using `set_output_embeddings` to set the embedding layer of ' +
'MPTForCausalLM with tied weights. Given weights are tied, ' +
'using `set_input_embeddings` is recommended over using ' +
'`set_output_embeddings`.')
self.transformer.set_input_embeddings(new_embeddings)

def tie_weights(self) -> None:
self.lm_head = None

def set_decoder(self, decoder: MPTModel) -> None:
self.transformer = decoder
Expand Down Expand Up @@ -658,12 +684,14 @@ def forward(
use_cache=use_cache,
)

# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
logits = self.transformer.wte(
outputs.last_hidden_state.to(self.transformer.wte.weight.device),
True,
)
if self.lm_head is not None:
logits = self.lm_head(outputs.last_hidden_state)
else:
# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
out = outputs.last_hidden_state
out = out.to(self.transformer.wte.weight.device)
logits = self.transformer.wte(out, True)

if self.logit_scale is not None:
if self.logit_scale == 0:
Expand Down Expand Up @@ -859,7 +887,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
params_flops_per_token = 2 * self.n_active_params
params = self.n_active_params
if not self.model.transformer.config.tie_word_embeddings:
# embedding layers are lookup tables, therefore are not counted in the FLOP computation
params -= self.model.transformer.wte.weight.numel()
params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
(self.model.config.d_model * (msl**2)))
Expand Down
41 changes: 29 additions & 12 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,21 @@ def test_callback_inits_with_defaults():

@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
@patch('os.cpu_count', MagicMock(return_value=None))
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool,
hf_save_interval: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
def test_huggingface_conversion_callback(
model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str], log_to_mlflow: bool,
hf_save_interval: str, save_interval: str, max_duration: str,
expected_hf_checkpoints: int, expected_normal_checkpoints: int):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
Expand Down Expand Up @@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
'attn_impl': 'torch',
},
'loss_fn': 'torch_crossentropy',
'tie_word_embeddings': tie_word_embeddings,
}
tokenizer_name = 'EleutherAI/gpt-neox-20b'
elif model == 'neo':
assert tie_word_embeddings is None
model_cfg = {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M',
Expand All @@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
}
tokenizer_name = 'EleutherAI/gpt-neo-125M'
elif model == 'llama2':
assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
Expand Down Expand Up @@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
delete_transformers_cache()


@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
def test_convert_and_generate(model: str, tmp_path: pathlib.Path):
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
def test_convert_and_generate(model: str, tie_word_embeddings: bool,
tmp_path: pathlib.Path):
delete_transformers_cache()

om_cfg = None
if model == 'mpt':
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
om_cfg['tie_word_embeddings'] = tie_word_embeddings
elif model == 'neo':
assert tie_word_embeddings is None
om_cfg = get_config(
conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml')
om_cfg['model']['config_overrides']['hidden_size'] = 36
elif model == 'llama2':
assert tie_word_embeddings is None
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
Expand Down Expand Up @@ -562,11 +573,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path):


@pytest.mark.gpu
def test_convert_and_generate_triton(tmp_path: pathlib.Path):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_triton(tie_word_embeddings: str,
tmp_path: pathlib.Path):
delete_transformers_cache()

cfg = get_config()
cfg['model']['init_device'] = 'cpu'
cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
'EleutherAI/gpt-neox-20b')
model = ComposerMPTCausalLM(cfg['model'], tokenizer)
Expand Down Expand Up @@ -602,7 +616,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
delete_transformers_cache()


def test_convert_and_generate_meta(tmp_path: pathlib.Path):
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_convert_and_generate_meta(tie_word_embeddings: str,
tmp_path: pathlib.Path):
delete_transformers_cache()

from composer.utils import dist
Expand All @@ -612,6 +628,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path):
om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

om_cfg['model']['init_device'] = 'cpu'
om_cfg['tie_word_embeddings'] = tie_word_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
om_cfg.tokenizer.name)
original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
Expand Down
Loading

0 comments on commit 7899178

Please sign in to comment.