diff --git a/README.md b/README.md
index 2cc4958d..32974079 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,20 @@
# float8_experimental
-This is a prototype of a float8 training UX in native PyTorch, with full torch.compile and distributed support.
+This is an early version of a library for accelerating training with float8 in native PyTorch
+according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
+``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
+throughput speedups of up to 1.2x on small-scale (8 GPUs) LLaMa pretraining jobs.
-Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in the backward pass and large-scale distributed support are not ready yet.
+
+:warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and
+will change rapidly.
# installation
+:warning: For now, use the latest PyTorch nightly for best results with ``torch.compile``.
+
```Shell
pip install .
@@ -18,9 +25,9 @@ pip install -e .
pip install -e ".[dev]"
```
-# User API, subject to change
+# User API
-We provide two scaling strategies: per-tensor dynamic and delayed.
+We provide two per-tensor scaling strategies: dynamic and delayed. Dynamic scaling computes each scale from the absolute maximum of the current tensor, while delayed scaling derives it from a history of absolute maxima observed in previous iterations. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
## float8 linear with dynamic scaling
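+At a high level, dynamic scaling only requires swapping the `torch.nn.Linear` modules in your
+model for their float8 equivalents. A minimal sketch is below; the import paths and the
+`Float8DynamicLinear` name are assumptions and may differ from the library's actual layout, so
+treat this as a sketch rather than the authoritative API.
+
+```python
+import torch
+import torch.nn as nn
+
+# assumed import locations for the swap helper and the dynamic-scaling linear
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+
+# toy model; real speedups require float8-capable hardware (e.g. H100) and large matmul shapes
+m = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).to("cuda", torch.bfloat16)
+
+# swap every nn.Linear for a linear that casts inputs/weights/grads to float8 with dynamic scales
+swap_linear_with_float8_linear(m, Float8DynamicLinear)
+
+# optional, but recommended for best performance
+m = torch.compile(m)
+
+x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
+y = m(x)
+y.sum().backward()
+```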
@@ -61,7 +68,7 @@ m = Model(...)
swap_linear_with_float8_linear(m, Float8Linear)
# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for autocast+compile+FSDP+float8 to work
+# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
from float8_experimental import config
config.enable_amax_init = False # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False # only needed for autocast + compile + FSDP + float8 delayed
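+
+# Illustrative continuation (a sketch, not the library's verbatim example):
+# `sync_float8_amax_and_scale_history` is assumed to live in float8_experimental's
+# linear utils, and `optimizer`, `x`, `num_iters` stand in for your own training setup.
+from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+
+for _ in range(num_iters):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    # delayed scaling keeps a history of amax values; sync it (across ranks under FSDP)
+    # after backward and before the optimizer step so the next iteration uses fresh scales
+    sync_float8_amax_and_scale_history(m)
+    optimizer.step()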
@@ -103,7 +110,7 @@ pytest test/test_compile.py
# run a two-GPU integration test on FSDP
./test/test_fsdp.sh
-# run integration tests for TP/SP
+# run integration tests for TP/SP (outdated)
./test/test_tp.sh
# run all of these tests
@@ -116,7 +123,7 @@ pytest test/test_compile.py
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py
-# benchmark fw/bw of `Linear`, `Float8Linear` on LLaMa 2 70B shapes
+# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
```