From 02e30ca6e4d629d9cddc87886acf06207f07f803 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 24 Feb 2024 11:54:14 +0200 Subject: [PATCH 1/2] Upgrade Ruff + configure formatting --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8ccfe8df..a859d05af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.3.2 hooks: - id: ruff args: - --fix - # - id: ruff-format # TODO: enable when the time is right + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index f74750720..609ff84fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,10 @@ src = [ "tests", "benchmarking" ] +target-version = "py38" +line-length = 119 + +[tool.ruff.lint] select = [ "B", # bugbear: security warnings "E", # pycodestyle @@ -17,7 +21,6 @@ select = [ "UP", # alert you when better syntax is available in your python version "RUF", # the ruff developer's own rules ] -target-version = "py38" ignore = [ "B007", # Loop control variable not used within the loop body (TODO: enable) "B028", # Warning without stacklevel (TODO: enable) @@ -30,7 +33,7 @@ ignore = [ ] ignore-init-module-imports = true # allow to expose in __init__.py via imports -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "**/__init__.py" = ["F401"] # allow unused imports in __init__.py "{benchmarking,tests}/**/*.py" = [ "B007", @@ -42,7 +45,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports "UP030", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true From 5a4263f4dc05fe8f78f4111beab9f68a81deeab1 Mon Sep 17 00:00:00 2001 From: Ruff Date: Sat, 24 Feb 2024 12:01:15 +0200 Subject: [PATCH 2/2] Reformat with ruff-format --- .github/scripts/set_platform_tag.py | 4 +- .../switchback/make_plot_with_jsonl.py | 122 +- benchmarking/switchback/speed_benchmark.py | 122 +- bitsandbytes/autograd/_functions.py | 50 +- bitsandbytes/cextension.py | 6 +- bitsandbytes/diagnostics/cuda.py | 6 +- bitsandbytes/diagnostics/main.py | 6 +- bitsandbytes/functional.py | 1020 +++++++++++------ bitsandbytes/nn/modules.py | 262 +++-- bitsandbytes/nn/triton_based_modules.py | 72 +- bitsandbytes/optim/adagrad.py | 12 +- bitsandbytes/optim/adam.py | 255 ++++- bitsandbytes/optim/adamw.py | 193 +++- bitsandbytes/optim/lars.py | 20 +- bitsandbytes/optim/lion.py | 171 ++- bitsandbytes/optim/optimizer.py | 172 +-- bitsandbytes/optim/rmsprop.py | 12 +- bitsandbytes/research/autograd/_functions.py | 47 +- bitsandbytes/research/nn/modules.py | 19 +- bitsandbytes/triton/dequantize_rowwise.py | 38 +- .../triton/int8_matmul_mixed_dequantize.py | 148 ++- .../triton/int8_matmul_rowwise_dequantize.py | 147 ++- .../quantize_columnwise_and_transpose.py | 48 +- bitsandbytes/triton/quantize_global.py | 81 +- bitsandbytes/triton/quantize_rowwise.py | 37 +- bitsandbytes/utils.py | 28 +- check_bnb_install.py | 10 +- examples/int8_inference_huggingface.py | 13 +- install_cuda.py | 16 +- scripts/stale.py | 3 +- tests/test_autograd.py | 140 +-- tests/test_cuda_setup_evaluator.py | 5 +- tests/test_functional.py | 777 ++++++------- tests/test_generation.py | 75 +- tests/test_linear4bit.py | 12 +- tests/test_linear8bitlt.py | 18 +- 
tests/test_modules.py | 148 ++- tests/test_optim.py | 91 +- tests/test_triton.py | 19 +- 39 files changed, 2653 insertions(+), 1772 deletions(-) diff --git a/.github/scripts/set_platform_tag.py b/.github/scripts/set_platform_tag.py index ca561c880..c82077074 100644 --- a/.github/scripts/set_platform_tag.py +++ b/.github/scripts/set_platform_tag.py @@ -7,9 +7,7 @@ def get_platform_tag(architecture): system = platform.system() if system == "Linux": - tag = ( - "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" - ) + tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" elif system == "Darwin": tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64" elif system == "Windows": diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py index b23f63562..fd0dd7d58 100644 --- a/benchmarking/switchback/make_plot_with_jsonl.py +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -1,13 +1,11 @@ - import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import pandas as pd -cmap=plt.get_cmap('cool') - -if __name__ == '__main__': +cmap = plt.get_cmap("cool") - fig = plt.figure(tight_layout=True, figsize=(12,3.5)) +if __name__ == "__main__": + fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) gs = gridspec.GridSpec(1, 2) dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] @@ -19,25 +17,28 @@ ax = fig.add_subplot(gs[0, 0]) # TODO: change this to what you want. - rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) + rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) df = rdf[rdf.batch_size == batch_size_for_plot1] # first plot the time occupied by different operations for k, marker, ls, color, name in [ - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), - - ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), - ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), - ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), - - ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), - ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), - - ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), - ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), - ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'), - ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'), + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), + ( + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", + "o", + "-", + "C4", + "SwitchBack int8 (sum of parts)", + ), + ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), + ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), + ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), + ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), + ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), + ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), + ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), + ("w_quantize_global", ".", "--", "C4", "Quantize global W 
(switchback)"), + ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), ]: xs = [] ys = [] @@ -47,40 +48,46 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] ys.append(y_ * 0.5) + ax.plot( + xs, + ys, + color=color, + label=name, + marker=marker, + markersize=5 if marker == "s" else 5, + linestyle=ls, + linewidth=2 if "+" in k else 1.0, + ) - ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) - - - ax.set_xlabel('dim', fontsize=13) - ax.set_ylabel('time (ms)', fontsize=13) + ax.set_xlabel("dim", fontsize=13) + ax.set_ylabel("time (ms)", fontsize=13) ax.grid() - ax.set_xscale('log') + ax.set_xscale("log") if logscale_plot1: - ax.set_yscale('log') + ax.set_yscale("log") - ax.tick_params(axis='x', labelsize=11) - ax.tick_params(axis='y', labelsize=11) + ax.tick_params(axis="x", labelsize=11) + ax.tick_params(axis="y", labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) - leg.get_texts()[0].set_fontweight('bold') - leg.get_texts()[1].set_fontweight('bold') + leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight("bold") + leg.get_texts()[1].set_fontweight("bold") plt.subplots_adjust(left=0.1) - ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) - + ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) ax = fig.add_subplot(gs[0, 1]) @@ -88,10 +95,15 @@ for j, batch_size in enumerate(batch_sizes_for_plot2): all_xs, all_ys = [], [] for k, marker, ls, color, name in [ - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), + ( + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", + "o", + "-", + "C4", + "SwitchBack int8 (total time)", + ), ]: - xs, ys = [], [] df = rdf[rdf.batch_size == batch_size] for embed_dim in dims_to_consider: @@ -99,11 +111,11 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split('+'): + for k_ in k.split("+"): y_ += df_[k_].values[0] ys.append(y_ * 0.5) all_xs.append(xs) @@ -111,25 +123,29 @@ color = cmap(j * 0.25) real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] - markers = ['^', 'v', 'P', 'o'] - ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + markers = ["^", "v", "P", "o"] + ax.plot( + all_xs[0], + real_ys, + color=color, + label=f"batch * sequence length = {batch_size}", + marker=markers[j], + markersize=5 if 
marker == "s" else 5, + ) ax.legend() - ax.set_xlabel('dim', fontsize=13) - ax.set_xscale('log') + ax.set_xlabel("dim", fontsize=13) + ax.set_xscale("log") ax.grid() - ax.set_ylabel(r'% speedup', fontsize=13) + ax.set_ylabel(r"% speedup", fontsize=13) - - ax.tick_params(axis='x', labelsize=11) - ax.tick_params(axis='y', labelsize=11) + ax.tick_params(axis="x", labelsize=11) + ax.tick_params(axis="y", labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) - - + ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) - plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') + plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py index c4f3cd4c6..eaba0e9cd 100644 --- a/benchmarking/switchback/speed_benchmark.py +++ b/benchmarking/switchback/speed_benchmark.py @@ -20,15 +20,15 @@ # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. -def get_time(k, fn, info_dict): +def get_time(k, fn, info_dict): for _ in range(repeat // 2): - fn() + fn() torch.cuda.synchronize() start = time.time() for _ in range(repeat): - fn() + fn() torch.cuda.synchronize() end = time.time() @@ -36,16 +36,15 @@ def get_time(k, fn, info_dict): print(f"time {k}: {ms:.3f} ms") info_dict[k] = ms -if __name__ == '__main__': + +if __name__ == "__main__": torch.manual_seed(0) wm = 4 for dim in [1024, 1280, 1408, 1664, 2048, 4096]: # note "batch_size" is actually "batch_size * embed_dim", which is why it's large - for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: - + for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]: # switch switches dim_in and dim_out for switch in [False, True]: - # hparams repeat = 64 batch_size = batch_size @@ -73,35 +72,86 @@ def get_time(k, fn, info_dict): state_w_rowwise = w.max(dim=1)[0] state_w_global = w.max() - info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} - - get_time('standard_fwd', lambda : x.matmul(w.t()), info) - get_time('standard_gw', lambda : g.t().matmul(x), info) - get_time('standard_gx', lambda : g.matmul(w), info) - get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) - get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) - get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) - get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) - get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) - get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) - get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) - get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) - get_time('w_quantize_global', lambda : quantize_global(w), info) - get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) - - time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] - time_rowwise = info['x_quantize_rowwise'] + 
info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] - time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] - - print('TOTAL STANDARD', time_standard) - print('TOTAL ROWWISE', time_rowwise) - print('TOTAL GLOBAL', time_global) - - print('speedup', -100*(time_global - time_standard)/time_standard) - - info['time_standard'] = time_standard - info['time_rowwise'] = time_rowwise - info['time_global'] = time_global + info = { + "repeat": repeat, + "batch_size": batch_size, + "dim_out": dim_out, + "dim_in": dim_in, + "wm": wm, + "switch": switch, + } + + get_time("standard_fwd", lambda: x.matmul(w.t()), info) + get_time("standard_gw", lambda: g.t().matmul(x), info) + get_time("standard_gx", lambda: g.matmul(w), info) + get_time( + "rowwise_fwd", + lambda: int8_matmul_rowwise_dequantize( + x_int8, + w_int8.t(), + state_x_rowwise, + state_w_columnwise, + None, + ), + info, + ) + get_time( + "rowwise_bwd", + lambda: int8_matmul_rowwise_dequantize( + g_int8, + wt_int8.t(), + state_x_rowwise, + state_w_rowwise, + None, + ), + info, + ) + get_time( + "global_fwd", + lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), + info, + ) + get_time( + "global_bwd", + lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), + info, + ) + get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info) + get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info) + get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info) + get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info) + get_time("w_quantize_global", lambda: quantize_global(w), info) + get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info) + + time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"] + time_rowwise = ( + info["x_quantize_rowwise"] + + info["g_quantize_rowwise"] + + info["w_quantize_colwise_transpose"] + + info["w_quantize_rowwise"] + + info["standard_gw"] + + info["rowwise_fwd"] + + info["rowwise_bwd"] + ) + time_global = ( + info["x_quantize_rowwise"] + + info["g_quantize_rowwise"] + + info["w_quantize_global"] + + info["w_quantize_global_transpose"] + + info["standard_gw"] + + info["global_fwd"] + + info["global_bwd"] + ) + + print("TOTAL STANDARD", time_standard) + print("TOTAL ROWWISE", time_rowwise) + print("TOTAL GLOBAL", time_global) + + print("speedup", -100 * (time_global - time_standard) / time_standard) + + info["time_standard"] = time_standard + info["time_rowwise"] = time_rowwise + info["time_global"] = time_global info_json = json.dumps(info) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6cbb6efd9..e9821cd36 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -14,16 +14,18 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) + # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py - """ This class pools outlier dimensions across layers. 
This is particularly important for small models where outlier features are less systematic and occur with low frequency. """ + + class GlobalOutlierPooler: _instance = None @@ -83,6 +85,7 @@ def get_inverse_transform_indices( break # if all indices fit in i bytes, stop early return permuted_tile_indices + def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: """ Undo a tiled permutation such as turing or ampere layout @@ -159,20 +162,12 @@ def backward(ctx, grad_output): ) if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant( - A.view(-1, A.shape[2]), dim=0, quant_type=quant_type - ) + qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant( - igrad_B, S2.t(), S1, grad_output.dtype, quant_type - ) + grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) - qA, S2 = F.vectorwise_quant( - A, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -201,9 +196,7 @@ def backward(ctx, grad_output): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( @@ -227,7 +220,7 @@ def supports_igemmlt(device: torch.device) -> bool: if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) - nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series + nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores return True @@ -246,6 +239,7 @@ def get_tile_inds(format, device): with torch.no_grad(): return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) + @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None @@ -510,7 +504,6 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] else: return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - # 1. Dequantize # 2. MatmulnN output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) @@ -532,7 +525,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad A, B = ctx.tensors grad_A, grad_B, grad_bias = None, None, None @@ -542,8 +535,9 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) # not supported by PyTorch. 
TODO: create work-around - #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: + grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None @@ -554,7 +548,7 @@ def matmul( out: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: @@ -562,11 +556,19 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None): +def matmul_4bit( + A: torch.Tensor, + B: torch.Tensor, + quant_state: F.QuantState, + out: Optional[torch.Tensor] = None, + bias=None, +): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: - warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') + warn( + f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", + ) return MatMul4Bit.apply(A, B, out, bias, quant_state) else: out = F.gemv_4bit(A, B.t(), out, state=quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 57ba71020..c8ae7358d 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -56,7 +56,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" - "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLibrary: logger.warning( "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable." + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", ) return BNBNativeLibrary(dll) @@ -120,5 +120,5 @@ def get_native_library() -> BNBNativeLibrary: Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues -""" +""", ) diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index d65f80d8b..f993dff7e 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -120,7 +120,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: The CUDA version for the compile might depend on your conda install, if using conda. Inspect CUDA version via `conda list | grep cuda`. - """ + """, ) cuda_major, cuda_minor = cuda_specs.cuda_version_tuple @@ -129,7 +129,7 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: """ WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). 
You will be only to use 8-bit optimizers and quantization routines! - """ + """, ) print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") @@ -170,7 +170,7 @@ def print_cuda_runtime_diagnostics() -> None: In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, - """ + """, ) for pth in cudart_paths: print(f"* Found CUDA runtime at: {pth}") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 7a88bca26..1ce096f69 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -25,7 +25,7 @@ def sanity_check(): See the documentation for more details if needed. Trying a simple check anyway, but this will likely fail... - """ + """, ) from bitsandbytes.optim import Adam @@ -71,7 +71,7 @@ def main(): print( f"WARNING: {__package__} is currently running as CPU-only!\n" "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" + f"If you think that this is so erroneously,\nplease report an issue!", ) except Exception: traceback.print_exc() @@ -80,6 +80,6 @@ def main(): Above we output some debug information. Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose WARNING: Please be sure to sanitize sensitive info from the output before posting it. - """ + """, ) sys.exit(1) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 61d0d83b2..8fa8f2f60 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -21,6 +21,7 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) + name2qmap = {} if lib and lib.compiled_with_cuda: @@ -127,7 +128,6 @@ def prefetch_all(self, to_cpu=False): prefetch_tensor(t, to_cpu) - class CUBLAS_Context: _instance = None @@ -169,6 +169,7 @@ def get_instance(cls): cls._instance.initialize() return cls._instance + dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -176,10 +177,11 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -FIRST_CUDA_DEVICE = torch.device('cuda', index=0) +FIRST_CUDA_DEVICE = torch.device("cuda", index=0) + def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): - num_bytes = dtype2bytes[dtype]*prod(shape) + num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -188,31 +190,35 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): out.page_deviceid = device.index return out + def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' + assert A.is_paged, "Only paged tensors can be prefetched!" 
if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype]*A.numel() + num_bytes = dtype2bytes[A.dtype] * A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) + func = getattr(lib, f"c{func_name}_fp32", None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) + func = getattr(lib, f"c{func_name}_uint8", None) cvalue = ct.c_uint8(value) - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + if func is None: + raise NotImplementedError(f"Function not implemented: {func_name}") - is_managed = getattr(A, 'is_managed', False) + is_managed = getattr(A, "is_managed", False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: prefetch_tensor(B) + if B is not None: + prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: @@ -222,28 +228,36 @@ def elementwise_func(func_name, A, B, value, prefetch=True): # operation occurred. So we synchronize. torch.cuda.synchronize() -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + +def fill(A, value, device=None, prefetch=True): + elementwise_func("fill", A, None, value) + + +def arange(A, device=None): + elementwise_func("arange", A, None, 0) + + +def _mul(A, B, device=None): + elementwise_func("_mul", A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) + sign = -1.0 if signed else 0.0 total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) + total_values = 2**total_bits if not signed else 2**total_bits - 1 values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel()//2 # noqa: E741 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + l = values.numel() // 2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -251,18 +265,17 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): from scipy.stats import norm except ImportError as ie: raise ImportError( - "Scipy is required for `create_normal_map`. " - "Install `bitsandbytes` with the `[test]` extra." + "Scipy is required for `create_normal_map`. 
Install `bitsandbytes` with the `[test]` extra.", ) from ie if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -275,38 +288,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign + assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): evalues.append(2**val) - values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): for bit_pattern in lst: - value = (1 if evalue != 0 else 0) + value = 1 if evalue != 0 else 0 for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) + value += pval * (2 ** -(i + 1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value * 2**-(bias) else: # normals - value = value*2**-(evalue-bias-1) + value = value * 2 ** -(evalue - bias - 1) values.append(value) if signed: values.append(-value) - assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -320,7 +332,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) return code - def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. 
@@ -345,7 +356,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -371,8 +386,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) + def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() q.append(0) @@ -383,11 +399,13 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q/q.abs().max() + q = q / q.abs().max() return q + def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' + if not torch.cuda.is_available(): + return "col_turing" major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -396,20 +414,24 @@ def get_special_format_str(): return "col_turing" - def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", + ) if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", + ) return on_gpu @@ -447,15 +469,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): if not hasattr(lib, name): print(name) raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", ) else: return getattr(lib, name) -def get_transform_buffer( - shape, dtype, device, to_order, from_order="row", transpose=False -): +def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): # init_func = torch.empty init_func = torch.zeros dims = len(shape) @@ -508,9 +528,7 @@ def nvidia_transform( else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1] - ) + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) else: new_state = (state[1], to_order) func = 
get_transform_func(A.dtype, from_order, to_order, transpose) @@ -534,8 +552,13 @@ def nvidia_transform( return out, new_state -def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' +def estimate_quantiles( + A: Tensor, + out: Optional[torch.Tensor] = None, + offset: float = 1 / 512, + num_quantiles=256, +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -562,14 +585,21 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): + """ + if A.numel() < 256: + raise NotImplementedError( + f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.", + ) + if num_quantiles > 256: + raise NotImplementedError( + f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}", + ) + if num_quantiles < 256 and offset == 1 / (512): # override default arguments - offset = 1/(2*num_quantiles) + offset = 1 / (2 * num_quantiles) - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: @@ -581,7 +611,7 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl post_call(device) if num_quantiles < 256: - step = round(256/num_quantiles) + step = round(256 / num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -590,12 +620,35 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl class QuantState: """container for quantization state components to work with Params4bit and similar classes""" - valid_quant_types = ('fp4', 'nf4') - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', - 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] - def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): self.absmax = absmax self.shape = shape self.code = code @@ -614,13 +667,20 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] + 
list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] else: list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": """ unpacks components of state_dict into QuantState where necessary, convert into strings, torch.dtype, ints, etc. @@ -632,37 +692,39 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState # unpacking tensor with non-tensor components qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and 'quant_type' not in qs_dict: + if not len(qs_key) and "quant_type" not in qs_dict: raise ValueError("Expected packed or unpacked quant_state items, found neither") elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: first_qs_key = qs_key[0] qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if 'nested_absmax' in qs_dict: - offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) state2 = cls( - absmax=qs_dict['nested_absmax'].to(device), - blocksize=qs_dict['nested_blocksize'], - code=qs_dict['nested_quant_map'].to(device), - dtype=getattr(torch, qs_dict['nested_dtype']), + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict['quant_type'], - absmax=qs_dict['absmax'].to(device), - blocksize=qs_dict['blocksize'], - code=qs_dict['quant_map'].to(device), - dtype=getattr(torch, qs_dict['dtype']), - shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, offset=offset, state2=state2, ) @@ -674,21 +736,23 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - 'quant_type': self.quant_type, - 'absmax': self.absmax, - 'blocksize': self.blocksize, - 'quant_map': self.code, - 'dtype': str(self.dtype).strip('torch.'), - 'shape': tuple(self.shape), + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), } if self.nested: - qs_dict.update({ - 'nested_absmax': self.state2.absmax, - 
'nested_blocksize': self.state2.blocksize, - 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - 'nested_dtype': str(self.state2.dtype).strip('torch.'), - 'nested_offset': self.offset.item(), - }) + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) if not packed: return qs_dict @@ -711,14 +775,22 @@ def __eq__(self, other): return False return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) and - self.shape == other.shape and - torch.allclose(self.code, other.code, atol=1e-6) and - self.dtype == other.dtype and - self.blocksize == other.blocksize and - self.quant_type == other.quant_type and - (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and - (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2) + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) ) @@ -756,7 +828,6 @@ def quantize_blockwise( The quantization state to undo the quantization. 
""" - if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -771,31 +842,66 @@ def quantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != 'cpu': + if A.device.type != "cpu": assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) else: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) @@ -809,7 +915,7 @@ def dequantize_blockwise( code: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False + nested=False, ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -843,43 +949,76 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != 'cpu': + if A.device.type != "cpu": device = pre_call(A.device) code = quant_state.code.to(A.device) if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + raise ValueError( + f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", + ) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_bf16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) return out + def get_4bit_type(typename, device=None, blocksize=64): - if device is None: device = 'cuda' + if device is None: + device = "cuda" data = None - if typename == 'nf4': - ''' Implements the NF4 data type. + if typename == "nf4": + """ Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. @@ -888,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. 
- ''' - data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, - -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, - 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0] - elif typename == 'fp4': + """ + data = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + elif typename == "fp4": # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -904,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] - elif typename == 'int4': + elif typename == "int4": data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == 'af4': + elif typename == "af4": # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, - -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, - 0.42563882, 0.55496234, 0.72424863, 1.][::-1] + data = [ + -1.0, + -0.69441008, + -0.51243739, + -0.3736951, + -0.25607552, + -0.14982478, + -0.04934812, + 0.0, + 0.04273164, + 0.12934483, + 0.21961274, + 0.31675666, + 0.42563882, + 0.55496234, + 0.72424863, + 1.0, + ][::-1] else: - raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.") if data is None: - raise NotImplementedError(f'Typename {typename} not supported') + raise NotImplementedError(f"Typename {typename} not supported") data = Tensor(data) data /= data.abs().max() @@ -926,11 +1094,26 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) +def quantize_fp4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) -def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) + +def quantize_nf4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) def quantize_4bit( @@ -939,7 +1122,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type='fp4', + quant_type="fp4", 
quant_storage=torch.uint8, ) -> Tuple[Tensor, QuantState]: """ @@ -967,10 +1150,10 @@ def quantize_4bit( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") n = A.numel() input_shape = A.shape @@ -980,10 +1163,9 @@ def quantize_4bit( blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -991,20 +1173,62 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -1016,19 +1240,57 @@ def quantize_4bit( absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, 
offset=offset, state2=state2) + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) return out, state -def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') +def dequantize_fp4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") -def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + +def dequantize_nf4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + + +def dequantize_4bit( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type="fp4", +) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1056,23 +1318,31 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Dequantized tensor. """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + raise ValueError( + f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) else: absmax = quant_state.absmax - if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1082,27 +1352,71 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out def quantize( @@ 
-1117,7 +1431,8 @@ def quantize( code = code.to(A.device) absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1144,7 +1459,7 @@ def dequantize( def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - ''' + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1163,9 +1478,10 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ------- torch.Tensor: Quantized 8-bit tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1173,7 +1489,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - ''' + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1192,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = ------- torch.Tensor: 32-bit output tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1261,16 +1578,17 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1290,7 +1608,8 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) + ct.c_int32(g.numel()), + ) post_call(prev_device) @@ -1422,7 +1741,7 @@ def optimizer_update_8bit( ) else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1446,7 +1765,6 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1454,12 +1772,15 @@ def optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = 
str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1487,9 +1808,8 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) -def percentile_clipping( - grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 -): + +def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): """Applies percentile clipping grad: torch.Tensor @@ -1531,9 +1851,7 @@ def percentile_clipping( return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d( - histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor -): +def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 @@ -1550,12 +1868,12 @@ def histogram_scatter_add_2d( is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError( - f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" - ) + raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}") sA = A.shape sB = B.shape @@ -1596,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if ( - sout[0] == sA[2] - and sout[1] == sB[2] - and sA[0] == sB[0] - and sA[1] == sB[1] - ): + if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]: correct = True else: if len(sA) == 2 and len(sB) == 2: @@ -1634,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 if not correct: raise ValueError( - f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.", ) return sout + def gemv_4bit( A: Tensor, B: Tensor, out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None + state=None, ): prev_device = pre_call(A.device) - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") if A.numel() != A.shape[-1]: - raise ValueError('Dimensions of A are invalid. 
Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', + ) Bshape = state.shape bout = Bshape[0] @@ -1673,7 +1989,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1]+1)//2 + ldb = (A.shape[-1] + 1) // 2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1684,21 +2000,61 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") post_call(prev_device) return out + def igemm( A: Tensor, B: Tensor, @@ -1764,7 +2120,7 @@ def igemm( assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}", ) transposed_A = True @@ -1783,8 +2139,20 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out @@ -1796,9 +2164,7 @@ def batched_igemm( transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError( - f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" - ) + raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}") sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) @@ -1865,9 +2231,24 @@ def batched_igemm( ptr = 
CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out @@ -1876,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -1892,13 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -1940,49 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") if has_error: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt 
ran into an error!') + print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) return out, Sout -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1990,19 +2351,11 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -2016,15 +2369,23 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + ptrBias, + numRows, + numCols, + ) post_call(prev_device) return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): +def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -2037,18 +2398,12 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) + nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -2122,14 +2477,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) + rowptr = 
torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values - ) + return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) def coo2csc(cooA): @@ -2138,14 +2489,10 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -2155,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2170,9 +2515,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2190,9 +2533,7 @@ def double_quant( if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -2251,12 +2592,16 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2267,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2288,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise 
NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") post_call(prev_device) @@ -2297,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2326,16 +2669,28 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: - out = torch.zeros( - (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype - ) + out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) nnz = cooA.nnz prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz @@ -2353,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert ( - max_count[0] <= 32 - ), f"Current max count per row is 8 but found {max_count[0]}." + assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." 
assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -2443,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( - x, dim=dim, keepdim=True - ) + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) dyna[dyna == 0] = 1 qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) @@ -2553,9 +2904,7 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2565,7 +2914,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2573,6 +2922,7 @@ def extract_outliers(A, SA, idx): return out + def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index f7b96205b..e1cc6600d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -44,6 +44,7 @@ class StableEmbedding(torch.nn.Embedding): reset_parameters(): Reset embedding parameters using Xavier uniform initialization. forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. """ + def __init__( self, num_embeddings: int, @@ -89,9 +90,7 @@ def __init__( dtype, ) self.norm = torch.nn.LayerNorm(embedding_dim, device=device) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} - ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -130,6 +129,7 @@ class Embedding(torch.nn.Embedding): """ Embedding class to store and retrieve word embeddings from their indices. 
""" + def __init__( self, num_embeddings: int, @@ -170,11 +170,9 @@ def __init__( scale_grad_by_freq, sparse, _weight, - device=device - ) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} + device=device, ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -208,16 +206,16 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): def __new__( - cls, - data: Optional[torch.Tensor] = None, - requires_grad=False, # quantized weights should be frozen by default - quant_state: Optional[QuantState] = None, - blocksize: int = 64, - compress_statistics: bool = True, - quant_type: str = 'fp4', - quant_storage: torch.dtype = torch.uint8, - module: Optional["Linear4bit"] = None, - bnb_quantized: bool = False + cls, + data: Optional[torch.Tensor] = None, + requires_grad=False, # quantized weights should be frozen by default + quant_state: Optional[QuantState] = None, + blocksize: int = 64, + compress_statistics: bool = True, + quant_type: str = "fp4", + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False, ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -250,7 +248,7 @@ def __setstate__(self, state): self.bnb_quantized = state["bnb_quantized"] self.module = state["module"] - def __deepcopy__(self,memo): + def __deepcopy__(self, memo): new_instance = type(self).__new__(type(self)) state = self.__getstate__() new_instance.__setstate__(state) @@ -265,7 +263,14 @@ def __copy__(self): return new_instance @classmethod - def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": + def from_prequantized( + cls, + data: torch.Tensor, + quantized_stats: Dict[str, Any], + requires_grad: bool = False, + device="cuda", + **kwargs, + ) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) self.requires_grad = requires_grad self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) @@ -292,33 +297,39 @@ def _quantize(self, device): return self def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): - return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: - ... + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... 
def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and device.type == "cuda" and not self.bnb_quantized): + if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) - new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, quant_state=self.quant_state, - blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type) + new_param = Params4bit( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + quant_state=self.quant_state, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + ) return new_param @@ -355,7 +366,18 @@ class Linear4bit(nn.Linear): quantized_model = quantized_model.to(0) # Quantization happens here ``` """ - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_type="fp4", + quant_storage=torch.uint8, + device=None, + ): """ Initialize Linear4bit class. @@ -368,7 +390,14 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non Whether the linear class uses the bias term as well. """ super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) + self.weight = Params4bit( + self.weight.data, + requires_grad=False, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + module=self, + ) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False @@ -385,11 +414,15 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') - warnings.filterwarnings('ignore', message='.*inference.') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.", + ) + warnings.filterwarnings("ignore", message=".*inference.") if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') - warnings.filterwarnings('ignore', message='.*inference or training') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). 
This will lead to slow inference or training speed.", + ) + warnings.filterwarnings("ignore", message=".*inference or training") def _save_to_state_dict(self, destination, prefix, keep_vars): """ @@ -407,8 +440,8 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, 'quant_state', None) is None: - if getattr(self, 'quant_state', None) is not None: + if getattr(self.weight, "quant_state", None) is None: + if getattr(self, "quant_state", None) is not None: # the quant state got lost when the parameter got converted. This happens for example for fsdp # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 @@ -416,7 +449,9 @@ def forward(self, x: torch.Tensor): self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight.quant_state = self.quant_state else: - print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", + ) if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -437,7 +472,17 @@ class LinearFP4(Linear4bit): """ Implements the FP4 data type. """ - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): """ Args: input_features (`str`): @@ -447,21 +492,40 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non bias (`bool`, defaults to `True`): Whether the linear class uses the bias term as well. """ - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "fp4", + quant_storage, + device, + ) class LinearNF4(Linear4bit): - ''' Implements the NF4 data type. + """Implements the NF4 data type. + + Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that + is normalized into the range [-1, 1]. - Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that - is normalized into the range [-1, 1]. + For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) - For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) + Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in + the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. + """ - Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in - the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. 
- ''' - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): """ Args: input_features (`str`): @@ -471,7 +535,16 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non bias (`bool`, defaults to `True`): Whether the linear class uses the bias term as well. """ - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "nf4", + quant_storage, + device, + ) class Int8Params(torch.nn.Parameter): @@ -514,33 +587,22 @@ def to( device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ..., - ) -> T: - ... + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( - *args, **kwargs - ) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if ( - device is not None - and device.type == "cuda" - and self.data.device.type == "cpu" - ): + if device is not None and device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) else: new_param = Int8Params( - super().to( - device=device, dtype=dtype, non_blocking=non_blocking - ), + super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -593,8 +655,18 @@ class Linear8bitLt(nn.Linear): int8_model = int8_model.to(0) # Quantization happens here ``` """ - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None, device=None): + + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + device=None, + ): """ Initialize Linear8bitLt class. 
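For context, here is a minimal usage sketch of the Linear8bitLt layer touched by this hunk. It assumes a CUDA device is available and that bitsandbytes is importable as bnb; the feature sizes and the threshold value below are illustrative only and are not taken from this patch. The flow mirrors the example already shown in the class docstring above: the int8 quantization of the weights happens when the module is moved to the GPU.

    import torch
    import bitsandbytes as bnb

    # Start from an ordinary full-precision linear layer (sizes are arbitrary here).
    fp_linear = torch.nn.Linear(1024, 1024, bias=True)

    # Build the 8-bit counterpart; threshold=6.0 enables the mixed-precision
    # decomposition for outlier features, and has_fp16_weights=False keeps the
    # weights in their quantized int8 form for inference.
    int8_linear = bnb.nn.Linear8bitLt(1024, 1024, bias=True, has_fp16_weights=False, threshold=6.0)

    # Copy the full-precision weights, then move to GPU 0 -- quantization
    # happens on the device transfer, as noted in the docstring example.
    int8_linear.load_state_dict(fp_linear.state_dict())
    int8_linear = int8_linear.to(0)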
@@ -647,19 +719,36 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[format_name] = self.state.formatB - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) unexpected_copy = list(unexpected_keys) for key in unexpected_copy: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name == "SCB": if self.weight.SCB is None: # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. Please call module.cuda() before module.load_state_dict()") + raise RuntimeError( + "Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()", + ) input_param = state_dict[key] self.weight.SCB.copy_(input_param) @@ -702,18 +791,18 @@ def __init__(self, input_features, output_features, bias=True, device=None): self.is_quantized = False def forward_with_outliers(self, x, outlier_idx): - raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function") def quantize_weight(self, w, outlier_idx): - raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function") def forward(self, x): if self.outlier_dim is None: tracer = OutlierTracer.get_instance() if not tracer.is_initialized(): - print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer") outlier_idx = tracer.get_outliers(self.weight) - #print(outlier_idx, tracer.get_hvalue(self.weight)) + # print(outlier_idx, tracer.get_hvalue(self.weight)) self.outlier_dim = outlier_idx if not self.is_quantized: @@ -721,6 +810,7 @@ def forward(self, x): self.weight.data.copy_(w) self.is_quantized = True + class SwitchBackLinearBnb(nn.Linear): def __init__( self, @@ -731,11 +821,9 @@ def __init__( memory_efficient_backward=False, threshold=0.0, index=None, - device=None + device=None, ): - super().__init__( - input_features, output_features, bias, device - ) + super().__init__(input_features, output_features, bias, device) self.state = bnb.MatmulLtState() self.index = index @@ -745,9 +833,7 @@ def __init__( if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index 
9c7738c59..aa8494942 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -22,7 +22,6 @@ class _switchback_global(torch.autograd.Function): - @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -37,9 +36,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -56,7 +53,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_global_transpose(W) grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -66,8 +64,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_vectorrize(torch.autograd.Function): +class _switchback_vectorrize(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -81,9 +79,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call kernel which expects rowwise quantized X and W - return int8_matmul_rowwise_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -99,7 +95,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_columnwise_and_transpose(W) grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -109,8 +106,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_global_mem_efficient(torch.autograd.Function): +class _switchback_global_mem_efficient(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -127,9 +124,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D_sz[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -151,35 +146,34 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) del G W_int8 = W_int8.t().contiguous() - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D_sz[:-1], -1 - ) + grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1) return grad_X, grad_W, grad_bias + class SwitchBackLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - vector_wise_quantization: bool = False, - mem_efficient : bool = False, - ): + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient: bool = False, 
+ ): super().__init__(in_features, out_features, bias, device, dtype) if not is_triton_available(): - raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. - Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""") # By default, we use the global quantization. self.vector_wise_quantization = vector_wise_quantization if self.vector_wise_quantization: self._fn = _switchback_vectorrize if mem_efficient: - print('mem efficient is not supported for vector-wise quantization.') + print("mem efficient is not supported for vector-wise quantization.") exit(1) else: if mem_efficient: @@ -195,7 +189,7 @@ def prepare_for_eval(self): # if hasattr(m, "prepare_for_eval"): # m.prepare_for_eval() # model.apply(cond_prepare) - print('=> preparing for eval.') + print("=> preparing for eval.") if self.vector_wise_quantization: W_int8, state_W = quantize_rowwise(self.weight) else: @@ -219,18 +213,22 @@ def forward(self, x): X_int8, state_X = quantize_rowwise(X) if self.vector_wise_quantization: - return int8_matmul_rowwise_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) else: - return int8_matmul_mixed_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) + SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) + # This is just the standard linear function. 
class StandardLinearFunction(torch.autograd.Function): @staticmethod @@ -260,7 +258,7 @@ def backward(ctx, grad_output_3D): return grad_input, grad_weight, grad_bias -class StandardLinear(nn.Linear): +class StandardLinear(nn.Linear): def forward(self, x): return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index c2ea87ab0..aace548fa 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -50,9 +50,7 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -119,9 +117,7 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -189,9 +185,7 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index e534c8b8f..d8ffca63e 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,8 +14,21 @@ class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ Base Adam optimizer. @@ -45,11 +58,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 8-bit Adam optimizer. @@ -79,11 +119,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. 
""" - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 32-bit Adam optimizer. @@ -113,11 +180,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class PagedAdam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ Paged Adam optimizer. @@ -147,11 +241,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 8-bit paged Adam optimizer. @@ -181,11 +302,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. 
""" - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ Paged 32-bit Adam optimizer. @@ -215,7 +363,21 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. @@ -293,9 +455,7 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") amsgrad = group.get("amsgrad", False) assert not amsgrad @@ -312,15 +472,9 @@ def step(self, closure=None): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state["abserrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["relerrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["counts"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) + state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. 
values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -328,25 +482,19 @@ def step(self, closure=None): state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) if amsgrad: - state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( - p_data_fp32 - ) + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) state["step"] += 1 beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = ( - group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - ) + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] if group["weight_decay"] != 0: - p_data_fp32.add_( - p_data_fp32, alpha=-group["weight_decay"] * group["lr"] - ) + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] if amsgrad: @@ -359,10 +507,7 @@ def step(self, closure=None): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if ( - p_data_fp32.numel() <= 8192 - or p_data_fp32.numel() > 50000 * 1000 - ): + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -401,9 +546,7 @@ def step(self, closure=None): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError( - f"Invalid analysis value: {self.analysis}!" - ) + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -415,9 +558,7 @@ def step(self, closure=None): F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d( - counts, C1.int(), C2.int(), torch.ones_like(abserr) - ) + F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) p_data_fp32 += -step_size * update_fp32 @@ -425,18 +566,10 @@ def step(self, closure=None): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join( - [str(dim) for dim in p_data_fp32.shape] - ) - pathe = os.path.join( - self.savedir, f"{p_id}_{shapestr}_abserr.pkl" - ) - pathrele = os.path.join( - self.savedir, f"{p_id}_{shapestr}_relerr.pkl" - ) - pathcounts = os.path.join( - self.savedir, f"{p_id}_{shapestr}_counts.pkl" - ) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl") + pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl") + pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl") torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 1e2dc04de..fa51458fd 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -6,8 +6,21 @@ class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + 
is_paged=False, + ): """ Base AdamW optimizer. @@ -37,11 +50,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 8-bit AdamW optimizer. @@ -71,11 +111,38 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 32-bit AdamW optimizer. @@ -105,12 +172,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedAdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged AdamW optimizer. @@ -140,11 +232,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. 
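        Example:
            A minimal sketch of dropping the paged optimizer into an existing training loop;
            `model` is assumed to be a CUDA `torch.nn.Module` defined elsewhere:

                import bitsandbytes as bnb

                # State tensors live in paged memory and are moved to the GPU on demand,
                # which can help avoid out-of-memory spikes during optimizer.step().
                optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)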
""" - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged 8-bit AdamW optimizer. @@ -174,11 +292,37 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged 32-bit AdamW optimizer. @@ -208,4 +352,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1 is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 7449b805b..63c062988 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -51,9 +51,7 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -110,9 +108,7 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -169,9 +165,7 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" 
- ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -204,9 +198,7 @@ def __init__( if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") if weight_decay < 0.0: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, @@ -217,9 +209,7 @@ def __init__( max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError( - "Nesterov momentum requires a momentum and zero dampening" - ) + raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index ce185f863..9f0f4a8a9 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -6,7 +6,19 @@ class Lion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ Base Lion optimizer. @@ -32,10 +44,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bit is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 8-bit Lion optimizer. @@ -59,10 +96,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): """ 32-bit Lion optimizer. @@ -86,11 +148,35 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. 
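        Example:
            A minimal sketch; `model` is assumed to be a CUDA `torch.nn.Module` defined elsewhere:

                import bitsandbytes as bnb

                # Lion keeps a single momentum state per parameter, so even this 32-bit
                # variant needs roughly half the optimizer memory of 32-bit Adam.
                optimizer = bnb.optim.Lion32bit(model.parameters(), lr=1e-4, weight_decay=1e-2)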
""" - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedLion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged Lion optimizer. @@ -114,10 +200,34 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bit block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged 8-bit Lion optimizer. @@ -141,10 +251,34 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): """ Paged 32-bit Lion optimizer. @@ -168,4 +302,17 @@ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index a97afb026..43ebbb24d 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -21,6 +21,7 @@ class GlobalOptimManager: """ A global optimizer manager for enabling custom optimizer configs. 
""" + _instance = None def __init__(self): @@ -48,13 +49,9 @@ def register_parameters(self, params): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[ - id(p) - ] + self.index2config[(group_index, p_index)] = self.pid2config[id(p)] - def override_config( - self, parameters, key=None, value=None, key_value_dict=None - ): + def override_config(self, parameters, key=None, value=None, key_value_dict=None): """ Override initial optimizer config with specific hyperparameters. @@ -132,18 +129,18 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { - "qmap1", - "qmap2", - "max1", - "max2", - "new_max1", - "new_max2", - "state1", - "state2", - "gnorm_vec", - "absmax1", - "absmax2", - "unorm_vec", + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", } if optim_bits == 8: @@ -170,16 +167,12 @@ def load_state_dict(self, state_dict): saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError( - "loaded state dict has a different number of " - "parameter groups" - ) + raise ValueError("loaded state dict has a different number of parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( - "loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group" + "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", ) # Update the state @@ -228,9 +221,7 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups) - ] + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -240,7 +231,7 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - is_paged = getattr(v, 'is_paged', False) + is_paged = getattr(v, "is_paged", False) if not is_paged: self.state[p][k] = v.to(p.device) @@ -248,9 +239,7 @@ def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance( - pmodule, torch.Parameter - ) + assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) found = False for gindex, group in enumerate(self.param_groups): if found: @@ -262,9 +251,7 @@ def check_overrides(self): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[ - (gindex, pindex) - ] = self.mng.pid2config[id(p)] + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -287,7 +274,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True - #if self.is_paged: self.page_mng.prefetch_all() + # if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -304,7 +291,6 @@ def 
step(self, closure=None): # to sync to make sure all tensors are in the right state torch.cuda.synchronize() - return loss def get_config(self, gindex, pindex, group): @@ -328,9 +314,7 @@ def init_state(self, group, p, gindex, pindex): raise NotImplementedError("init_state method needs to be overridden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError( - "The update_step method needs to be overridden" - ) + raise NotImplementedError("The update_step method needs to be overridden") def get_state_buffer(self, p, dtype=torch.float32): if not self.is_paged or p.numel() < 1e5: @@ -345,12 +329,12 @@ def get_state_buffer(self, p, dtype=torch.float32): def prefetch_state(self, p): if self.is_paged: state = self.state[p] - s1 = state['state1'] - is_paged = getattr(s1, 'is_paged', False) + s1 = state["state1"] + is_paged = getattr(s1, "is_paged", False) if is_paged: - F.prefetch_tensor(state['state1']) - if 'state2' in state: - F.prefetch_tensor(state['state2']) + F.prefetch_tensor(state["state1"]) + if "state2" in state: + F.prefetch_tensor(state["state2"]) class Optimizer2State(Optimizer8bit): @@ -369,7 +353,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): """ Base 2-state update optimizer class. @@ -414,13 +398,9 @@ def __init__( betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -449,9 +429,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -459,21 +437,15 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -486,25 +458,13 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) - state["absmax2"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax2"] = 
torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -524,7 +484,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -568,9 +531,7 @@ def update_step(self, group, p, gindex, pindex): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] - if config["max_unorm"] > 0.0 - else None, + unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, max_unorm=config["max_unorm"], ) @@ -615,7 +576,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): """ Base 1-state update optimizer class. @@ -656,13 +617,9 @@ def __init__( raise ValueError(f"Invalid epsilon value: {eps}") for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -691,9 +648,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -701,17 +656,13 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -721,16 +672,10 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, 
device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -750,7 +695,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -766,7 +714,7 @@ def update_step(self, group, p, gindex, pindex): step, config["lr"], None, - config['betas'][1], + config["betas"][1], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index ac371a66f..659617654 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -51,9 +51,7 @@ def __init__( Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -116,9 +114,7 @@ def __init__( Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -182,9 +178,7 @@ def __init__( """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 7d869e39a..b194b8777 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -195,9 +195,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -216,9 +216,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 1. 
Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -234,14 +232,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - #print('A shape', A.shape) + # print('A shape', A.shape) if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B if state.has_fp16_weights: - #print('B shape', B.shape) + # print('B shape', B.shape) has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: @@ -272,12 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -320,14 +313,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors @@ -342,9 +334,7 @@ def backward(ctx, grad_output): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) @@ -357,25 +347,24 @@ def backward(ctx, grad_output): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) # print('back B shape', state.CxBt.shape) # print('back grad shape', C32grad.shape) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. 
/ 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception('State must contain either CBt or CB matrix for backward') + raise Exception("State must contain either CBt or CB matrix for backward") return grad_A, grad_B, None, grad_bias, None + def get_block_sizes(input_matrix, weight_matrix): input_features = input_matrix.shape[-1] - output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) + output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] array = [4096, 2048, 1024, 512, 256, 128, 64, 0] bsz, bsz2 = 1024, 1024 for i, k in enumerate(array): @@ -399,7 +388,8 @@ def matmul_fp8_global( bsz: int = -1, bsz2: int = -1, ): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) @@ -412,7 +402,8 @@ def matmul_fp8_mixed( bsz: int = -1, bsz2: int = -1, ): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) @@ -422,7 +413,7 @@ def switchback_bnb( out: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 7fca34d23..57c0f3358 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -28,12 +28,20 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.research.matmul_fp8_mixed( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias return out + class LinearFP8Global(nn.Linear): def __init__(self, input_features, output_features, bias=True): super().__init__(input_features, output_features, bias) @@ -54,7 +62,14 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.matmul_fp8_global( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index 3d7529852..26eab84f2 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -5,9 +5,10 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None -else: + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + return None +else: import triton import triton.language as tl @@ -15,21 +16,21 @@ def 
dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _dequantize_rowwise( @@ -51,7 +52,6 @@ def _dequantize_rowwise( output = max_val * x * inv_127 tl.store(output_ptr + offsets, output, mask=row_mask) - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) @@ -60,5 +60,5 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index dc3047d7e..583371d91 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -3,14 +3,14 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None -else: + def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): + return None +else: import triton import triton.language as tl from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and global quantized weight # It's purpose is fused matmul then dequantize @@ -27,58 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) 
return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, 
"BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), *get_configs_io_bound(), ], - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_mixed_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor: tl.constexpr, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -115,13 +140,13 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = (w_factor * (x_factor * (acc * divfactor))) + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) # conditionally add bias @@ -137,10 +162,9 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): device = a.device - divfactor = 1. / (127. * 127.) 
+ divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -154,12 +178,28 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_mixed_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 4881e1468..e3d192ded 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -3,7 +3,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None + + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + return None else: import triton import triton.language as tl @@ -17,7 +19,6 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None def init_to_zero(name): return lambda nargs: nargs[name].zero_() - def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -26,58 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 
32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), *get_configs_io_bound(), ], - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + key=["M", "N", "K"], 
+ prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_rowwise_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -114,13 +140,13 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = (w_factor * (x_factor * (acc * divfactor))) + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) if has_bias: @@ -135,9 +161,8 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1. / (127. * 127.) 
+ divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 @@ -154,12 +179,28 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_rowwise_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index e7961cf53..b8eeffd0c 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -5,9 +5,10 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_columnwise_and_transpose(x: torch.Tensor): return None -else: + def quantize_columnwise_and_transpose(x: torch.Tensor): + return None +else: import triton import triton.language as tl @@ -15,23 +16,23 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): return None # TODO: autotune this better. 
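# Illustrative usage sketch (not part of this patch): the rowwise int8 path pairs
# quantize_rowwise() with int8_matmul_rowwise_dequantize(), as reformatted above.
# Shapes below are example assumptions; both wrappers return None when Triton is unavailable.
import torch

from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise

X = torch.randn(32, 4096, device="cuda", dtype=torch.float16)    # activations (M, K)
W = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)  # weights (N, K)

X_int8, state_X = quantize_rowwise(X)  # int8 values + per-row absmax
W_int8, state_W = quantize_rowwise(W)

# The second operand is expected in (K, N) layout, so pass the transposed int8 weight.
out = int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias=None)  # fp16 (M, N)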
@triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_columnwise_and_transpose( @@ -39,7 +40,8 @@ def _quantize_columnwise_and_transpose( output_ptr, output_maxs, n_elements, - M : tl.constexpr, N : tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr, ): @@ -47,12 +49,12 @@ def _quantize_columnwise_and_transpose( block_start = pid p2_arange = tl.arange(0, P2) p2_arange_mask = p2_arange < M - arange = p2_arange * N + arange = p2_arange * N offsets = block_start + arange x = tl.load(x_ptr + offsets, mask=p2_arange_mask) abs_x = tl.abs(x) max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) new_start = pid * M new_offsets = new_start + p2_arange @@ -68,6 +70,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) return output, output_maxs diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index 5cf194744..f35bdd304 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,24 +1,25 @@ - import torch from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_global_transpose(input): return None - def quantize_global(x: torch.Tensor): return None -else: + def quantize_global_transpose(input): + return None + + def quantize_global(x: torch.Tensor): + return None +else: import triton import triton.language as tl # global quantize @triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), - triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), - - ], - key=['n_elements'] + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), + ], + key=["n_elements"], ) @triton.jit def _quantize_global( @@ -34,35 +35,43 @@ def _quantize_global( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127. 
* (x * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax - output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + absmax_inv = 1.0 / absmax + output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_global[grid](x, absmax_inv, output, n_elements) return output, absmax - # global quantize and transpose @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - - # ... - ], - key=['M', 'N'] + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + # ... + ], + key=["M", "N"], ) @triton.jit - def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - GROUP_M : tl.constexpr): + def _quantize_global_transpose( + A, + absmax_inv_ptr, + B, + stride_am, + stride_an, + stride_bn, + stride_bm, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + ): pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N @@ -86,20 +95,30 @@ def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, strid B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] - output = tl.libdevice.llrint(127. * (a * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) tl.store(B, output, mask=mask) def quantize_global_transpose(input): absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax + absmax_inv = 1.0 / absmax M, N = input.shape - out = torch.empty(N, M, device='cuda', dtype=torch.int8) + out = torch.empty(N, M, device="cuda", dtype=torch.int8) assert out.size(0) == N and out.size(1) == M assert input.stride(0) == 1 or input.stride(1) == 1 assert out.stride(0) == 1 or out.stride(1) == 1 - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _quantize_global_transpose[grid]( + input, + absmax_inv, + out, + input.stride(0), + input.stride(1), + out.stride(0), + out.stride(1), + M, + N, + ) return out, absmax diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index 078f4aa2d..f92ace02c 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -5,9 +5,10 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_rowwise(x: torch.Tensor): return None -else: + def quantize_rowwise(x: torch.Tensor): + return None +else: import triton import triton.language as tl @@ -15,21 +16,21 @@ def quantize_rowwise(x: torch.Tensor): return None # TODO: autotune this better. 
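# Illustrative usage sketch (not part of this patch): quantize_global() scales by a single
# tensor-wide absmax, while quantize_global_transpose() additionally emits the int8 transpose
# used for the backward matmul. Shapes are example assumptions; both return None without Triton.
import torch

from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose

W = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)

W_int8, absmax = quantize_global(W)                 # int8 (1024, 4096) + tensor-wide absmax
W_int8_t, absmax_t = quantize_global_transpose(W)   # int8 (4096, 1024) + tensor-wide absmax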
@triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_rowwise( @@ -49,7 +50,7 @@ def _quantize_rowwise( abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_maxs + pid, max_val) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0582f7fc0..48c7fc82d 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -30,7 +30,7 @@ def outlier_hook(module, input): # (1) zscore test of std of hidden dimension outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) # (2) magnitude > 6 test - dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) + dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1))) outlier_idx2 = torch.where(dims > 0)[0] outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() tracer.hvalue2outlier_idx[hvalue] = outlier_idx @@ -59,14 +59,14 @@ def initialize(self, model): self.hooks.append(m.register_forward_pre_hook(outlier_hook)) def is_initialized(self): - return getattr(self, 'initialized', False) + return getattr(self, "initialized", False) def get_hvalue(self, weight): return weight.data.storage().data_ptr() def get_outliers(self, weight): if not self.is_initialized(): - print('Outlier tracer is not initialized...') + print("Outlier tracer is not initialized...") return None hvalue = self.get_hvalue(weight) if hvalue in self.hvalue2outlier_idx: @@ -80,6 +80,7 @@ def get_instance(cls): cls._instance = cls.__new__(cls) return cls._instance + def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): if rdm: return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() @@ -87,13 +88,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) m = weight.mean(reduction_dim) mm = m.mean() mstd = m.std() - zm = (m-mm)/mstd + zm = (m - mm) / mstd std = weight.std(reduction_dim) stdm = std.mean() stdstd = std.std() - zstd = (std-stdm)/stdstd + zstd = (std - stdm) / stdstd if topk is not None: val, idx = torch.topk(std.abs(), k=topk, dim=0) @@ -105,10 +106,7 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) def execute_and_return(command_string: str) -> Tuple[str, str]: def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) + return 
tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple) def execute_and_return_decoded_std_streams(command_string): return _decode( @@ -116,14 +114,13 @@ def execute_and_return_decoded_std_streams(command_string): shlex.split(command_string), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - ).communicate() + ).communicate(), ) std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err - def replace_linear( model, linear_replacement, @@ -163,8 +160,9 @@ def replace_linear( model._modules[name].bias = old_module.bias if post_processing_function is not None: - func = getattr(module, post_processing_function, None) - if func is not None: func(module) + func = getattr(module, post_processing_function, None) + if func is not None: + func(module) return model @@ -179,7 +177,7 @@ def pack_dict_to_tensor(source_dict): A torch tensor containing the packed data. """ json_str = json.dumps(source_dict) - json_bytes = json_str.encode('utf-8') + json_bytes = json_str.encode("utf-8") tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8) return tensor_data @@ -196,7 +194,7 @@ def unpack_tensor_to_dict(tensor_data): A Python dictionary containing the unpacked data. """ json_bytes = bytes(tensor_data.cpu().numpy()) - json_str = json_bytes.decode('utf-8') + json_str = json_bytes.decode("utf-8") unpacked_dict = json.loads(json_str) return unpacked_dict diff --git a/check_bnb_install.py b/check_bnb_install.py index 5a7f74f89..7a9dc93fc 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -2,14 +2,14 @@ import bitsandbytes as bnb -p = torch.nn.Parameter(torch.rand(10,10).cuda()) -a = torch.rand(10,10).cuda() +p = torch.nn.Parameter(torch.rand(10, 10).cuda()) +a = torch.rand(10, 10).cuda() p1 = p.data.sum().item() adam = bnb.optim.Adam([p]) -out = a*p +out = a * p loss = out.sum() loss.backward() adam.step() @@ -17,5 +17,5 @@ p2 = p.data.sum().item() assert p1 != p2 -print('SUCCESS!') -print('Installation was successful!') +print("SUCCESS!") +print("Installation was successful!") diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py index c89ba8d11..2d4c77952 100644 --- a/examples/int8_inference_huggingface.py +++ b/examples/int8_inference_huggingface.py @@ -2,23 +2,18 @@ from transformers import LlamaForCausalLM, LlamaTokenizer MAX_NEW_TOKENS = 128 -model_name = 'meta-llama/Llama-2-7b-hf' +model_name = "meta-llama/Llama-2-7b-hf" -text = 'Hamburg is in which country?\n' +text = "Hamburg is in which country?\n" tokenizer = LlamaTokenizer.from_pretrained(model_name) input_ids = tokenizer(text, return_tensors="pt").input_ids -max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' +max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB" n_gpus = torch.cuda.device_count() max_memory = {i: max_memory for i in range(n_gpus)} -model = LlamaForCausalLM.from_pretrained( - model_name, - device_map='auto', - load_in_8bit=True, - max_memory=max_memory -) +model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory) generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) diff --git a/install_cuda.py b/install_cuda.py index b41b33b39..9e426cbd7 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -19,6 +19,7 @@ "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run", } + def 
install_cuda(version, base_path, download_path): formatted_version = f"{version[:-1]}.{version[-1]}" folder = f"cuda-{formatted_version}" @@ -29,7 +30,7 @@ def install_cuda(version, base_path, download_path): subprocess.run(["rm", "-rf", install_path], check=True) url = cuda_versions[version] - filename = url.split('/')[-1] + filename = url.split("/")[-1] filepath = os.path.join(download_path, filename) if not os.path.exists(filepath): @@ -44,9 +45,14 @@ def install_cuda(version, base_path, download_path): # Install CUDA print(f"Installing CUDA version {version}...") install_command = [ - "bash", filepath, - "--no-drm", "--no-man-page", "--override", - "--toolkitpath=" + install_path, "--toolkit", "--silent" + "bash", + filepath, + "--no-drm", + "--no-man-page", + "--override", + "--toolkitpath=" + install_path, + "--toolkit", + "--silent", ] print(f"Running command: {' '.join(install_command)}") @@ -62,6 +68,7 @@ def install_cuda(version, base_path, download_path): print(f"CUDA version {version} installed at {install_path}") + def main(): user_base_path = os.path.expanduser("~/cuda") system_base_path = "/usr/local/cuda" @@ -93,5 +100,6 @@ def main(): print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1) + if __name__ == "__main__": main() diff --git a/scripts/stale.py b/scripts/stale.py index 613f5b7cb..a65652aeb 100644 --- a/scripts/stale.py +++ b/scripts/stale.py @@ -15,6 +15,7 @@ Script to close stale issue. Taken in part from the AllenNLP repository. https://github.com/allenai/allennlp. """ + from datetime import datetime as dt, timezone import os @@ -50,7 +51,7 @@ def main(): issue.create_comment( "This issue has been automatically marked as stale because it has not had " "recent activity. 
If you think this still needs to be addressed " - "please comment on this thread.\n\n" + "please comment on this thread.\n\n", ) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d01e5e9db..9da665a2d 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -20,7 +20,11 @@ @pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) -@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"]) +@pytest.mark.parametrize( + "funcs", + [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], + ids=["func=bmm", "func=matmul"], +) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) @@ -30,16 +34,13 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(25): - # normal multiply if funcs[0] in [torch.mm, torch.matmul]: dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn( - size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] - ) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]) torch.nn.init.xavier_uniform_(B) if not transpose[0] and not transpose[1]: @@ -71,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -81,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: @@ -119,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_close( - out_bnb, out_torch, atol=0.027, rtol=0.2 - ) + torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -133,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, 
target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -143,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -192,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -202,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -218,25 +203,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) -@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"]) +@pytest.mark.parametrize( + "funcs", + [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], + ids=["func=matmul", "func=switchback_bnb"], +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_matmullt( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias -): +def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") @@ -245,18 +222,13 @@ def test_matmullt( req_grad[2] = False for i in range(3): - # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn( - size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype - ) + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn( - size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype - ) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) target = torch.randn( size=(dim2, dim4), device="cuda", @@ -266,7 +238,7 @@ def test_matmullt( bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = 
bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -311,9 +283,7 @@ def test_matmullt( if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss( - out_bnb, target - ).mean() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad @@ -323,9 +293,7 @@ def test_matmullt( gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -336,9 +304,7 @@ def test_matmullt( bias.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() if dim2 > 0: @@ -352,9 +318,7 @@ def test_matmullt( assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) @@ -370,8 +334,20 @@ def test_matmullt( @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) -@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type")) -def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) +def test_matmul_4bit( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + has_bias, + compress_statistics, + quant_type, +): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: @@ -387,11 +363,15 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + B2, quant_state = bnb.functional.quantize_4bit( + B, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) @@ -410,7 +390,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if n > 0: assert err < 0.115 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -424,7 +404,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -435,7 +415,7 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias.grad = None if 
req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) @@ -448,8 +428,12 @@ def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global']) -def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize( + "funcs", + [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], + ids=["matmul_fp8_mixed", "matmul_fp8_global"], +) +def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) req_grad = list(req_grad) @@ -480,7 +464,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: assert err < 0.115 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -491,7 +475,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -499,7 +483,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() @@ -514,8 +498,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - grad_err = (gradB1-gradB2).abs().mean() + grad_err = (gradB1 - gradB2).abs().mean() assert grad_err.item() < 0.003 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index cb0b38fdd..fc79a54b0 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -35,7 +35,4 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert ( - get_cuda_bnb_library_path(cuda111_noblas_spec).stem - == "libbitsandbytes_cuda111_nocublaslt" - ) + assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_functional.py b/tests/test_functional.py index d4f65755f..b9f1a6ead 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -19,9 +19,7 @@ id_formatter, ) -torch.set_printoptions( - precision=5, 
sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 -) +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) k = 20 @@ -98,9 +96,7 @@ def teardown(): pass -@pytest.mark.parametrize( - "dtype", [torch.float32, torch.float16], ids=["float", "half"] -) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_estimate_quantiles(dtype): A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) @@ -136,7 +132,6 @@ def test_quantile_quantization(): assert diff < 0.001 - def test_dynamic_quantization(): diffs = [] reldiffs = [] @@ -149,8 +144,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + print(sum(diffs) / len(diffs)) + print(sum(reldiffs) / len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") @@ -161,13 +156,12 @@ def test_dynamic_quantization(): assert diff < 0.004 - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - #print('') + # print('') diffs = [] reldiffs = [] for i in range(100): @@ -178,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) assert abserr < 0.011 assert relerr < 0.018 assert A2.dtype == dtype @@ -196,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) + # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) if signed: assert abserr < 0.0035 assert relerr < 0.015 @@ -206,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): assert abserr < 0.00175 assert relerr < 0.012 assert A2.dtype == dtype - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) - + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) -@pytest.mark.parametrize( - "gtype", [torch.float32, torch.float16], ids=["float", "half"] -) +@pytest.mark.parametrize("gtype", [torch.float32, 
torch.float16], ids=["float", "half"]) def test_percentile_clipping(gtype): gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda") @@ -223,9 +214,7 @@ def test_percentile_clipping(gtype): for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping( - g, gnorm_vec2, step, percentile=percentile - ) + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 gnorm2 = torch.norm(g.float()) @@ -309,7 +298,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - #print("") + # print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -321,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_close( - quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 - ) + torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -338,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - #print(mean(errors)) - #print(mean(relerrors)) + # print(mean(errors)) + # print(mean(relerrors)) def test_stable_embedding(): @@ -356,16 +343,8 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = ( - (batch_dim, hidden_dim) - if not transpose[0] - else (hidden_dim, batch_dim) - ) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -385,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -410,16 +385,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint( - -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ).to(torch.int8) - B = torch.randint( - -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" - ).to(torch.int8) + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) out2 = 
torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty( - A.shape[2], B.shape[2], dtype=torch.int32, device=A.device - ) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) out = F.igemm(A, B, out=iout) torch.testing.assert_close(out.float(), out2) @@ -444,9 +413,7 @@ def min_max(x): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal( - 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ) + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") if transpose: B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: @@ -523,9 +490,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B) elif transpose[0] and transpose[1]: - out2 = torch.bmm( - A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() - ) + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_close(out.float(), out2.float()) @@ -541,7 +506,7 @@ def test_vector_quant(dim1, dim2, dim3): qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @@ -565,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( - dtype - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) out, S = F.nvidia_transform(A, to_order=orderOut) @@ -579,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = ( - A.shape[0] - * A.shape[1] - * (A.shape[2] + (32 - (A.shape[2] % 32))) - ) + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) assert out.numel() == n elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( - A.shape[1] + (32 - (A.shape[1] % 32)) - ) + n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): @@ -598,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ( - (row // 8) + (1 if row % 8 != 0 else 0) - ) * total_coltile + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 row2 = (row % 8) * 32 @@ -611,9 +566,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": - out2, S = F.nvidia_transform( - out, from_order=orderOut, to_order="row", state=S - ) + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) torch.testing.assert_close(A, out2) @@ -626,16 +579,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans def 
test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) elif dims == 3: - A = torch.randint( - -128, 127, size=(dim1, dim2, dim3), device="cuda" - ).to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") @@ -645,9 +592,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_close(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( - torch.int8 - ) + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) @@ -667,9 +612,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal( - 0, 0.5, size=(dim1, dim2, dim3), device="cuda" - ).half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -700,6 +643,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # C3, S = F.transform(C2, 'row', state=SC) # torch.testing.assert_close(C1, C3.float()) + @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), [ @@ -729,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 # out2 = torch.matmul(out1, w2.t())# fc2 @@ -866,13 +809,15 @@ def test_bench_8bit_training(batch, seq, model, hidden): def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() bias = None - if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) + if has_bias: + bias = torch.randn(dim4, device="cuda", dtype=torch.float16) formatB = F.get_special_format_str() for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) - if has_bias: C1 += bias + if has_bias: + C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -883,7 +828,8 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - if has_bias: C4 += bias + if has_bias: + C4 += bias # TODO: is something wrong here? 
If so, the problem goes deeper # n = C1.numel() @@ -917,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=threshold - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) A_blocked = einops.rearrange( torch.abs(A), @@ -939,9 +883,7 @@ def test_colrow_absmax(dim1, dim2, dims): torch.testing.assert_close(row_stats1_trunc, row_stats2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=0.0 - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) @@ -963,24 +905,16 @@ def test_double_quant(dim1, dim2): torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() - num_not_close_rows = ( - (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - ) - num_not_close_cols = ( - (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() - ) + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() # allow for 1:500 error due to rounding differences min_error = 1 / 500 if num_not_close_cols > (min_error * n): - print( - f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}") assert False if num_not_close_rows > (min_error * n): - print( - f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_rows} elements are different. 
Error: {num_not_close_rows/n:.4f}") assert False torch.testing.assert_close(Srow.flatten().float(), statsA) @@ -991,13 +925,12 @@ def test_double_quant(dim1, dim2): ("dim1", "dim4", "inner"), ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) - in zip( + for (dim1, dim4, inner) in zip( get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4), ) - ) + ), ) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): @@ -1037,13 +970,12 @@ def test_integrated_igemmlt(dim1, dim4, inner): ("dim1", "dim4", "inner"), ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) - in zip( + for (dim1, dim4, inner) in zip( get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6), ) - ) + ), ) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): @@ -1067,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1150,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1177,13 +1105,9 @@ def test_row_scale_bench(dim1, dim4, inner): def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( - dtype - ) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint( - 10, 99, size=(dim1, dim2, dim3), device="cuda" - ).to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1224,23 +1148,17 @@ def test_coo_double_quant(dim1, dim2): idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) - A2[ - coo_tensor.rowidx.long(), coo_tensor.colidx.long() - ] = coo_tensor.values + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close( - A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 - ) + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) @@ -1261,9 +1179,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx if transposed_B: @@ -1303,9 +1219,7 @@ def test_spmm_bench(): 
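# Illustrative usage sketch (not part of this patch): the COO pattern exercised by these
# tests builds an F.COOSparseTensor from a boolean mask and multiplies it with a dense
# half-precision matrix via F.spmm_coo. Shapes and the magnitude threshold are example assumptions.
import torch

import bitsandbytes.functional as F

A = torch.randn(256, 512, device="cuda", dtype=torch.float16)
B = torch.randn(512, 128, device="cuda", dtype=torch.float16)

idx = torch.abs(A) >= 2.0  # keep only large-magnitude entries
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], idx.sum().item(), rows.int(), cols.int(), values)

out_sparse = F.spmm_coo(cooA, B)       # ~ (A * idx) @ B
out_dense = torch.matmul(A * idx, B)   # dense reference for comparison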
print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1339,9 +1253,7 @@ def test_integrated_sparse_decomp(dim1, dim2): out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) @@ -1396,9 +1308,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) @@ -1413,9 +1323,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): std = out1.std() out1 /= std out2 /= std - assert_all_approx_close( - out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count - ) + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) @@ -1443,9 +1351,7 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] @@ -1463,9 +1369,7 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] @@ -1499,9 +1403,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) @@ -1582,7 +1484,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), - [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")], + [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], ) @pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): @@ -1605,8 +1507,8 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) - #linearMixedBit.eval() + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, 
threshold=6.0).cuda().half() + # linearMixedBit.eval() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() @@ -1623,121 +1525,123 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print( + f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", + ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - #torch.cuda.synchronize() - #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - #torch.cuda.synchronize() - #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() - print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) torch.cuda.synchronize() - print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B, threshold=6.0) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - #C32A, SA = F.transform(CA, "col32") - #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - #CxB, SB = F.transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in 
range(iters): + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + # C32A, SA = F.transform(CA, "col32") + # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + # CxB, SB = F.transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - #torch.cuda.synchronize() - #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1) - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1) + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1) # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - #torch.cuda.synchronize() - #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # out = Cout * statsB * statsA * (1.0 / (127 * 127)) - #torch.cuda.synchronize() - #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linearMixedBit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linearMixedBit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in 
range(iters): # linearMixedBit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train_thresh(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train_thresh(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + def test_zeropoint(): def quant_zp(x): @@ -1778,8 +1682,8 @@ def quant_zp(x): C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) - #print(ca.min(), ca.max()) - #print((ca - cza).min(), (ca - cza).max()) + # print(ca.min(), ca.max()) + # print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 @@ -1808,14 +1712,14 @@ def quant_zp(x): C7 -= zpa * zpb * A.shape[1] C7 /= qa * qb - #print("") + # print("") # print(C0.flatten()[:10]) - #print(C1.flatten()[:10]) - #print(C2.flatten()[:10]) - #print(C3.flatten()[:10]) - #print(C5.flatten()[:10]) - #print(C6.flatten()[:10]) - #print(C7.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) + # print(C3.flatten()[:10]) + # print(C5.flatten()[:10]) + # print(C6.flatten()[:10]) + # print(C7.flatten()[:10]) err1 = torch.abs(C1 - C2).mean().item() err2 = torch.abs(C1 - C3).mean().item() err3 = torch.abs(C1 - C4).mean().item() @@ -1852,16 +1756,15 @@ def test_extract_outliers(): torch.testing.assert_close(outliers1, outliers2) - def test_blockwise_cpu_large(): diffs = [] reldiffs = [] batch = 128 seq = 128 - for hidden in [128]:#, 14336]: + for hidden in [128]: # , 14336]: for blocksize in [4096, 16384]: for i in range(2): - A1 = torch.randn(batch, seq, hidden, device='cpu') + A1 = torch.randn(batch, seq, hidden, device="cpu") t0 = time.time() C, S = F.quantize_blockwise(A1, blocksize=blocksize) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) @@ -1875,10 +1778,9 @@ def test_blockwise_cpu_large(): # print(sum(reldiffs)/len(reldiffs)) - def test_fp8_quant(): for e_bits in range(1, 7): - p_bits = 7-e_bits + p_bits = 7 - e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() abserr = [] @@ -1888,12 +1790,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - 
#print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -1902,12 +1804,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -1916,50 +1818,48 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(3, sum(abserr)/len(abserr)) - #print(3, sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(3, sum(abserr)/len(abserr)) + # print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - - #print('') + # print('') for bits in range(2, 9): - #print('='*30, bits, '='*30) - for method in ['linear', 'fp8', 'dynamic', 'quantile']: + # print('='*30, bits, '='*30) + for method in ["linear", "fp8", "dynamic", "quantile"]: abserrs = [] relerrs = [] code = None - if method == 'linear': + if method == "linear": code = F.create_linear_map(True, total_bits=bits).cuda() - elif method == 'fp8': - ebits = math.ceil(bits/2) - pbits = bits-ebits-1 + elif method == "fp8": + ebits = math.ceil(bits / 2) + pbits = bits - ebits - 1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - elif method == 'dynamic': - code = F.create_dynamic_map(True, bits-0, bits).cuda() - elif method == 'quantile': - values = torch.randn(2048, 2048, device='cuda') + elif method == "dynamic": + code = F.create_dynamic_map(True, bits - 0, bits).cuda() + elif method == "quantile": + values = torch.randn(2048, 2048, device="cuda") code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero # for some data types we have one zero # for some data types we have two zeros - assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' - #print(method, (code==0).sum()) + assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" + # print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): - - values = torch.randn(1, 32, device='cuda') + values = torch.randn(1, 32, device="cuda") values /= values.abs().max() - #values[values.abs() < 1e-6] += 1e-5 + # values[values.abs() < 1e-6] += 1e-5 q1 = [] v1 = [] for v in values[0]: - idx = torch.abs(v-code).argmin() + idx = torch.abs(v - code).argmin() q1.append(idx.item()) v1.append(code[idx].item()) @@ -1970,62 +1870,61 @@ def test_few_bit_quant(): v2 = F.dequantize_blockwise(q2, S2) idx = torch.isclose(q1.int(), q2.int()) - err2 = torch.abs(v2-values) + err2 = torch.abs(v2 - values) abserrs.append(err2.mean().item()) - relerrs.append((err2/(1e-10+values).abs()).mean().item()) + relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) if idx.sum(): # some weird cases - err1 = torch.abs(v1-values).mean() - #assert err2.mean() <= err1 + err1 = torch.abs(v1 - values).mean() + # assert err2.mean() <= err1 else: torch.testing.assert_close(q1, q2) - #print(method, 'abserr:', 
sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) - #assert False + # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + # assert False def test_kbit_quantile_estimation(): for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 9): - p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) val1 = torch.Tensor(norm.ppf(p)).cuda() val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1-val2).mean() + err = torch.abs(val1 - val2).mean() assert err < 0.038 for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 4): - total_values = 2**bits-1 - p = np.linspace(0, 1, 2*total_values+1) - idx = np.arange(1, 2*total_values+1, 2) + total_values = 2**bits - 1 + p = np.linspace(0, 1, 2 * total_values + 1) + idx = np.arange(1, 2 * total_values + 1, 2) p = p[idx] - offset = 1/(2*total_values) - p = np.linspace(offset, 1-offset, total_values) + offset = 1 / (2 * total_values) + p = np.linspace(offset, 1 - offset, total_values) val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) - err = torch.abs(val1-val2).mean() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) + err = torch.abs(val1 - val2).mean() assert err < 0.035 @pytest.mark.benchmark def test_bench_dequantization(): - a = torch.rand(1024, 1024, device='cuda').half() - code =F.create_fp8_map(True, 3, 0, 4).cuda() + a = torch.rand(1024, 1024, device="cuda").half() + code = F.create_fp8_map(True, 3, 0, 4).cuda() qa, SA = F.quantize_blockwise(a, code=code) print(qa.max()) - max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 - #print(max_theoretical_mu) + max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 + # print(max_theoretical_mu) torch.cuda.synchronize() t0 = time.time() for i in range(100): qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() - #print((time.time()-t0)/1e6) - + # print((time.time()-t0)/1e6) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @@ -2037,26 +1936,28 @@ def test_fp4_quant(dtype): result = 0 bias = 3 sign, e1, e2, p1 = bits - idx = sign*8 + e1*4 + e2*2 + p1*1 + idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1 sign = -1.0 if sign else 1.0 - exp = e1*2 + e2*1 + exp = e1 * 2 + e2 * 1 if exp == 0: # sub-normal - if p1 == 0: result = 0 - else: result = sign*0.0625 + if p1 == 0: + result = 0 + else: + result = sign * 0.0625 else: # normal - exp = 2**(-exp + bias + 1) + exp = 2 ** (-exp + bias + 1) frac = 1.5 if p1 else 1.0 - result = sign*exp*frac + result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-8)).mean() + relerr = (err / (A1.abs().float() + 1e-8)).mean() idx = err > 1.0 err = err.mean() @@ -2065,31 +1966,29 @@ def test_fp4_quant(dtype): assert relerr.item() < 0.28 -@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) def test_4bit_compressed_stats(quant_type): for blocksize in [128, 64]: errs1 = [] errs2 = [] for i in range(10): - A1 = torch.randn(1024, 1024, device='cuda').half() + A1 = torch.randn(1024, 
1024, device="cuda").half() q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) - err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs1.append(err.item()) - assert err.item() < 0.11 assert relerr.item() < 0.28 err = (A1 - A3).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs2.append(err.item()) @@ -2097,70 +1996,71 @@ def test_4bit_compressed_stats(quant_type): assert err.item() < 0.11 assert relerr.item() < 0.28 - #print(sum(errs1)/len(errs1), blocksize, quant_type) - #print(sum(errs2)/len(errs2), blocksize, quant_type) - + # print(sum(errs1)/len(errs1), blocksize, quant_type) + # print(sum(errs2)/len(errs2), blocksize, quant_type) - -#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -@pytest.mark.parametrize("quant_type", ['nf4']) +# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["nf4"]) @pytest.mark.benchmark def test_bench_4bit_dequant(quant_type): blocksize = 256 - a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) - input_size = a.numel()/2 - output_size = a.numel()*2 - num_bytes = input_size+output_size - GB = num_bytes/1e9 - max_theoretical_s = GB/768 - #print(max_theoretical_s*1e6) - b = torch.randn(128, 1024*12, device='cuda').half() + input_size = a.numel() / 2 + output_size = a.numel() * 2 + num_bytes = input_size + output_size + GB = num_bytes / 1e9 + max_theoretical_s = GB / 768 + # print(max_theoretical_s*1e6) + b = torch.randn(128, 1024 * 12, device="cuda").half() iters = 100 torch.cuda.synchronize() t0 = time.time() for i in range(iters): F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - #b.copy_(a) + # b.copy_(a) torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) + # print((time.time()-t0)/iters*1e6) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # torch.matmul(b, a.t()) - #torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) - + # torch.cuda.synchronize() + # print((time.time()-t0)/iters*1e6) def test_normal_map_tree(): code = F.create_normal_map() - values =code[:8].tolist() + code[-8:].tolist() + values = code[:8].tolist() + code[-8:].tolist() num_pivots = 1 - #print(values) - while num_pivots <16: - idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) - #print(idx) + # print(values) + while num_pivots < 16: + idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots)) + # print(idx) num_pivots *= 2 pivots = [] for i in idx: - pivots.append((values[i-1]+values[i])/2) - #print(pivots) + pivots.append((values[i - 1] + values[i]) / 2) + # print(pivots) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4']) -@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed']) 
+@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "quant_storage", + [torch.uint8, torch.float16, torch.bfloat16, torch.float32], + ids=describe_dtype, +) def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: - #for dim in [4*1024]: - #for dim in [1*16]: + # for dim in [4*1024]: + # for dim in [1*16]: errs1 = [] errs2 = [] errs3 = [] @@ -2171,38 +2071,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2 = [] max_errs3 = [] - for i in range(100): - if kind == 'fc1': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'fc2': - A = torch.randn(1, 4*dim, dtype=dtype, device='cuda') - B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn_packed': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=storage_type, + compress_statistics=double_quant, + quant_storage=quant_storage, + ) C3 = torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) - err1 = (C1-C2).abs().float() - err2 = (C3-C2).abs().float() - err3 = (C3-C1).abs().float() + err1 = (C1 - C2).abs().float() + err2 = (C3 - C2).abs().float() + err3 = (C3 - C1).abs().float() - mag1 = torch.abs(C1).float()+1e-5 - mag2 = torch.abs(C3).float()+1e-5 - mag3 = torch.abs(C3).float()+1e-5 + mag1 = torch.abs(C1).float() + 1e-5 + mag2 = torch.abs(C3).float() + 1e-5 + mag3 = torch.abs(C3).float() + 1e-5 - relerr1 = err1/mag1 - relerr2 = err2/mag2 - relerr3 = err3/mag3 + relerr1 = err1 / mag1 + relerr2 = err2 / mag2 + relerr3 = err3 / mag3 max_err1 = err1.max() max_err2 = err2.max() @@ -2220,34 +2124,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2.append(max_err2.item()) max_errs3.append(max_err3.item()) - c = int(C1.numel()*0.0014*(dim/256))+1 + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - err1 = sum(errs1)/len(errs1)/math.sqrt(dim) - err2 = sum(errs2)/len(errs2)/math.sqrt(dim) - err3 = sum(errs3)/len(errs3)/math.sqrt(dim) 
- relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim) - relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim) - relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim) - maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim) - maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim) - maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim) - absratio = err2/err3 - relratio = relerr2/relerr3 - maxratio = relerr2/relerr3 + err1 = sum(errs1) / len(errs1) / math.sqrt(dim) + err2 = sum(errs2) / len(errs2) / math.sqrt(dim) + err3 = sum(errs3) / len(errs3) / math.sqrt(dim) + relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) + relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) + relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) + maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) + maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) + maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) + absratio = err2 / err3 + relratio = relerr2 / relerr3 + maxratio = relerr2 / relerr3 # for debugging if the tests fails # - #print('='*80) - #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - #print(C1.flatten()[-20:]) - #print(C2.flatten()[-20:]) - #print(f'inference vs training abs: {err1}') - #print(f'inference vs training rel: {relerr1}') - #print(f'inference vs training max: {maxerr1}') - #print(f'inference vs training vs torch err ratio abs: {absratio}') - #print(f'inference vs training vs torch err ratio rel: {relratio}') - #print(f'inference vs training vs torch err ratio max: {maxratio}') + # print('='*80) + # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + # print(C1.flatten()[-20:]) + # print(C2.flatten()[-20:]) + # print(f'inference vs training abs: {err1}') + # print(f'inference vs training rel: {relerr1}') + # print(f'inference vs training max: {maxerr1}') + # print(f'inference vs training vs torch err ratio abs: {absratio}') + # print(f'inference vs training vs torch err ratio rel: {relratio}') + # print(f'inference vs training vs torch err ratio max: {maxratio}') if dtype == torch.float16: if dim <= 512: assert err1 < 7e-5 @@ -2283,56 +2187,59 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): - n = 32*10 + n = 32 * 10 A = F.get_paged(n, n, dtype=torch.float32) B = F.get_paged(n, n, dtype=torch.uint8) B2 = F.get_paged(n, n, dtype=torch.float32) assert A.is_paged assert B.is_paged - assert A.page_deviceid==0 - assert B.page_deviceid==0 + assert A.page_deviceid == 0 + assert B.page_deviceid == 0 F.fill(A, 17.0) F.fill(B, 17) F.fill(B2, 2) - assert (A==17).sum().item() == n*n - assert (B==17).sum().item() == n*n - C = A*B.float() - assert (C==289).sum().item() == n*n + assert (A == 17).sum().item() == n * n + assert (B == 17).sum().item() == n * n + C = A * B.float() + assert (C == 289).sum().item() == n * n F._mul(A, B2) F._mul(A, B2) F._mul(A, B2) - assert (A==17*(2**3)).sum().item() == n*n - # F.prefetch_tensor(A) - # F.prefetch_tensor(B) + assert (A == 17 * (2**3)).sum().item() == n * n + + +# F.prefetch_tensor(A) +# F.prefetch_tensor(B) - # F.fill(B2, 17.0) - # F._mul(A, B2) +# F.fill(B2, 17.0) +# F._mul(A, B2) - # F.prefetch_tensor(A, to_cpu=True) - # F.prefetch_tensor(B, to_cpu=True) - # F.prefetch_tensor(B2, to_cpu=True) - # torch.cuda.synchronize() +# 
F.prefetch_tensor(A, to_cpu=True) +# F.prefetch_tensor(B, to_cpu=True) +# F.prefetch_tensor(B2, to_cpu=True) +# torch.cuda.synchronize() - # assert (A==17).sum().item() == n*n +# assert (A==17).sum().item() == n*n - # torch.testing.assert_close(A, torch.ones(A.shape)*289) +# torch.testing.assert_close(A, torch.ones(A.shape)*289) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) +@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) +@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) - dims = [dim + (64-(dim % 64)) for dim in dims] - #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: + dims = [dim + (64 - (dim % 64)) for dim in dims] + # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda') - B = torch.eye(dim, dtype=dtype, device='cuda') + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") + B = torch.eye(dim, dtype=dtype, device="cuda") qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) C3 = torch.matmul(A, B.t()) @@ -2343,5 +2250,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C3) torch.testing.assert_close(A, C1) torch.testing.assert_close(A, C2) - #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) - #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) + # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) + # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) diff --git a/tests/test_generation.py b/tests/test_generation.py index ef354d70a..911aa14da 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -10,56 +10,61 @@ def get_4bit_config(): - return transformers.BitsAndBytesConfig( - load_in_4bit=True, - load_in_8bit=False, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4', - ) + return transformers.BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) def get_model_and_tokenizer(config): model_name_or_path, quant_type = config bnb_config = get_4bit_config() - if quant_type == '16bit': + if quant_type == "16bit": bnb_config.load_in_4bit = False else: - bnb_config.bnb_4bit_quant_type= quant_type - model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, + bnb_config.bnb_4bit_quant_type = quant_type + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name_or_path, quantization_config=bnb_config, - max_memory={0:'48GB'}, - device_map='auto', - torch_dtype=torch.bfloat16 - ).eval() + max_memory={0: "48GB"}, + device_map="auto", + torch_dtype=torch.bfloat16, + ).eval() tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) return model, tokenizer + def get_prompt_for_generation_eval(text, add_roles=True): description = ( "A chat between a curious human and an artificial intelligence assistant. 
" "The assistant gives helpful, detailed, and polite answers to the user's questions." ) if add_roles: - prompt = f'{description} ### Human: {text} ### Assistant:' + prompt = f"{description} ### Human: {text} ### Assistant:" else: - prompt = f'{description} {text}' + prompt = f"{description} {text}" return prompt + def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): text = prompt_func(text) - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config) return tokenizer.decode(outputs[0], skip_special_tokens=True) -models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] -dtypes = ['nf4', 'fp4'] -@pytest.fixture(scope='session', params=product(models, dtypes)) +models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"] +dtypes = ["nf4", "fp4"] + + +@pytest.fixture(scope="session", params=product(models, dtypes)) def model_and_tokenizer(request): model, tokenizer = get_model_and_tokenizer(request.param) yield request.param, model, tokenizer @@ -81,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): ) generation_config.max_new_tokens = 20 - - #text = 'Please write down the first 50 digits of pi.' - #text = get_prompt_for_generation_eval(text) - #text += ' Sure, here the first 50 digits of pi: 3.14159' + # text = 'Please write down the first 50 digits of pi.' + # text = get_prompt_for_generation_eval(text) + # text += ' Sure, here the first 50 digits of pi: 3.14159' n_cases = 6 - text = '3.14159' - if hasattr(model.config, 'quantization_config'): + text = "3.14159" + if hasattr(model.config, "quantization_config"): model.config.quantization_config.bnb_4bit_compute_dtype = dtype model.config.quantization_config.bnb_4bit_use_double_quant = DQ if not inference_kernel: - text = [text]*n_cases - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - x = inputs['input_ids'] + text = [text] * n_cases + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + x = inputs["input_ids"] outputs = [] if inference_kernel: for i in range(n_cases): @@ -105,15 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): outputs = model.generate(x, generation_config=generation_config) outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] - assert len(outputs) == n_cases failure_count = 0 for i in range(n_cases): - if not outputs[i][:len(str(math.pi))] == str(math.pi): + if not outputs[i][: len(str(math.pi))] == str(math.pi): failure_count += 1 - failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) + failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 if failure_count > failure_max: print(math.pi) for out in outputs: print(out) - raise ValueError(f'Failure count: {failure_count}/{n_cases}') + raise ValueError(f"Failure count: {failure_count}/{n_cases}") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 567e1a466..bbbd05335 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -28,9 +28,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora device = "cuda" layer_shape = (300, 400) - linear = torch.nn.Linear( - *layer_shape, dtype=original_dtype, device="cpu" - ) # original layer + linear = torch.nn.Linear(*layer_shape, 
dtype=original_dtype, device="cpu") # original layer # Quantizing original layer linear_q = bnb.nn.Linear4bit( @@ -42,9 +40,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_type=quant_type, device="meta", ) - new_weight = bnb.nn.Params4bit( - data=linear.weight, quant_type=quant_type, requires_grad=False - ) + new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) linear_q.weight = new_weight if bias: linear_q.bias = torch.nn.Parameter(linear.bias) @@ -172,7 +168,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora target_compression = ( 0.143 if original_dtype == torch.float32 else 0.29 ) # these numbers get lower as weight shape increases - ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + ratio_error_msg = ( + f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + ) assert size_ratio < target_compression, ratio_error_msg diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index edc3409cd..4b62abd6d 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -19,6 +19,7 @@ # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py + @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", @@ -50,7 +51,9 @@ def test_linear_no_igemmlt(): linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + linear.weight.data.clone(), + requires_grad=False, + has_fp16_weights=False, ).to(linear.weight.dtype) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -77,7 +80,14 @@ def test_linear_no_igemmlt(): @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) -def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda): +def test_linear_serialization( + has_fp16_weights, + serialize_before_forward, + deserialize_before_cuda, + force_no_igemmlt, + save_before_forward, + load_before_cuda, +): linear = torch.nn.Linear(32, 96) x = torch.randn(3, 32, dtype=torch.half) @@ -92,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights + linear.weight.data.clone(), + requires_grad=has_fp16_weights, + has_fp16_weights=has_fp16_weights, ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() diff --git a/tests/test_modules.py b/tests/test_modules.py index 674620e29..db4d72410 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -19,12 +19,18 @@ class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super().__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, dim2, 
has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim1, + dim2, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim2, + dim1, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) def forward(self, x): @@ -52,9 +58,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -122,9 +126,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -133,9 +135,7 @@ def get_8bit_linear(x, stochastic=False): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -219,9 +219,7 @@ def forward(ctx, x, weight, bias=None, args=None): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant( - outputq, S1, S2, x.dtype, args.quant_type - ) + output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -250,37 +248,25 @@ def backward(ctx, grad_output): # weight and x are already 8bit # -> transform grad_output to 8-bit if args.use_8bit_training == "forward+wgrad": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - ) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) grad_input = grad_output.matmul(weight) elif args.use_8bit_training == "full": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - 
) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=2 - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input = LinearFunction.dequant( - grad_input8, S1, S3, grad_output.dtype, args.quant_type - ) + grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) else: grad_input = grad_output.matmul(weight) @@ -356,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient(): opt1.zero_grad(True) opt2.step() opt2.zero_grad(True) - assert_all_approx_close( - l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 - ) - assert_all_approx_close( - l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 - ) + assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) @@ -375,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) + l1 = ( + bnb.nn.Linear8bitLt( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .cuda() + .half() + ) assert l1.weight.dtype == torch.int8 l1.eval() @@ -397,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .cuda() - .half() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -414,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .half() - .cuda() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -431,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) + mlp = ( + MLP8bit( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .half() + .to("cuda") + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -447,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc2.weight.device.type == "cuda" mlp = MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward 
- ) + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization @@ -489,7 +487,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4, ], - ids=['Int8Lt', 'FP4'], + ids=["Int8Lt", "FP4"], ) def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically @@ -544,7 +542,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() - kbit = kbit.half().to('cuda') + kbit = kbit.half().to("cuda") errs1 = [] errs2 = [] @@ -562,10 +560,10 @@ def test_kbit_backprop(module): bgrad1 = ref[0].bias.grad bgrad2 = kbit[0].bias.grad - err1 = (out1-out2).abs().float() - err2 = (grad1-grad2).abs().float() - relerr1 = (err1/(out1.abs().float()+1e-9)) - relerr2 = (err2/(grad1.abs().float()+1e-9)) + err1 = (out1 - out2).abs().float() + err2 = (grad1 - grad2).abs().float() + relerr1 = err1 / (out1.abs().float() + 1e-9) + relerr2 = err2 / (grad1.abs().float() + 1e-9) errs1.append(err1.mean().item()) errs2.append(err2.mean().item()) relerrs1.append(relerr1.mean().item()) @@ -582,20 +580,20 @@ def test_kbit_backprop(module): assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 - #print('out', sum(errs1)/len(errs1)) - #print('grad', sum(errs2)/len(errs2)) - #print('rel out', sum(relerrs1)/len(relerrs1)) - #print('rel grad', sum(relerrs2)/len(relerrs2)) + # print('out', sum(errs1)/len(errs1)) + # print('grad', sum(errs2)/len(errs2)) + # print('rel out', sum(relerrs1)/len(relerrs1)) + # print('rel grad', sum(relerrs2)/len(relerrs2)) -def test_fp8linear(): +def test_fp8linear(): b = 10 h = 1024 inp = torch.randn(b, h).cuda() - fp32 = torch.nn.Linear(h, h*2).cuda() - fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() - fp32b = torch.nn.Linear(h*2, h).cuda() - fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() + fp32 = torch.nn.Linear(h, h * 2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda() + fp32b = torch.nn.Linear(h * 2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda() fp8.weight.data.copy_(fp32.weight.data) fp8.bias.data.copy_(fp32.bias.data) @@ -605,34 +603,34 @@ def test_fp8linear(): a = fp32b(torch.nn.functional.gelu(fp32(inp))) b = fp8b(torch.nn.functional.gelu(fp8(inp))) - err = (a-b).abs().mean() + err = (a - b).abs().mean() a.mean().backward() b.mean().backward() - graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() - bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() + graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean() assert err < 0.05 assert graderr < 0.00002 assert bgraderr < 0.00002 + def test_4bit_warnings(): dim1 = 64 - with pytest.warns(UserWarning, match=r'inference or training'): + with pytest.warns(UserWarning, match=r"inference or training"): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() net(inp) - with pytest.warns(UserWarning, match=r'inference.'): + with pytest.warns(UserWarning, match=r"inference."): net = 
nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(1, dim1).cuda().half() net(inp) with pytest.warns(UserWarning) as record: - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() diff --git a/tests/test_optim.py b/tests/test_optim.py index 9395b8820..d8c46e415 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -16,6 +16,7 @@ k = 20 + def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() @@ -33,6 +34,7 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) + str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) @@ -66,8 +68,14 @@ def rm_path(path): ) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) -str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = ( + torch.optim.AdamW, + lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True), +) +str2optimizers["paged_adam8bit_blockwise"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True), +) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( @@ -90,9 +98,18 @@ def rm_path(path): str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] -str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adamw8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] @@ -101,7 +118,7 @@ def rm_path(path): str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] -optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 
'lion', 'paged_lion'] +optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"] @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) @@ -109,7 +126,7 @@ def rm_path(path): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: + if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1: return @@ -161,9 +178,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], - atol=atol, rtol=rtol, - max_error_count=10) + assert_most_approx_close( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + max_error_count=10, + ) if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit @@ -193,13 +214,9 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config( - p3, "optim_bits", 8 - ) + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) - bnb.optim.GlobalOptimManager.get_instance().register_parameters( - [p1, p2, p3] - ) + bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) p1 = p1.cuda() p2 = p2.cuda() p3 = p3.cuda() @@ -242,7 +259,8 @@ def test_global_config(dim1, dim2, gtype): @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() + if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -294,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol - ) - == 0 - ) - #assert num_not_close.sum().item() < 20 + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 + # assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) - relerr = err / (torch.abs(p1)+1e-9) + relerr = err / (torch.abs(p1) + 1e-9) if g.dtype == torch.bfloat16: assert err.mean() < 0.00015 assert relerr.mean() < 0.0016 @@ -316,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() @@ -348,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, 
optim_name): ) torch.testing.assert_close(s1cpy, s1) - num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -395,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): for i in range(50): step += 1 - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( - 0.01 * i - ) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( - g1, gnorm_vec, step, 5 - ) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 @@ -497,8 +504,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) -@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name")) -@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode")) +@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) @pytest.mark.benchmark def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) @@ -506,24 +513,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = layers1.cuda() large_tensor = None - if mode == 'torch': + if mode == "torch": optim = str2optimizers[optim_name][0](layers1.parameters()) else: optim = str2optimizers[optim_name][1](layers1.parameters()) # 12 GB - large_tensor = torch.empty((int(4.5e9),), device='cuda') + large_tensor = torch.empty((int(4.5e9),), device="cuda") torch.cuda.synchronize() time.sleep(5) num_batches = 5 - batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) - lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) + lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() for i in range(num_batches): print(i) b = batches[i] - if i ==2: + if i == 2: torch.cuda.synchronize() t0 = time.time() diff --git a/tests/test_triton.py b/tests/test_triton.py index 218a533d5..3624fb5e9 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -7,15 +7,18 @@ from tests.helpers import TRUE_FALSE -@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, - reason="This test requires triton and a GPU with compute capability 8.0 or higher.") +@pytest.mark.skipif( + not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.", +) @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) def test_switchback(vector_wise_quantization): for dim in [83]: for batch in [13]: - standard = torch.nn.Linear(dim, 4 * dim).cuda().half() - switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + switchback = ( + SwitchBackLinear(dim, 4 * dim, 
vector_wise_quantization=vector_wise_quantization).cuda().half() + ) baseline = Linear8bitLt(dim, 4 * dim).cuda().half() switchback.weight.data.copy_(standard.weight) switchback.bias.data.copy_(standard.bias) @@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization): err_sb = (out_standard - out_sb).abs().mean() err_baseline = (out_standard - out_baseline).abs().mean() - print('OUT', err_sb, err_baseline) + print("OUT", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() - print('GW2', err_sb, err_baseline) + print("GW2", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() - print('GW1', err_sb, err_baseline) + print("GW1", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (x1.grad - x2.grad).abs().mean() err_baseline = (x1.grad - x3.grad).abs().mean() - print('GX1', err_sb, err_baseline) + print("GX1", err_sb, err_baseline) assert err_sb < 2 * err_baseline