From 896da5d3b59252cba40aea6818621b2fbc77fbf1 Mon Sep 17 00:00:00 2001
From: turboderp
Date: Sun, 11 Jun 2023 19:44:41 +0200
Subject: [PATCH] Benchmark >2048 token sequence prompts in batches

---
 model.py                    |  4 ++++
 test_benchmark_inference.py | 11 +++++++++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/model.py b/model.py
index 64e1c43a..be15f5b4 100644
--- a/model.py
+++ b/model.py
@@ -705,6 +705,10 @@ def __init__(self, config):
                                temp_zeros_float,
                                temp_dq)
 
+        # Clear the cache
+
+        torch.cuda.empty_cache()
+
 
     def forward(self, input_ids, cache, last_id_only = True, preprocess_only = False):
 
diff --git a/test_benchmark_inference.py b/test_benchmark_inference.py
index 177ed21c..20b0ac66 100644
--- a/test_benchmark_inference.py
+++ b/test_benchmark_inference.py
@@ -33,7 +33,14 @@ def begin():
 def next_logits(input_ids, last_id_only = True):
     global model, cache
 
-    return model.forward(input_ids, cache, last_id_only)
+    n_logits = None
+    a = 0
+    while a < input_ids.shape[-1]:
+        b = min(input_ids.shape[-1], a + 2048)
+        n_logits = model.forward(input_ids[:, a:b], cache, last_id_only)
+        a = b
+
+    return n_logits
 
 
 def tokenize(text):
@@ -121,7 +128,7 @@ def mem(name, total = False):
 
     # Warming up apparently makes a huge difference
 
-    for i in range(1, 4):
+    for i in range(1, 3):
         print(f" -- Warmup pass {i}...")
         begin()
         logits = timer("Warmup", lambda: next_logits(ids))
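
For context, the reworked next_logits feeds a long prompt through the model in 2048-token slices, relying on the cache to carry state from one slice to the next, and keeps only the logits returned by the final slice. Below is a minimal, self-contained sketch of that pattern, not part of the patch: it assumes an ExLlama-style forward(input_ids, cache, last_id_only) signature as used above, and DummyModel / DummyCache are hypothetical stand-ins so the loop can run without loading any weights.

import torch

CHUNK = 2048  # same slice size the patched next_logits uses

class DummyCache:
    # Hypothetical stand-in for the real key/value cache; only tracks length.
    def __init__(self):
        self.current_seq_len = 0

class DummyModel:
    # Hypothetical stand-in for the real model; returns zero logits and
    # advances the cache by the number of tokens it was fed.
    vocab_size = 32000
    def forward(self, input_ids, cache, last_id_only = True):
        cache.current_seq_len += input_ids.shape[-1]
        rows = 1 if last_id_only else input_ids.shape[-1]
        return torch.zeros((input_ids.shape[0], rows, self.vocab_size))

def next_logits(model, cache, input_ids, last_id_only = True):
    # Process the prompt in CHUNK-sized slices; the cache accumulates state,
    # so only the logits from the last slice matter.
    n_logits = None
    a = 0
    while a < input_ids.shape[-1]:
        b = min(input_ids.shape[-1], a + CHUNK)
        n_logits = model.forward(input_ids[:, a:b], cache, last_id_only)
        a = b
    return n_logits

model, cache = DummyModel(), DummyCache()
ids = torch.zeros((1, 5000), dtype = torch.long)
logits = next_logits(model, cache, ids)
print(logits.shape, cache.current_seq_len)  # torch.Size([1, 1, 32000]) 5000

With a 5000-token prompt this makes three forward calls (2048 + 2048 + 904 tokens) rather than one oversized pass, which is what lets the benchmark handle prompts beyond the 2048-token limit.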