Commit

linter fixes
Karl-Johan Alm committed Dec 6, 2023
1 parent 4054c80 commit 107e138
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/axolotl/utils/models.py
@@ -242,23 +242,28 @@ def load_model(
     device_map = cfg.device_map
 
     if cfg.gpu_memory_limit and max_memory is None:
-        gpu_memory_limit = str(cfg.gpu_memory_limit) + "GiB" if isinstance(cfg.gpu_memory_limit, int) else cfg.gpu_memory_limit
+        gpu_memory_limit = (
+            str(cfg.gpu_memory_limit) + "GiB"
+            if isinstance(cfg.gpu_memory_limit, int)
+            else cfg.gpu_memory_limit
+        )
+
         max_memory = {}
         for i in range(torch.cuda.device_count()):
             max_memory[i] = gpu_memory_limit
-        max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
+        max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything
 
     if max_memory is not None:
         # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
         from accelerate import infer_auto_device_map, init_empty_weights
 
         with init_empty_weights():
             model_canvas = AutoModelForCausalLM.from_config(model_config)
         model_canvas.tie_weights()
         device_map = infer_auto_device_map(
             model_canvas,
             max_memory=max_memory,
-            dtype='float16', # TODO: may probably use bfloat16 and others here as well
+            dtype="float16",  # TODO: may probably use bfloat16 and others here as well
         )
         # We can discard max_memory now as we have a device map set up for us
         max_memory = None
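
For context, the hunk above is the part of load_model that turns the cfg.gpu_memory_limit setting into an accelerate device map: each visible GPU is capped at the configured limit, the CPU gets a generous ceiling, and infer_auto_device_map assigns modules to devices for a weight-less model skeleton. The following is a minimal standalone sketch of that same pattern, not part of this commit; the "gpt2" config and the 16 GiB cap are illustrative assumptions.

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

gpu_memory_limit = 16  # stands in for cfg.gpu_memory_limit; an int is treated as GiB
if isinstance(gpu_memory_limit, int):
    gpu_memory_limit = str(gpu_memory_limit) + "GiB"

# Cap every visible GPU at the same limit and give the CPU a generous ceiling
# so modules that do not fit on the GPUs can spill over.
max_memory = {i: gpu_memory_limit for i in range(torch.cuda.device_count())}
max_memory["cpu"] = "256GiB"

# Build a weight-less model skeleton so the device map can be computed without
# allocating real tensors, then let accelerate assign modules to devices.
model_config = AutoConfig.from_pretrained("gpt2")  # assumed model choice
with init_empty_weights():
    model_canvas = AutoModelForCausalLM.from_config(model_config)
model_canvas.tie_weights()

device_map = infer_auto_device_map(
    model_canvas,
    max_memory=max_memory,
    dtype="float16",
)
print(device_map)  # e.g. {"transformer.wte": 0, ..., "lm_head": "cpu"}

Once the device_map has been computed, the explicit max_memory dict is no longer needed, which is why the hunk sets it back to None (per the in-code comment).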
