Skip to content

Commit

Permalink
scan and scan_layers (#7901)
Browse files Browse the repository at this point in the history
  • Loading branch information
tengyifei authored Nov 23, 2024
1 parent 31d348e commit 51575db
Show file tree
Hide file tree
Showing 13 changed files with 1,610 additions and 150 deletions.
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
from torch import nn


# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
@dataclass
class DecoderOnlyConfig:
hidden_size: int = 1024
num_hidden_layers: int = 2
num_attention_heads: int = 8
num_key_value_heads: int = 4
intermediate_size = 32 * 1024
vocab_size = 3200
use_flash_attention = False
intermediate_size: int = 32 * 1024
vocab_size: int = 3200
use_flash_attention: bool = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down
3 changes: 2 additions & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
Expand Down
Loading

0 comments on commit 51575db

Please sign in to comment.