
remove TE from the codebase (#184)
Summary:
as titled

Pull Request resolved: #184

Test Plan:
```
// scripts work
python benchmarks/profile_linear_float8.py ../tmp/ False dynamic
python benchmarks/bench_linear_float8.py -o ../tmp/test.txt -n 1
CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py

// no hits
grep -r transformer_engine .
```

Reviewed By: drisspg

Differential Revision: D52715981

Pulled By: vkuzo

fbshipit-source-id: 30d8036e3454148d5611585984b5f21ecbe674d3
vkuzo authored and facebook-github-bot committed Jan 13, 2024
1 parent d0af81a commit d272138
Showing 3 changed files with 5 additions and 176 deletions.
70 changes: 0 additions & 70 deletions benchmarks/bench_linear_float8.py
@@ -18,16 +18,6 @@
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from tqdm import tqdm

# Check if transformer_engine is installed
transformer_engine_installed = False
try:
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

transformer_engine_installed = True
except ImportError:
print("transformer_engine not installed and we won't compare against this")

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

@@ -66,7 +56,6 @@ class Experiment:
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
te_time_sec: Optional[float] = None

# 3 Times since we are calculating forward backward
@property
@@ -87,21 +76,6 @@ def float8_tops_sec(self):
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]

@property
def te_tops_sec(self):
M, K, N = self.shape
if self.te_time_sec is not None:
return float(3 * (2 * M * K * N)) / self.te_time_sec
else:
return None

@property
def te_pct_top_peak(self):
if self.te_tops_sec is not None:
return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
else:
return None


def main(
sweep_path: Path,
@@ -113,7 +87,6 @@ def main(

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
@@ -145,19 +118,6 @@ def float8_forw_backward():
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

if transformer_engine_installed:
# Use the same recipe as float8_linear.DelayedScalingRecipe
fp8_format = recipe.Format.HYBRID
fp8_recipe = recipe.DelayedScaling(
fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
)
te_linear = te.Linear(K, N, bias=input_bias).to(device=device, dtype=dtype)

def te_forw_backward():
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
y = te_linear(input_tensor)
y.sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
for _ in range(n):
@@ -169,21 +129,14 @@ def wrapper(*args, **kwargs):

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
if transformer_engine_installed:
te_forw_backward = n_times(REPEAT_N, te_forw_backward)

if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward)
# Compiling TE_linear fails but they are already compiling under the hood
# if transformer_engine_installed:
# te_forw_backward = torch.compile(te_forw_backward)

for _ in range(5):
ref_forw_backward()
float8_forw_backward()
if transformer_engine_installed:
te_forw_backward()

ref_time = (
benchmark_torch_function_in_microseconds(ref_forw_backward)
@@ -195,27 +148,16 @@ def wrapper(*args, **kwargs):
* 1e-6
/ REPEAT_N
)
if transformer_engine_installed:
te_time_sec = (
benchmark_torch_function_in_microseconds(te_forw_backward)
* 1e-6
/ REPEAT_N
)
else:
te_time_sec = None
experiment = Experiment(
name,
(M, K, N),
ref_time,
float8_time,
dtype,
compile,
te_time_sec=te_time_sec,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
if transformer_engine_installed:
print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
experiment_list.append(experiment)
torch._dynamo.reset()

@@ -229,13 +171,10 @@ def wrapper(*args, **kwargs):
"fp8_dtype",
"ref_time_sec",
"pt_fp8_time_sec",
"te_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
"pt_fp8_tops_sec",
"pt_fp8_pct_top_peak",
"te_fp8_tops_sec",
"te_fp8_pct_top_peak",
]
data = []
for experiment in experiment_list:
@@ -250,22 +189,15 @@ def wrapper(*args, **kwargs):
experiment.float_8_dtype,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.te_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
experiment.float8_tops_sec,
experiment.float8_pct_top_peak,
experiment.te_tops_sec,
experiment.te_pct_top_peak,
]
)

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
if transformer_engine_installed:
data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
else:
data_pd["te_fp8_speedup"] = -1.0
data_pd["shape"] = (
"("
+ data_pd["M"].astype(str)
@@ -284,9 +216,7 @@ def wrapper(*args, **kwargs):
"compiled",
"ref_time_sec",
"pt_fp8_time_sec",
"te_fp8_time_sec",
"pt_fp8_speedup",
"te_fp8_speedup",
]
]
print(data_pd_simple)
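After this change, the single-GPU benchmark reports only the reference and PyTorch float8 columns. For context, its throughput columns follow the matmul FLOP count noted in the comments above (the 3x factor covers forward plus backward). A minimal sketch of that accounting; the peak-TOPs numbers below are placeholders, not the benchmark's `dtype_to_peak_tops` table:

```python
# Sketch of the TOPS bookkeeping kept in bench_linear_float8.py (TE columns removed).

def tops_sec(M: int, K: int, N: int, time_sec: float) -> float:
    # A (M x K) @ B (K x N) matmul costs ~2*M*K*N FLOPs; x3 for forward + backward.
    return 3 * (2 * M * K * N) / time_sec

placeholder_peak_tops = {"high_precision": 1.0e15, "float8": 2.0e15}  # illustrative only

achieved = tops_sec(M=8192, K=8192, N=1280, time_sec=1.5e-3)
print("pct of float8 peak:", achieved / placeholder_peak_tops["float8"])
```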
71 changes: 5 additions & 66 deletions benchmarks/bench_multi_gpu.py
@@ -26,16 +26,6 @@
StateDictType,
)

# Check if transformer_engine is installed
transformer_engine_installed = False
try:
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

transformer_engine_installed = True
except ImportError:
print("transformer_engine not installed and we won't compare against this")


torch.manual_seed(0)

@@ -68,26 +58,18 @@ def cleanup():
dist.destroy_process_group()


def get_model(K, N, is_fp8, is_te, base_dtype=torch.float32):
def get_model(K, N, is_fp8, base_dtype=torch.float32):
modules = [
(
nn.Linear(K, N, dtype=base_dtype)
if not is_te
else te.Linear(K, N, params_dtype=base_dtype)
),
nn.Linear(K, N, dtype=base_dtype),
nn.ReLU(),
]
N_LAYERS = 20
# N linear layers
for _ in range(N_LAYERS - 1):
if is_te:
modules.append(te.Linear(N, N, params_dtype=base_dtype))
else:
modules.append(nn.Linear(N, N, dtype=base_dtype))
modules.append(nn.Linear(N, N, dtype=base_dtype))
modules.append(nn.ReLU())
m = nn.Sequential(*modules)
if is_fp8:
assert not is_te, "`is_fp8` (using pytorch fp8) can't be used with `is_te`"
swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
return m

@@ -105,9 +87,7 @@ def fsdp_main(rank, world_size, args):
bsz_local_end = int((rank + 1) / world_size * B)
input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank)

fp8_model = get_model(K, N, is_fp8=True, is_te=False, base_dtype=base_dtype).to(
rank
)
fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank)
# Need use_orig_params=True to compile FSDP
fp8_model = FSDP(fp8_model, use_orig_params=True)
fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size)
@@ -132,9 +112,7 @@ def float8_forw_backward():
fp8_optimizer.step()
sync_float8_func(fp8_model)

ref_model = get_model(K, N, is_fp8=False, is_te=False, base_dtype=base_dtype).to(
rank
)
ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size)
if compile:
ref_model = torch.compile(ref_model)
@@ -146,30 +124,6 @@ def ref_forw_backward():
ref_model(input_tensor).sum().backward()
ref_optimizer.step()

if transformer_engine_installed:
te_model = FSDP(
get_model(K, N, is_fp8=False, is_te=True, base_dtype=base_dtype).to(rank),
use_orig_params=True,
)
fp8_format = recipe.Format.HYBRID
fp8_recipe = recipe.DelayedScaling(
fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
)
# Compiling TE_linear fails but they are already compiling under the hood
# if transformer_engine_installed:
# te_forw_backward = torch.compile(te_forw_backward)
if rank == 0:
print(te_model)

te_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size)

def te_forw_backward():
te_optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
y = te_model(input_tensor)
y.sum().backward()
te_optimizer.step()

def run_n_iterations(n, fn):
for _ in range(n):
fn()
@@ -179,8 +133,6 @@ def run_n_iterations(n, fn):
# warmup
run_n_iterations(50, ref_forw_backward)
run_n_iterations(50, float8_forw_backward)
if transformer_engine_installed:
run_n_iterations(50, te_forw_backward)

N_ITER = 50
ref_time = (
Expand All @@ -197,24 +149,11 @@ def run_n_iterations(n, fn):
* 1e-6
/ N_ITER
)
if transformer_engine_installed:
te_time_sec = (
benchmark_torch_function_in_microseconds(
run_n_iterations, N_ITER, te_forw_backward
)
* 1e-6
/ N_ITER
)
else:
te_time_sec = None

if rank == 0:
print("ref_time", ref_time)
print("float8_time", float8_time)
print("te_time_sec", te_time_sec)
print("float8 speedup", ref_time / float8_time)
if transformer_engine_installed:
print("te speedup", ref_time / te_time_sec)

cleanup()

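With the TE branch gone, bench_multi_gpu.py times two FSDP-wrapped models: the high-precision reference and the float8 model produced by `swap_linear_with_float8_linear`. A rough consolidated sketch of the fp8 step it measures; the `zero_grad` call and the exact sync helper are assumptions based on the surrounding code, where `sync_float8_func` wraps `sync_float8_amax_and_scale_history`:

```python
# Rough sketch of the fp8 training step timed in bench_multi_gpu.py (delayed scaling).
import torch

def float8_step(model: torch.nn.Module,
                optimizer: torch.optim.Optimizer,
                x: torch.Tensor,
                sync_func) -> None:
    optimizer.zero_grad()          # assumed; not visible in the hunks above
    model(x).sum().backward()      # forward + backward on a scalar loss
    optimizer.step()
    sync_func(model)               # refresh float8 amax/scale history after the step
```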
40 changes: 0 additions & 40 deletions benchmarks/profile_linear_float8.py
@@ -76,17 +76,6 @@ def profile_function(
return prof


# Check if transformer_engine is installed
transformer_engine_installed = False
try:
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

transformer_engine_installed = True
except ImportError:
print("transformer_engine not installed and we won't compare against this")


@dataclass(frozen=True)
class LinearParams:
M: int
@@ -165,35 +154,13 @@ def float8_forw_backward_wrapper(x):
with record_function("backward"):
out.sum().backward()

if transformer_engine_installed:
# Use the same recipe as float8_linear.DelayedScalingRecipe
fp8_format = recipe.Format.HYBRID
fp8_recipe = recipe.DelayedScaling(
fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
)
te_linear = te.Linear(params.K, params.N, bias=params.input_bias).to(
device="cuda", dtype=params.ref_dtype
)

def te_forw_backward(x):
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with record_function("forward"):
out = te_linear(x)
with record_function("backward"):
out.sum().backward()

if params.torch_compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
# Compiling TE_linear fails but they are already compiling under the hood
# if transformer_engine_installed:
# te_forw_backward = torch.compile(te_forw_backward)

for _ in range(5):
ref_forw_backward(input_tensor)
float8_forw_backward_wrapper(input_tensor)
if transformer_engine_installed:
te_forw_backward(input_tensor)

# Profile Reference Linear
ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
@@ -213,13 +180,6 @@ def te_forw_backward(x):
)
profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)

te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json"
if transformer_engine_installed:
profile_config = ProfileConfig(
str(profile_path / te_string), te_string, iters=5, warmup_iters=5, sync=True
)
profile_function(profile_config, te_forw_backward, input_tensor)


def invoke_main() -> None:
# Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
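The profiling script keeps the same warmup-then-trace flow for the reference and float8 paths, writing Chrome traces via its `ProfileConfig`/`profile_function` helpers (defined earlier in the file, not shown in this diff). A minimal stand-in using `torch.profiler` directly, with assumed knobs matching the `iters=5, warmup_iters=5, sync=True` config above:

```python
# Minimal stand-in for profile_function(ProfileConfig(...), fn, x); assumptions only.
import torch
from torch.profiler import ProfilerActivity, profile

def profile_fn(trace_path: str, fn, *args, warmup_iters: int = 5, iters: int = 5):
    for _ in range(warmup_iters):                      # warm up before tracing
        fn(*args)
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(iters):
            fn(*args)
        torch.cuda.synchronize()
    prof.export_chrome_trace(trace_path)               # .json Chrome trace, as above
    return prof
```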
