Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add missing assertions
Browse files Browse the repository at this point in the history
notsyncing committed Mar 8, 2024
1 parent 91ed9de commit f67ac00
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
@@ -635,9 +635,10 @@ def test_infer_auto_device_map_with_buffer_check(self):
# Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit
# device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now
# Should NOT print a warning in such case
with warnings.catch_warnings():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True)
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}

def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
@@ -653,9 +654,10 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
# Now we have two devices, linear1 will fit on device 0, batchnorm will fit on device 1, and the second device
# can hold all remaining buffers
# Should NOT print a warning in such case
with warnings.catch_warnings():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"})
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

# Now we have two devices, but neither the first nor the second device can hold all remaining buffers
@@ -666,9 +668,10 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):

# Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True
# Should NOT print a warning in such case
with warnings.catch_warnings():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True)
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

@require_cuda

0 comments on commit f67ac00

Please sign in to comment.