Skip to content

Commit

Permalink
[fix] update context estimate interface (#1194)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Oct 18, 2023
1 parent 0440b80 commit ff0f654
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,55 +99,44 @@ def load_hf_model(self):
revision=self.revision,
low_cpu_mem_usage=True)

def get_model_specific_kwargs(self, model_type):
    """Build the keyword arguments used to construct the Neuron model.

    The base set (batch size, amp, tensor-parallel degree, n_positions and
    unroll) is read from this instance. ``context_length_estimate`` is only
    added when ``model_type`` is ``"llama"``, so other model classes never
    receive that keyword.

    Args:
        model_type: Model type string; compared against ``"llama"``.

    Returns:
        A new dict of keyword arguments on every call.
    """
    model_kwargs = {
        "batch_size": self.batch_size,
        "amp": self.amp,
        "tp_degree": self.tensor_parallel_degree,
        "n_positions": self.n_positions,
        "unroll": self.unroll,
    }
    if model_type == "llama":
        # Only the llama path is given the context length estimate.
        model_kwargs["context_length_estimate"] = self.context_length_estimate
    return model_kwargs

def load_inf2_model_from_disk(self, model_type, load_path):
    """Load the Neuron model for ``model_type`` from ``load_path``.

    When ``self.load_split_model`` is falsy, ``self.model`` is first written
    to ``load_path`` with ``save_pretrained_split``. The model class is looked
    up in ``MODEL_TYPE_TO_MODEL`` and built via ``from_pretrained`` using the
    keyword arguments from ``get_model_specific_kwargs``.

    Args:
        model_type: Key into ``MODEL_TYPE_TO_MODEL``; also decides which
            model-specific keyword arguments are passed.
        load_path: Directory the split model is saved to and loaded from.

    Returns:
        The object returned by the model class's ``from_pretrained``.
    """
    if not self.load_split_model:
        logging.info(f"Saving INF2 model to {load_path} ...")
        save_pretrained_split(self.model, load_path)

    # Single source of truth for constructor kwargs, so model types that do
    # not take ``context_length_estimate`` are never handed it.
    model_kwargs = self.get_model_specific_kwargs(model_type)
    model_cls = MODEL_TYPE_TO_MODEL[model_type]

    if self.load_in_8bit:
        # 8-bit path: attach a quantization config (s8 storage, dequantized
        # to the configured amp dtype).
        neuron_config = NeuronConfig()
        neuron_config.quant = QuantizationConfig(quant_dtype='s8',
                                                 dequant_dtype=self.amp)
        return model_cls.from_pretrained(load_path,
                                         neuron_config=neuron_config,
                                         **model_kwargs)
    return model_cls.from_pretrained(load_path, **model_kwargs)

def load_inf2_model_from_memory(self, model_type):
    """Build the Neuron model for ``model_type`` from the in-memory model.

    The model class is looked up in ``MODEL_TYPE_TO_MODEL`` and constructed
    from ``self.model.config`` with the keyword arguments from
    ``get_model_specific_kwargs``. Weights are then copied in from
    ``self.model.state_dict()`` via ``load_state_dict_low_memory``.

    Args:
        model_type: Key into ``MODEL_TYPE_TO_MODEL``; also decides which
            model-specific keyword arguments are passed.

    Returns:
        The constructed model after its state dict has been loaded.
    """
    # Single source of truth for constructor kwargs, so model types that do
    # not take ``context_length_estimate`` are never handed it.
    model_kwargs = self.get_model_specific_kwargs(model_type)
    model_cls = MODEL_TYPE_TO_MODEL[model_type]

    if self.load_in_8bit:
        # 8-bit path: attach a quantization config (s8 storage, dequantized
        # to the configured amp dtype).
        neuron_config = NeuronConfig()
        neuron_config.quant = QuantizationConfig(quant_dtype='s8',
                                                 dequant_dtype=self.amp)
        model = model_cls(self.model.config,
                          neuron_config=neuron_config,
                          **model_kwargs)
    else:
        model = model_cls(self.model.config, **model_kwargs)

    model.load_state_dict_low_memory(self.model.state_dict())
    return model

Expand Down Expand Up @@ -176,7 +165,8 @@ def load_model(self, model_type):
self.model.to_neuron()
os.chdir(path)
elapsed = time.time() - start
logging.info(f"SysHealth: LLM sharding and compilation latency: {elapsed} secs")
logging.info(
f"SysHealth: LLM sharding and compilation latency: {elapsed} secs")

def initialize(self, properties):
# Neuron recommendation for transformersneuronx speedup
Expand Down Expand Up @@ -220,7 +210,8 @@ def initialize(self, properties):
self.low_cpu_mem_usage = True
if "context_length_estimate" in properties:
# expect input like [256, 1024, 2048]
self.context_length_estimate = json.loads(properties.get("context_length_estimate"))
self.context_length_estimate = json.loads(
properties.get("context_length_estimate"))
model_config = AutoConfig.from_pretrained(self.model_id_or_path,
revision=self.revision)
if model_config.model_type not in SUPPORTED_MODEL_TYPES:
Expand Down

0 comments on commit ff0f654

Please sign in to comment.