Fix pylint for micro benchmark (AI-Hypercomputer#36)
fix pylint for micro benchmark
FanhaiLu1 authored Apr 25, 2024
1 parent 8786fb2 commit dca79a5
Showing 3 changed files with 39 additions and 16 deletions.
2 changes: 2 additions & 0 deletions benchmarks/.pylintrc
@@ -0,0 +1,2 @@
[MESSAGES CONTROL]
disable=C0114
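
Note: C0114 is pylint's missing-module-docstring check, so this rcfile entry silences the "missing module docstring" warning for everything under benchmarks/. A per-file pragma would do the same without an rcfile; a minimal sketch of that alternative (not what the commit does):

    # pylint: disable=missing-module-docstring
    import json  # rest of the module follows as usual
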
13 changes: 13 additions & 0 deletions benchmarks/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
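
A note on this file: it contains only the license header. Besides marking benchmarks/ as a Python package, an __init__.py matters to pylint's configuration discovery, which can walk up through package directories when looking for a pylintrc; that is plausibly why the package marker is added together with benchmarks/.pylintrc, though the commit message does not say so explicitly.
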
40 changes: 24 additions & 16 deletions benchmarks/analyze_sharegpt.py
@@ -58,20 +58,20 @@
}

# batch size 60, full cache, bfloat
-system_time_per_decode_token_ms = 26.55 / 60
+SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.55 / 60

# batch size 96, full cache, quantized
-system_time_per_decode_token_ms = 26.0 / 96
+SYSTEM_TIME_PER_DECODE_TOKEN_MS = 26.0 / 96

# batch size 96, rolling, bfloat
-system_time_per_decode_token_ms = 28.18 / 96
+SYSTEM_TIME_PER_DECODE_TOKEN_MS = 28.18 / 96

# batch size 160, rolling, quantized
-system_time_per_decode_token_ms = 30 / 160
+SYSTEM_TIME_PER_DECODE_TOKEN_MS = 30 / 160


+#pylint: disable-next=all
def do_simulation(prefill_bucket_size_to_ms, system_time_per_decode_token_ms):

def next_power_of_2(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()

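
The renames above address pylint C0103 (invalid-name): module-level assignments are treated as constants and expected in UPPER_CASE. Note that, before and after this change, the four assignments overwrite one another, so only the last value (30 / 160) is live; the earlier lines serve as a record of the measured configurations. A sketch of keeping all four addressable at once, with hypothetical key labels not taken from the repo:

    # Hypothetical restructuring; the keys are invented labels for the configs.
    SYSTEM_TIME_PER_DECODE_TOKEN_MS_BY_CONFIG = {
        ("bs60", "full cache", "bfloat"): 26.55 / 60,
        ("bs96", "full cache", "quantized"): 26.0 / 96,
        ("bs96", "rolling", "bfloat"): 28.18 / 96,
        ("bs160", "rolling", "quantized"): 30 / 160,
    }
    # Select the active configuration explicitly instead of by overwrite order.
    SYSTEM_TIME_PER_DECODE_TOKEN_MS = SYSTEM_TIME_PER_DECODE_TOKEN_MS_BY_CONFIG[
        ("bs160", "rolling", "quantized")
    ]

For reference, next_power_of_2 rounds up via bit_length: for x = 600, (599).bit_length() is 10, so the function returns 2**10 = 1024, and the caller below clamps the result to at least 128.
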
@@ -83,6 +83,7 @@ def tokens_in_input_str(s):
convo_numbers = []
# Please update with your own data file path
loaded_share_gpt = json.load(
+#pylint: disable-next=all
open("~/data/ShareGPT_V3_unfiltered_cleaned_split.json", "r")
)
for example in loaded_share_gpt:
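
The disable-next=all pragma silences every check on the open(...) line; the likely offenders there are R1732 (consider-using-with) and W1514 (unspecified-encoding), though the blanket form hides any others too. Separately, Python's open() does not expand "~", so the literal path only works if a directory actually named "~" exists; os.path.expanduser handles that. A sketch that needs no pragma, assuming the same data file:

    import json
    import os

    # open() never expands "~"; expanduser() turns it into the real home path.
    path = os.path.expanduser("~/data/ShareGPT_V3_unfiltered_cleaned_split.json")
    with open(path, "r", encoding="utf-8") as f:
        loaded_share_gpt = json.load(f)
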
@@ -97,17 +98,17 @@ def tokens_in_input_str(s):
c for c in convo_numbers if c[0] <= CUTOFF_INPUT and c[1] <= CUTOFF_OUTPUT
]

-mean_input = sum([c[0] for c in kept_convos]) / len(kept_convos)
-mean_output = sum([c[1] for c in kept_convos]) / len(kept_convos)
+mean_input = sum(c[0] for c in kept_convos) / len(kept_convos)
+mean_output = sum(c[1] for c in kept_convos) / len(kept_convos)

print(
f"Total {num_convos=} but only kept {kept_convos=}. Out of kept, {mean_input=}, {mean_output=}"
f"""Total {num_convos=} but only kept {kept_convos=}.
Out of kept, {mean_input=}, {mean_output=}"""
)

total_prefill_system_ms = 0
total_generate_system_ms = 0

total_system_output_tokens = 0
for convo in kept_convos:
input_tok, output_tok = convo
bucket = max(128, next_power_of_2(input_tok))
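
The sum changes in this hunk and the next satisfy pylint R1728 (consider-using-generator): passing a generator expression lets sum consume items lazily instead of first building a throwaway list. The value is identical; a toy example, not data from the benchmark:

    kept_convos = [(10, 20), (30, 40)]
    mean_input = sum(c[0] for c in kept_convos) / len(kept_convos)
    assert mean_input == 20.0  # same result the sum([...]) form produced
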
@@ -122,18 +123,20 @@ def tokens_in_input_str(s):
total_generate_system_ms += generate_system_ms

total_time_ms = total_prefill_system_ms + total_generate_system_ms
-input_tokens = sum([c[0] for c in kept_convos])
+input_tokens = sum(c[0] for c in kept_convos)

-output_tokens = sum([c[1] for c in kept_convos])
+output_tokens = sum(c[1] for c in kept_convos)
print(
f"Output tokens {output_tokens} in {total_time_ms/1000:.2f} seconds, for {output_tokens/(total_time_ms/1000):.2f} out tok/s"
f"""Output tokens {output_tokens} in {total_time_ms/1000:.2f} seconds,
for {output_tokens/(total_time_ms/1000):.2f} out tok/s"""
)

total_prefill_sec = total_prefill_system_ms / 1000
total_generate_sec = total_generate_system_ms / 1000

print(
f"Total time {total_time_ms/1000:.2f} seconds, split {total_prefill_sec=:.2f} seconds and {total_generate_sec=:.2f} seconds"
f"""Total time {total_time_ms/1000:.2f} seconds,
split {total_prefill_sec=:.2f} seconds and {total_generate_sec=:.2f} seconds"""
)

idealized_prefill_sec = (
@@ -148,16 +151,21 @@ def tokens_in_input_str(s):
generate_savings_sec = total_generate_sec - idealized_generate_sec

print(
f"we think prefill will take {total_prefill_sec=:.2f}, we could get it to {idealized_prefill_sec=:.2f} so we'd save {prefill_savings_sec=:.2f} seconds "
f"""we think prefill will take {total_prefill_sec=:.2f},
we could get it to {idealized_prefill_sec=:.2f} so we'd
save {prefill_savings_sec=:.2f} seconds """
)
print(
f"with sparsity we could go from {total_generate_sec=:.2f}, we could get it to {idealized_generate_sec=:.2f} so we'd save {generate_savings_sec=:.2f} seconds "
f"""with sparsity we could go from {total_generate_sec=:.2f},
we could get it to {idealized_generate_sec=:.2f} so we'd save
{generate_savings_sec=:.2f} seconds """
)

idealized_overall_time = idealized_generate_sec + idealized_prefill_sec

print(
f"Idealized out tokens {output_tokens} in {idealized_overall_time:.2f} seconds, for {output_tokens/idealized_overall_time:.2f} out tok/s"
f"""Idealized out tokens {output_tokens} in {idealized_overall_time:.2f} seconds,
for {output_tokens/idealized_overall_time:.2f} out tok/s"""
)
print("prfill", prefill_bucket_size_to_ms)
print("decode step", system_time_per_decode_token_ms)
