diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index f4230c55994..2b1be34eb7f 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1065,8 +1065,14 @@ def get_balanced_memory(
         )
     )
     # The last device is left with max_memory just in case the buffer is not enough.
+    max_leave_size = max([module_sizes[leave] for leave in leaves])
     for idx in gpus_idx_list[:-1]:
-        max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
+        if idx == 0 and not low_zero and max_leave_size > per_gpu * 0.9:
+            max_memory[idx] = min(max_leave_size * 1.3, max_memory[idx])
+        elif idx == 1 and low_zero and max_leave_size > per_gpu * 0.9:
+            max_memory[idx] = min(max_leave_size * 1.3, max_memory[idx])
+        else:
+            max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
 
     if low_zero:
         min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
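
For context, here is a minimal standalone sketch of what the patched capping loop does, using toy values for `module_sizes`, `leaves`, `per_gpu`, and `max_memory` (the helper name `cap_gpu_memory` and the sample sizes are illustrative, not accelerate's API): the first usable GPU (device 0 normally, device 1 when `low_zero` is set) is allowed up to 1.3x the largest leaf module whenever that leaf would not fit in the balanced `per_gpu` budget, so a single oversized layer no longer spills onto the last device.

```python
# Illustrative sketch of the patched cap logic (toy values, hypothetical helper).

def cap_gpu_memory(max_memory, gpus_idx_list, module_sizes, leaves, per_gpu, low_zero):
    """Mirror of the patched loop: every GPU except the last is capped at
    `per_gpu`, but the first usable GPU may grow to 1.3x the largest leaf
    module when that leaf exceeds 90% of the balanced budget."""
    max_leave_size = max(module_sizes[leave] for leave in leaves)
    for idx in gpus_idx_list[:-1]:
        first_usable = (idx == 1) if low_zero else (idx == 0)
        if first_usable and max_leave_size > per_gpu * 0.9:
            max_memory[idx] = min(max_leave_size * 1.3, max_memory[idx])
        else:
            max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
    return max_memory


if __name__ == "__main__":
    # Two 24 GiB GPUs, a balanced budget of 10 GiB per GPU, and one 11 GiB leaf:
    # before the patch GPU 0 would be capped at 10 GiB and the large leaf would
    # fall through to the last device; with it, GPU 0 is allowed ~14.3 GiB.
    GiB = 1024**3
    max_memory = {0: 24 * GiB, 1: 24 * GiB}
    module_sizes = {"decoder.block.0": 11 * GiB, "decoder.block.1": 2 * GiB}
    leaves = ["decoder.block.0", "decoder.block.1"]
    print(cap_gpu_memory(max_memory, [0, 1], module_sizes, leaves, per_gpu=10 * GiB, low_zero=False))
```

The 0.9 and 1.3 factors come straight from the diff: the cap is only relaxed when the largest leaf is within 10% of (or above) the per-GPU budget, and it is relaxed with a 30% margin rather than removed entirely, so the last device keeps its role as the overflow buffer.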