diff --git a/README.md b/README.md
index 2cc4958d..32974079 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,20 @@
 # float8_experimental
 
-This is a prototype of a float8 training UX in native PyTorch, with full torch.compile and distributed support.
+This is an early version of a library for accelerating training with float8 in native PyTorch
+according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
 The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
+``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
+throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
 
-Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet.
+
+:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and
+will change rapidly.
 
 # installation
 
+:warning: For now, use the latest PyTorch nightly for best results with torch.compile.
+
 ```Shell
 
 pip install .
@@ -18,9 +25,9 @@ pip install -e .
 pip install -e ".[dev]"
 ```
 
-# User API, subject to change
+# User API
 
-We provide two scaling strategies: per-tensor dynamic and delayed.
+We provide two per-tensor scaling strategies: dynamic and delayed.
 See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
 
 ## float8 linear with dynamic scaling
@@ -61,7 +68,7 @@ m = Model(...)
 swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for autocast+compile+FSDP+float8 to work
+# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
 from float8_experimental import config
 config.enable_amax_init = False # only needed for autocast + compile + FSDP + float8 delayed
 config.enable_pre_and_post_forward = False # only needed for autocast + compile + FSDP + float8 delayed
@@ -103,7 +110,7 @@ pytest test/test_compile.py
 # run a two-GPU integration test on FSDP
 ./test/test_fsdp.sh
 
-# run integration tests for TP/SP
+# run integration tests for TP/SP (outdated)
 ./test/test_tp.sh
 
 # run all of these tests
@@ -116,7 +123,7 @@ pytest test/test_compile.py
 # benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
 ./benchmarks/bench_matmul.py
 
-# benchmark fw/bw of `Linear`, `Float8Linear` on LLaMa 2 70B shapes
+# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
 # make sure to turn on torch.compile to get the best performance
 ./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
 ```
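
Note: as a quick way to exercise the delayed-scaling path touched by the `m = Model(...)` hunk above, here is a minimal sketch. Only `swap_linear_with_float8_linear`, `Float8Linear`, `config.enable_amax_init`, and `config.enable_pre_and_post_forward` appear in this diff; the import paths, the toy model, and the `torch.compile` call are assumptions and may need adjusting to the repository's actual module layout.

```python
# Hedged sketch of the delayed-scaling setup referenced in the diff above.
# The import paths and the toy model are assumptions; only the helper names
# and config flags are taken from the diff itself.
import torch
import torch.nn as nn

from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear  # assumed location
from float8_experimental.float8_linear_utils import (  # assumed location
    swap_linear_with_float8_linear,
)

# toy model standing in for `Model(...)` from the README snippet
m = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()

# convert all `torch.nn.Linear` modules to `Float8Linear` (delayed scaling)
swap_linear_with_float8_linear(m, Float8Linear)

# workarounds called out in the diff; per its comments they are only needed
# for the autocast + compile + FSDP + float8 delayed combination
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

# the updated README states torch.compile is supported out of the box
m = torch.compile(m)
```

With delayed scaling, the amax/scale history also has to be synced between training iterations; that step is outside the hunks shown here, so refer to the full README for the exact helper.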