scan and apply_layers: milestone 1
This commit adds the lowering of scan to the HLO While op. It also
introduces apply_layers, which sequentially applies a list of layers
using scan underneath.
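
For readers unfamiliar with scan, the semantics described above can be sketched in plain Python. This is only a reference loop, not the code in this commit: `scan_reference` and `apply_layers_reference` are hypothetical names, and the `fn(carry, x) -> (new_carry, y)` convention is assumed from the usual definition of scan. The real lowering emits a single While op rather than running a Python loop.

```python
def scan_reference(fn, init, xs):
    """Reference semantics of scan (hypothetical helper, plain Python).

    fn takes (carry, x) and returns (new_carry, y). The carry is threaded
    through every step; the per-step outputs y are collected in order.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


def apply_layers_reference(layers, x):
    """Apply layers one after another, expressed as a scan over the layers.

    The activation is the carry; there is no per-step output, so y is None.
    """

    def step(activation, layer):
        return layer(activation), None

    out, _ = scan_reference(step, x, layers)
    return out


# Cumulative sum as a scan: the carry is the running total.
total, partial_sums = scan_reference(lambda c, x: (c + x, c + x), 0, [1, 2, 3])

# Two toy "layers" applied in sequence: (3 + 1) * 2.
result = apply_layers_reference([lambda v: v + 1, lambda v: v * 2], 3)
```

Here `total` is 6, `partial_sums` is `[1, 3, 6]`, and `result` is 8. Expressing the layer loop as a scan is what allows the body to be compiled once instead of being unrolled per layer.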

In this milestone we use AOTAutograd to obtain the backward pass of
the function being scanned. Users can either save the activations in
fn or recompute them by passing different graph partitioners to
AOTAutograd.
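
The save-versus-recompute tradeoff mentioned above can be illustrated with a hand-written backward pass for a toy scalar scan. This is a sketch only: it does not use AOTAutograd or any partitioner, and `step`, `grad_saving_activations`, and `grad_recomputing_activations` are hypothetical names chosen for the example. Both strategies produce the same gradient; they differ in what is kept from the forward pass.

```python
import math


def step(h, x):
    """Toy scanned function: returns the new carry."""
    return math.tanh(h + x)


def grad_saving_activations(h0, xs):
    """d(final carry)/d(h0), keeping each step's output for the backward pass."""
    h, saved = h0, []
    for x in xs:
        h = step(h, x)
        saved.append(h)  # activation stored during the forward pass
    grad = 1.0
    for h_next in reversed(saved):
        grad *= 1.0 - h_next * h_next  # d tanh(u)/du = 1 - tanh(u)^2
    return grad


def grad_recomputing_activations(h0, xs):
    """Same gradient, keeping only each step's inputs and recomputing outputs."""
    h, step_inputs = h0, []
    for x in xs:
        step_inputs.append((h, x))  # only the inputs of the step are stored
        h = step(h, x)
    grad = 1.0
    for h_in, x in reversed(step_inputs):
        h_next = step(h_in, x)  # activation recomputed in the backward pass
        grad *= 1.0 - h_next * h_next
    return grad
```

In this scalar toy the two variants store the same amount of data. With a real layer, `fn` has many intermediate tensors, so recomputing them trades extra compute in the backward pass for lower peak memory.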
tengyifei committed Nov 21, 2024
1 parent 2ec2264 commit 41acb50
Showing 13 changed files with 1,597 additions and 144 deletions.
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
@@ -7,16 +7,16 @@
from torch import nn


-# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
+# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
@dataclass
class DecoderOnlyConfig:
  hidden_size: int = 1024
  num_hidden_layers: int = 2
  num_attention_heads: int = 8
  num_key_value_heads: int = 4
-  intermediate_size = 32 * 1024
-  vocab_size = 3200
-  use_flash_attention = False
+  intermediate_size: int = 32 * 1024
+  vocab_size: int = 3200
+  use_flash_attention: bool = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -208,7 +208,8 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
  run_test "$CDIR/pjrt/test_dtypes.py"
  run_test "$CDIR/test_while_loop.py"
-  run_test "$CDIR/test_scan.py"
+  run_test "$CDIR/scan/test_scan.py"
+  run_test "$CDIR/scan/test_scan_layers.py"
  run_test "$CDIR/test_autocast.py"
  run_test "$CDIR/eager/test_eager.py"
  run_test "$CDIR/eager/test_eager_with_xla_compile.py"