
Inference with lower VRAM requirements #18

Merged: 2 commits merged into NVlabs:main on Nov 22, 2024

Conversation

@frutiemax92 (Contributor) commented on Nov 21, 2024

Inference with bfloat16 and float16 works on an RTX 4070; float32 gives OOM.
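(For reference, a minimal sketch of reduced-precision, inference-only execution in plain PyTorch; the model below is a stand-in, not the repo's SanaPipeline API:)

    import torch
    import torch.nn as nn

    # Stand-in model; the real pipeline is SanaPipeline in app/sana_pipeline.py.
    model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1), nn.SiLU(), nn.Conv2d(64, 4, 3, padding=1))
    model = model.to(device="cuda", dtype=torch.bfloat16)  # or torch.float16
    model.eval()

    with torch.no_grad():  # no autograd graph, so intermediate activations are freed immediately
        x = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.bfloat16)  # dummy latent, shape illustrative
        y = model(x)

Going from float32 to bfloat16/float16 halves the memory needed for weights and activations, which is why the lower-precision runs fit on the RTX 4070 while float32 does not.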

@nitinmukesh commented on Nov 22, 2024

Thank you for this PR.

Tested on 8 GB VRAM.

Before:
[screenshot]
Inference time: 5-6 min

After your updates:
[screenshot]
Inference time: 2 min

@lawrence-cj merged commit c66ebf9 into NVlabs:main on Nov 22, 2024. 2 checks passed.
@lawrence-cj (Collaborator) commented

Thanks a lot for your work over time.

@FurkanGozukara commented on Nov 22, 2024

This changed how batch size works.

It now processes 6 times if the batch size is 6, and each iteration gets slower and slower.

Something is definitely wrong. Also, VRAM is not cleared after generations :) it uses the entire VRAM after 6 iterations with batch size 6.

[screenshot]

@lawrence-cj (Collaborator) commented

Batch inference is not changed by this PR; the only thing this PR changes is adding a few torch.no_grad() lines. The SanaPipeline class in the current repo only supports batch inference for the same prompt:

num_images_per_prompt=1,

If you input a list of prompts, it will generate them separately.
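(For anyone following along, the pattern this PR adds is the standard inference-only guard; a minimal, self-contained sketch with placeholder modules:)

    import torch
    import torch.nn as nn

    # encoder stands in for any module called during sampling (text encoder, DiT, VAE decoder).
    encoder = nn.Linear(16, 16)
    tokens = torch.randn(2, 16)

    with torch.no_grad():
        embeddings = encoder(tokens)  # no activations are retained for a backward pass
    assert not embeddings.requires_grad

Without the guard, every forward pass keeps intermediate tensors around for a backward pass that inference never runs, which is presumably where the extra VRAM was going.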

@FurkanGozukara commented
Then it is probably because of the changes I had to make to get it to run on Windows; I probably broke something, because I am only giving 1 prompt via Gradio.

https://github.com/FurkanGozukara/Sana/blob/main/app/sana_pipeline.py

@FurkanGozukara commented

> num_images_per_prompt=1,

Your code has this for batch size:

            for _ in range(num_images_per_prompt):
                with torch.no_grad():
                    prompts.append(
                        prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
                    )

@lawrence-cj (Collaborator) commented

Yes. This is only for generating multiple images from a single prompt. If you want batch inference with an input like ['car', 'car', 'car'], that is not supported for now.
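(A usage sketch of the distinction; the pipe() stub below is illustrative, only num_images_per_prompt matches the snippet above:)

    from typing import List

    def pipe(prompt: str, num_images_per_prompt: int = 1) -> List[str]:
        """Stub standing in for SanaPipeline's call; the real signature may differ."""
        return [f"image of {prompt}"] * num_images_per_prompt

    # Supported: several images from one prompt.
    images = pipe("a red car", num_images_per_prompt=4)

    # Not a true batch: a list of prompts is handled one prompt at a time.
    results = [pipe(p) for p in ["car", "car", "car"]]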

@frutiemax92 (Contributor, PR author) commented

> Then it is probably because of the changes I had to make to get it to run on Windows; I probably broke something, because I am only giving 1 prompt via Gradio.
>
> https://github.com/FurkanGozukara/Sana/blob/main/app/sana_pipeline.py

I am using WSL2 and the sample script from the model's page on GitHub.

@yujincheng08 (Collaborator) commented

This PR breaks batch inference. Fixed in 9da8550.

@FurkanGozukara commented

> This PR breaks batch inference. Fixed in 9da8550.

Thanks, I told you that batch inference was broken.
