For inference in model_worker, allow the device to be specified via a command line parameter.

Right now it has only been tested on Apple Silicon devices via the mps device.
filipe-m-almeida committed Oct 6, 2023
1 parent f7d2c1a commit 9987782
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -116,6 +116,8 @@ You can launch as many workers as you want, and compare between different model
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
```

If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.

#### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)

If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
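As an aside (not part of this commit): before launching the worker with `--device mps`, it can be worth confirming that the installed PyTorch build actually exposes the Metal backend. A minimal check, assuming PyTorch 1.12 or newer:

```python
import torch

# Both should be True for --device mps to work: the PyTorch build must
# include MPS support, and the machine must expose a Metal-capable GPU.
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())
```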
4 changes: 2 additions & 2 deletions llava/model/builder.py
@@ -23,7 +23,7 @@
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}

if load_8bit:
@@ -137,7 +137,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device='cuda', dtype=torch.float16)
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
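For illustration (not part of the diff), a caller could resolve a device string and pass it through the new `device` keyword of `load_pretrained_model`. A sketch, assuming the import path `llava.model.builder` and an example checkpoint name; the fallback helper is hypothetical:

```python
import torch
from llava.model.builder import load_pretrained_model

def resolve_device(requested="cuda"):
    # Hypothetical helper: use the requested backend when available, else CPU.
    if requested == "cuda" and torch.cuda.is_available():
        return "cuda"
    if requested == "mps" and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = resolve_device("mps")
tokenizer, model, image_processor, context_len = load_pretrained_model(
    "liuhaotian/llava-v1.5-7b",   # example model path; substitute your checkpoint
    None,                         # model_base
    "llava-v1.5-7b",              # model_name
    device=device,                # keyword added by this commit
)
```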
11 changes: 7 additions & 4 deletions llava/serve/model_worker.py
@@ -45,7 +45,7 @@ class ModelWorker:
def __init__(self, controller_addr, worker_addr,
worker_id, no_register,
model_path, model_base, model_name,
load_8bit, load_4bit):
load_8bit, load_4bit, device):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
@@ -60,9 +60,10 @@ def __init__(self, controller_addr, worker_addr,
else:
self.model_name = model_name

self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit)
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
self.is_multimodal = 'llava' in self.model_name.lower()

if not no_register:
@@ -159,7 +160,7 @@ def generate_stream(self, params):
stop_str = params.get("stop", None)
do_sample = True if temperature > 0.001 else False

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
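The point of this hunk is the device-agnostic tensor move: instead of hard-coding `.cuda()`, inputs are sent to whichever backend the worker was configured with. A stand-alone illustration (not repo code):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
input_ids = torch.randint(0, 32000, (1, 16))  # stand-in for tokenized prompt ids
input_ids = input_ids.to(device)              # works for cuda, mps, or cpu alike
```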
@@ -258,6 +259,7 @@ async def get_status(request: Request):
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--model-name", type=str)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
@@ -278,5 +280,6 @@ async def get_status(request: Request):
args.model_base,
args.model_name,
args.load_8bit,
args.load_4bit)
args.load_4bit,
args.device)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
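One possible hardening of the new flag (a variation, not what the commit does) is to have argparse reject unsupported values outright via `choices`:

```python
import argparse

parser = argparse.ArgumentParser()
# Stricter variant of the new --device flag: only accept backends the
# worker can actually run on.
parser.add_argument("--device", type=str, default="cuda",
                    choices=["cuda", "mps", "cpu"])

args = parser.parse_args(["--device", "mps"])
print(args.device)  # -> mps
```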
