diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index da42dacbfe..a10ada48c5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -242,23 +242,28 @@ def load_model( device_map = cfg.device_map if cfg.gpu_memory_limit and max_memory is None: - gpu_memory_limit = str(cfg.gpu_memory_limit) + "GiB" if isinstance(cfg.gpu_memory_limit, int) else cfg.gpu_memory_limit + gpu_memory_limit = ( + str(cfg.gpu_memory_limit) + "GiB" + if isinstance(cfg.gpu_memory_limit, int) + else cfg.gpu_memory_limit + ) max_memory = {} for i in range(torch.cuda.device_count()): max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything if max_memory is not None: # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py from accelerate import infer_auto_device_map, init_empty_weights + with init_empty_weights(): model_canvas = AutoModelForCausalLM.from_config(model_config) model_canvas.tie_weights() device_map = infer_auto_device_map( model_canvas, max_memory=max_memory, - dtype='float16', # TODO: may probably use bfloat16 and others here as well + dtype="float16", # TODO: may probably use bfloat16 and others here as well ) # We can discard max_memory now as we have a device map set up for us max_memory = None