
Commit: Update README.md (#189)
Summary:
Add more product context, link to feature tracker, fix some typos

Pull Request resolved: #189

Reviewed By: drisspg, malfet

Differential Revision: D52806367

Pulled By: vkuzo

fbshipit-source-id: f6d9b549ae697cf75cb00d0ee7e03989e3c4175c
vkuzo authored and facebook-github-bot committed Jan 16, 2024
1 parent d272138 commit f86dd67
1 changed file: README.md (15 additions, 8 deletions)
@@ -1,13 +1,20 @@
# float8_experimental

-This is a prototype of a float8 training UX in native PyTorch, with full torch.compile and distributed support.
+This is an early version of a library for accelerating training with float8 in native PyTorch
+according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
+The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
+``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
+throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.

-Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet.</em>

+:warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and
+will change rapidly.</em>

# installation

+:warning: <em>For now, use the latest PyTorch nightly for best results with torch.compile.</em>

```Shell
pip install .

@@ -18,9 +25,9 @@ pip install -e .
pip install -e ".[dev]"
```

-# User API, subject to change
+# User API

-We provide two scaling strategies: per-tensor dynamic and delayed.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
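
For intuition, the two strategies differ in where the amax statistic used to compute the scale comes from: dynamic scaling reduces over the tensor being cast at the moment of the cast, while delayed scaling reuses amax values recorded on earlier iterations. A minimal sketch of the idea (the constant and the history handling below are assumptions, not the library's implementation):

```python
import torch

E4M3_MAX = 448.0  # largest representable magnitude of float8 e4m3 (assumed target dtype)

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # dynamic: reduce over the tensor being cast right now
    return E4M3_MAX / torch.clamp(x.abs().max(), min=1e-12)

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # delayed: reuse amax values observed on previous iterations,
    # avoiding an extra reduction over the current tensor in the hot path
    return E4M3_MAX / torch.clamp(amax_history.max(), min=1e-12)
```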

## float8 linear with dynamic scaling

@@ -61,7 +68,7 @@ m = Model(...)
swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for autocast+compile+FSDP+float8 to work
+# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
from float8_experimental import config
config.enable_amax_init = False # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False # only needed for autocast + compile + FSDP + float8 delayed
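
As a hedged end-to-end sketch of how the pieces above might fit together for delayed scaling with ``torch.compile`` (the import paths and the toy model are assumptions, not verbatim library API):

```python
# Sketch only: import paths and the toy model are assumptions.
import torch
import torch.nn as nn

from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# workarounds for autocast + compile + FSDP + float8 delayed scaling
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

# toy model; swap its nn.Linear modules for delayed-scaling Float8Linear
m = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
swap_linear_with_float8_linear(m, Float8Linear)

# compile after the swap for best performance
m = torch.compile(m)
```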
@@ -103,7 +110,7 @@ pytest test/test_compile.py
# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

-# run integration tests for TP/SP
+# run integration tests for TP/SP (outdated)
./test/test_tp.sh

# run all of these tests
@@ -116,7 +123,7 @@ pytest test/test_compile.py
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

-# benchmark fw/bw of `Linear`, `Float8Linear` on LLaMa 2 70B shapes
+# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
```
