Commit

Minor refactors.
notsyncing committed Mar 6, 2024
1 parent 2a278f4 commit 91ed9de
Showing 2 changed files with 11 additions and 12 deletions.
17 changes: 8 additions & 9 deletions src/accelerate/utils/modeling.py
@@ -1125,7 +1125,6 @@ def infer_auto_device_map(
     current_device = 0
     current_memory_used = 0
     device_memory_used = {}
-    device_memory_reserved = {}
     device_buffer_sizes = {}
 
     # Direct submodules and parameters
@@ -1176,10 +1175,11 @@ def infer_auto_device_map(
 
         device = devices[current_device]
         current_max_size = max_memory[device] if device != "disk" else None
+        current_memory_reserved = 0
         # Reduce max size available by the largest layer.
         if devices[current_device] in main_devices:
             current_max_size = current_max_size - max_layer_size
-            device_memory_reserved[current_device] = max_layer_size
+            current_memory_reserved = max_layer_size
         # Case 1 -> We're too big!
         if current_max_size is not None and current_memory_used + module_size > current_max_size:
             # Split or not split?
@@ -1198,7 +1198,7 @@ def infer_auto_device_map(
                 if verbose:
                     print("This module cannot be split, going to the next device.")
 
-                device_memory_used[device] = current_memory_used
+                device_memory_used[device] = current_memory_used + current_memory_reserved
                 current_device += 1
                 modules_to_treat = [(name, module)] + modules_to_treat
                 current_memory_used = 0
@@ -1297,7 +1297,7 @@ def infer_auto_device_map(
                 if verbose:
                     print("None of the tied module can be split, going to the next device.")
 
-                device_memory_used[device] = current_memory_used
+                device_memory_used[device] = current_memory_used + current_memory_reserved
                 current_device += 1
                 modules_to_treat = [(name, module)] + modules_to_treat
                 current_memory_used = 0
@@ -1312,7 +1312,7 @@ def infer_auto_device_map(
                     f"(available={current_max_size - current_memory_used})."
                 )
         current_memory_used += module_size
-        device_memory_used[device] = current_memory_used
+        device_memory_used[device] = current_memory_used + current_memory_reserved
         device_map[name] = devices[current_device]
 
         if not offload_buffers and isinstance(module, nn.Module):
@@ -1324,16 +1324,15 @@ def infer_auto_device_map(
     if clean_result:
         device_map = clean_device_map(device_map)
 
-    if not offload_buffers:
-        non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
-
+    non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
+    if non_gpu_buffer_size > 0 and not offload_buffers:
         is_buffer_fit_any_gpu = False
         for gpu_device, gpu_max_memory in max_memory.items():
             if gpu_device == "cpu" or gpu_device == "disk":
                 continue
 
             if not is_buffer_fit_any_gpu:
-                gpu_memory_used = device_memory_used.get(gpu_device, 0) + device_memory_reserved.get(gpu_device, 0)
+                gpu_memory_used = device_memory_used.get(gpu_device, 0)
 
                 if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
                     is_buffer_fit_any_gpu = True
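
The heart of the change is easier to see outside the diff: the per-device device_memory_reserved dict is dropped in favor of a loop-local current_memory_reserved, which is folded into device_memory_used[device] at each of the three places a device's total is recorded. Below is a minimal, self-contained sketch of that bookkeeping pattern; the variable names mirror the diff, but assign() and its toy module list are hypothetical, and the real infer_auto_device_map additionally handles module splitting, tied weights, and buffer sizes.

    # Minimal sketch (not the accelerate implementation) of the pattern this
    # commit moves to: current_memory_reserved is a loop-local value folded
    # into device_memory_used, instead of a separate per-device dict.

    def assign(modules, max_memory, main_devices, max_layer_size):
        """Greedy first-fit placement of (name, size) pairs onto devices."""
        devices = list(max_memory)
        device_map = {}
        device_memory_used = {}
        current_device = 0
        current_memory_used = 0
        modules_to_treat = list(modules)

        while modules_to_treat:
            name, module_size = modules_to_treat.pop(0)
            device = devices[current_device]
            current_max_size = max_memory[device] if device != "disk" else None
            current_memory_reserved = 0
            # Reserve room on main devices for the largest layer.
            if device in main_devices:
                current_max_size = current_max_size - max_layer_size
                current_memory_reserved = max_layer_size

            if current_max_size is not None and current_memory_used + module_size > current_max_size:
                # Too big: close out this device, reservation included, and
                # retry the same module on the next device. (A real
                # implementation would guard current_device against
                # running past the device list.)
                device_memory_used[device] = current_memory_used + current_memory_reserved
                current_device += 1
                modules_to_treat = [(name, module_size)] + modules_to_treat
                current_memory_used = 0
            else:
                current_memory_used += module_size
                device_memory_used[device] = current_memory_used + current_memory_reserved
                device_map[name] = device

        return device_map, device_memory_used

    device_map, used = assign(
        [("linear1", 200), ("batchnorm", 50), ("linear2", 200)],
        max_memory={0: 400, "cpu": 10**9},
        main_devices={0},
        max_layer_size=200,
    )
    print(device_map)  # {'linear1': 0, 'batchnorm': 'cpu', 'linear2': 'cpu'}

With these hypothetical sizes the sketch reproduces the expected map from the first test below: linear1 on GPU 0, the remaining modules on CPU.
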
6 changes: 3 additions & 3 deletions tests/test_modeling_utils.py
@@ -636,7 +636,7 @@ def test_infer_auto_device_map_with_buffer_check(self):
         # device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now
         # Should NOT print a warning in such case
         with warnings.catch_warnings():
-            warnings.simplefilter("error")
+            warnings.simplefilter("always")
             device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True)
         assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}
 
@@ -654,7 +654,7 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
         # can hold all remaining buffers
         # Should NOT print a warning in such case
         with warnings.catch_warnings():
-            warnings.simplefilter("error")
+            warnings.simplefilter("always")
             device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"})
         assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}
 
@@ -667,7 +667,7 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
         # Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True
         # Should NOT print a warning in such case
         with warnings.catch_warnings():
-            warnings.simplefilter("error")
+            warnings.simplefilter("always")
             device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True)
         assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}
 
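
All three test hunks make the same swap, so it is worth recalling the stdlib semantics they rely on. This is a sketch of the standard warnings module's behavior, nothing accelerate-specific; the noisy() helper is a made-up stand-in for a warning-emitting call.

    import warnings

    def noisy():
        warnings.warn("current model requires X bytes of buffer", UserWarning)

    # "error" promotes matching warnings to exceptions, so any warning
    # fails a test loudly:
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        try:
            noisy()
        except UserWarning as exc:
            print(f"raised: {exc}")

    # "always" merely forces every warning to be emitted; to assert on (or
    # against) warnings, pair it with record=True and inspect the list:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        noisy()
        assert len(caught) == 1
        assert "buffer" in str(caught[0].message)

Note that with plain simplefilter("always") and no recording, as in the updated tests, a stray warning is printed but no longer fails the test run.
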
