Skip to content

Commit

Permalink
Fix clearing backend cache from device agnostic testing (huggingface#…
Browse files Browse the repository at this point in the history
…6075)

update
  • Loading branch information
DN6 authored and donhardman committed Dec 18, 2023
1 parent 43ff401 commit 4a09065
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tests/models/test_models_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

@parameterized.expand(
[
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
Expand Down Expand Up @@ -565,7 +565,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
Expand Down Expand Up @@ -820,7 +820,7 @@ def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
Expand Down Expand Up @@ -531,7 +531,7 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)

def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
Expand Down

0 comments on commit 4a09065

Please sign in to comment.