Support inputs_embeds (#687)
* support inputs_embeds

* update tests to test inputs_embeds

* make input_ids an optional input to forward

* remove check for both input_ids and inputs_embeds

in MPTForCausalLM. It is already checked in the base model, and it is common practice to pass both during autoregressive generation: inputs_embeds are used first, then once the KV cache is non-empty, input_ids are used instead (see the sketch after these notes)

* reorder kwargs

* add more tests

* fix device merge artifact in test_model.py

* fix generate test

* yapf
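The removed-check note above describes the standard Hugging Face pattern for generating from prompt embeddings. The following is a minimal sketch (not part of this commit), assuming `model` is a loaded MPTForCausalLM and `tok` is a compatible tokenizer; it mirrors the prepare_inputs_for_generation logic in the diff, which consumes inputs_embeds only while the KV cache is empty and switches to the freshly sampled input_ids afterwards.

```python
import torch

# Tokenize a prompt and build its embeddings explicitly through the model's
# token embedding table (self.wte in MPTModel below).
prompt = tok('MosaicML is', return_tensors='pt')
input_ids = prompt['input_ids']
attention_mask = prompt['attention_mask']
inputs_embeds = model.transformer.wte(input_ids)

# Passing both is allowed: the embeddings drive the first forward pass, and
# once past_key_values is populated only the newly sampled input_ids are fed.
output_ids = model.generate(input_ids=input_ids,
                            inputs_embeds=inputs_embeds,
                            attention_mask=attention_mask,
                            max_new_tokens=5,
                            use_cache=True)
print(tok.decode(output_ids[0], skip_special_tokens=True))
```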
samhavens authored Dec 1, 2023
1 parent 3100859 commit 22ae919
Showing 2 changed files with 101 additions and 29 deletions.
51 changes: 28 additions & 23 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -368,7 +368,7 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor,

def forward(
self,
input_ids: torch.LongTensor,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
@@ -412,11 +412,6 @@ def forward(
'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
)

# Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT.')

if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
@@ -430,14 +425,25 @@ def forward(
'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
)

S = input_ids.size(1)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
S = input_ids.size(1)
x = self.wte(input_ids)
input_device = input_ids.device
elif inputs_embeds is not None:
S = inputs_embeds.size(1)
x = inputs_embeds
input_device = inputs_embeds.device
else:
raise ValueError('You must specify input_ids or inputs_embeds')

assert (
S <= self.config.max_seq_len
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

rotary_emb_w_meta_info = None
x = self.wte(input_ids)
if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
@@ -467,7 +473,7 @@ def forward(
past_position,
S + past_position,
dtype=torch.long,
device=input_ids.device,
device=input_device,
).unsqueeze(0)
if attention_mask is not None:
# adjust the position indices to account for padding tokens
@@ -652,7 +658,7 @@ def get_decoder(self) -> MPTModel:

def forward(
self,
input_ids: torch.LongTensor,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
prefix_mask: Optional[torch.ByteTensor] = None,
@@ -669,11 +675,6 @@ def forward(
use_cache = (use_cache
if use_cache is not None else self.config.use_cache)

# if input_embeds is not none, raise a not implemented error
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds has to be None (for hf/peft support).')
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
@@ -684,6 +685,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
)

if self.lm_head is not None:
@@ -773,10 +775,6 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT yet')

attention_mask = kwargs['attention_mask'].bool()
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
raise NotImplementedError(
@@ -787,6 +785,7 @@ def prepare_inputs_for_generation(
else:
sequence_id = None

# only last token for input_ids if past is defined in kwargs
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)

@@ -800,14 +799,20 @@ def prepare_inputs_for_generation(
else:
prefix_mask = None

return {
'input_ids': input_ids,
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}

model_inputs.update({
'attention_mask': attention_mask,
'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache', True),
}
})
return model_inputs

@staticmethod
def _reorder_cache(
@@ -898,7 +903,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
add_bidirectional_mask_if_missing(batch)
# Note: prefix_mask is only used if model.prefix_lm is True
return self.model(
input_ids=batch['input_ids'],
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
prefix_mask=batch.get('bidirectional_mask', None),
sequence_id=batch.get('sequence_id', None),
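For reference, here is a short sketch (not from the diff) of exercising the new optional-input_ids path directly, assuming `model` is an MPTForCausalLM whose config exposes d_model. Exactly one of input_ids / inputs_embeds must be supplied; providing both, or neither, raises the ValueError added in the base model above.

```python
import torch

batch_size, seq_len = 2, 8
# Random embeddings standing in for, e.g., soft prompts produced by a PEFT method.
inputs_embeds = torch.randn(batch_size, seq_len, model.config.d_model)

out = model(input_ids=None, inputs_embeds=inputs_embeds)
logits = out.logits  # shape: (batch_size, seq_len, vocab_size)
```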
79 changes: 73 additions & 6 deletions tests/test_model.py
@@ -5,7 +5,7 @@
import os
import pathlib
import warnings
from typing import Any, Dict, Union, cast
from typing import Any, Dict, List, Optional, Union, cast
from unittest import mock

import pytest
@@ -94,13 +94,26 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
return test_cfg, model, optimizer


def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]):
def gen_random_batch(batch_size: int,
test_cfg: Union[DictConfig, ListConfig],
inputs: Optional[List[str]] = None):
# inputs can be [], ['input_ids'], ['input_ids', 'inputs_embeds'], and ['inputs_embeds']
# default to only input ids
if inputs is None:
inputs = ['input_ids']
# generate input batch of random data, suitable for a Causal or Prefix LM
batch = {}
batch['input_ids'] = torch.randint(
low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
for inp in inputs:
if inp == 'input_ids':
batch['input_ids'] = torch.randint(
low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
if inp == 'inputs_embeds':
batch['inputs_embeds'] = torch.randn(
batch_size, test_cfg.max_seq_len,
test_cfg.model.d_model).to(test_cfg.device)

batch['labels'] = torch.randint(low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(
@@ -150,6 +163,34 @@ def test_full_forward_and_backward(batch_size: int = 2):
assert not torch.equal(original_params, updated_params)


def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
test_cfg, model, optimizer = get_objs(
conf_path='scripts/train/yamls/pretrain/testing.yaml')

batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds'])

model.train()
original_params = next(model.parameters()).clone().data
outputs = model(batch)
loss = model.loss(outputs, batch)
loss.backward()
optimizer.step()
updated_params = next(model.parameters()).clone().data
assert not torch.equal(original_params, updated_params)


@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']])
def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
test_cfg, model, _ = get_objs(
conf_path='scripts/train/yamls/pretrain/testing.yaml')

batch = gen_random_batch(2, test_cfg, inputs=inputs)

model.train()
with pytest.raises(ValueError):
_ = model(batch)


def test_attention_mechanism(batch_size: int = 2):
test_cfg, model, _ = get_objs(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
@@ -825,6 +866,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
no_padding_attention_mask = composer_device.tensor_to_device(
no_padding_attention_mask)

# inputs_embeds
inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128))

# a single batch with different amounts of left padding in the input
batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11],
[50256, 50256, 16, 11274, 16390, 11]])
@@ -860,6 +904,29 @@
assert generation_with_no_padding[:, 3:].equal(
generation_with_left_padding[:, 6:])

# check that both/neither ids and embeds do not error
# note that we need to set the BOS token ID for generating from neither
_ = mpt.generate(input_ids=no_padding_input_ids,
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=False)
_ = mpt.generate(input_ids=no_padding_input_ids,
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=True)
_ = mpt.generate(input_ids=None,
inputs_embeds=None,
max_new_tokens=5,
use_cache=False,
bos_token_id=50256)
_ = mpt.generate(input_ids=None,
inputs_embeds=None,
max_new_tokens=5,
use_cache=True,
bos_token_id=50256)


@pytest.mark.gpu
@pytest.mark.parametrize('world_size', [1, 2])
