[Don't merge] Platform tests #5

Draft

pcuenca wants to merge 10 commits into main
Conversation

Collaborator

@pcuenca pcuenca commented Aug 26, 2022

No description provided.

@pcuenca pcuenca marked this pull request as draft August 26, 2022 10:21
pickle.dump(text_embeddings.detach().to("cpu").numpy(), f)

num_inference_steps = 50
pipe.scheduler.set_timesteps(num_inference_steps, offset=1)
Collaborator

indeed offset=1 is correct here!

The code is very ugly and hard to understand; I'll fix it later. But it
produces the same results in PyTorch and JAX when using the test scripts
`jax-micro-scheduler.py` and `torch-micro-scheduler.py`.

I couldn't just use a dataclass to store the state because jitting
requires simple Python containers, so I convert between the dataclass and
a dict. We should probably just use dicts instead.
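
A minimal sketch of the dataclass ↔ dict workaround described above. The state fields and the update are invented for illustration; only the conversion pattern matters: a plain dataclass is not a pytree, so it can't cross the `jax.jit` boundary directly, while a dict of arrays can.

```python
# Illustrative only: `SchedulerState` and its fields are hypothetical.
import dataclasses
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class SchedulerState:
    counter: jnp.ndarray
    cur_sample: jnp.ndarray


def state_to_dict(state: SchedulerState) -> dict:
    return {f.name: getattr(state, f.name) for f in dataclasses.fields(state)}


def dict_to_state(d: dict) -> SchedulerState:
    return SchedulerState(**d)


@jax.jit
def step(state_dict: dict, model_output: jnp.ndarray) -> dict:
    # Dummy update standing in for the real scheduler math.
    return {
        "counter": state_dict["counter"] + 1,
        "cur_sample": state_dict["cur_sample"] - 0.1 * model_output,
    }


state = SchedulerState(counter=jnp.array(0), cur_sample=jnp.zeros((4,)))
state = dict_to_state(step(state_to_dict(state), jnp.ones((4,))))
```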

The pure functional implementation became convoluted when trying to
reproduce the previous logic. Qinsheng Zhang sent me a couple of links
to their implementation that could be worth exploring.

Even though the scheduler now produces the same output on test inputs,
the unet+scheduler loop still shows big differences.
The test only works without pmap / fori_loop.
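
As a rough illustration of why the compiled loop is harder to test, the sketch below contrasts a plain Python loop (easy to print intermediates from) with the same update expressed as `jax.lax.fori_loop` under `jit`. `apply_step` is a placeholder, not the real unet+scheduler update.

```python
from functools import partial

import jax
import jax.numpy as jnp


def apply_step(i, latents):
    # Placeholder update; the real loop would call the UNet and scheduler here.
    return latents * 0.99 + 0.01 * i


# Debug variant: a plain Python loop, so intermediate latents can be printed.
def run_python_loop(latents, num_steps):
    for i in range(num_steps):
        latents = apply_step(i, latents)
        print(f"Step {i}:", latents[0, :4])
    return latents


# Compiled variant: the same loop expressed as lax.fori_loop under jit;
# intermediates are no longer directly printable from Python.
@partial(jax.jit, static_argnums=1)
def run_fori_loop(latents, num_steps):
    return jax.lax.fori_loop(0, num_steps, apply_step, latents)


latents = jnp.zeros((1, 8))
assert jnp.allclose(run_python_loop(latents, 5), run_fori_loop(latents, 5))
```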
@pcuenca
Collaborator Author

pcuenca commented Aug 29, 2022

Current status

☑️ Easy bugs

☑️ Scheduler state bug

I created a stateless implementation that I test using torch-micro-scheduler.py and jax-micro-scheduler.py. They use random input (the same in both scripts) and print the output after every call to scheduler.step. The last line from torch-micro-scheduler.py, when running on CPU, is:

50 [1]: [[ -0.03086062 -14.875211   -21.42081      4.41971   ]]

The last one from jax-micro-scheduler.py using CPU is:

50 [1]: [[[ -0.0308551 -14.875193 -21.420778 4.419718 ]]]

And when using a CUDA device it is:
50 [1]: [[[ -0.03085638 -14.875197 -21.420776 4.4197164 ]]]

I consider this to be close enough. These tests were conducted on the same Linux box.
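
A quick way to make "close enough" concrete is to compare the final-step slices quoted above; the tolerances below are my own choice, not values used by the test scripts.

```python
import numpy as np

# Final-step values copied from the logs above (PyTorch CPU vs JAX CPU).
torch_out = np.array([-0.03086062, -14.875211, -21.42081, 4.41971])
jax_out = np.array([-0.0308551, -14.875193, -21.420778, 4.419718])

abs_diff = np.abs(torch_out - jax_out)
print("max abs diff:", abs_diff.max())                        # ~3.2e-05
print("max rel diff:", (abs_diff / np.abs(torch_out)).max())  # ~1.8e-04
print("allclose:", np.allclose(torch_out, jax_out, atol=1e-4, rtol=1e-3))  # True
```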

☑️ Sinusoidal embeddings bug

After this fix, the difference after the 50-step inference loop is ~8e-4 for the slice I'm printing.

❌ Remaining bugs

A simple image generation loop, without classifier-free guidance, still shows differences between the PyTorch and the Flax versions.

The test scripts show a slice of the current latents after each iteration step. These are the results:

PyTorch CPU

Step  0 [981]: [-0.20029372 -0.71322274 -0.6588407   1.3012452 ]
Step  1 [961]: [-0.20023333 -0.7130672  -0.6587741   1.3011351 ]
Step  2 [961]: [-0.19388276 -0.7101537  -0.6593911   1.2980776 ]
Step  3 [941]: [-0.18676892 -0.7066815  -0.6598941   1.2944702 ]
Step  4 [921]: [-0.17876498 -0.70243776 -0.660212    1.290017  ]
Step  5 [901]: [-0.16972902 -0.69751877 -0.66045487  1.284985  ]
Step  6 [881]: [-0.15949863 -0.69193625 -0.66029674  1.2790979 ]
Step  7 [861]: [-0.14811812 -0.68585664 -0.6598889   1.2726934 ]
Step  8 [841]: [-0.13533752 -0.6792915  -0.6590149   1.2656195 ]
Step  9 [821]: [-0.12131497 -0.6722729  -0.657692    1.2578422 ]
Step 10 [801]: [-0.10598446 -0.6647858  -0.65582865  1.2492626 ]
Step 11 [781]: [-0.08965665 -0.65693915 -0.6534673   1.2399257 ]
Step 12 [761]: [-0.07231925 -0.6487282  -0.6504389   1.2296318 ]
Step 13 [741]: [-0.05423042 -0.64025223 -0.64689887  1.2185035 ]
Step 14 [721]: [-0.03560711 -0.63162565 -0.64254385  1.2063537 ]
Step 15 [701]: [-0.0166229  -0.62293816 -0.63756216  1.1934131 ]
Step 16 [681]: [ 0.00244747 -0.61430305 -0.631572    1.1795033 ]
Step 17 [661]: [ 0.02151561 -0.6057371  -0.62473357  1.1647282 ]
Step 18 [641]: [ 0.04053753 -0.59714484 -0.6170455   1.1489599 ]
Step 19 [621]: [ 0.05943215 -0.5886498  -0.6084518   1.1321318 ]
Step 20 [601]: [ 0.07833029 -0.58015937 -0.5992237   1.114365  ]
Step 21 [581]: [ 0.09711783 -0.5717824  -0.5891966   1.0954585 ]
Step 22 [561]: [ 0.11586673 -0.56339365 -0.57859325  1.0755281 ]
Step 23 [541]: [ 0.13457038 -0.5550037  -0.56728995  1.054374  ]
Step 24 [521]: [ 0.15327923 -0.54639286 -0.5554153   1.0321356 ]
Step 25 [501]: [ 0.17189951 -0.53757924 -0.5428081   1.008607  ]
Step 26 [481]: [ 0.19037156 -0.52846026 -0.5295105   0.9838172 ]
Step 27 [461]: [ 0.20872533 -0.5189969  -0.5155331   0.9578293 ]
Step 28 [441]: [ 0.22693683 -0.50913084 -0.5008364   0.9305079 ]
Step 29 [421]: [ 0.24495293 -0.4988736  -0.48545808  0.9019891 ]
Step 30 [401]: [ 0.2628477  -0.48817235 -0.46945277  0.87225354]
Step 31 [381]: [ 0.28055626 -0.47705117 -0.45268854  0.8412365 ]
Step 32 [361]: [ 0.2980854  -0.46552035 -0.4353878   0.80914426]
Step 33 [341]: [ 0.3154298  -0.45355433 -0.41739047  0.77581835]
Step 34 [321]: [ 0.33256987 -0.44116807 -0.39883557  0.7414012 ]
Step 35 [301]: [ 0.34944463 -0.4283454  -0.3796378   0.705828  ]
Step 36 [281]: [ 0.3660169  -0.41508853 -0.35983148  0.6691652 ]
Step 37 [261]: [ 0.38228554 -0.40139943 -0.3394577   0.6314346 ]
Step 38 [241]: [ 0.39813244 -0.38732246 -0.31853062  0.59268993]
Step 39 [221]: [ 0.41363195 -0.37279892 -0.29705465  0.5529112 ]
Step 40 [201]: [ 0.42867348 -0.35788408 -0.27501422  0.5121036 ]
Step 41 [181]: [ 0.44330248 -0.3425106  -0.25247994  0.47017682]
Step 42 [161]: [ 0.45744026 -0.32664642 -0.22940254  0.42711228]
Step 43 [141]: [ 0.47099906 -0.3102275  -0.20562789  0.38262194]
Step 44 [121]: [ 0.48409614 -0.29324844 -0.1810959   0.3365584 ]
Step 45 [101]: [ 0.4967933  -0.27560505 -0.15544392  0.28805938]
Step 46  [81]: [ 0.5093855  -0.25712317 -0.12850419  0.23624928]
Step 47  [61]: [ 0.5223362  -0.23739506 -0.09932347  0.17866679]
Step 48  [41]: [ 0.53584635 -0.21572267 -0.06636252  0.11064841]
Step 49  [21]: [ 0.5504587  -0.18439673 -0.01844398  0.00506777]
Step 50   [1]: [ 0.5498064  -0.18463875 -0.01927387  0.00364937]

Flax CPU

Step: 0: [-0.19992101 -0.7129871  -0.6586888   1.3008868 ]
Step: 1: [-0.19981347 -0.7128339  -0.6585821   1.3007385 ]
Step: 2: [-0.19290741 -0.70965886 -0.6589236   1.2971743 ]
Step: 3: [-0.18514648 -0.7057757  -0.6591348   1.2928768 ]
Step: 4: [-0.17614478 -0.70119715 -0.6590942   1.2878238 ]
Step: 5: [-0.16603893 -0.69599694 -0.6587924   1.2821018 ]
Step: 6: [-0.1546922  -0.69035506 -0.65814054  1.2758265 ]
Step: 7: [-0.14186884 -0.6842141  -0.6570212   1.2688968 ]
Step: 8: [-0.12765141 -0.6776505  -0.65541327  1.2612278 ]
Step: 9: [-0.11221321 -0.6706691  -0.6532921   1.2528518 ]
Step: 10: [-0.09563611 -0.6633487  -0.6506718   1.2437    ]
Step: 11: [-0.07814168 -0.65575045 -0.64742196  1.2336963 ]
Step: 12: [-0.05991069 -0.6480145  -0.6434747   1.2227519 ]
Step: 13: [-0.04123378 -0.6402351  -0.63899314  1.211138  ]
Step: 14: [-0.02239008 -0.6325704  -0.6335036   1.198588  ]
Step: 15: [-0.00347782 -0.6249878  -0.62726957  1.1853817 ]
Step: 16: [ 0.01554522 -0.61738086 -0.62022686  1.1712673 ]
Step: 17: [ 0.03435938 -0.61008286 -0.61222446  1.1561339 ]
Step: 18: [ 0.05318831 -0.6028376  -0.60362244  1.140135  ]
Step: 19: [ 0.07208736 -0.5956207  -0.5944642   1.1233044 ]
Step: 20: [ 0.09092688 -0.58852077 -0.58460766  1.1053636 ]
Step: 21: [ 0.10976899 -0.58154446 -0.574152    1.0863402 ]
Step: 22: [ 0.12861738 -0.5744467  -0.5630966   1.0662225 ]
Step: 23: [ 0.14747892 -0.56707346 -0.55156934  1.0450902 ]
Step: 24: [ 0.1661616  -0.55952215 -0.53917605  1.022551  ]
Step: 25: [ 0.18472286 -0.5516906  -0.52600926  0.9986482 ]
Step: 26: [ 0.2032078  -0.54341716 -0.51236373  0.9736709 ]
Step: 27: [ 0.22153237 -0.53473914 -0.49796087  0.94738066]
Step: 28: [ 0.23979937 -0.5256013  -0.48277816  0.9196414 ]
Step: 29: [ 0.2578938  -0.51604843 -0.4671864   0.89093804]
Step: 30: [ 0.27593407 -0.5060353  -0.4508992   0.8608283 ]
Step: 31: [ 0.29384482 -0.4955864  -0.43395382  0.8294597 ]
Step: 32: [ 0.31158978 -0.48479837 -0.41666374  0.79701966]
Step: 33: [ 0.32928428 -0.4736257  -0.3989653   0.7633914 ]
Step: 34: [ 0.3467968  -0.4620075  -0.38061678  0.72844476]
Step: 35: [ 0.36414167 -0.45000747 -0.36187106  0.69242454]
Step: 36: [ 0.3813184  -0.43756184 -0.34273604  0.65527713]
Step: 37: [ 0.39826918 -0.42457163 -0.32300457  0.616891  ]
Step: 38: [ 0.41499922 -0.4111316  -0.3029643   0.5774899 ]
Step: 39: [ 0.43152043 -0.39716995 -0.28250125  0.53699094]
Step: 40: [ 0.44782242 -0.3827076  -0.26167208  0.4953915 ]
Step: 41: [ 0.46390548 -0.36765796 -0.24032444  0.45241216]
Step: 42: [ 0.47982553 -0.35209504 -0.21875334  0.40826562]
Step: 43: [ 0.49573308 -0.33579704 -0.19671243  0.36226097]
Step: 44: [ 0.5115899  -0.3185639  -0.17404059  0.31394032]
Step: 45: [ 0.527544   -0.30043286 -0.15111566  0.26273066]
Step: 46: [ 0.5436902 -0.2809677 -0.1277442  0.2067176]
Step: 47: [ 0.5601497  -0.25989434 -0.10393226  0.14277294]
Step: 48: [ 0.57713884 -0.2371628  -0.08050387  0.06425086]
Step: 49: [ 0.59747016 -0.20854189 -0.05620244 -0.06810611]
Step: 50: [ 0.5983118  -0.21079235 -0.05890573 -0.07158186]

As you can see, the numbers are slowly drifting apart. For example, the difference after the first step is on the order of 1e-4. After step 50, it is 0.048 for the first column (9% relative).

Could there be something in the UNet that is slightly different? How would you go about debugging this?
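
One way to quantify the drift (and to spot where it accelerates) would be to collect the per-step slices from both runs and report the divergence step by step. The helper below is only a sketch; it assumes the printed slices have already been parsed into two aligned arrays.

```python
import numpy as np


def report_drift(torch_slices, flax_slices):
    """Print per-step absolute and relative divergence of the first column.

    Both arguments are sequences of the per-step latent slices printed by the
    PyTorch and Flax test scripts, aligned step by step.
    """
    torch_slices = np.asarray(torch_slices)
    flax_slices = np.asarray(flax_slices)
    for step, (t, f) in enumerate(zip(torch_slices, flax_slices)):
        abs_diff = abs(t[0] - f[0])
        rel_diff = abs_diff / abs(t[0])
        print(f"Step {step:2d}: abs diff {abs_diff:.5f}  rel diff {rel_diff:.1%}")
```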

@pcuenca
Collaborator Author

pcuenca commented Aug 30, 2022

After the sinusoidal fix (#7), the differences after the inference loop are on the order of ~8e-4 for the slice I print for one of my test images. This is the visual result (left: before the fix; middle: PyTorch CPU reference; right: after the fix). I see no difference between the PyTorch and Flax versions:

[comparison images: before the fix | PyTorch CPU reference | after the fix]

Next steps:

  • Should we enable int64 for timesteps?

This means that some operations will be performed in float64 and then we'll cast to float32. It has to be enabled explicitly using jax.config.update("jax_enable_x64", True) (see the sketch after this list). I haven't measured the performance impact; my anecdotal results on one test image are:

int32: jnp.sum(jnp.abs((latents - torch_latents))) -> 98.92086
int64: jnp.sum(jnp.abs((latents - torch_latents))) -> 96.41958

  • Improve the stateless scheduler code

I'm using both a dataclass and a dict depending on which part of the code, and I'm sharding manually. I believe I should move to a flax.struct.dataclass to simplify.
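
This is the x64 sketch referenced in the first bullet. Only the flag itself comes from the discussion above; the example values are made up.

```python
# The flag must be set before any arrays are created (typically at startup).
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

timesteps = jnp.arange(0, 1000, 20)      # int64 with the flag on, int32 otherwise
t_float = timesteps.astype(jnp.float64)  # silently truncated to float32 without the flag

# ... float64 intermediate computations would happen here ...

t_float32 = t_float.astype(jnp.float32)  # cast back before feeding the model
```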

@patrickvonplaten
Collaborator

Thanks a lot for the summary @pcuenca and great job finding the bug!

Regarding the questions:

1.) I think we should not create float64 timesteps, but just keep float32 timesteps (the 2% relative diff is not worth it IMO)

More generally, I think that after having verified that TPU works well, we can spin up a demo with this. Then, before publishing the notebook, I'd actually merge everything directly into diffusers and make a bigger release (0.3.0) saying we now also support JAX.

@pcuenca
Collaborator Author

pcuenca commented Aug 30, 2022

I think we should not create float64 timesteps
👍

I'd actually merge everything directly into diffusers and make a bigger release (0.3.0) saying we now also support JAX

Fine for me! But this is only for Stable Diffusion with this one scheduler; we need to be clear about that :)

flax.struct.dataclass instances can be automatically replicated, and the
API enforces a functional approach. I don't think they provide a lot of
benefits over a pure dict though.
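
For reference, a minimal sketch of what a flax.struct.dataclass-based state could look like; the fields and the update are invented for illustration, not taken from the actual scheduler.

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class SchedulerState:          # hypothetical fields, not the actual state
    counter: jnp.ndarray
    cur_sample: jnp.ndarray


@jax.jit
def step(state: SchedulerState, model_output: jnp.ndarray) -> SchedulerState:
    # A struct dataclass is a registered pytree, so it can cross jit/pmap
    # boundaries directly; `replace` returns a new instance (functional update).
    return state.replace(
        counter=state.counter + 1,
        cur_sample=state.cur_sample - 0.1 * model_output,
    )


state = SchedulerState(counter=jnp.array(0), cur_sample=jnp.zeros((4,)))
state = step(state, jnp.ones((4,)))
# flax.jax_utils.replicate(state) would broadcast the whole state across devices.
```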
@pcuenca pcuenca mentioned this pull request Aug 31, 2022