[TNX] llama 70b special param in cc flags (#1469)
Qing Lan authored Jan 10, 2024
1 parent b298187 commit 7069a1f
Showing 6 changed files with 75 additions and 16 deletions.
----- Changed file 1 of 6 -----
@@ -44,6 +44,7 @@ class TnXQuantizeMethods(str, Enum):
class TransformerNeuronXProperties(Properties):
"""Transformer neuronx related configurations"""
neuron_optimize_level: Optional[OptimizeLevel] = None
enable_mixed_precision_accumulation: bool = False
dtype: Dtype = Dtype.f32
n_positions: int = 128
unroll: Optional[str] = None
@@ -59,9 +60,18 @@ class TransformerNeuronXProperties(Properties):

@validator('neuron_optimize_level')
def set_neuron_optimal_env(cls, level):
if "NEURON_CC_FLAGS" not in os.environ:
os.environ["NEURON_CC_FLAGS"] = ""
os.environ[
"NEURON_CC_FLAGS"] = os.environ["NEURON_CC_FLAGS"] + f" -O{level}"

@validator('enable_mixed_precision_accumulation')
def set_mixed_precision_accumulation(cls, enablement):
if "NEURON_CC_FLAGS" not in os.environ:
os.environ["NEURON_CC_FLAGS"] = ""
os.environ["NEURON_CC_FLAGS"] = os.environ[
"NEURON_CC_FLAGS"] + f" --enable-mixed-precision-accumulation"

@validator('context_length_estimate', pre=True)
def parse_context_length(cls, context_length_estimate):
return [
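For reference, the two validators above do nothing beyond appending compiler flags to the NEURON_CC_FLAGS environment variable before compilation. A minimal standalone sketch of that behavior (append_neuron_cc_flag is an illustrative helper, not code from this repository):

import os

def append_neuron_cc_flag(flag: str) -> None:
    # Same pattern as the validators: start from an empty string if the
    # variable is unset, then append the new compiler flag.
    os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + f" {flag}"

# neuron_optimize_level=3 and enable_mixed_precision_accumulation=true would add:
append_neuron_cc_flag("-O3")
append_neuron_cc_flag("--enable-mixed-precision-accumulation")
print(os.environ["NEURON_CC_FLAGS"])  # " -O3 --enable-mixed-precision-accumulation" if it started unset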
----- Changed file 2 of 6 -----
@@ -10,6 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import torch

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token, FINISH_REASON_MAPPER
from djl_python.transformers_neuronx_scheduler.optimum_neuron_scheduler import NeuronGenerator
@@ -59,10 +60,17 @@ def inference(self, input_data, parameters):
if not is_last_token:
req_ids.append(request.id)

token_id = generation.token_id
log_prob = generation.token_logprob
if isinstance(token_id, torch.Tensor):
token_id = token_id.item()
if isinstance(log_prob, torch.Tensor):
log_prob = log_prob.item()

token = Token(
- generation.token_id, ""
+ token_id, ""
if generation.token_is_special else generation.token_text,
- generation.token_logprob, generation.token_is_special)
+ log_prob, generation.token_is_special)
request.set_next_token(token,
self.output_formatter,
last_token=is_last_token,
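The isinstance checks above normalize token ids and logprobs that may arrive as 0-d tensors. A tiny self-contained sketch of the same guard (to_scalar is an illustrative name):

import torch

def to_scalar(value):
    # Unwrap 0-d tensors into plain Python numbers; pass other values through.
    return value.item() if isinstance(value, torch.Tensor) else value

print(to_scalar(torch.tensor(42)))  # 42
print(to_scalar(-0.25))             # -0.25 (already a plain float)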
----- Changed file 3 of 6 -----
@@ -116,7 +116,9 @@ def test_tnx_all_configs(self):
"low_cpu_mem_usage": "true",
'context_length_estimate': '256, 512, 1024',
"task": "feature-extraction",
"save_mp_checkpoint_path": "/path/to/checkpoint"
"save_mp_checkpoint_path": "/path/to/checkpoint",
"neuron_optimize_level": 3,
"enable_mixed_precision_accumulation": "true"
}
tnx_configs = TransformerNeuronXProperties(**common_properties,
**properties)
@@ -134,6 +136,9 @@ def test_tnx_all_configs(self):
self.assertEqual(tnx_configs.task, properties['task'])
self.assertEqual(tnx_configs.save_mp_checkpoint_path,
properties['save_mp_checkpoint_path'])
neuron_cc = os.environ["NEURON_CC_FLAGS"]
self.assertTrue("-O3" in neuron_cc)
self.assertTrue("--enable-mixed-precision-accumulation" in neuron_cc)

# tests context length estimate as integer
def test_tnx_cle_int(context_length_estimate):
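The test above passes 'context_length_estimate' as the string '256, 512, 1024', and a separate case (test_tnx_cle_int) covers a plain integer. The parse_context_length validator itself is truncated in the first file's hunk; it presumably normalizes both forms into a list of ints, roughly as in this sketch (the exact implementation may differ):

def parse_context_length(context_length_estimate):
    # Sketch only: accept 512 or "256, 512, 1024" and return a list of ints.
    if isinstance(context_length_estimate, int):
        return [context_length_estimate]
    return [int(x) for x in str(context_length_estimate).split(",")]

print(parse_context_length("256, 512, 1024"))  # [256, 512, 1024]
print(parse_context_length(512))               # [512]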
----- Changed file 4 of 6 -----
@@ -24,7 +24,7 @@

from djl_python.rolling_batch.rolling_batch import Request
from djl_python.transformers_neuronx_scheduler.token_selector import TokenSelector
- from djl_python.transformers_neuronx_scheduler.utils import Generation, FinishReason, GeneratedText
+ from djl_python.transformers_neuronx_scheduler.utils import Generation, FinishReason, GeneratedText, TokenDecoder


class Slot:
@@ -51,6 +51,7 @@ def clear(self):
self._generated_tokens = 0
self._next_token_text = ""
self._cache_id = torch.zeros(1)
self._token_decoder = None

@property
def id(self) -> int:
@@ -76,14 +77,21 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> torch.LongTensor:
return self._generated_tokens

- def assign(self, request: Request, generation_config: GenerationConfig):
+ @property
+ def decoder(self) -> TokenDecoder:
+ return self._token_decoder
+
+ def assign(self, request: Request, generation_config: GenerationConfig,
+ tokenizer):
"""Assign a request to a slot.
Args:
request (`Request`):
The request to be assigned. Contains the inputs and tokens selection parameters.
generation_config (`transformers.GenerationConfig`):
The base generation config (might be modified by the request generation parameters).
tokenizer:
The tokenizer used to decode tokens.
"""
self._state = Slot.State.READY
self._request_id = request.id
@@ -102,6 +110,7 @@ def assign(self, request: Request, generation_config: GenerationConfig):
self._generation_config.max_new_tokens = param.get(
"max_new_tokens", 30)
# TODO: stop_sequences, ignore_eos_token
self._token_decoder = TokenDecoder(tokenizer)

def reset(self, input_ids, attention_mask, selector, cache_id):
"""Reset the slot for the next generation.
@@ -256,7 +265,7 @@ def prefill(self, new_requests: List[Request]):
prefill_slots = []
for request in new_requests:
slot = empty_slots.pop()
- slot.assign(request, self.model.generation_config)
+ slot.assign(request, self.model.generation_config, self.tokenizer)
prefill_slots.append(slot)
logging.debug(
f"Request {slot.request_id} assigned to slot {slot.id}")
@@ -366,22 +375,15 @@ def _generate_token(
next_token_logits = outputs.logits[i:i + 1, -1, :]
slot_input_ids = input_ids[i:i + 1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
- next_token_text = self.tokenizer.decode(next_token)
- if not slot.generated_text.endswith(
- " ") and not next_token_text.startswith(" "):
- # Some tokenizers do not prepend spaces automatically when decoding a single token
- contextual_text = self.tokenizer.decode(
- [slot.next_token, next_token])
- if contextual_text[:-len(next_token_text)].endswith(" "):
- next_token_text = " " + next_token_text
+ next_token_text = slot.decoder.decode(next_token.item())
slot.trim_cache_id()
slot.append(next_token, next_token_text)
generated_text = None
finish_reason = None
if next_token == self.tokenizer.eos_token_id:
finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
elif slot.stopped:
- finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
+ finish_reason = FinishReason.FINISH_REASON_LENGTH
if finish_reason is not None:
# We must include the generated text for each finished sequence in the response
generated_text = GeneratedText(
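The deleted block above re-decoded the previous and current token together to guess whether a leading space had been dropped; the new line delegates that to the per-slot TokenDecoder introduced in the next file. A self-contained toy illustration of the underlying problem (ToyTokenizer is invented for this example; the real code uses the Hugging Face tokenizer):

class ToyTokenizer:
    # Mimics a SentencePiece-style tokenizer: "\u2581" marks a word-initial
    # space, and decoding a lone token strips that leading space.
    vocab = {1: "Hello", 2: "\u2581world"}

    def decode(self, ids, skip_special_tokens=False):
        if isinstance(ids, int):
            ids = [ids]
        text = "".join(self.vocab[i] for i in ids).replace("\u2581", " ")
        return text.lstrip() if len(ids) == 1 else text

tok = ToyTokenizer()
print(tok.decode(2))       # "world"        (space lost when decoded alone)
print(tok.decode([1, 2]))  # "Hello world"  (space recovered with context)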
----- Changed file 5 of 6 -----
@@ -15,7 +15,6 @@


class FinishReason(str, Enum):

FINISH_REASON_LENGTH = 0
FINISH_REASON_EOS_TOKEN = 1
FINISH_REASON_STOP_SEQUENCE = 2
@@ -49,3 +48,37 @@ def __init__(self, text: str, generated_tokens: int,
self.generated_tokens = generated_tokens
self.finish_reason = finish_reason
self.seed = seed


class TokenDecoder:

def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.prefix_offset = 0
self.read_offset = 0
self.all_input_ids = []

def _decode_token(self) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
self.all_input_ids[self.prefix_offset:self.read_offset],
skip_special_tokens=False)
new_text = self.tokenizer.decode(
self.all_input_ids[self.prefix_offset:], skip_special_tokens=False)
if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text = new_text[len(prefix_text):]
self.prefix_offset = self.read_offset
self.read_offset = len(self.all_input_ids)
return new_text
else:
return ""

def decode(self, token_id: int):
self.all_input_ids.append(token_id)
return self._decode_token()
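A quick usage sketch of TokenDecoder as the scheduler uses it, feeding one token id per step (the gpt2 checkpoint is only an example; the import mirrors the one added in the rolling batch file above):

from transformers import AutoTokenizer

from djl_python.transformers_neuronx_scheduler.utils import TokenDecoder

tokenizer = AutoTokenizer.from_pretrained("gpt2")
decoder = TokenDecoder(tokenizer)

streamed = ""
for token_id in tokenizer.encode("Hello world, streaming!"):
    # Incomplete UTF-8 sequences return "" until enough bytes have arrived.
    streamed += decoder.decode(token_id)
print(streamed)  # Hello world, streaming!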
----- Changed file 6 of 6 -----
@@ -104,6 +104,7 @@ If you are using Neuron container and engine set to Python, the following parame
| option.low_cpu_mem_usage | No | Reduce CPU memory usage when loading models. | Default: `False` |
| option.load_split_model | No | Toggle to True when using model artifacts that have already been split for neuron compilation/loading. | Default: `False` |
| option.compiled_graph_path | No | Provide an s3 URI, or a local directory that stores the pre-compiled graph for your model (NEFF cache) to skip runtime compilation. | Default: `None` |
| option.enable_mixed_precision_accumulation | No | Turn this on for the LLaMA 70B model to achieve better accuracy. | `true` Default: `None` |


### TensorRT-LLM
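For context, the new documentation row maps to a serving.properties entry. A hypothetical minimal configuration (model id and tensor parallel degree are placeholders, not recommendations):

engine=Python
option.model_id=meta-llama/Llama-2-70b-hf
option.tensor_parallel_degree=24
option.enable_mixed_precision_accumulation=true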
