diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index eef8f41c..4736cd29 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -18,16 +18,6 @@ from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history from tqdm import tqdm -# Check if transformer_engine is installed -transformer_engine_installed = False -try: - import transformer_engine.pytorch as te - from transformer_engine.common import recipe - - transformer_engine_installed = True -except ImportError: - print("transformer_engine not installed and we won't compare against this") - # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N @@ -66,7 +56,6 @@ class Experiment: dtype: torch.dtype compiled: bool = False float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn - te_time_sec: Optional[float] = None # 3 Times since we are calculating forward backward @property @@ -87,21 +76,6 @@ def float8_tops_sec(self): def float8_pct_top_peak(self): return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype] - @property - def te_tops_sec(self): - M, K, N = self.shape - if self.te_time_sec is not None: - return float(3 * (2 * M * K * N)) / self.te_time_sec - else: - return None - - @property - def te_pct_top_peak(self): - if self.te_tops_sec is not None: - return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype] - else: - return None - def main( sweep_path: Path, @@ -113,7 +87,6 @@ def main( # LLaMa 2 70B single-node weight shapes # assumes fused attn.wqkv and ffn.w13 - # source: https://fburl.com/gsheet/g8onr7rh name_to_shapes_70b = { "attn.wqkv": (8192, 1280), "attn.w0": (1024, 8192), @@ -145,19 +118,6 @@ def float8_forw_backward(): sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() - if transformer_engine_installed: - # Use the same recipe as float8_linear.DelayedScalingRecipe - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling( - fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max" - ) - te_linear = te.Linear(K, N, bias=input_bias).to(device=device, dtype=dtype) - - def te_forw_backward(): - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - y = te_linear(input_tensor) - y.sum().backward() - def n_times(n, fn, *args, **kwargs): def wrapper(*args, **kwargs): for _ in range(n): @@ -169,21 +129,14 @@ def wrapper(*args, **kwargs): ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) - if transformer_engine_installed: - te_forw_backward = n_times(REPEAT_N, te_forw_backward) if compile: ref_forw_backward = torch.compile(ref_forw_backward) float8_forw_backward = torch.compile(float8_forw_backward) - # Compiling TE_linear fails but they are already compiling under the hood - # if transformer_engine_installed: - # te_forw_backward = torch.compile(te_forw_backward) for _ in range(5): ref_forw_backward() float8_forw_backward() - if transformer_engine_installed: - te_forw_backward() ref_time = ( benchmark_torch_function_in_microseconds(ref_forw_backward) @@ -195,14 +148,6 @@ def wrapper(*args, **kwargs): * 1e-6 / REPEAT_N ) - if transformer_engine_installed: - te_time_sec = ( - benchmark_torch_function_in_microseconds(te_forw_backward) - * 1e-6 - / REPEAT_N - ) - else: - te_time_sec = None experiment = Experiment( name, (M, K, N), @@ -210,12 +155,9 @@ def wrapper(*args, **kwargs): float8_time, dtype, compile, - te_time_sec=te_time_sec, ) print(experiment) print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) - if transformer_engine_installed: - print("te speedup", experiment.ref_time_sec / experiment.te_time_sec) experiment_list.append(experiment) torch._dynamo.reset() @@ -229,13 +171,10 @@ def wrapper(*args, **kwargs): "fp8_dtype", "ref_time_sec", "pt_fp8_time_sec", - "te_fp8_time_sec", "ref_tops_sec", "ref_pct_top_peak", "pt_fp8_tops_sec", "pt_fp8_pct_top_peak", - "te_fp8_tops_sec", - "te_fp8_pct_top_peak", ] data = [] for experiment in experiment_list: @@ -250,22 +189,15 @@ def wrapper(*args, **kwargs): experiment.float_8_dtype, experiment.ref_time_sec, experiment.float8_time_sec, - experiment.te_time_sec, experiment.ref_tops_sec, experiment.ref_pct_top_peak, experiment.float8_tops_sec, experiment.float8_pct_top_peak, - experiment.te_tops_sec, - experiment.te_pct_top_peak, ] ) data_pd = pd.DataFrame(data, columns=headers) data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] - if transformer_engine_installed: - data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"] - else: - data_pd["te_fp8_speedup"] = -1.0 data_pd["shape"] = ( "(" + data_pd["M"].astype(str) @@ -284,9 +216,7 @@ def wrapper(*args, **kwargs): "compiled", "ref_time_sec", "pt_fp8_time_sec", - "te_fp8_time_sec", "pt_fp8_speedup", - "te_fp8_speedup", ] ] print(data_pd_simple) diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index 9c584753..6ac51356 100644 --- a/benchmarks/bench_multi_gpu.py +++ b/benchmarks/bench_multi_gpu.py @@ -26,16 +26,6 @@ StateDictType, ) -# Check if transformer_engine is installed -transformer_engine_installed = False -try: - import transformer_engine.pytorch as te - from transformer_engine.common import recipe - - transformer_engine_installed = True -except ImportError: - print("transformer_engine not installed and we won't compare against this") - torch.manual_seed(0) @@ -68,26 +58,18 @@ def cleanup(): dist.destroy_process_group() -def get_model(K, N, is_fp8, is_te, base_dtype=torch.float32): +def get_model(K, N, is_fp8, base_dtype=torch.float32): modules = [ - ( - nn.Linear(K, N, dtype=base_dtype) - if not is_te - else te.Linear(K, N, params_dtype=base_dtype) - ), + nn.Linear(K, N, dtype=base_dtype), nn.ReLU(), ] N_LAYERS = 20 # N linear layers for _ in range(N_LAYERS - 1): - if is_te: - modules.append(te.Linear(N, N, params_dtype=base_dtype)) - else: - modules.append(nn.Linear(N, N, dtype=base_dtype)) + modules.append(nn.Linear(N, N, dtype=base_dtype)) modules.append(nn.ReLU()) m = nn.Sequential(*modules) if is_fp8: - assert not is_te, "`is_fp8` (using pytorch fp8) can't be used with `is_te`" swap_linear_with_float8_linear(m, Float8Linear, emulate=False) return m @@ -105,9 +87,7 @@ def fsdp_main(rank, world_size, args): bsz_local_end = int((rank + 1) / world_size * B) input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) - fp8_model = get_model(K, N, is_fp8=True, is_te=False, base_dtype=base_dtype).to( - rank - ) + fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) # Need use_orig_params=True to compile FSDP fp8_model = FSDP(fp8_model, use_orig_params=True) fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) @@ -132,9 +112,7 @@ def float8_forw_backward(): fp8_optimizer.step() sync_float8_func(fp8_model) - ref_model = get_model(K, N, is_fp8=False, is_te=False, base_dtype=base_dtype).to( - rank - ) + ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) if compile: ref_model = torch.compile(ref_model) @@ -146,30 +124,6 @@ def ref_forw_backward(): ref_model(input_tensor).sum().backward() ref_optimizer.step() - if transformer_engine_installed: - te_model = FSDP( - get_model(K, N, is_fp8=False, is_te=True, base_dtype=base_dtype).to(rank), - use_orig_params=True, - ) - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling( - fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max" - ) - # Compiling TE_linear fails but they are already compiling under the hood - # if transformer_engine_installed: - # te_forw_backward = torch.compile(te_forw_backward) - if rank == 0: - print(te_model) - - te_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) - - def te_forw_backward(): - te_optimizer.zero_grad() - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - y = te_model(input_tensor) - y.sum().backward() - te_optimizer.step() - def run_n_iterations(n, fn): for _ in range(n): fn() @@ -179,8 +133,6 @@ def run_n_iterations(n, fn): # warmup run_n_iterations(50, ref_forw_backward) run_n_iterations(50, float8_forw_backward) - if transformer_engine_installed: - run_n_iterations(50, te_forw_backward) N_ITER = 50 ref_time = ( @@ -197,24 +149,11 @@ def run_n_iterations(n, fn): * 1e-6 / N_ITER ) - if transformer_engine_installed: - te_time_sec = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, te_forw_backward - ) - * 1e-6 - / N_ITER - ) - else: - te_time_sec = None if rank == 0: print("ref_time", ref_time) print("float8_time", float8_time) - print("te_time_sec", te_time_sec) print("float8 speedup", ref_time / float8_time) - if transformer_engine_installed: - print("te speedup", ref_time / te_time_sec) cleanup() diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 7f6371c3..b62020b1 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -76,17 +76,6 @@ def profile_function( return prof -# Check if transformer_engine is installed -transformer_engine_installed = False -try: - import transformer_engine.pytorch as te - from transformer_engine.common import recipe - - transformer_engine_installed = True -except ImportError: - print("transformer_engine not installed and we won't compare against this") - - @dataclass(frozen=True) class LinearParams: M: int @@ -165,35 +154,13 @@ def float8_forw_backward_wrapper(x): with record_function("backward"): out.sum().backward() - if transformer_engine_installed: - # Use the same recipe as float8_linear.DelayedScalingRecipe - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling( - fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max" - ) - te_linear = te.Linear(params.K, params.N, bias=params.input_bias).to( - device="cuda", dtype=params.ref_dtype - ) - - def te_forw_backward(x): - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with record_function("forward"): - out = te_linear(x) - with record_function("backward"): - out.sum().backward() - if params.torch_compile: ref_forw_backward = torch.compile(ref_forw_backward) float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True) - # Compiling TE_linear fails but they are already compiling under the hood - # if transformer_engine_installed: - # te_forw_backward = torch.compile(te_forw_backward) for _ in range(5): ref_forw_backward(input_tensor) float8_forw_backward_wrapper(input_tensor) - if transformer_engine_installed: - te_forw_backward(input_tensor) # Profile Reference Linear ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json" @@ -213,13 +180,6 @@ def te_forw_backward(x): ) profile_function(profile_config, float8_forw_backward_wrapper, input_tensor) - te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json" - if transformer_engine_installed: - profile_config = ProfileConfig( - str(profile_path / te_string), te_string, iters=5, warmup_iters=5, sync=True - ) - profile_function(profile_config, te_forw_backward, input_tensor) - def invoke_main() -> None: # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"