[Don't merge] Platform tests #5
base: main
Conversation
```python
pickle.dump(text_embeddings.detach().to("cpu").numpy(), f)
```

```python
num_inference_steps = 50
pipe.scheduler.set_timesteps(num_inference_steps, offset=1)
```
Indeed, `offset=1` is correct here!
The code is very ugly and hard to understand; I'll clean it up later. However, it produces the same results in PyTorch and JAX when using the test scripts `jax-micro-scheduler.py` and `torch-micro-scheduler.py`. I couldn't simply use a dataclass to store the state because jitting requires simple Python containers, so I convert between the dataclass and a dict. We should probably just use dicts instead. The pure functional implementation became convoluted when trying to reproduce the previous logic; Qinsheng Zhang sent me a couple of links to their implementation that could be worth exploring. Even though the scheduler now produces the same output on test inputs, the UNet+scheduler loop still shows big differences.
The test only works without `pmap` / `fori_loop`.
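For illustration only (the names `SchedulerState` and `scheduler_step` below are made up, not the code in this PR), the dataclass <-> dict conversion amounts to something like this: the state crosses the `jax.jit` boundary as a plain dict, which JAX treats as a pytree, and is converted back to a dataclass outside the jitted function.

```python
# Hypothetical sketch of a stateless scheduler step that carries its state
# as a plain dict, so it can be passed through jax.jit as a pytree.
import jax
import jax.numpy as jnp
from dataclasses import dataclass


@dataclass
class SchedulerState:
    counter: int
    cur_sample: jnp.ndarray


def to_dict(state: SchedulerState) -> dict:
    # Shallow conversion; jit only needs a pytree of arrays and scalars.
    return {"counter": state.counter, "cur_sample": state.cur_sample}


@jax.jit
def scheduler_step(state: dict, model_output: jnp.ndarray, sample: jnp.ndarray):
    # Dummy update rule, just to show the state dict being threaded through.
    prev_sample = sample - 0.1 * model_output
    new_state = {**state, "counter": state["counter"] + 1, "cur_sample": prev_sample}
    return new_state, prev_sample


# Convert dataclass -> dict before the jitted call, and back afterwards.
state_dict = to_dict(SchedulerState(counter=0, cur_sample=jnp.zeros((4,))))
state_dict, sample = scheduler_step(state_dict, jnp.ones((4,)), jnp.zeros((4,)))
state = SchedulerState(**state_dict)
```

Using dicts everywhere (or a `flax.struct.dataclass`, discussed further down) would remove the conversion step entirely.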
**Current status**

☑️ **Easy bugs**

☑️ **Scheduler state bug**

I created a stateless implementation that I test using the `jax-micro-scheduler.py` and `torch-micro-scheduler.py` scripts. … I consider this to be close enough. These tests were conducted on the same Linux box.

☑️ **Sinusoidal embeddings bug**

After this fix, the difference after the 50-step inference loop is at …

❌ **Remaining bugs**

A simple image generation loop, without classifier-free guidance, still shows differences between the PyTorch and the Flax versions. The test scripts show a slice of the current results:

- PyTorch CPU: …
- Flax CPU: …

As you can see, the numbers are slowly drifting apart: the difference after the first step is in the order of 1e-4, but after step 50 it is 0.048 for the first column (9% relative). Could there be something in the UNet that is slightly different? How would you go about debugging this?
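One possible way to narrow this down (just a sketch, not code from this PR): run both loops from identical initial latents, save the raw noise prediction at every step from each framework, and track how the maximum absolute and relative differences grow per step. That tells you whether the drift originates in the UNet forward pass or in the scheduler update. The file names below are hypothetical.

```python
# Sketch: per-step comparison of saved PyTorch vs. Flax noise predictions.
# Assumes each loop dumped its prediction at every step to .npy files
# (file names here are made up for illustration).
import numpy as np

num_steps = 50
for step in range(num_steps):
    torch_pred = np.load(f"torch_noise_pred_{step:02d}.npy")
    flax_pred = np.load(f"flax_noise_pred_{step:02d}.npy")
    abs_diff = np.abs(torch_pred - flax_pred)
    rel_diff = abs_diff / (np.abs(torch_pred) + 1e-8)
    print(f"step {step:02d}: max abs {abs_diff.max():.3e}, max rel {rel_diff.max():.3e}")
```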
After the sinusoidal fix (#7), the differences after the inference loop are in the order of …

Next steps:

- Should we create float64 timesteps? This means that some operations will be performed in float64.
- I'm using both a dataclass and a dict, depending on the part of the code, and I'm sharding manually. I believe I should move to a `flax.struct.dataclass`.
Thanks a lot for the summary @pcuenca, and great job finding the bug! Regarding the questions:

1. I think we should not create float64 timesteps, but just keep float32 timesteps (the 2% relative diff is not worth it IMO).

More generally, I think that after verifying that TPU works well we can spin up a demo with this, and then, before publishing the notebook, I'd actually merge everything directly into …
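For context on what float64 timesteps would imply on the JAX side (a sketch, not code from the PR): 64-bit floats are disabled by default, so actually getting float64 timesteps requires flipping a global flag, and every operation derived from them then runs in float64.

```python
import jax
import jax.numpy as jnp

# With default settings, JAX silently downcasts a requested float64 dtype to
# float32 (and emits a UserWarning), so these timesteps end up as float32:
timesteps = jnp.linspace(999.0, 0.0, 50, dtype=jnp.float64)  # illustrative range
print(timesteps.dtype)  # float32

# True float64 requires enabling x64 globally *before* creating any arrays,
# after which float64 propagates into everything computed from the timesteps:
# jax.config.update("jax_enable_x64", True)
```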
Fine for me! But this is only for Stable Diffusion with this one scheduler; we need to be clear about that :)
`flax.struct.dataclass` instances can be automatically replicated, and the API enforces a functional approach. I don't think they provide a lot of benefits over a pure dict, though.
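A minimal sketch of what that could look like (illustrative field names, not the actual scheduler state in this PR): a `flax.struct.dataclass` instance is registered as a pytree, so it can pass through `jax.jit` directly and be replicated across devices with `flax.jax_utils.replicate`.

```python
import jax
import jax.numpy as jnp
from flax import jax_utils, struct


@struct.dataclass
class SchedulerState:
    # Hypothetical fields, just to demonstrate the pattern.
    counter: jnp.ndarray
    cur_sample: jnp.ndarray


@jax.jit
def step(state: SchedulerState, model_output: jnp.ndarray) -> SchedulerState:
    # The dataclass is frozen; `replace` returns a new instance (functional update).
    return state.replace(
        counter=state.counter + 1,
        cur_sample=state.cur_sample - 0.1 * model_output,
    )


state = SchedulerState(counter=jnp.array(0), cur_sample=jnp.zeros((4,)))
state = step(state, jnp.ones((4,)))

# For pmap, the whole state can be replicated across devices in one call.
replicated_state = jax_utils.replicate(state)
```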