[CI] add a big GPU marker to run memory-intensive tests separately on CI (#9691)

* add a marker for big gpu tests

* update

* trigger on PRs temporarily.

* onnx

* fix

* total memory

* fixes

* reduce memory threshold.

* bigger gpu

* empty

* g6e

* Apply suggestions from code review

* address comments.

* fix

* fix

* fix

* fix

* fix

* okay

* further reduce.

* updates

* remove

* updates

* updates

* updates

* updates

* fixes

* fixes

* updates.

* fix

* workflow fixes.

---------

Co-authored-by: Aryan <[email protected]>
sayakpaul and a-r-r-o-w authored Oct 31, 2024
1 parent 4adf6af commit ff182ad
Showing 9 changed files with 181 additions and 123 deletions.
56 changes: 56 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -180,6 +180,62 @@ jobs:
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_big_gpu_torch_tests:
    name: Torch tests on big GPU
    strategy:
      fail-fast: false
      max-parallel: 2
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
      image: diffusers/diffusers-pytorch-cuda
      options: --shm-size "16gb" --ipc host --gpus 0
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: NVIDIA-SMI
        run: nvidia-smi
      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
          python -m uv pip install pytest-reportlog
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Selected Torch CUDA Test on big GPU
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
          CUBLAS_WORKSPACE_CONFIG: :16:8
          BIG_GPU_MEMORY: 40
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -m "big_gpu_with_torch_cuda" \
            --make-reports=tests_big_gpu_torch_cuda \
            --report-log=tests_big_gpu_torch_cuda.log \
            tests/
      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_big_gpu_torch_cuda_stats.txt
          cat reports/tests_big_gpu_torch_cuda_failures_short.txt
      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: torch_cuda_big_gpu_test_reports
          path: reports
      - name: Generate Report and Notify Channel
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_flax_tpu_tests:
    name: Nightly Flax TPU Tests
    runs-on: docker-tpu
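Note on the `-m "big_gpu_with_torch_cuda"` selection used above: pytest only recognizes custom markers that are registered somewhere in the test configuration, and that registration is not visible in the excerpt shown here. The snippet below is only a sketch of the usual pattern (a `pytest_configure` hook in a `conftest.py`); the PR may instead declare the marker in `pyproject.toml` or `setup.cfg`.

# conftest.py -- illustrative sketch only, not taken verbatim from this diff.
def pytest_configure(config):
    # Registering the marker silences pytest's unknown-marker warning and lets CI
    # select the big-GPU tests with `pytest -m "big_gpu_with_torch_cuda"`.
    config.addinivalue_line(
        "markers",
        "big_gpu_with_torch_cuda: tests that need a large-memory CUDA GPU (see BIG_GPU_MEMORY)",
    )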
21 changes: 21 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -57,6 +57,7 @@
) > version.parse("4.33")

USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))

if is_torch_available():
    import torch
@@ -310,6 +311,26 @@ def require_torch_accelerator_with_fp64(test_case):
)


def require_big_gpu_with_torch_cuda(test_case):
    """
    Decorator marking a test that requires a big GPU (at least `BIG_GPU_MEMORY` GB of VRAM, 40 GB by default) for
    execution. Some example pipelines: Flux, SD3, Cog, etc.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    if not torch.cuda.is_available():
        return unittest.skip("test requires PyTorch CUDA")(test_case)

    device_properties = torch.cuda.get_device_properties(0)
    total_memory = device_properties.total_memory / (1024**3)
    return unittest.skipUnless(
        total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
    )(test_case)


def require_torch_accelerator_with_training(test_case):
    """Decorator marking a test that requires an accelerator with support for training."""
    return unittest.skipUnless(
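Taken together, the new decorator skips a test when the visible CUDA device has less than `BIG_GPU_MEMORY` GB of memory, while the `big_gpu_with_torch_cuda` pytest marker is what the nightly workflow above uses to select these tests. A minimal usage sketch follows; the class and test names are illustrative and not part of this diff, but the decorator stack mirrors the test files changed below.

import unittest

import pytest

from diffusers.utils.testing_utils import require_big_gpu_with_torch_cuda, slow


@slow
@require_big_gpu_with_torch_cuda  # skipped unless total VRAM >= BIG_GPU_MEMORY (40 GB by default)
@pytest.mark.big_gpu_with_torch_cuda  # lets CI select the test via `pytest -m "big_gpu_with_torch_cuda"`
class ExampleBigGpuPipelineTests(unittest.TestCase):
    def test_memory_hungry_pipeline(self):
        # Placeholder body; the real tests load large pipelines such as Flux or SD3.
        self.assertTrue(True)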
38 changes: 28 additions & 10 deletions tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,9 @@
import unittest

import numpy as np
import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers import (
@@ -30,7 +32,8 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
slow,
torch_device,
)
@@ -180,7 +183,8 @@ def test_xformers_attention_forwardGenerator_pass(self):


@slow
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline

@@ -199,35 +203,49 @@ def test_canny(self):
"InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
controlnet=controlnet,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "A girl in city, 25 years old, cool, futuristic"
control_image = load_image(
"https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
).resize((512, 512))

prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
)
pooled_prompt_embeds = torch.load(
hf_hub_download(
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
)
)

output = pipe(
prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
control_image=control_image,
controlnet_conditioning_scale=0.6,
num_inference_steps=2,
guidance_scale=3.5,
max_sequence_length=256,
output_type="np",
height=512,
width=512,
generator=generator,
)

image = output.images[0]

assert image.shape == (1024, 1024, 3)
assert image.shape == (512, 512, 3)

original_image = image[-3:, -3:, -1].flatten()

expected_image = np.array(
[0.33007812, 0.33984375, 0.33984375, 0.328125, 0.34179688, 0.33984375, 0.30859375, 0.3203125, 0.3203125]
)
expected_image = np.array([0.2734, 0.2852, 0.2852, 0.2734, 0.2754, 0.2891, 0.2617, 0.2637, 0.2773])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
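The slow test above now compares image slices with `numpy_cosine_similarity_distance` instead of a max-absolute-difference check, which is less brittle to small per-pixel jitter across GPU generations. The helper comes from `diffusers.utils.testing_utils`; the sketch below illustrates the idea (one minus the cosine similarity of the flattened slices) rather than the library's exact implementation.

import numpy as np

# Illustrative re-implementation; the assertion in the test uses the library helper.
def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    # 0.0 means the two slices point in the same direction; the tests assert this stays below 1e-2.
    return float(1.0 - similarity)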
71 changes: 0 additions & 71 deletions tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -1,4 +1,3 @@
import gc
import unittest

import numpy as np
@@ -13,9 +12,6 @@
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)

@@ -222,70 +218,3 @@ def test_fused_qkv_projections(self):
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."


@slow
@require_torch_gpu
class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetImg2ImgPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)

image = torch.randn(1, 3, 64, 64).to(device)
control_image = torch.randn(1, 3, 64, 64).to(device)

return {
"prompt": "A photo of a cat",
"image": image,
"control_image": control_image,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"controlnet_conditioning_scale": 1.0,
"strength": 0.8,
"output_type": "np",
"generator": generator,
}

@unittest.skip("We cannot run inference on this model with the current CI hardware")
def test_flux_controlnet_img2img_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

inputs = self.get_inputs(torch_device)

image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.36132812, 0.30004883, 0.25830078],
[0.36669922, 0.31103516, 0.23754883],
[0.34814453, 0.29248047, 0.23583984],
[0.35791016, 0.30981445, 0.23999023],
[0.36328125, 0.31274414, 0.2607422],
[0.37304688, 0.32177734, 0.26171875],
[0.3671875, 0.31933594, 0.25756836],
[0.36035156, 0.31103516, 0.2578125],
[0.3857422, 0.33789062, 0.27563477],
[0.3701172, 0.31982422, 0.265625],
],
dtype=np.float32,
)

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

assert max_diff < 1e-4
35 changes: 14 additions & 21 deletions tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -17,6 +17,7 @@
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -30,7 +31,8 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
slow,
torch_device,
)
@@ -195,7 +197,8 @@ def test_xformers_attention_forwardGenerator_pass(self):


@slow
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline

@@ -238,11 +241,9 @@ def test_canny(self):

original_image = image[-3:, -3:, -1].flatten()

expected_image = np.array(
[0.20947266, 0.1574707, 0.19897461, 0.15063477, 0.1418457, 0.17285156, 0.14160156, 0.13989258, 0.30810547]
)
expected_image = np.array([0.7314, 0.7075, 0.6611, 0.7539, 0.7563, 0.6650, 0.6123, 0.7275, 0.7222])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_pose(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16)
@@ -272,15 +273,12 @@ def test_pose(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.9048, 0.8740, 0.8936, 0.8516, 0.8799, 0.9360, 0.8379, 0.8408, 0.8652])

expected_image = np.array(
[0.8671875, 0.86621094, 0.91015625, 0.8491211, 0.87890625, 0.9140625, 0.8300781, 0.8334961, 0.8623047]
)

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_tile(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX//SD3-Controlnet-Tile", torch_dtype=torch.float16)
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Tile", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
@@ -307,12 +305,9 @@ def test_tile(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.6699, 0.6836, 0.6226, 0.6572, 0.7310, 0.6646, 0.6650, 0.6694, 0.6011])

expected_image = np.array(
[0.6982422, 0.7011719, 0.65771484, 0.6904297, 0.7416992, 0.6904297, 0.6977539, 0.7080078, 0.6386719]
)

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_multi_controlnet(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
@@ -344,8 +339,6 @@ def test_multi_controlnet(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.7451172, 0.7416992, 0.7158203, 0.7792969, 0.7607422, 0.7089844, 0.6855469, 0.71777344, 0.7314453]
)
expected_image = np.array([0.7207, 0.7041, 0.6543, 0.7500, 0.7490, 0.6592, 0.6001, 0.7168, 0.7231])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2