Skip to content

Commit

Permalink
Merge pull request #1081 from akx/ruff-format
Browse files Browse the repository at this point in the history
Reformat Python code with Ruff
  • Loading branch information
Titus-von-Koeller authored Mar 13, 2024
2 parents fd723b7 + 5a4263f commit 06029dd
Show file tree
Hide file tree
Showing 41 changed files with 2,661 additions and 1,777 deletions.
4 changes: 1 addition & 3 deletions .github/scripts/set_platform_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ def get_platform_tag(architecture):
system = platform.system()

if system == "Linux":
tag = (
"manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
)
tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
elif system == "Darwin":
tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
elif system == "Windows":
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.3.2
hooks:
- id: ruff
args:
- --fix
# - id: ruff-format # TODO: enable when the time is right
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand Down
122 changes: 69 additions & 53 deletions benchmarking/switchback/make_plot_with_jsonl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

cmap=plt.get_cmap('cool')

if __name__ == '__main__':
cmap = plt.get_cmap("cool")

fig = plt.figure(tight_layout=True, figsize=(12,3.5))
if __name__ == "__main__":
fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
gs = gridspec.GridSpec(1, 2)

dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
Expand All @@ -19,25 +17,28 @@
ax = fig.add_subplot(gs[0, 0])

# TODO: change this to what you want.
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]

# first plot the time occupied by different operations
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),

('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),

('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (sum of parts)",
),
("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
]:
xs = []
ys = []
Expand All @@ -47,89 +48,104 @@
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)

ax.plot(
xs,
ys,
color=color,
label=name,
marker=marker,
markersize=5 if marker == "s" else 5,
linestyle=ls,
linewidth=2 if "+" in k else 1.0,
)

ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)


ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)
ax.set_xlabel("dim", fontsize=13)
ax.set_ylabel("time (ms)", fontsize=13)

ax.grid()

ax.set_xscale('log')
ax.set_xscale("log")
if logscale_plot1:
ax.set_yscale('log')
ax.set_yscale("log")

ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)

ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)

leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight("bold")
leg.get_texts()[1].set_fontweight("bold")
plt.subplots_adjust(left=0.1)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)

ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)

ax = fig.add_subplot(gs[0, 1])

# now plot the % speedup for different batch sizes
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (total time)",
),
]:

xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
for k_ in k.split("+"):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)

color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
markers = ["^", "v", "P", "o"]
ax.plot(
all_xs[0],
real_ys,
color=color,
label=f"batch * sequence length = {batch_size}",
marker=markers[j],
markersize=5 if marker == "s" else 5,
)

ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.set_xlabel("dim", fontsize=13)
ax.set_xscale("log")
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=13)
ax.set_ylabel(r"% speedup", fontsize=13)


ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)

ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)

ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)


ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)

plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
122 changes: 86 additions & 36 deletions benchmarking/switchback/speed_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,31 @@

# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

def get_time(k, fn, info_dict):

def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
fn()

torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
fn()

torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms

if __name__ == '__main__':

if __name__ == "__main__":
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:

for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
# switch switches dim_in and dim_out
for switch in [False, True]:

# hparams
repeat = 64
batch_size = batch_size
Expand Down Expand Up @@ -73,35 +72,86 @@ def get_time(k, fn, info_dict):
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()

info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)

time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)

print('speedup', -100*(time_global - time_standard)/time_standard)

info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info = {
"repeat": repeat,
"batch_size": batch_size,
"dim_out": dim_out,
"dim_in": dim_in,
"wm": wm,
"switch": switch,
}

get_time("standard_fwd", lambda: x.matmul(w.t()), info)
get_time("standard_gw", lambda: g.t().matmul(x), info)
get_time("standard_gx", lambda: g.matmul(w), info)
get_time(
"rowwise_fwd",
lambda: int8_matmul_rowwise_dequantize(
x_int8,
w_int8.t(),
state_x_rowwise,
state_w_columnwise,
None,
),
info,
)
get_time(
"rowwise_bwd",
lambda: int8_matmul_rowwise_dequantize(
g_int8,
wt_int8.t(),
state_x_rowwise,
state_w_rowwise,
None,
),
info,
)
get_time(
"global_fwd",
lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time(
"global_bwd",
lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
get_time("w_quantize_global", lambda: quantize_global(w), info)
get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)

time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
time_rowwise = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_colwise_transpose"]
+ info["w_quantize_rowwise"]
+ info["standard_gw"]
+ info["rowwise_fwd"]
+ info["rowwise_bwd"]
)
time_global = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_global"]
+ info["w_quantize_global_transpose"]
+ info["standard_gw"]
+ info["global_fwd"]
+ info["global_bwd"]
)

print("TOTAL STANDARD", time_standard)
print("TOTAL ROWWISE", time_rowwise)
print("TOTAL GLOBAL", time_global)

print("speedup", -100 * (time_global - time_standard) / time_standard)

info["time_standard"] = time_standard
info["time_rowwise"] = time_rowwise
info["time_global"] = time_global

info_json = json.dumps(info)

Expand Down
Loading

0 comments on commit 06029dd

Please sign in to comment.