add Lora config to arg list in Neo sharding script & its integ test change (#2552)
HappyAmazonian authored Nov 14, 2024
1 parent e19237a commit c5f1efc
Showing 3 changed files with 73 additions and 5 deletions.
38 changes: 34 additions & 4 deletions serving/docker/partition/sm_neo_shard.py
@@ -16,7 +16,7 @@
 import sys
 import logging
 from importlib.metadata import version
-from typing import Final
+from typing import Final, Optional
 
 from sm_neo_utils import (OptimizationFatalError, write_error_to_file,
                           get_neo_env_vars)
@@ -106,15 +106,43 @@ def shard_lmi_dist_model(self, input_dir: str, output_dir: str,
         # unless specified otherwise by the customer
         gpu_memory_utilization = float(
             self.properties.get("option.gpu_memory_utilization", 0.9))
-        enforce_eager: bool = str(
-            self.properties.get("option.enforce_eager",
-                                False)).lower() == "true"
+        enforce_eager: bool = self.properties.get("option.enforce_eager",
+                                                  "true").lower() == "true"
         max_rolling_batch_size = int(
             self.properties.get("option.max_rolling_batch_size", 256))
         max_model_len = self.properties.get("option.max_model_len", None)
         if max_model_len is not None:
             max_model_len = int(max_model_len)
 
+        # LoraConfigs
+        lora_kwargs = {}
+        if enable_lora := self.properties.get("option.enable_lora"):
+            enable_lora_bool = enable_lora.lower() == "true"
+
+            if enable_lora_bool:
+                max_loras: int = int(
+                    self.properties.get("option.max_loras", "4"))
+                max_lora_rank: int = int(
+                    self.properties.get("option.max_lora_rank", "16"))
+                fully_sharded_loras: bool = str(
+                    self.properties.get("option.fully_sharded_loras",
+                                        "false")).lower() == "true"
+                lora_extra_vocab_size: int = int(
+                    self.properties.get("option.lora_extra_vocab_size", "256"))
+                lora_dtype: str = self.properties.get("option.lora_dtype",
+                                                      "auto")
+                max_cpu_loras: Optional[int] = None
+                if cpu_loras := self.properties.get("option.max_cpu_loras"):
+                    max_cpu_loras = int(cpu_loras)
+
+                lora_kwargs["enable_lora"] = enable_lora_bool
+                lora_kwargs["fully_sharded_loras"] = fully_sharded_loras
+                lora_kwargs["max_loras"] = max_loras
+                lora_kwargs["max_lora_rank"] = max_lora_rank
+                lora_kwargs["lora_extra_vocab_size"] = lora_extra_vocab_size
+                lora_kwargs["lora_dtype"] = lora_dtype
+                lora_kwargs["max_cpu_loras"] = max_cpu_loras
+
         engine_args = VllmEngineArgs(
             model=input_dir,
             pipeline_parallel_size=pp_degree,
@@ -125,7 +153,9 @@ def shard_lmi_dist_model(self, input_dir: str, output_dir: str,
             enforce_eager=enforce_eager,
             max_num_seqs=max_rolling_batch_size,
             max_model_len=max_model_len,
+            **lora_kwargs,
         )
+
         engine = engine_from_args(engine_args)
 
         model_dir = os.path.join(output_dir, sm_fml.MODEL_DIR_NAME)
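Two functional changes in this file are worth noting. First, the enforce_eager default flips from False to True: property values arrive as strings, and the old str(False).lower() == "true" comparison always defaulted to False. Second, the new LoRA options reach vLLM's engine args only when option.enable_lora is "true"; otherwise lora_kwargs stays empty and **lora_kwargs adds nothing. A minimal standalone sketch of the parsing pattern (the name parse_lora_kwargs and the props dict are illustrative, not part of the commit):

# Illustrative sketch of the property-parsing pattern above; not the
# committed code. All serving.properties values arrive as strings, hence
# the explicit int() and "true" coercions.
from typing import Any, Dict, Optional


def parse_lora_kwargs(props: Dict[str, str]) -> Dict[str, Any]:
    lora_kwargs: Dict[str, Any] = {}
    enable_lora = props.get("option.enable_lora")
    if enable_lora and enable_lora.lower() == "true":
        # max_cpu_loras stays None unless the customer sets it explicitly.
        max_cpu_loras: Optional[int] = None
        if cpu_loras := props.get("option.max_cpu_loras"):
            max_cpu_loras = int(cpu_loras)
        lora_kwargs = {
            "enable_lora": True,
            "fully_sharded_loras":
            props.get("option.fully_sharded_loras", "false").lower() == "true",
            "max_loras": int(props.get("option.max_loras", "4")),
            "max_lora_rank": int(props.get("option.max_lora_rank", "16")),
            "lora_extra_vocab_size":
            int(props.get("option.lora_extra_vocab_size", "256")),
            "lora_dtype": props.get("option.lora_dtype", "auto"),
            "max_cpu_loras": max_cpu_loras,
        }
    return lora_kwargs


# With option.enable_lora unset, nothing is forwarded, so sharding
# behavior is unchanged for existing customers.
assert parse_lora_kwargs({}) == {}
assert parse_lora_kwargs({
    "option.enable_lora": "true",
    "option.max_lora_rank": "64",
})["max_lora_rank"] == 64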
6 changes: 6 additions & 0 deletions tests/integration/llm/client.py
@@ -583,6 +583,12 @@ def get_model_name():
         "seq_length": [32],
         "tokenizer": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     },
+    "tiny-llama-lora-fml": {
+        "batch_size": [4],
+        "seq_length": [32],
+        "adapters": ["tarot"],
+        "tokenizer": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    },
     "llama-3.1-8b": {
         "batch_size": [1],
         "seq_length": [256],
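The new tiny-llama-lora-fml entry introduces an adapters axis alongside batch_size and seq_length. As a hedged sketch of how a spec like this could expand into per-adapter requests (the actual payload construction in client.py is outside this diff, and the "adapters" request field is an assumption based on LMI's adapter-routing convention):

# Hedged sketch: crossing the spec's axes into per-adapter requests.
# The payload shape and "adapters" field are assumptions, not taken
# from client.py itself.
import itertools

spec = {
    "batch_size": [4],
    "seq_length": [32],
    "adapters": ["tarot"],
    "tokenizer": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}

for batch_size, seq_length, adapter in itertools.product(
        spec["batch_size"], spec["seq_length"], spec["adapters"]):
    payload = {
        "inputs": ["The future of AI is"] * batch_size,
        "parameters": {"max_new_tokens": seq_length},
        "adapters": [adapter] * batch_size,  # assumed adapter-routing field
    }
    print(payload)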
34 changes: 33 additions & 1 deletion tests/integration/llm/prepare.py
@@ -888,7 +888,17 @@
     "tiny-llama-fml": {
         "option.model_id": "s3://djl-llm/tinyllama-1.1b-chat/",
         "option.tensor_parallel_degree": 2,
-        "option.load_format": 'sagemaker_fast_model_loader',
+        "option.load_format": "sagemaker_fast_model_loader",
     },
+    "tiny-llama-lora-fml": {
+        "option.model_id": "s3://djl-llm/tinyllama-1.1b-chat/",
+        "option.tensor_parallel_degree": 2,
+        "option.load_format": "sagemaker_fast_model_loader",
+        "option.adapters": "adapters",
+        "option.enable_lora": "true",
+        "option.max_lora_rank": "64",
+        "adapter_ids": ["barissglc/tinyllama-tarot-v1"],
+        "adapter_names": ["tarot"],
+    },
     "llama-3.1-8b": {
         "option.model_id": "s3://djl-llm/llama-3.1-8b-hf/",
@@ -1330,6 +1340,28 @@ def create_neo_input_model(properties):
         cmd = ["aws", "s3", "sync", model_s3_uri, model_download_path]
         subprocess.check_call(cmd)
 
+    adapter_ids = properties.pop("adapter_ids", [])
+    adapter_names = properties.pop("adapter_names", [])
+    # Copy Adapters if any
+    if adapter_ids:
+        print("copying adapter models")
+        adapters_path = os.path.join(model_download_path, "adapters")
+        os.makedirs(adapters_path, exist_ok=True)
+        ## install huggingface_hub in your workflow file to use this
+        from huggingface_hub import snapshot_download
+        adapter_cache = {}
+        for adapter_id, adapter_name in zip(adapter_ids, adapter_names):
+            print(f"copying adapter models {adapter_id} {adapter_name}")
+            dir = os.path.join(adapters_path, adapter_name)
+            if adapter_id in adapter_cache:
+                shutil.copytree(adapter_cache[adapter_id], dir)
+            else:
+                os.makedirs(dir, exist_ok=True)
+                snapshot_download(adapter_id,
+                                  local_dir_use_symlinks=False,
+                                  local_dir=dir)
+                adapter_cache[adapter_id] = dir
+
 
 def build_hf_handler_model(model):
     if model not in hf_handler_list:
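The adapter_cache dict above is a small deduplication: each unique Hub repo is downloaded once via snapshot_download, and any later adapter name backed by the same repo is filled in by a local shutil.copytree instead of a second network fetch. A hypothetical two-name example of the effect (the committed test registers only the single "tarot" adapter):

# Hypothetical: two adapter names backed by the same Hub repo.
adapter_ids = ["barissglc/tinyllama-tarot-v1", "barissglc/tinyllama-tarot-v1"]
adapter_names = ["tarot", "tarot2"]

# With the loop above, the layout under the model download path becomes:
#   adapters/
#   ├── tarot/     <- one snapshot_download from the Hub
#   └── tarot2/    <- shutil.copytree from adapters/tarot, no second fetch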
