Add support for transformers-neuronx continuous batching (#488)
* test(generation): move decoder tests in their own file

* test(decoder): increase batch size

* test(decoder): add test to check for unk tokens

* test(decoder): test LLama padding issues

* feat(decoder): add default attention_mask

* refactor(decoder): isolate prefill from decode

* feat(decoder): add support for continuous batching

* feat(exporters): decoders with continuous batching

* feat(decoder): continuous_batching used by default

* fix(tgi): use SDK 2.17 torch-neuronx version

* fix(tgi): use correct versions in Dockerfile

* feat(tgi): bump version and use max-batch-size

* review: address comments
dacorvo authored Feb 19, 2024
1 parent 460d226 commit 8f3e96a
Showing 13 changed files with 289 additions and 147 deletions.
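Before the file-by-file changes, here is a minimal usage sketch of what this commit enables, based on the NeuronModelForCausalLM API exercised in the tests touched below; the checkpoint id is hypothetical and the generation settings are illustrative.

import torch

from optimum.neuron import NeuronModelForCausalLM

# Export a decoder with a static batch size > 1; for the models flagged in this commit
# (Llama, Mistral), continuous batching is then enabled automatically.
model = NeuronModelForCausalLM.from_pretrained(
    "my-org/my-llama-checkpoint",  # hypothetical checkpoint id
    export=True,
    batch_size=2,
    sequence_length=100,
    num_cores=2,
)

# Generation works as before; padding and KV cache bookkeeping are handled by the
# prefill/decode helpers introduced in optimum/neuron/modeling.py.
input_ids = torch.ones((2, 10), dtype=torch.int64)
with torch.inference_mode():
    output_ids = model.generate(input_ids, do_sample=False)
print(output_ids.shape)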
2 changes: 1 addition & 1 deletion Makefile
@@ -40,7 +40,7 @@ PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \
$(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
python -m build

TGI_VERSION ?= 1.4.0
TGI_VERSION ?= 1.4.1

neuronx-tgi: $(PACKAGE_DIST)
docker build --rm -f text-generation-inference/Dockerfile \
7 changes: 6 additions & 1 deletion optimum/exporters/neuron/base.py
@@ -379,7 +379,7 @@ class NeuronDecoderConfig(NeuronConfig):
be passed to export the model,
- NEURONX_CLASS (`str`) -- the name of the transformers-neuronx class to instantiate for the model.
It is a full class name defined relative to the transformers-neuronx module, e.g. `gpt2.model.GPT2ForSampling`
[`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
- CONTINUOUS_BATCHING (`bool`, defaults to `False`) -- Whether the model supports continuous batching or not.
The NEURONX_CLASS must always be defined in each model configuration.
@@ -389,6 +389,7 @@ class NeuronDecoderConfig(NeuronConfig):

INPUT_ARGS = ("batch_size", "sequence_length")
NEURONX_CLASS = None
CONTINUOUS_BATCHING = False

def __init__(self, task: str):
if not is_transformers_neuronx_available():
@@ -404,3 +405,7 @@ def __init__(self, task: str):
@property
def neuronx_class(self):
return self._neuronx_class

@property
def continuous_batching(self):
return self.CONTINUOUS_BATCHING
2 changes: 2 additions & 0 deletions optimum/exporters/neuron/model_configs.py
@@ -431,6 +431,7 @@ class GPT2NeuronConfig(TextNeuronDecoderConfig):
@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "llama.model.LlamaForSampling"
CONTINUOUS_BATCHING = True


@register_in_tasks_manager("t5-encoder", "text2text-generation")
@@ -533,3 +534,4 @@ def generate_io_aliases(self, model):
@register_in_tasks_manager("mistral", "text-generation")
class MistralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mistral.model.MistralForSampling"
CONTINUOUS_BATCHING = True
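As an illustration of the pattern above, a hypothetical decoder would opt into continuous batching with just two class attributes; the model name and transformers-neuronx class path below are made up, only the attribute names come from the diff.

@register_in_tasks_manager("my-decoder", "text-generation")
class MyDecoderNeuronConfig(TextNeuronDecoderConfig):
    # Full class name relative to the transformers-neuronx package (hypothetical)
    NEURONX_CLASS = "my_decoder.model.MyDecoderForSampling"
    # Opt into continuous batching; NeuronDecoderConfig defaults this to False
    CONTINUOUS_BATCHING = True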
135 changes: 90 additions & 45 deletions optimum/neuron/modeling.py
@@ -16,7 +16,7 @@

import copy
import logging
from typing import TYPE_CHECKING, Dict, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import torch
from transformers import (
@@ -656,14 +656,14 @@ def __init__(
generation_config: Optional["GenerationConfig"] = None,
):
super().__init__(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)
self.cur_len = 0
self.batch_size = self.model.config.batch_size
self.max_length = self.model.config.n_positions
self.batch_size = self.config.neuron["batch_size"]
self.max_length = self.config.neuron["sequence_length"]
self.continuous_batching = self.model.neuron_config and self.model.neuron_config.continuous_batching
# The generate method from GenerationMixin expects the device attribute to be set
self.device = torch.device("cpu")

def reset_generation(self):
self.cur_len = 0
pass

@add_start_docstrings_to_model_forward(
NEURON_CAUSALLM_MODEL_FORWARD_DOCSTRING
@@ -688,32 +688,78 @@ def forward(
return ModelOutput([("logits", out_logits)])
return (out_logits,)

def prepare_inputs_for_generation(
self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
) -> Dict[str, torch.Tensor]:
# convert attention_mask to start_ids
def get_start_ids(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
seq_ids: Optional[torch.Tensor] = None,
):
# The start_ids parameter has different meanings:
# - for continuous (unpadded) batching it corresponds to the sequence id,
# - for static batching it corresponds to the start of the padded sequence.
if self.continuous_batching:
if seq_ids is None:
seq_ids = torch.arange(input_ids.shape[0])
else:
assert seq_ids.shape[0] == input_ids.shape[0]
return seq_ids
start_ids = None
if attention_mask is not None:
_, start_ids = attention_mask.max(axis=1)

if self.cur_len > 0:
# Only pass the last tokens of each sample
input_ids = input_ids[:, -1:]
# Specify the single index at which the new keys and values need to be stored
cache_ids = torch.as_tensor([self.cur_len], dtype=torch.int32)
else:
# cache_ids will be set directly by the parallel context encoding code
cache_ids = None

# Increment the current cache index
self.cur_len += input_ids.shape[-1]
model_inputs = {
return start_ids

def get_cache_ids(self, attention_mask: torch.tensor, prefill: bool):
cache_n, cache_len = attention_mask.shape
if self.continuous_batching:
# Evaluate the inputs that are not masked for each sequence
input_length = attention_mask.sum(axis=1)
if not prefill:
# When decoding, cache_ids contains a single value per sequence
return (input_length - 1).unsqueeze(1)
# When prefilling, cache_ids is an increasing range
cache_ids = torch.zeros_like(attention_mask)
for i in range(cache_n):
cur_length = input_length[i]
cache_ids[i, :cur_length] = torch.arange(cur_length)
return cache_ids
# Static batching
return None if prefill else torch.tensor([cache_len - 1], dtype=torch.int32)

def prepare_inputs_for_prefill(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, seq_ids: Optional[List[int]] = None
) -> Dict[str, torch.Tensor]:
start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
cache_ids = self.get_cache_ids(attention_mask, prefill=True)
if self.continuous_batching and torch.any(attention_mask[:, 0] == 0):
# Inputs are left padded: we need to invert padding as continuous batching requires right-padding
batch_size, seq_len = input_ids.shape
input_length = attention_mask.sum(axis=1)
new_input_ids = torch.zeros_like(input_ids)
for i in range(batch_size):
cur_length = input_length[i]
new_input_ids[i, :cur_length] = input_ids[i, seq_len - cur_length :]
input_ids = new_input_ids
return {
"input_ids": input_ids,
"cache_ids": cache_ids,
"start_ids": start_ids,
}

return model_inputs
def prepare_inputs_for_decode(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
seq_ids: Optional[List[int]] = None,
) -> Dict[str, torch.Tensor]:
start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
cache_ids = self.get_cache_ids(attention_mask, prefill=False)
# Only pass the last tokens of each sample
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"cache_ids": cache_ids,
"start_ids": start_ids,
}

def can_generate(self) -> bool:
"""Returns True to validate the check made in `GenerationMixin.generate()`."""
@@ -775,7 +821,7 @@ def generate(
f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
)
padded_input_ids = input_ids
padded_attention_mask = attention_mask
padded_attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
if batch_size > self.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
@@ -784,18 +830,15 @@
logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
padding_shape = [self.batch_size - batch_size, sequence_length]
padding = torch.full(padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64)
padded_input_ids = torch.cat([input_ids, padding])
if attention_mask is not None:
padding = torch.zeros(padding_shape, dtype=torch.int64)
padded_attention_mask = torch.cat([attention_mask, padding])
# Drop the current generation context and clear the Key/Value cache
self.reset_generation()
padded_input_ids = torch.cat([padded_input_ids, padding])
padding = torch.zeros(padding_shape, dtype=torch.int64)
padded_attention_mask = torch.cat([padded_attention_mask, padding])

output_ids = self.generate_tokens(
padded_input_ids,
selector,
batch_size,
attention_mask=padded_attention_mask,
padded_attention_mask,
**model_kwargs,
)
return output_ids[:batch_size, :]
@@ -805,7 +848,7 @@ def generate_tokens(
input_ids: torch.LongTensor,
selector: TokenSelector,
batch_size: int,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: torch.Tensor,
**model_kwargs,
) -> torch.LongTensor:
r"""
@@ -831,17 +874,15 @@ def generate_tokens(
unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
unfinished_sequences[:batch_size] = 1

# Prefill and obtain the first token
model_inputs = self.prepare_inputs_for_prefill(input_ids, attention_mask)
outputs = self(
**model_inputs,
return_dict=True,
)

# auto-regressive generation
while True:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, **model_kwargs)

# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
)

next_token_logits = outputs.logits[:, -1, :]

next_tokens = selector.select(input_ids, next_token_logits)
Expand All @@ -851,10 +892,7 @@ def generate_tokens(

# update inputs for the next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if attention_mask is not None:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

# if eos_token was found in one sentence, set sentence to finished
unfinished_sequences = unfinished_sequences * next_tokens.ne(selector.eos_token_id)
Expand All @@ -867,4 +905,11 @@ def generate_tokens(
if selector.stopping_criteria(input_ids, None):
break

# forward pass to get next token
model_inputs = self.prepare_inputs_for_decode(input_ids, attention_mask)
outputs = self(
**model_inputs,
return_dict=True,
)

return input_ids
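To make the cache_ids semantics concrete, here is a small standalone sketch (not part of the commit) that mirrors the get_cache_ids logic above on a toy right-padded attention mask.

import torch

def toy_cache_ids(attention_mask: torch.Tensor, prefill: bool, continuous_batching: bool):
    # Mirrors NeuronModelForCausalLM.get_cache_ids from the diff above.
    cache_n, cache_len = attention_mask.shape
    if continuous_batching:
        input_length = attention_mask.sum(axis=1)
        if not prefill:
            # When decoding, one cache position per sequence: the index of its last valid token
            return (input_length - 1).unsqueeze(1)
        # When prefilling, each sequence gets an increasing range of cache positions
        cache_ids = torch.zeros_like(attention_mask)
        for i in range(cache_n):
            cur_length = input_length[i]
            cache_ids[i, :cur_length] = torch.arange(cur_length)
        return cache_ids
    # Static batching: a single shared cache index, set by the model itself during prefill
    return None if prefill else torch.tensor([cache_len - 1], dtype=torch.int32)

# Two sequences of lengths 3 and 2, right-padded as continuous batching expects
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(toy_cache_ids(mask, prefill=True, continuous_batching=True))
# tensor([[0, 1, 2, 0],
#         [0, 1, 0, 0]])
print(toy_cache_ids(mask, prefill=False, continuous_batching=True))
# tensor([[2],
#         [1]])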
29 changes: 20 additions & 9 deletions optimum/neuron/modeling_decoder.py
@@ -35,6 +35,7 @@


if is_transformers_neuronx_available():
from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
from transformers_neuronx.module import save_split


@@ -131,16 +132,26 @@ def __init__(

exporter = get_exporter(config, task)

# transformers-neuronx uses f32/f16 instead of fp32/fp16
auto_cast_type = auto_cast_type.replace("p", "")
tnx_kwargs = {
"batch_size": batch_size,
"tp_degree": num_cores,
# transformers-neuronx uses f32/f16 instead of fp32/fp16
"amp": auto_cast_type.replace("p", ""),
}
if batch_size > 1 and exporter.continuous_batching:
# Continuous batching is always enabled for models that support it because static batching
# is broken for these models: see https://github.com/aws-neuron/transformers-neuronx/issues/79
tnx_kwargs["neuron_config"] = NeuronConfig(
continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
)
tnx_kwargs["n_positions"] = [sequence_length]
tnx_kwargs["context_length_estimate"] = [sequence_length]
else:
tnx_kwargs["n_positions"] = sequence_length

# Instantiate neuronx model
checkpoint_path = checkpoint_dir.name if isinstance(checkpoint_dir, TemporaryDirectory) else checkpoint_dir
neuronx_model = exporter.neuronx_class.from_pretrained(
checkpoint_path,
batch_size=batch_size,
n_positions=sequence_length,
tp_degree=num_cores,
amp=auto_cast_type,
)
neuronx_model = exporter.neuronx_class.from_pretrained(checkpoint_path, **tnx_kwargs)

if compiled_dir is not None:
# Specify the path where compiled artifacts are stored before conversion
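For reference, a sketch of the keyword arguments that the branch above ends up passing to the transformers-neuronx from_pretrained call in each mode; the numeric values are illustrative, the keys mirror the diff.

from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig

batch_size, sequence_length, num_cores = 4, 2048, 2

# Static batching (batch_size == 1, or a model without continuous batching support)
static_kwargs = {
    "batch_size": batch_size,
    "tp_degree": num_cores,
    "amp": "f16",  # transformers-neuronx uses f32/f16 instead of fp32/fp16
    "n_positions": sequence_length,
}

# Continuous batching (e.g. Llama or Mistral with batch_size > 1)
continuous_kwargs = {
    "batch_size": batch_size,
    "tp_degree": num_cores,
    "amp": "f16",
    "neuron_config": NeuronConfig(
        continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
    ),
    # Passed as single-element lists in this mode
    "n_positions": [sequence_length],
    "context_length_estimate": [sequence_length],
}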
2 changes: 1 addition & 1 deletion tests/generation/conftest.py
@@ -72,7 +72,7 @@ def export_seq2seq_model_class(request):
@requires_neuronx
def neuron_decoder_path(export_decoder_id):
model = NeuronModelForCausalLM.from_pretrained(
export_decoder_id, export=True, batch_size=1, sequence_length=100, num_cores=2
export_decoder_id, export=True, batch_size=2, sequence_length=100, num_cores=2
)
model_dir = TemporaryDirectory()
model_path = model_dir.name
62 changes: 1 addition & 61 deletions tests/generation/test_generate.py
@@ -18,20 +18,12 @@
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteria

from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron import NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, is_trainium_test, requires_neuronx
from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin


def _test_model_generation(model, tokenizer, batch_size, input_length, **gen_kwargs):
input_ids = torch.ones((batch_size, input_length), dtype=torch.int64)
with torch.inference_mode():
sample_output = model.generate(input_ids, **gen_kwargs)
assert sample_output.shape[0] == batch_size


def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen_kwargs):
import torch_xla.core.xla_model as xm

@@ -43,58 +35,6 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen
assert sample_output.shape[0] == batch_size


@pytest.mark.parametrize(
"gen_kwargs",
[
{"do_sample": True},
{"do_sample": True, "temperature": 0.7},
{"do_sample": False},
{"do_sample": False, "repetition_penalty": 1.2},
],
ids=["sample", "sample-with-temp", "greedy", "greedy_no-repeat"],
)
@is_inferentia_test
@requires_neuronx
def test_decoder_generation(neuron_decoder_path, gen_kwargs):
model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
_test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs)


@is_inferentia_test
@requires_neuronx
def test_model_generation_input_dimensions(neuron_decoder_path):
model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
# Using valid input dimensions
_test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2)
# Using an incompatible batch_size
with pytest.raises(ValueError, match="The specified batch_size"):
_test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length)
# Using an incompatible input length
with pytest.raises(ValueError, match="The input sequence length"):
_test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2)


@is_inferentia_test
@requires_neuronx
def test_decoder_generation_custom_stopping_criteria():
model_id = "hf-internal-testing/tiny-random-gpt2"
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=1)

class CustomStoppingCriteria(StoppingCriteria):
def __init__(self):
self.called = False

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
self.called = True
return True

criteria = CustomStoppingCriteria()
model.generate(input_ids=torch.ones([1, 10], dtype=torch.int64), stopping_criteria=[criteria])
assert criteria.called, "Custom StoppingCriteria should have been called"


@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):