Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inputs_embeds #687

Merged
merged 13 commits into from
Dec 1, 2023
77 changes: 48 additions & 29 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,6 @@ def forward(
'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
)

# Raise a NotImplementedError if inputs_embeds is not None (this is an argument in Hugging Face transformers, and we need to accept it for PEFT)
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT.')

if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
Expand All @@ -361,7 +356,16 @@ def forward(
S <= self.config.max_seq_len
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

tok_emb = self.wte(input_ids)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
elif input_ids is not None:
samhavens marked this conversation as resolved.
Show resolved Hide resolved
tok_emb = self.wte(input_ids)
elif inputs_embeds is not None:
tok_emb = inputs_embeds
else:
raise ValueError('You must specify input_ids or inputs_embeds')

if self.learned_pos_emb:
past_position = 0
if past_key_values is not None:
Expand Down Expand Up @@ -554,22 +558,34 @@ def forward(
use_cache = (use_cache
if use_cache is not None else self.config.use_cache)

# if inputs_embeds is not None, raise a NotImplementedError
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds has to be None (for hf/peft support).')
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
prefix_mask=prefix_mask,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds.')
samhavens marked this conversation as resolved.
Show resolved Hide resolved
samhavens marked this conversation as resolved.
Show resolved Hide resolved
elif inputs_embeds is not None:
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.transformer(
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
attention_mask=attention_mask,
prefix_mask=prefix_mask,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
)
else:
outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
prefix_mask=prefix_mask,
sequence_id=sequence_id,
return_dict=return_dict,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
)
samhavens marked this conversation as resolved.
Show resolved Hide resolved

# move outputs to same device as weights for token embedding
# needed to support HF `device_map`
Expand Down Expand Up @@ -628,10 +644,6 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT yet')

attention_mask = kwargs['attention_mask'].bool()
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
raise NotImplementedError(
Expand All @@ -642,6 +654,7 @@ def prepare_inputs_for_generation(
else:
sequence_id = None

# only keep the last token of input_ids if past_key_values is provided
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)

Expand All @@ -655,14 +668,20 @@ def prepare_inputs_for_generation(
else:
prefix_mask = None

return {
'input_ids': input_ids,
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}

model_inputs.update({
'attention_mask': attention_mask,
'prefix_mask': prefix_mask,
'sequence_id': sequence_id,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache', True),
}
})
return model_inputs

@staticmethod
def _reorder_cache(
Expand Down
33 changes: 28 additions & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,20 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
return test_cfg, model, optimizer


def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]):
def gen_random_batch(batch_size: int,
test_cfg: Union[DictConfig, ListConfig],
inputs_embeds: bool = False):
# generate input batch of random data, suitable for a Causal or Prefix LM
batch = {}
batch['input_ids'] = torch.randint(
low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
if inputs_embeds:
batch['inputs_embeds'] = torch.randn(batch_size, test_cfg.max_seq_len,
test_cfg.model.d_model).to(
test_cfg.device)
else:
batch['input_ids'] = torch.randint(
low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
batch['labels'] = torch.randint(low=0,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(
Expand Down Expand Up @@ -153,6 +160,22 @@ def test_full_forward_and_backward(batch_size: int = 2):
assert not torch.equal(original_params, updated_params)


def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
    """Check that one training step driven by ``inputs_embeds`` updates weights.

    Builds the test config, model and optimizer, generates a random batch with
    ``inputs_embeds=True``, runs forward, loss, backward and one optimizer
    step, then asserts that the first model parameter no longer equals its
    pre-step snapshot.

    Args:
        batch_size: Number of samples in the randomly generated batch.
    """
    config_path = 'scripts/train/yamls/pretrain/testing.yaml'
    test_cfg, model, optimizer = get_objs(conf_path=config_path)

    embeds_batch = gen_random_batch(batch_size, test_cfg, inputs_embeds=True)

    model.train()

    # Snapshot the first parameter before the update so it can be compared
    # against its value after the optimizer step.
    params_before = next(model.parameters()).clone().data

    # Single training step: forward -> loss -> backward -> optimizer update.
    model_outputs = model(embeds_batch)
    step_loss = model.loss(model_outputs, embeds_batch)
    step_loss.backward()
    optimizer.step()

    params_after = next(model.parameters()).clone().data
    assert not torch.equal(params_before, params_after)


def test_attention_mechanism(batch_size: int = 2):
test_cfg, model, _ = get_objs(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
Expand Down
Loading