Skip to content

Commit

Permalink
Reduce memory usage (#100)
Browse files (browse the repository at this point in the history)
  • Loading branch information
oraluben authored Jan 5, 2025
1 parent 60f0f87 commit a710e18
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions examples/benchmark.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,6 +26,21 @@
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap


# Registry of the benchmark examples, keyed by the name used to select them.
#
# Every value is a zero-argument callable: building this dict runs nothing, and
# the names used inside each lambda are only resolved when the entry is called.
# Insertion order is significant wherever the keys are listed or iterated.
AVAILABLE_EXAMPLES = {
    "causal": lambda: test_mask(mask_mod=causal_mask),
    "alibi": lambda: test_mask(
        score_mod=generate_alibi_bias(16),
        skip_correctness=True,
    ),
    "sliding_window": lambda: test_mask(
        mask_mod=generate_sliding_window(window_size=1024),
    ),
    "prefix_lm": lambda: test_mask(
        mask_mod=generate_prefix_lm_mask(prefix_length=1024),
    ),
    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
    # The two softcap entries differ only in the `approx` flag.
    "softcap": lambda: test_mask(
        score_mod=generate_tanh_softcap(30, approx=False),
        skip_correctness=True,
    ),
    "softcap_approx": lambda: test_mask(
        score_mod=generate_tanh_softcap(30, approx=True),
        skip_correctness=True,
    ),
}


# Make CUDA the default device for newly created tensors.
torch.set_default_device("cuda")
# Seed PyTorch's random number generator with a fixed value (0).
torch.manual_seed(0)

Expand Down Expand Up @@ -97,19 +112,15 @@ def test_mask(
causal_fav2_flops = 0.5 * B * H * D * S * S
flops = density * B * H * D * S * S

# Forward pass
causal_fa2_time = do_bench(causal_fa2)
sdpa_mask_time = do_bench(sdpa_mask)
flex_ms = do_bench(flex_attention_call)
times = []
for attn in (causal_fa2, sdpa_mask, flex_attention_call):
fwd_time = do_bench(attn)
fwd_out = attn()
bwd_time = do_bench(lambda: fwd_out.backward(gradOut, retain_graph=True)) # noqa: F821
times.append((fwd_time, bwd_time))

# Backward pass
causal_fa2_out = causal_fa2()
sdpa_mask_out = sdpa_mask()
flex_out = flex_attention_call()

causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))
flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
del fwd_out
torch.cuda.empty_cache()

print_header(
f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace(
Expand Down Expand Up @@ -140,6 +151,12 @@ def test_mask(
torch.testing.assert_close(flex, sdpa_mask, atol=1e-1, rtol=1e-2)

print("Correctness check passed ✅")

(
(causal_fa2_time, causal_fa2_bw_time),
(sdpa_mask_time, sdpa_mask_bw_time),
(flex_ms, flex_bw_ms),
) = times
# Usage in your results formatting:
results = [
[
Expand Down Expand Up @@ -210,28 +227,16 @@ def main(examples: List[str] = ["all"]):
Args:
examples: List of examples to run. If "all" is specified, all examples will be run.
"""
available_examples = {
"causal": lambda: test_mask(mask_mod=causal_mask),
"alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
"sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
"prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
"document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
"softcap": lambda: test_mask(
score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
),
"softcap_approx": lambda: test_mask(
score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
),
}

if "all" in examples:
ex_to_run = list(available_examples.keys())
ex_to_run = list(AVAILABLE_EXAMPLES.keys())
else:
ex_to_run = examples

for ex in ex_to_run:
if ex in available_examples:
available_examples[ex]()
if ex in AVAILABLE_EXAMPLES:
AVAILABLE_EXAMPLES[ex]()
torch.cuda.empty_cache()
else:
print(f"Warning: Unknown example key '{ex}'. Skipping.")

Expand All @@ -248,8 +253,9 @@ def main(examples: List[str] = ["all"]):
nargs="+",
default=["all"],
help="List of examples to run. Use space to separate multiple examples. "
"Available options: causal, alibi, sliding_window, prefix_lm, "
"document, softcap, softcap_approx, or 'all' to run all examples.",
"Available options: "
+ ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
+ ", or 'all' to run all examples.",
)

args = parser.parse_args()
Expand Down

0 comments on commit a710e18

Please sign in to comment.