diff --git a/models/demos/t3000/falcon40b/scripts/distributed_layernorm.py b/models/demos/t3000/falcon40b/scripts/distributed_layernorm.py new file mode 100644 index 00000000000..4e53d98b2ed --- /dev/null +++ b/models/demos/t3000/falcon40b/scripts/distributed_layernorm.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np + +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( + get_atol_rtol_pcc, +) + + +def basic_layernorm(x, gamma, beta, epsilon=1e-5): + mean = torch.mean(x, dim=-1, keepdim=True) + variance = torch.var(x, dim=-1, keepdim=True) + + # Normalize the input + x_hat = (x - mean) / np.sqrt(variance + epsilon) + + # Scale and shift + y = gamma * x_hat + beta + + return y + + +def compute_mean_and_variance(chunk): + n = chunk.shape[-1] + mean = torch.mean(chunk, dim=-1, keepdim=True) + variance = torch.var(chunk, dim=-1, keepdim=True) + return mean, variance, n + + +def combine_statistics(mean1, var1, count1, mean2, var2, count2): + combined_count = count1 + count2 + delta = mean2 - mean1 + combined_mean = (count1 * mean1 + count2 * mean2) / combined_count + combined_variance = (count1 * var1 + count2 * var2 + delta**2 * count1 * count2 / combined_count) / combined_count + return combined_mean, combined_variance, combined_count + + +def chunked_layer_norm_direct(x, gamma, beta, chunk_size=1024, epsilon=1e-5): + total_mean = 0 + total_variance = 0 + total_count = 0 + + # Process each chunk + for i in range(0, x.shape[-1], chunk_size): + chunk = x[:, i : i + chunk_size] + chunk_mean, chunk_variance, chunk_count = compute_mean_and_variance(chunk) + + # Combine statistics from the chunk with the total statistics + total_mean, total_variance, total_count = combine_statistics( + total_mean, total_variance, total_count, chunk_mean, chunk_variance, chunk_count + ) + + # Normalize the input + x_hat = (x - total_mean) / np.sqrt(total_variance + epsilon) + + # Scale and shift + y = gamma * x_hat + beta + + return y + + +def layer_norm_welford(x, gamma, beta, epsilon=1e-5): + # Initialize mean and M2 for Welford's algorithm + mean = torch.zeros((x.shape[0], 1), dtype=x.dtype) + M2 = torch.zeros((x.shape[0], 1), dtype=x.dtype) + count = 0 + + # First pass to compute mean and variance using Welford's algorithm + for i in range(x.shape[-1]): + value = x[:, i : i + 1] + count += 1 + delta = value - mean + mean += delta / count + delta2 = value - mean + M2 += delta * delta2 + + variance = M2 / count + + # Normalize the input + x_hat = (x - mean) / (variance + epsilon) ** 0.5 + + # Scale and shift + y = gamma * x_hat + beta + + return y + + +def combine_statistics_welford(n_a, avg_a, M2_a, n_b, avg_b, M2_b): + n = n_a + n_b + delta = avg_b - avg_a + avg_ab = avg_a + delta * n_b / n + M2_ab = M2_a + M2_b + delta**2 * n_a * n_b / n + return n, avg_ab, M2_ab + + +def layer_norm_welford_chunked(x, gamma, beta, chunk_size=1024, epsilon=1e-5): + mean = torch.zeros((x.shape[0], 1), dtype=x.dtype) + M2 = torch.zeros((x.shape[0], 1), dtype=x.dtype) + count = 0 + + # Process each chunk + for c in range(0, x.shape[-1], chunk_size): + mean_c = torch.zeros((x.shape[0], 1), dtype=x.dtype) + M2_c = torch.zeros((x.shape[0], 1), dtype=x.dtype) + count_c = 0 + chunk = x[:, c : c + chunk_size] + for i in range(chunk.shape[-1]): + value = chunk[:, i : i + 1] + count_c += 1 + delta = value - mean_c + mean_c += delta / count_c + delta2 = value - mean_c + M2_c += delta * delta2 + + count, mean, M2 = 
combine_statistics_welford(count, mean, M2, count_c, mean_c, M2_c) + + variance = M2 / count + + # Normalize the input + x_hat = (x - mean) / (variance + epsilon) ** 0.5 + + # Scale and shift + y = gamma * x_hat + beta + + return y + + +def layer_norm_decomp_chunked(x, gamma, beta, chunk_size=1024, epsilon=1e-5): + meanx = torch.zeros((x.shape[0], 1), dtype=x.dtype) + meanx2 = torch.zeros((x.shape[0], 1), dtype=x.dtype) + count = 0 + + # Process each chunk + num_chunks = x.shape[-1] // chunk_size + for i in range(0, x.shape[-1], chunk_size): + chunk = x[:, i : i + chunk_size] + count += chunk.shape[-1] + + meanx += torch.mean(chunk, dim=-1, keepdim=True) + meanx2 += torch.mean(torch.square(chunk), dim=-1, keepdim=True) + + mean = meanx / num_chunks + meanx2 = meanx2 / num_chunks + var = meanx2 - torch.square(mean) + + # Normalize the input + x_hat = (x - mean) / torch.sqrt(var + epsilon) + + # Scale and shift + y = gamma * x_hat + beta + + return y + + +def layer_norm_decomp(x, gamma, beta, epsilon=1e-5): + mean = torch.mean(x, dim=-1, keepdim=True) + var = x - mean + var = torch.mean(torch.square(var)) + x_hat = (x - mean) / torch.sqrt(var + epsilon) + y = gamma * x_hat + beta + return y + + +def distributed_layernorm_poc(x, gamma, beta, chunk_size=1024, epsilon=1e-5): + # Prepare inputs for distributed processing + num_chunks = x.shape[-1] // chunk_size + xs = [] + gammas = [] + betas = [] + for i in range(0, x.shape[-1], chunk_size): + x_chunk = x[:, i : i + chunk_size] + xs.append(x_chunk) + + gamma_chunk = gamma[i : i + chunk_size] + gammas.append(gamma_chunk) + + beta_chunk = beta[i : i + chunk_size] + betas.append(beta_chunk) + + count = [] + meanx = [] + meanx2 = [] + # Distributed processing + for chunk in xs: + count_local = chunk.shape[-1] + meanx_local = torch.mean(chunk, dim=-1, keepdim=True) + meanx2_local = torch.mean(torch.square(chunk), dim=-1, keepdim=True) + + count.append(count_local) + meanx.append(meanx_local) + meanx2.append(meanx2_local) + + # AllReduce cound, meanx, meanx2 + count = torch.torch.FloatTensor(count).sum(dim=0) + mean = torch.stack(meanx, dim=0).sum(dim=0) / num_chunks + meanx2 = torch.stack(meanx2, dim=0).sum(dim=0) / num_chunks + var = meanx2 - torch.square(mean) + + # Distributed processing + ys = [] + for i in range(num_chunks): + # Normalize the input + x_hat_local = (xs[i] - mean) / torch.sqrt(var + epsilon) + + # Scale and shift + y_local = gammas[i] * x_hat_local + betas[i] + ys.append(y_local) + + # Post processing: concat ys + y = torch.cat(ys, dim=-1) + + return y + + +def main(): + S = 2048 + H = 8192 + + input_shape = (S, H) + + x = torch.randn(input_shape, dtype=torch.float32) * 4.0 # Example input + + gamma = torch.randn(H) # Scale parameter + beta = torch.randn(H) # Shift parameter + + # PyTorch LayerNorm + layer_norm = torch.nn.LayerNorm(H, elementwise_affine=True) + layer_norm.eval() + layer_norm.weight.data = gamma + layer_norm.bias.data = beta + normalized_output_torch = layer_norm(x) + + # Custom LayerNorm + basic_layernorm_output = basic_layernorm(x, gamma, beta) + normalized_output_custom = chunked_layer_norm_direct(x, gamma, beta) + welford_output = layer_norm_welford(x, gamma, beta) + tt_output = layer_norm_decomp(x, gamma, beta) + decomp_chunked_output = layer_norm_decomp_chunked(x, gamma, beta) + welford_chunked_output = layer_norm_welford_chunked(x, gamma, beta) + distributed_layernorm_output = distributed_layernorm_poc(x, gamma, beta) + + # Comparison + + print("\nBasic LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str 
= get_atol_rtol_pcc(basic_layernorm_output, normalized_output_torch) + print(output_str) + + print("\nCustom Chunked LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(normalized_output_custom, normalized_output_torch) + print(output_str) + + print("\nWelford LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(welford_output, normalized_output_torch) + print(output_str) + + print("\nTT LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(tt_output, normalized_output_torch) + print(output_str) + + print("\nDecomposed Chunked LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(decomp_chunked_output, normalized_output_torch) + print(output_str) + + print("\nWelford Chunked LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(welford_chunked_output, normalized_output_torch) + print(output_str) + + print("\nDistributed LayerNorm") + cal_atol, cal_rtol, cal_pcc, output_str = get_atol_rtol_pcc(distributed_layernorm_output, normalized_output_torch) + print(output_str) + + +if __name__ == "__main__": + main() diff --git a/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm.py b/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm.py new file mode 100644 index 00000000000..ddc517fd9c2 --- /dev/null +++ b/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import math +from loguru import logger + +import tt_lib as ttl +import ttnn +from models.demos.t3000.falcon40b.tt.ops.distributed_layernorm import TtDistributedLayernorm +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( + comp_pcc, +) +from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000 +from models.demos.t3000.falcon40b.tt.model_config import ( + get_model_config, +) + +from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( + FalconForCausalLM, +) + + +class PytorchDistributedLayernorm(torch.nn.Module): + def __init__(self, gammas, betas, epsilon=1e-5): + super().__init__() + self.gammas = gammas + self.betas = betas + self.epsilon = epsilon + + def forward(self, xs): + num_chunks = len(xs) + counts = [] + meanxs = [] + meanx2s = [] + # Distributed processing + for chunk in xs: + count_local = chunk.shape[-1] + meanx_local = torch.mean(chunk, dim=-1, keepdim=True) + meanx2_local = torch.mean(torch.square(chunk), dim=-1, keepdim=True) + + counts.append(count_local) + meanxs.append(meanx_local) + meanx2s.append(meanx2_local) + + count = torch.torch.FloatTensor(counts).sum(dim=0) + meanxs = [meanxs[i] * counts[i] for i in range(num_chunks)] # Weighting by chunk size + meanx2s = [meanx2s[i] * counts[i] for i in range(num_chunks)] # Weighting by chunk size + + # AllGather meanx, meanx2 + meanxs = torch.stack(meanxs, dim=0) + meanx2s = torch.stack(meanx2s, dim=0) + + # Reduce + mean = meanxs.sum(dim=0) / count + meanx2 = meanx2s.sum(dim=0) / count + var = meanx2 - torch.square(mean) + + # Distributed processing + ys = [] + for i in range(num_chunks): + # Normalize the input + x_hat_local = (xs[i] - mean) / torch.sqrt(var + self.epsilon) + + # Scale and shift + y_local = self.gammas[i] * x_hat_local + self.betas[i] + ys.append(y_local) + + # Post processing: concat ys + y = torch.cat(ys, dim=-1) + + return y + + +class PytorchLayernorm(torch.nn.Module): + def 
__init__(self, gamma, beta, hidden_size=8192): + super().__init__() + self.ln = torch.nn.LayerNorm(hidden_size, elementwise_affine=True) + self.ln.weight = gamma + self.ln.bias = beta + + self.ln.eval() + + def forward(self, x): + result = self.ln(x) + return result + + +def run_test_DistributedLayernorm_inference(pcc, devices, model_location_generator, get_tt_cache_path): + S = 2048 + num_chips = 8 + epsilon = 1e-5 + + # Prepare input + torch.manual_seed(0) + + model_input_shape = [1, S] + model_version = "tiiuae/falcon-40b-instruct" + + model_config = get_model_config("BFLOAT8_B-DRAM", "prefill", model_input_shape, num_chips) + + tt_cache_path = get_tt_cache_path( + model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"] + ) + + if 0: + model_version = "tiiuae/falcon-40b-instruct" + + model_name = model_location_generator(model_version, model_subdir="Falcon") + + hugging_face_reference_model = FalconForCausalLM.from_pretrained( + model_name, low_cpu_mem_usage=True, num_hidden_layers=1 + ) + hugging_face_reference_model.eval() + config = hugging_face_reference_model.config + + hidden_size = config.hidden_size + gamma = hugging_face_reference_model.transformer.h[0].ln_attn.weight + beta = hugging_face_reference_model.transformer.h[0].ln_attn.bias + else: + hidden_size = 8192 + gamma = torch.nn.Parameter(torch.randn(hidden_size)) # Scale parameter + beta = torch.nn.Parameter(torch.randn(hidden_size)) # Shift parameter + + input_shape = [1, 1, S, hidden_size] + + input_torch = (torch.rand(input_shape) * 2) - 1 + + inputs_torch = torch.chunk(input_torch, len(devices), -1) + gammas_torch = torch.chunk(gamma, len(devices), -1) + betas_torch = torch.chunk(beta, len(devices), -1) + + tt_inputs = [] + for i in range(len(devices)): + tt_input_host = torch2tt_tensor(inputs_torch[i], None, tt_dtype=ttl.tensor.DataType.BFLOAT16) + tt_inputs.append(tt_input_host.to(devices[i], model_config["DEFAULT_MEMCFG"])) + + # PyTorch basic layernorm output -------------------------------------------------------------------- + pytorch_FalconLayernorm_model = PytorchLayernorm(gamma=gamma, beta=beta) + torch_layernorm_output = pytorch_FalconLayernorm_model(input_torch) + + # PyTorch distributed layernorm output -------------------------------------------------------------------- + pytorch_FalconLayernorm_model = PytorchDistributedLayernorm(gammas=gammas_torch, betas=betas_torch) + torch_distributed_layernorm_outputs = pytorch_FalconLayernorm_model(inputs_torch) + torch_distributed_layernorm_output = torch.concat([torch_distributed_layernorm_outputs], -1) + + # check pytorch vs. 
distributed pytorch implementation--------------------------------------------------------- + does_pass, output_pcc = comp_pcc(torch_layernorm_output, torch_distributed_layernorm_output, pcc) + logger.info(f"PCC value: {output_pcc}") + + if does_pass: + logger.info("Pytorch distributed layernorm Passed!") + else: + logger.warning("Pytorch distributed layernorm Failed!") + assert does_pass, f"PCC value is lower than {pcc}" + + # TT hardware execution ------------------------------------------------------------- + tt_distributed_layernorm = TtDistributedLayernorm(devices, gammas_torch, betas_torch, epsilon, tt_cache_path) + + tt_outputs = tt_distributed_layernorm(tt_inputs) + + tt_out = torch.concat([tt2torch_tensor(tt_o) for tt_o in tt_outputs], -1) + + # check outputs ---------------------------------------------------------------------- + does_pass, output_pcc = comp_pcc(torch_layernorm_output, tt_out, pcc) + logger.info(f"PCC value: {output_pcc}") + + if does_pass: + logger.info("TT Distributed Layernorm Passed!") + else: + logger.warning("TT Distributed Layernorm Failed!") + assert does_pass, f"PCC value is lower than {pcc}" + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize("pcc", [(0.99)]) +def test_DistributedLayernorm_inference( + pcc, + all_devices, + model_location_generator, + get_tt_cache_path, +): + devices = get_devices_for_t3000(all_devices, 8) + + run_test_DistributedLayernorm_inference(pcc, devices, model_location_generator, get_tt_cache_path) diff --git a/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm_dlnp1.py b/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm_dlnp1.py new file mode 100644 index 00000000000..40988154853 --- /dev/null +++ b/models/demos/t3000/falcon40b/tests/ops/test_distributed_layernorm_dlnp1.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
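# --- Editorial note (not part of the original patch) -------------------------------
# The distributed layernorm above hinges on the identity var(x) = E[x^2] - E[x]^2:
# each device only has to share two per-row partial sums, sum(x) and sum(x^2).
# A minimal torch-only sketch of that recombination, assuming equally sized chunks
# (the helper name and shapes are illustrative, not part of the patch):
import torch


def combined_layernorm_sketch(x, gamma, beta, n_chunks=8, eps=1e-5):
    chunks = x.chunk(n_chunks, dim=-1)
    # Per-device partials: element count, sum(x) and sum(x^2) per row.
    counts = [c.shape[-1] for c in chunks]
    sum_x = [c.sum(dim=-1, keepdim=True) for c in chunks]
    sum_x2 = [c.pow(2).sum(dim=-1, keepdim=True) for c in chunks]
    # "AllReduce": combine the partials into global row statistics.
    n = sum(counts)
    mean = torch.stack(sum_x).sum(dim=0) / n
    mean_x2 = torch.stack(sum_x2).sum(dim=0) / n
    var = mean_x2 - mean.square()
    # Every device can now normalize its own chunk with the shared mean/var.
    y = torch.cat([(c - mean) / torch.sqrt(var + eps) for c in chunks], dim=-1)
    return gamma * y + beta


# Quick check against torch.nn.functional.layer_norm (expected to match to fp32 error):
#   x = torch.randn(4, 8192); g = torch.randn(8192); b = torch.randn(8192)
#   torch.allclose(combined_layernorm_sketch(x, g, b),
#                  torch.nn.functional.layer_norm(x, (8192,), g, b, 1e-5), atol=1e-4)
# ------------------------------------------------------------------------------------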
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import math +from loguru import logger + +import tt_lib as ttl +import ttnn +from models.demos.t3000.falcon40b.tt.ops.distributed_layernorm_dlnp1 import TtDistributedLayernormDLNP1 +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( + comp_pcc, +) +from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000 +from models.demos.t3000.falcon40b.tt.model_config import ( + get_model_config, +) + +from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( + FalconForCausalLM, +) + + +class PytorchDistributedLayernormDLNP1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, xs): + num_chunks = len(xs) + S = xs[0].shape[2] + counts = [] + meanxs = [] + meanx2s = [] + # Distributed processing + for chunk in xs: + count_local = chunk.shape[-1] + meanx_local = torch.mean(chunk, dim=-1, keepdim=True) + meanx2_local = torch.mean(torch.square(chunk), dim=-1, keepdim=True) + + counts.append(count_local) + meanxs.append(meanx_local) + meanx2s.append(meanx2_local) + + meanxs = [meanxs[i] * counts[i] for i in range(num_chunks)] # Weighting by chunk size + meanx2s = [meanx2s[i] * counts[i] for i in range(num_chunks)] # Weighting by chunk size + + # pad with zeros as for tiles + output = [] + for i in range(num_chunks): + output.append( + torch.concat([meanxs[i], torch.zeros([1, 1, S, 31]), meanx2s[i], torch.zeros([1, 1, S, 31])], dim=-1) + ) + + return output + + +def run_test_DistributedLayernorm_inference(pcc, devices, model_location_generator, get_tt_cache_path): + S = 2048 + + # Prepare input + torch.manual_seed(0) + + hidden_size = 8192 + input_shape = [1, 1, S, hidden_size] + input_torch = (torch.rand(input_shape) * 2) - 1 + inputs_torch = torch.chunk(input_torch, len(devices), -1) + + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + tt_inputs = [] + for i in range(len(devices)): + tt_input_host = torch2tt_tensor(inputs_torch[i], None, tt_dtype=ttl.tensor.DataType.BFLOAT16) + tt_inputs.append(tt_input_host.to(devices[i], dram_memcfg)) + + # PyTorch distributed layernorm output -------------------------------------------------------------------- + pytorch_FalconLayernorm_model = PytorchDistributedLayernormDLNP1() + torch_output = pytorch_FalconLayernorm_model(inputs_torch) + torch_output = torch.concat(torch_output, -1) + + # TT hardware execution ------------------------------------------------------------- + tt_distributed_layernorm = TtDistributedLayernormDLNP1() + tt_output = tt_distributed_layernorm(tt_inputs) + + tt_output_host = torch.concat([tt2torch_tensor(tt_o) for tt_o in tt_output], -1) + + # check outputs ---------------------------------------------------------------------- + does_pass, output_pcc = comp_pcc(torch_output, tt_output_host, pcc) + logger.info(f"PCC value: {output_pcc}") + + if does_pass: + logger.info("TT DLP1 Passed!") + else: + logger.warning("TT DLP1 Failed!") + assert does_pass, f"PCC value is lower than {pcc}" + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize("pcc", [(0.99)]) +def test_DistributedLayernorm_inference( + pcc, + all_devices, + model_location_generator, + get_tt_cache_path, +): + devices = get_devices_for_t3000(all_devices, 8) + + run_test_DistributedLayernorm_inference(pcc, devices, model_location_generator, get_tt_cache_path) diff --git 
a/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm.py b/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm.py new file mode 100644 index 00000000000..93e3392ff6a --- /dev/null +++ b/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import math +from torch import nn +import tt_lib as ttl +import ttnn + +from typing import List +from models.utility_functions import torch2tt_tensor, tt2torch_tensor + + +class TtDistributedLayernorm: + def __init__(self, devices, gammas, betas, epsilon, tt_cache_path): + super().__init__() + + self.devices = devices + ln_weights_str = f"ln.weight" + ln_bias_str = f"ln.bias" + + dtype = ttl.tensor.DataType.BFLOAT16 + # dtype = ttl.tensor.DataType.BFLOAT8_B + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + self.dram_memcfg = dram_memcfg + + num_devices = len(devices) + + self.ln_gamma = [] + self.ln_beta = [] + for i in range(num_devices): + ln_weights_path = tt_cache_path / f"{ln_weights_str}_{dtype.name}_{i}_{num_devices}.bin" + if (ln_weights_path).exists(): + ln_gamma_host = ttl.tensor.load_tensor(str(ln_weights_path)) + self.ln_gamma.append(ln_gamma_host.to(devices[i], dram_memcfg)) + else: + ln_gamma_host = torch2tt_tensor( + gammas[i], + None, + tt_layout=ttl.tensor.Layout.ROW_MAJOR, + tt_memory_config=dram_memcfg, + tt_dtype=dtype, + ) + + self.ln_gamma.append(ln_gamma_host.to(devices[i], dram_memcfg)) + + ttl.tensor.dump_tensor( + str(ln_weights_path), + ln_gamma_host, + ) + + ln_bias_path = tt_cache_path / f"{ln_bias_str}_{dtype.name}_{i}_{num_devices}.bin" + if (ln_bias_path).exists(): + ln_beta_host = ttl.tensor.load_tensor(str(ln_bias_path)) + self.ln_beta.append(ln_beta_host.to(devices[i], dram_memcfg)) + else: + ln_beta_host = torch2tt_tensor( + betas[i], + None, + tt_layout=ttl.tensor.Layout.ROW_MAJOR, + tt_memory_config=dram_memcfg, + tt_dtype=dtype, + ) + self.ln_beta.append(ln_beta_host.to(devices[i], dram_memcfg)) + + ttl.tensor.dump_tensor( + str(ln_bias_path), + ln_beta_host, + ) + + self.ln_eps = epsilon + + def __call__(self, xs: ttl.tensor.Tensor) -> ttl.tensor.Tensor: + num_devices = len(xs) + + counts = [] + total_count = 0 + meanxs = [] + + # Each device computes local statistics mean(x) and mean(x^2) + # meanx = torch.mean(xs, dim=-1, keepdim=True) + for i in range(num_devices): + count_local = xs[i].shape[-1] + total_count += count_local + counts.append(count_local) + + meanx_local = ttl.tensor.reduce( + xs[i], ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / counts[i] + ) + meanxs.append(meanx_local) + + # meanx2 = torch.mean(torch.square(xs), dim=-1, keepdim=True) + meanx2s = [] + for i in range(num_devices): + x2_local = ttl.tensor.pow(xs[i], 2) + meanx2_local = ttl.tensor.reduce( + x2_local, ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / counts[i] + ) + meanx2s.append(meanx2_local) + + # AllReduce meanx and meanx2 + # Weighted meanx to number of samples per device + for i in range(num_devices): + meanxs[i] = ttl.tensor.mul_unary(meanxs[i], counts[i]) + # AllGather + meanxs = ttl.tensor.all_gather( + meanxs, + dim=3, + num_links=1, + output_mem_config=self.dram_memcfg, + ) + # Mean over per-device meanx + # mean = torch.stack(meanx, dim=0).sum(dim=0) / total_count + mean = [] + for i in range(num_devices): + mean.append( + ttl.tensor.reduce( + meanxs[i], 
ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / total_count + ) + ) + + # Weighted meanx2 to number of samples per device + for i in range(num_devices): + meanx2s[i] = ttl.tensor.mul_unary(meanx2s[i], counts[i]) + # AllGather + meanx2s = ttl.tensor.all_gather( + meanx2s, + dim=3, + num_links=1, + output_mem_config=self.dram_memcfg, + ) + # Mean over per-device meanx2 + # meanx2 = torch.stack(meanx2, dim=0).sum(dim=0) / total_count + meanx2 = [] + for i in range(num_devices): + meanx2.append( + ttl.tensor.reduce( + meanx2s[i], ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / total_count + ) + ) + + # Variance + # var = meanx2 - torch.square(mean) + var = [] + for i in range(num_devices): + var.append(ttl.tensor.pow(mean[i], 2)) + for i in range(num_devices): + var[i] = ttl.tensor.sub(meanx2[i], var[i]) + meanx2[i].deallocate(True) + + # Normalize the input: x_hat = (xs[i] - mean) / torch.sqrt(var + epsilon) + denominators = [] + for i in range(num_devices): + denominators.append(ttl.tensor.add_unary(var[i], self.ln_eps)) + for i in range(num_devices): + denominators[i] = ttl.tensor.pow(denominators[i], 0.5) + for i in range(num_devices): + denominators[i] = ttl.tensor.recip(denominators[i]) + + nominators = [] + for i in range(num_devices): + nominators.append( + ttl.tensor.bcast(xs[i], mean[i], math_op=ttl.tensor.BcastOpMath.SUB, dim=ttl.tensor.BcastOpDim.W) + ) + + x_hat = [] + for i in range(num_devices): + x_hat.append( + ttl.tensor.bcast( + nominators[i], denominators[i], math_op=ttl.tensor.BcastOpMath.MUL, dim=ttl.tensor.BcastOpDim.W + ) + ) + nominators[i].deallocate(True) + denominators[i].deallocate(True) + + # Scale and shift: x_hat = self.gammas * x_hat + self.betas_torch + for i in range(num_devices): + x_hat[i] = ttl.tensor.bcast( + x_hat[i], self.ln_gamma[i], math_op=ttl.tensor.BcastOpMath.MUL, dim=ttl.tensor.BcastOpDim.H + ) + for i in range(num_devices): + x_hat[i] = ttl.tensor.bcast( + x_hat[i], self.ln_beta[i], math_op=ttl.tensor.BcastOpMath.ADD, dim=ttl.tensor.BcastOpDim.H + ) + + return x_hat diff --git a/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm_dlnp1.py b/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm_dlnp1.py new file mode 100644 index 00000000000..99d02920a8a --- /dev/null +++ b/models/demos/t3000/falcon40b/tt/ops/distributed_layernorm_dlnp1.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
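# --- Editorial note (not part of the original patch) -------------------------------
# TtDistributedLayernorm above emulates an all-reduce of the per-device statistics
# with an all-gather followed by a width reduction: each device scales its local mean
# by its sample count (mul_unary), the scaled partials are gathered along the last
# dimension, and a SUM reduce with scaler 1/total_count recovers the global mean.
# A hedged torch-only sketch of that trick (the function name is illustrative):
import torch


def allreduce_via_gather_reduce(local_means, counts):
    # Weight each local mean by its sample count so it becomes a partial sum ...
    weighted = [m * c for m, c in zip(local_means, counts)]
    # ... gather all partials onto every device (here: concat along the last dim) ...
    gathered = torch.cat(weighted, dim=-1)
    # ... and reduce over that dim with a 1/total_count scaler to get the global mean.
    return gathered.sum(dim=-1, keepdim=True) / sum(counts)


# Example with two "devices" holding halves of a row:
#   x = torch.randn(1, 64); halves = x.chunk(2, dim=-1)
#   means = [h.mean(dim=-1, keepdim=True) for h in halves]
#   allreduce_via_gather_reduce(means, [32, 32])  # equals x.mean(dim=-1, keepdim=True)
# ------------------------------------------------------------------------------------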
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import math +from torch import nn +import tt_lib as ttl +import ttnn + + +class TtDistributedLayernormDLNP1: + def __init__(self): + super().__init__() + + def __call__(self, xs: ttl.tensor.Tensor) -> ttl.tensor.Tensor: + num_devices = len(xs) + + counts = [] + total_count = 0 + meanxs = [] + + # Each device computes local statistics mean(x) and mean(x^2) + # meanx = torch.mean(xs, dim=-1, keepdim=True) + for i in range(num_devices): + count_local = xs[i].shape[-1] + total_count += count_local + counts.append(count_local) + + meanx_local = ttl.tensor.reduce( + xs[i], ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / counts[i] + ) + meanxs.append(meanx_local) + + # meanx2 = torch.mean(torch.square(xs), dim=-1, keepdim=True) + meanx2s = [] + for i in range(num_devices): + x2_local = ttl.tensor.pow(xs[i], 2) + meanx2_local = ttl.tensor.reduce( + x2_local, ttl.tensor.ReduceOpMath.SUM, ttl.tensor.ReduceOpDim.W, scaler=1.0 / counts[i] + ) + meanx2s.append(meanx2_local) + + # Weighted meanx to number of samples per device + for i in range(num_devices): + meanxs[i] = ttl.tensor.mul_unary(meanxs[i], counts[i]) + + # Weighted meanx2 to number of samples per device + for i in range(num_devices): + meanx2s[i] = ttl.tensor.mul_unary(meanx2s[i], counts[i]) + + output = [] + for i in range(num_devices): + output.append(ttl.tensor.concat([meanxs[i], meanx2s[i]], 3)) + + return output diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index b50217e24b8..b52b72fcc4a 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -76,9 +76,23 @@ run_t3000_trace_stress_tests() { # Record the end time end_time=$(date +%s) duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_trace_stress_tests $duration seconds to complete" } +run_t3000_distributed_layernorm_tests() { + # Record the start time + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_distributed_layernorm_tests" + + WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/unit_tests/operations/test_distributed_layernorm.py + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_distributed_layernorm_tests $duration seconds to complete" +} run_t3000_falcon40b_tests() { # Record the start time @@ -104,6 +118,9 @@ run_t3000_tests() { # Run tteager tests run_t3000_tteager_tests + #Run distributed layernorm tests + run_t3000_distributed_layernorm_tests + # Run trace tests run_t3000_trace_stress_tests diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_post_allgather.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_post_allgather.py new file mode 100644 index 00000000000..c152b56c4fb --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_post_allgather.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
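# --- Editorial note (not part of the original patch) -------------------------------
# The pre/post-allgather ops exchange row statistics as full 32-wide tiles in which
# only column 0 carries data; the remaining 31 columns are zero padding. A small
# sketch of that packing, following the sum(x^2)-first layout this test builds
# (TILE_W and the helper name are illustrative):
import torch

TILE_W = 32


def pack_stats_tiles(sum_x2, sum_x, is_rmsnorm):
    # sum_x2 / sum_x: per-device tensors of shape [..., 1]; rmsnorm only needs sum(x^2).
    cols_per_device = 1 if is_rmsnorm else 2
    n_devices = len(sum_x2)
    stats = torch.zeros(sum_x2[0].shape[:-1] + (TILE_W * cols_per_device * n_devices,))
    for d in range(n_devices):
        base = d * cols_per_device * TILE_W
        stats[..., base : base + 1] = sum_x2[d]  # column 0 of the device's first tile
        if not is_rmsnorm:
            stats[..., base + TILE_W : base + TILE_W + 1] = sum_x[d]  # column 0 of the next tile
    return stats
# ------------------------------------------------------------------------------------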
+ +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import torch + +import tt_lib as ttl + +from models.utility_functions import tt2torch_tensor, torch2tt_tensor, skip_for_grayskull + +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_allclose, comp_pcc + + +def reference_layernorm(x, gamma, beta, epsilon, is_rmsnorm): + if is_rmsnorm: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon) * gamma + else: + return torch.nn.functional.layer_norm(x, x.shape[-1:], gamma, beta, epsilon) + + +def run_layernorm_part_2(inp_shape, n_devices, is_rmsnorm, dtype, device, fp32_enabled=False): + kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=ttl.tensor.MathFidelity.HiFi4, # Highest fidelity + math_approx_mode=False, + fp32_dest_acc_en=fp32_enabled, + packer_l1_acc=False, + ) + + torch.manual_seed(1234) + tile_cols_per_device = 1 if is_rmsnorm else 2 # layernorm has 2 stats to distribute + + canon_inp = torch.randn(inp_shape) * 4 - 1 + gamma = torch.rand(inp_shape[-1]) * 2 - 1 + beta = torch.rand(inp_shape[-1]) * 2 - 1 + gamma_chunked = gamma.chunk(n_devices, dim=-1) + beta_chunked = beta.chunk(n_devices, dim=-1) + # Get per-chunk mean and mean(x^2) + inp_chunked = canon_inp.chunk(n_devices, dim=-1) + mean = [x.sum(dim=-1, keepdim=True) for x in inp_chunked] + meanx2 = [x.pow(2).sum(dim=-1, keepdim=True) for x in inp_chunked] + + stats_tiles = torch.zeros(inp_shape[:-1] + (32 * n_devices * tile_cols_per_device,)) + for idx, (m, mm) in enumerate(zip(mean, meanx2)): + mm_idx = idx * tile_cols_per_device * 32 + stats_tiles[..., mm_idx : mm_idx + 1] = mm + + if not is_rmsnorm: + m_idx = mm_idx + 32 # next tile is m + stats_tiles[..., m_idx : m_idx + 1] = m + + epsilon = 1e-5 + # reference layernorm + ref_out = reference_layernorm(canon_inp, gamma, beta, epsilon, is_rmsnorm) + ref_chunks = ref_out.chunk(n_devices, dim=-1) + + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + all_pass = True + # layernorm post all gather + for d in range(n_devices): + tt_inp = torch2tt_tensor( + inp_chunked[d], + tt_dtype=dtype, + tt_device=device, + tt_layout=ttl.tensor.Layout.TILE, + tt_memory_config=dram_memcfg, + ) + tt_gamma = torch2tt_tensor( + gamma_chunked[d].reshape(1, 1, -1, 32), + tt_dtype=ttl.tensor.DataType.BFLOAT16, + tt_device=device, + tt_layout=ttl.tensor.Layout.ROW_MAJOR, + tt_memory_config=dram_memcfg, + ) + tt_beta = torch2tt_tensor( + beta_chunked[d].reshape(1, 1, -1, 32), + tt_dtype=ttl.tensor.DataType.BFLOAT16, + tt_device=device, + tt_layout=ttl.tensor.Layout.ROW_MAJOR, + tt_memory_config=dram_memcfg, + ) + tt_stats = torch2tt_tensor( + stats_tiles, + tt_dtype=ttl.tensor.DataType.BFLOAT16, + tt_device=device, + tt_layout=ttl.tensor.Layout.TILE, + tt_memory_config=dram_memcfg, + ) + + if is_rmsnorm: + tt_lnp2_out = ttl.operations.primary.rmsnorm_post_allgather( + tt_inp, tt_stats, epsilon, tt_gamma, compute_kernel_config=kernel_config + ) + else: + tt_lnp2_out = ttl.operations.primary.layernorm_post_allgather( + tt_inp, tt_stats, epsilon, tt_gamma, tt_beta, compute_kernel_config=kernel_config + ) + + tt_lnp2_out_cpu = tt2torch_tensor(tt_lnp2_out) + passing, output_str = comp_allclose(ref_chunks[d], tt_lnp2_out_cpu, rtol=1e-1, atol=1e-01) + logger.debug(f"layernorm vs tt={output_str}") + all_pass = all_pass and passing + + assert all_pass + + +@skip_for_grayskull("Requires wormhole") +@pytest.mark.parametrize( + "dtype", + 
(ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B), + ids=["BFLOAT16", "BFLOAT8_B"], +) +@pytest.mark.parametrize( + "inp_shape", + [ + (1, 1, 2048, 8192), + (1, 1, 128, 8192), + (2, 1, 128, 8192), + ], +) +@pytest.mark.parametrize( + "n_devices", + [4, 8], +) +@pytest.mark.parametrize( + "is_rmsnorm", + [True, False], + ids=["rmsnorm", "layernorm"], +) +@pytest.mark.parametrize( + "fp32_enabled", + [True, False], + ids=["fp32_enabled", "fp32_disabled"], +) +def test_layernorm_part_2_with_program_cache( + inp_shape, n_devices, is_rmsnorm, dtype, fp32_enabled, device, use_program_cache +): + run_layernorm_part_2(inp_shape, n_devices, is_rmsnorm, dtype, device, fp32_enabled) + + +@skip_for_grayskull("Requires wormhole") +@pytest.mark.parametrize( + "dtype", + [ttl.tensor.DataType.BFLOAT16], + ids=["BFLOAT16"], +) +@pytest.mark.parametrize( + "inp_shape", + [ + (1, 1, 2048, 8192), + ], +) +@pytest.mark.parametrize( + "n_devices", + [8], +) +@pytest.mark.parametrize( + "is_rmsnorm", + [True, False], + ids=["rmsnorm", "layernorm"], +) +def test_layernorm_part_2_with_program_cache2(inp_shape, n_devices, is_rmsnorm, dtype, device, use_program_cache): + dummy_tensors = [] + + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + for i in range(2): + if i > 0: + dummy_tensors.append( + torch2tt_tensor( + torch.randn(inp_shape), + tt_dtype=dtype, + tt_device=device, + tt_layout=ttl.tensor.Layout.TILE, + tt_memory_config=dram_memcfg, + ) + ) + run_layernorm_part_2(inp_shape, n_devices, is_rmsnorm, dtype, device) + + assert device.num_program_cache_entries() == 1, "Program cache should have only one entry" + str( + device.num_program_cache_entries() + ) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_pre_allgather.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_pre_allgather.py new file mode 100644 index 00000000000..4e1cdffc35b --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_distributed_layernorm_pre_allgather.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
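# --- Editorial note (not part of the original patch) -------------------------------
# RMSNorm only needs one gathered statistic per device because it never subtracts the
# mean, while layernorm needs both sum(x) and sum(x^2); that is why these tests use
# one vs. two tile columns per device. A hedged torch sketch of the post-gather math
# (the function name is illustrative):
import torch


def norm_from_global_sums(x, sum_x2, sum_x, n, gamma, beta, eps, is_rmsnorm):
    mean_x2 = sum_x2 / n
    if is_rmsnorm:
        return x * torch.rsqrt(mean_x2 + eps) * gamma      # only E[x^2] is required
    mean = sum_x / n
    var = mean_x2 - mean.square()                          # E[x^2] - E[x]^2
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta
# ------------------------------------------------------------------------------------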
+ +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import torch +from models.utility_functions import tt2torch_tensor, torch2tt_tensor, skip_for_grayskull + +import tt_lib as ttl + +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_allclose_and_pcc, comp_equal + + +def reference(x, n_devices, is_rmsnorm): + num_chunks = len(x) + S = x[0].shape[2] + B = x[0].shape[0] + counts = [] + sumxs = [] + sumx2s = [] + # Distributed processing + for chunk in x: + count_local = chunk.shape[-1] + sumx_local = torch.sum(chunk, dim=-1, keepdim=True) + sumx2_local = torch.sum(torch.square(chunk), dim=-1, keepdim=True) + + counts.append(count_local) + sumxs.append(sumx_local) + sumx2s.append(sumx2_local) + + # pad with zeros as for tiles + output = [] + for i in range(num_chunks): + if is_rmsnorm: + output.append(torch.concat([sumx2s[i], torch.zeros([B, 1, S, 31])], dim=-1)) + else: + output.append( + torch.concat([sumx2s[i], torch.zeros([B, 1, S, 31]), sumxs[i], torch.zeros([B, 1, S, 31])], dim=-1) + ) + + return output + + +def ln_pre_allgather_op(xs, n_devices, is_rmsnorm, out_dtpe): + kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=ttl.tensor.MathFidelity.HiFi4, # Highest fidelity + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + tt_out = [] + for d in range(n_devices): + if is_rmsnorm: + tt_out.append( + ttl.operations.primary.rmsnorm_pre_allgather( + xs[d], compute_kernel_config=kernel_config, output_dtype=out_dtpe + ) + ) + else: + tt_out.append( + ttl.operations.primary.layernorm_pre_allgather( + xs[d], compute_kernel_config=kernel_config, output_dtype=out_dtpe + ) + ) + return tt_out + + +def run_layernorm_part_1(inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device): + torch.manual_seed(1234) + + # Set print options + torch.set_printoptions(threshold=100) + + canon_inp = torch.randn(inp_shape).bfloat16() * 4 - 1 + + # Get per-chunk inputs + inp_chunked = canon_inp.chunk(n_devices, dim=-1) + + # Reference + out_torch = reference(inp_chunked, n_devices, is_rmsnorm) + out_torch = torch.concat(out_torch, -1) + + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + tt_inp = [] + for d in range(n_devices): + tt_inp.append( + torch2tt_tensor( + inp_chunked[d], + tt_dtype=input_dtype, + tt_device=device, + tt_layout=ttl.tensor.Layout.TILE, + tt_memory_config=dram_memcfg, + ) + ) + + # LN pre all gather OP + tt_out = ln_pre_allgather_op(tt_inp, n_devices, is_rmsnorm, output_dtype) + + tt_output_host = torch.concat([tt2torch_tensor(tt_o) for tt_o in tt_out], -1) + + all_passing = True + + for i in range(n_devices): + device_offset = i * 32 if is_rmsnorm else i * 64 + # Compare sum(xˆ2) + passing, output_str = comp_allclose_and_pcc( + out_torch[:, :, :, 0 + device_offset], + tt_output_host[:, :, :, 0 + device_offset], + rtol=1e-1, + atol=1e-01, + pcc=0.9, + ) + logger.debug(f"tt vs torch sum(xˆ2) = {output_str}") + all_passing &= passing + + # Check if zeros are same + passing, output_str = comp_equal( + out_torch[:, :, :, 1 + device_offset : 32 + device_offset], + tt_output_host[:, :, :, 1 + device_offset : 32 + device_offset], + ) + logger.debug(f"tt vs torch padding 1 = {output_str}") + all_passing &= passing + + if not is_rmsnorm: + # Compare sum(x) + passing, output_str = comp_allclose_and_pcc( + out_torch[:, :, :, 32 + device_offset], + tt_output_host[:, :, :, 32 + device_offset], + rtol=1e-1, + atol=1e-01, + 
pcc=0.98, + ) + logger.debug(f"tt vs torch sum(x) = {output_str}") + all_passing &= passing + + # Check if zeros are same + passing, output_str = comp_equal( + out_torch[:, :, :, 33 + device_offset : 64 + device_offset], + tt_output_host[:, :, :, 33 + device_offset : 64 + device_offset], + ) + logger.debug(f"tt vs torch padding 2 = {output_str}") + all_passing &= passing + + assert all_passing + + +@skip_for_grayskull("Requires wormhole") +@pytest.mark.parametrize( + "input_dtype", + (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B), + ids=["BFLOAT16", "BFLOAT8_B"], +) +@pytest.mark.parametrize( + "output_dtype", + (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B), + ids=["BFLOAT16", "BFLOAT8_B"], +) +@pytest.mark.parametrize( + "inp_shape", + [ + (1, 1, 2048, 8192), + (1, 1, 128, 8192), + (2, 1, 128, 8192), + ], +) +@pytest.mark.parametrize( + "n_devices", + [4, 8], +) +@pytest.mark.parametrize( + "is_rmsnorm", + [True, False], + ids=["rmsnorm", "layernorm"], +) +def test_layernorm_part_1_with_program_cache( + inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device, use_program_cache +): + run_layernorm_part_1(inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device) + + +@skip_for_grayskull("Requires wormhole") +@pytest.mark.parametrize( + "input_dtype", + [ttl.tensor.DataType.BFLOAT16], + ids=["BFLOAT16"], +) +@pytest.mark.parametrize( + "output_dtype", + [ttl.tensor.DataType.BFLOAT16], + ids=["BFLOAT16"], +) +@pytest.mark.parametrize( + "inp_shape", + [ + (1, 1, 2048, 8192), + ], +) +@pytest.mark.parametrize( + "n_devices", + [8], +) +@pytest.mark.parametrize( + "is_rmsnorm", + [True, False], + ids=["rmsnorm", "layernorm"], +) +def test_layernorm_part_1_with_program_cache2( + inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device, use_program_cache +): + dummy_tensors = [] + + dram_memcfg = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + for i in range(2): + if i > 0: + dummy_tensors.append( + torch2tt_tensor( + torch.randn(inp_shape), + tt_dtype=input_dtype, + tt_device=device, + tt_layout=ttl.tensor.Layout.TILE, + tt_memory_config=dram_memcfg, + ) + ) + run_layernorm_part_1(inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device) + + assert device.num_program_cache_entries() == 1, "Program cache should have only one entry" + str( + device.num_program_cache_entries() + ) diff --git a/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py new file mode 100644 index 00000000000..b84514ff8fa --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
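# --- Editorial note (not part of the original patch) -------------------------------
# A known caveat of the E[x^2] - E[x]^2 decomposition used throughout this change is
# catastrophic cancellation when |mean| is much larger than the standard deviation,
# which is presumably why the POC script in models/demos/t3000/falcon40b/scripts also
# evaluates Welford-style variants. A small float32 illustration (values are arbitrary):
import torch

x = torch.full((1, 4096), 1000.0) + torch.randn(1, 4096) * 1e-2
var_decomp = x.pow(2).mean(-1) - x.mean(-1).pow(2)               # loses precision here
var_two_pass = (x - x.mean(-1, keepdim=True)).pow(2).mean(-1)    # ~1e-4, as expected
# print(var_decomp.item(), var_two_pass.item())
# ------------------------------------------------------------------------------------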
+ +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import torch + +import ttnn + +from models.utility_functions import tt2torch_tensor, get_devices_for_t3000, skip_for_grayskull + +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_allclose, comp_pcc + + +def reference_layernorm(x, gamma, beta, epsilon, is_rmsnorm): + if is_rmsnorm: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon) * gamma + else: + return torch.nn.functional.layer_norm(x, x.shape[-1:], gamma, beta, epsilon) + + +def tt_distributed_layernorm(inp, gamma, beta, epsilon, is_rmsnorm, compute_kernel_config, stats_dtype): + n_devices = len(inp) + + # Run layernorm part 1 + tt_stats = [] + for d in range(n_devices): + if is_rmsnorm: + tt_stats.append( + ttnn.experimental.operations.primary.rmsnorm_pre_allgather( + inp[d], compute_kernel_config=compute_kernel_config, output_dtype=stats_dtype + ) + ) + else: + tt_stats.append( + ttnn.experimental.operations.primary.layernorm_pre_allgather( + inp[d], compute_kernel_config=compute_kernel_config, output_dtype=stats_dtype + ) + ) + + # AllGather stats + tt_stats = ttnn.experimental.tensor.all_gather( + tt_stats, dim=3, num_links=1, output_mem_config=ttnn.DRAM_MEMORY_CONFIG + ) + + # Run layernorm part 2 + tt_out = [] + for d in range(n_devices): + if is_rmsnorm: + tt_out.append( + ttnn.experimental.operations.primary.rmsnorm_post_allgather( + inp[d], tt_stats[d], epsilon, gamma[d], compute_kernel_config=compute_kernel_config + ) + ) + else: + tt_out.append( + ttnn.experimental.operations.primary.layernorm_post_allgather( + inp[d], tt_stats[d], epsilon, gamma[d], beta[d], compute_kernel_config=compute_kernel_config + ) + ) + tt_stats[d].deallocate(True) + return tt_out + + +def run_distributed_layernorm( + inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, fp32_enabled=False, iterations=1 +): + compute_kernel_config = ttnn.experimental.tensor.WormholeComputeKernelConfig( + math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, # Highest fidelity + math_approx_mode=False, + fp32_dest_acc_en=fp32_enabled, + packer_l1_acc=False, + ) + + torch.manual_seed(1234) + + canon_inp = torch.randn(inp_shape) * 4 - 1 + gamma = torch.rand(inp_shape[-1]) * 2 - 1 + beta = torch.rand(inp_shape[-1]) * 2 - 1 + gamma_chunked = gamma.chunk(n_devices, dim=-1) + beta_chunked = beta.chunk(n_devices, dim=-1) + inp_chunked = canon_inp.chunk(n_devices, dim=-1) + epsilon = 1e-5 + + # reference impl + out_torch = reference_layernorm(canon_inp, gamma, beta, epsilon, is_rmsnorm) + + tt_inp = [] + for d in range(n_devices): + tt_inp.append( + ttnn.as_tensor( + inp_chunked[d], + dtype=dtype, + device=devices[d], + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ) + + tt_gamma = [] + for d in range(n_devices): + tt_gamma.append( + ttnn.as_tensor( + gamma_chunked[d].reshape(1, 1, -1, 32), + dtype=ttnn.bfloat16, + device=devices[d], + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ) + + tt_beta = [] + for d in range(n_devices): + tt_beta.append( + ttnn.as_tensor( + beta_chunked[d].reshape(1, 1, -1, 32), + dtype=ttnn.bfloat16, + device=devices[d], + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ) + for i in range(iterations): + tt_out = tt_distributed_layernorm( + tt_inp, tt_gamma, tt_beta, epsilon, is_rmsnorm, compute_kernel_config, stats_dtype + ) + tt_output_host = torch.concat([tt2torch_tensor(tt_o) for tt_o in tt_out], -1) + + passing, 
output_str = comp_allclose(tt_output_host, out_torch, rtol=1e-1, atol=1e-01) + logger.debug(f"torch vs tt distributed layernorm = {output_str}") + + assert passing + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "iterations", + [2], + ids=["loops2"], +) +@pytest.mark.parametrize( + "dtype", + (ttnn.bfloat16, ttnn.bfloat8_b), + ids=["BFLOAT16_in", "BFLOAT8_B_in"], +) +@pytest.mark.parametrize( + "stats_dtype", + (ttnn.bfloat16, ttnn.bfloat8_b), + ids=["BFLOAT16_stats", "BFLOAT8_B_stats"], +) +@pytest.mark.parametrize( + "inp_shape", + [ + (1, 1, 2048, 8192), + (1, 1, 128, 8192), + (2, 1, 128, 8192), + ], + ids=["inp_shape0", "inp_shape1", "inp_shape2"], +) +@pytest.mark.parametrize( + "n_devices", + [4, 8], +) +@pytest.mark.parametrize( + "is_rmsnorm", + [True, False], + ids=["rmsnorm", "layernorm"], +) +def test_distributed_layernorm_with_program_cache( + inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, iterations, all_devices, use_program_cache +): + if len(all_devices) != 8: + pytest.skip("Not T3000!") + + devices = get_devices_for_t3000(all_devices, n_devices) + + run_distributed_layernorm(inp_shape, n_devices, is_rmsnorm, dtype, stats_dtype, devices, iterations=iterations) + + for d in range(len(devices)): + assert devices[d].num_program_cache_entries() == 3, "Program cache should have only 3 entries, but has " + str( + devices[d].num_program_cache_entries() + ) diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt index 1cd0e90bf0a..faadaf1337a 100644 --- a/tt_eager/tt_dnn/op_library/CMakeLists.txt +++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt @@ -129,6 +129,10 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/moreh_mean_backward/moreh_mean_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layernorm/multi_core/layernorm_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layernorm/layernorm_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layernorm_distributed/layernorm_pre_allgather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layernorm_distributed/layernorm_pre_allgather_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layernorm_distributed/layernorm_post_allgather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layernorm_distributed/layernorm_post_allgather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_bmm/moreh_bmm_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_bmm_backward/moreh_bmm_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_linear/moreh_linear_op.cpp diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/compute/layernorm_post_allgather.cpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/compute/layernorm_post_allgather.cpp new file mode 100644 index 00000000000..6e3154fb77b --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/compute/layernorm_post_allgather.cpp @@ -0,0 +1,268 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
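// Editorial note (not part of the original patch): this compute kernel runs after the
// all-gather. It reduces the gathered per-device partial sums in cb_stats into E(x)
// and E(x^2) (E(x^2) only for RMSNORM), forms var = E(x^2) - E(x)^2, normalizes the
// local input with 1/sqrt(var + eps), and then applies gamma (and beta for layernorm)
// when those inputs are enabled.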
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#define REDUCE_OP PoolType::SUM +#define REDUCE_DIM ReduceDim::REDUCE_ROW + +#define BCAST_LLKOP EltwiseBinaryType::ELWMUL +#define BCAST_DIM BroadcastType::COL + +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/layernorm.h" + + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + + +namespace NAMESPACE { +void MAIN { + uint32_t NCHt = get_arg_val(0); + constexpr uint32_t Wt = get_compile_time_arg_val(0); + constexpr uint32_t blk = get_compile_time_arg_val(1); + constexpr uint32_t stats_tiles_cols = get_compile_time_arg_val(2); + constexpr uint32_t do_gamma = get_compile_time_arg_val(3); + constexpr uint32_t do_beta = get_compile_time_arg_val(4); + constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(5) == 1; + + constexpr uint32_t onetile = 1; + + constexpr uint32_t cb_inp = tt::CB::c_in0; + constexpr uint32_t cb_stats = tt::CB::c_in1; + + constexpr uint32_t cb_eps = tt::CB::c_in4; + constexpr uint32_t cb_reduce = tt::CB::c_in5; + + constexpr uint32_t cb_out = tt::CB::c_out0; + + constexpr uint32_t cb_stats_reduced = tt::CB::c_intermed0; // [E(x**2), E(x)] + constexpr uint32_t cb_var_eps = tt::CB::c_intermed3; // var + epsilon (or E(x**2) + epsilon) + constexpr uint32_t cb_recip_sqrt_var = tt::CB::c_intermed4; // 1/sqrt(var+eps) + constexpr uint32_t cb_x_normed = tt::CB::c_intermed6; // (x - E(x)) * 1/sqrt(var+eps) or x * 1/sqrt(E(x**2) + eps) + + constexpr uint32_t cb_var = tt::CB::c_intermed2; // E(x**2) - E(x)**2 or E(x**2) + #ifndef RMSNORM + // Layernorm-specific CBs + constexpr uint32_t cb_mean_squared = tt::CB::c_intermed1; // E(x)**2 + constexpr uint32_t cb_x_minus_mean = tt::CB::c_intermed5; // x - E(x) + + constexpr uint32_t cb_norm_x_input = cb_x_minus_mean; + constexpr uint32_t stats_tile_stride = 2; + #else + constexpr uint32_t cb_norm_x_input = cb_inp; + constexpr uint32_t stats_tile_stride = 1; + #endif + + constexpr uint32_t cb_gamma = tt::CB::c_in2; + constexpr uint32_t cb_beta = tt::CB::c_in3; + uint32_t cb_times_gamma_out = cb_out; + if constexpr(do_gamma and do_beta) { + cb_times_gamma_out = tt::CB::c_intermed7; + } + + binary_op_init_common(cb_inp, cb_inp, cb_stats_reduced); + + cb_wait_front(cb_reduce, 1); // comes from the reader + cb_wait_front(cb_eps, 1); // comes from the reader + + + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + + constexpr int onetile = 1; + constexpr int dst0 = 0; + + unpack_reconfig_data_format(cb_stats, cb_reduce); + pack_reconfig_data_format(cb_stats_reduced); + + /* + * Reduce stats input. + * cb_stats = [sum(x0**2), sum(x0), sum(x1**2), sum(x1), ...] + * RMSNorm packs mean(x**2) into cb_var. Layernorm just uses cb_stats_reduced. 
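 * Editorial note (not part of the original patch): the reduce below accumulates the
 * gathered per-device partials row-wise; cb_reduce holds the scaler tile produced by
 * the reader, which rescales the summed partials during reduce_tile.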
+ */ + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_wait_front(cb_stats, stats_tiles_cols); + cb_reserve_back(cb_stats_reduced, stats_tile_stride); + #ifdef RMSNORM + cb_reserve_back(cb_var, 1); + #endif + ACQ(); + // Reduce sum(x**2) first + for (uint32_t i = 0; i < stats_tiles_cols; i += stats_tile_stride) { + reduce_tile(cb_stats, cb_reduce, i, 0, 0); + } + pack_tile(0, cb_stats_reduced); + + #ifndef RMSNORM + // Reduce sum(x) next + for (uint32_t i = 1; i < stats_tiles_cols; i += stats_tile_stride) { + reduce_tile(cb_stats, cb_reduce, i, 0, 1); + } + pack_tile(1, cb_stats_reduced); + #else + pack_tile(0, cb_var); + #endif + REL(); + cb_push_back(cb_stats_reduced, stats_tile_stride); + cb_pop_front(cb_stats, stats_tiles_cols); + #ifdef RMSNORM + cb_push_back(cb_var, 1); + #endif + + reduce_revert_delta(); + + #ifndef RMSNORM + /* + * E[x]**2 + */ + unpack_reconfig_data_format(cb_stats_reduced, cb_stats_reduced); + pack_reconfig_data_format(cb_mean_squared); + mul_tiles_init(); + cb_reserve_back(cb_mean_squared, onetile); + cb_wait_front(cb_stats_reduced, stats_tile_stride); + ACQ(); + mul_tiles(cb_stats_reduced, cb_stats_reduced, 1, 1, 0); + pack_tile(0, cb_mean_squared); + REL(); + + cb_push_back(cb_mean_squared, 1); + + /* + * E[x**2] - E[x]**2 + */ + unpack_reconfig_data_format(cb_stats_reduced, cb_mean_squared); + pack_reconfig_data_format(cb_var); + sub_tiles_init(); + + cb_reserve_back(cb_var, onetile); + cb_wait_front(cb_mean_squared, 1); + ACQ(); + sub_tiles(cb_stats_reduced, cb_mean_squared, 0, 0, 0); + pack_tile(0, cb_var); + REL(); + cb_push_back(cb_var, 1); + cb_pop_front(cb_mean_squared, 1); + + /* + * x - E[x] + */ + unpack_reconfig_data_format(cb_inp, cb_stats_reduced); + pack_reconfig_data_format(cb_x_minus_mean); + sub_bcast_cols_init_short(); + for (uint32_t wt = 0; wt < Wt; wt += blk) { + cb_wait_front(cb_inp, blk); + cb_reserve_back(cb_x_minus_mean, blk); + ACQ(); + for (uint32_t wtr = 0; wtr + +#define REDUCE_OP PoolType::SUM +#define REDUCE_DIM ReduceDim::REDUCE_ROW + +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/layernorm.h" + +#include "debug/dprint.h" + + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + + +namespace NAMESPACE { +void MAIN { + uint32_t NCHt = get_arg_val(0); + constexpr uint32_t Wt = get_compile_time_arg_val(0); + constexpr uint32_t blk = get_compile_time_arg_val(1); + + constexpr uint32_t onetile = 1; + + constexpr uint32_t cb_inp = tt::CB::c_in0; + constexpr uint32_t cb_reduce = tt::CB::c_in1; + + constexpr uint32_t cb_out = tt::CB::c_out0; + + constexpr uint32_t cb_x2 = tt::CB::c_intermed0; // x**2 + + cb_wait_front(cb_reduce, 1); // comes from the reader + + binary_op_init_common(cb_inp, cb_reduce, cb_x2); + + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + + constexpr int onetile = 1; + constexpr int dst0 = 0; + + /* + * x**2 + */ + unpack_reconfig_data_format(cb_inp, cb_inp); + pack_reconfig_data_format(cb_x2); + mul_tiles_init(cb_inp, cb_inp); + for (uint32_t wt = 0; wt < Wt; wt += blk) { + cb_wait_front(cb_inp, wt+blk); // cumulative wait + cb_reserve_back(cb_x2, blk); + ACQ(); + for (uint32_t wtr = 0; wtr(REDUCE_OP, REDUCE_DIM); + cb_wait_front(cb_x2, Wt); + cb_reserve_back(cb_out, onetile); + ACQ(); + for (uint32_t wtr = 0; wtr(REDUCE_OP, REDUCE_DIM); + cb_reserve_back(cb_out, onetile); + ACQ(); + for (uint32_t wtr = 0; wtr +#include "dataflow_api.h" +#include 
"tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp" +#include "debug/assert.h" + +void kernel_main() { + const uint32_t src_addr = get_arg_val(0); + const uint32_t NCHt = get_arg_val(1); + const uint32_t Wt = get_arg_val(2); + const uint32_t tile_offset = get_arg_val(3); + const uint32_t stats_tile_offset = get_arg_val(4); + + const uint32_t gamma_addr = get_arg_val(7); + const uint32_t beta_addr = get_arg_val(8); + const uint32_t stats_addr = get_arg_val(9); + + constexpr uint32_t cb_inp = tt::CB::c_in0; + constexpr uint32_t cb_stats = tt::CB::c_in1; + constexpr uint32_t cb_gamma = tt::CB::c_in2; + constexpr uint32_t cb_beta = tt::CB::c_in3; + constexpr uint32_t cb_eps = tt::CB::c_in4; + constexpr uint32_t cb_reduce = tt::CB::c_in5; + + // ublocks size defined in tiles + const uint32_t src0_tile_bytes = get_tile_size(cb_inp); + const DataFormat src0_data_format = get_dataformat(cb_inp); + const uint32_t stats_tile_bytes = get_tile_size(cb_stats); + const DataFormat stats_data_format = get_dataformat(cb_stats); + + constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool stats_is_dram = get_compile_time_arg_val(1) == 1; + constexpr bool gamma_is_dram = get_compile_time_arg_val(2) == 1; + constexpr bool beta_is_dram = get_compile_time_arg_val(3) == 1; + constexpr uint32_t blk = get_compile_time_arg_val(4); + constexpr uint32_t stats_tiles_cols = get_compile_time_arg_val(5); + + const InterleavedAddrGenFast src_a = { + .bank_base_address = src_addr, + .page_size = src0_tile_bytes, + .data_format = src0_data_format + }; + + const InterleavedAddrGenFast src_stats = { + .bank_base_address = stats_addr, + .page_size = stats_tile_bytes, + .data_format = stats_data_format + }; + + + constexpr bool stick_size_is_pow2 = get_compile_time_arg_val(6) == 1; + ASSERT(stick_size_is_pow2); + const uint32_t log_base_2_of_page_size = get_compile_time_arg_val(7); + #ifdef FUSE_GAMMA + const InterleavedPow2AddrGen addrg = { + .bank_base_address = gamma_addr, + .log_base_2_of_page_size = log_base_2_of_page_size + }; + const uint32_t gamma_tile_bytes = get_tile_size(cb_gamma); + #endif + #ifdef FUSE_BETA + const InterleavedPow2AddrGen addrb = { + .bank_base_address = beta_addr, + .log_base_2_of_page_size = log_base_2_of_page_size + }; + const uint32_t beta_tile_bytes = get_tile_size(cb_beta); + #endif + + + + // Generate constant tiles for layernorm compute + uint32_t scaler = get_arg_val(5); + generate_reduce_scaler(cb_reduce, scaler); + const uint32_t eps = get_arg_val(6); + generate_bcast_col_scalar(cb_eps, eps); + + uint32_t inp_tile_idx = tile_offset; + uint32_t stats_tile_idx = stats_tile_offset; + + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + // Read stats tiles + cb_reserve_back(cb_stats, stats_tiles_cols); + uint32_t stats_wr_ptr = get_write_ptr(cb_stats); + for (uint32_t st = 0; st < stats_tiles_cols; ++st) { + noc_async_read_tile(stats_tile_idx, src_stats, stats_wr_ptr); + stats_wr_ptr += stats_tile_bytes; + stats_tile_idx++; + } + noc_async_read_barrier(); + cb_push_back(cb_stats, stats_tiles_cols); + + // read input tiles + for (uint32_t wt = 0; wt +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp" +#include "debug/assert.h" + +void kernel_main() { + const uint32_t src_addr = get_arg_val(0); + const uint32_t NCHt = get_arg_val(1); + const uint32_t Wt = get_arg_val(2); + 
const uint32_t tile_offset = get_arg_val(3); + + constexpr uint32_t cb_inp = tt::CB::c_in0; + constexpr uint32_t cb_reduce = tt::CB::c_in1; + + // ublocks size defined in tiles + const uint32_t src0_tile_bytes = get_tile_size(cb_inp); + const DataFormat src0_data_format = get_dataformat(cb_inp); + + constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t blk = get_compile_time_arg_val(1); + + const InterleavedAddrGenFast src_a = { + .bank_base_address = src_addr, + .page_size = src0_tile_bytes, + .data_format = src0_data_format + }; + + // Generate constant tiles for reduce scalar + uint32_t scaler = get_arg_val(4); + generate_reduce_scaler(cb_reduce, scaler); + + uint32_t inp_tile_idx = tile_offset; + + for (uint32_t ncht = 0; ncht < NCHt; ncht++) { + + // read input tiles + for (uint32_t wt = 0; wt(0); + const uint32_t num_tiles = get_arg_val(1); + const uint32_t tile_offset = get_arg_val(2); + + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t blk = get_compile_time_arg_val(1); // needed for correctness of softmax/LN kernels + + constexpr uint32_t cb_out = tt::CB::c_out0; + constexpr uint32_t onetile = 1; + const uint32_t tile_bytes = get_tile_size(cb_out); + const DataFormat data_format = get_dataformat(cb_out); + + const InterleavedAddrGenFast s = { + .bank_base_address = dst_addr, + .page_size = tile_bytes, + .data_format = data_format + }; + + uint32_t tile_id = tile_offset; + for (uint32_t i = 0; i + +using uint32_t = std::uint32_t; +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace tt { + +namespace tt_metal { + +void LayerNormPostAllGather::validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const { + TT_FATAL(input_tensors.size() == 2 and optional_input_tensors.size() <= 2, "Must have between 1 to 4 input tensors"); + auto& a = input_tensors.at(0); + auto& stats = input_tensors.at(1); + const auto& gamma = optional_input_tensors.at(0); + const auto& beta = optional_input_tensors.at(1); + + for (const auto& tensor: input_tensors) { + TT_FATAL(tensor.get_layout() == Layout::TILE); + TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::BFLOAT8_B); + TT_FATAL(tensor.storage_type() == StorageType::DEVICE, "Operands to layernorm need to be on device!"); + TT_FATAL(tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); + } + + // stats has 2 or 1 tile columns per device if layernorm or rmsnorm + TT_FATAL(stats.get_legacy_shape()[-1] % TILE_WIDTH == 0); + TT_FATAL(stats.get_legacy_shape()[0] == a.get_legacy_shape()[0]); + TT_FATAL(stats.get_legacy_shape()[1] == a.get_legacy_shape()[1]); + TT_FATAL(stats.get_legacy_shape()[2] == a.get_legacy_shape()[2]); + // TODO: How to check if number of tile columns is correct? 
Would have to know # of devices and is_rmsnorm + + TT_FATAL(gamma.has_value()); + const auto& gamma_tensor = gamma.value(); + + TT_FATAL(gamma_tensor.get_layout() == Layout::ROW_MAJOR); // Only support packed RM right now + if (gamma_tensor.get_layout() == Layout::TILE) { + TT_FATAL(a.get_legacy_shape()[-1] == gamma.value().get_legacy_shape()[-1], fmt::format("{} != {}", a.get_legacy_shape()[-1], gamma.value().get_legacy_shape()[-1])); + TT_FATAL(gamma.value().buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); + TT_FATAL(a.device() == gamma.value().device()); + TT_FATAL(gamma.value().get_legacy_shape()[-2] == TILE_HEIGHT); + } else { + TT_FATAL(gamma_tensor.get_layout() == Layout::ROW_MAJOR); + TT_FATAL((gamma_tensor.get_legacy_shape()[-1] == TILE_WIDTH && gamma_tensor.volume() / TILE_WIDTH == a.get_legacy_shape()[-1] / TILE_WIDTH)); + TT_FATAL(gamma_tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); + TT_FATAL(a.device() == gamma_tensor.device()); + TT_FATAL(gamma_tensor.get_dtype() == DataType::BFLOAT16); + } + const bool is_layernorm = this->norm_type == LayerNormType::LAYERNORM; + const bool has_beta = beta.has_value(); + TT_FATAL(is_layernorm == has_beta); // TODO: Is this a necessary check? + + if (beta.has_value()) { + const auto& beta_tensor = beta.value(); + TT_FATAL(gamma_tensor.get_layout() == beta_tensor.get_layout(), "Gamma and beta must have the same layout!"); + TT_FATAL(beta_tensor.get_layout() == Layout::ROW_MAJOR); + if (beta_tensor.get_layout() == Layout::TILE) { + TT_FATAL(a.get_legacy_shape()[-1] == beta_tensor.get_legacy_shape()[-1]); + TT_FATAL(beta_tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); + TT_FATAL(a.device() == beta_tensor.device()); + TT_FATAL(beta.value().get_legacy_shape()[-2] == TILE_HEIGHT); + } else { + TT_FATAL(beta_tensor.get_layout() == Layout::ROW_MAJOR); + TT_FATAL((beta_tensor.get_legacy_shape()[-1] == TILE_WIDTH && beta_tensor.volume() / TILE_WIDTH == a.get_legacy_shape()[-1] / TILE_WIDTH)); + TT_FATAL(beta_tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); + TT_FATAL(a.device() == beta_tensor.device()); + TT_FATAL(beta_tensor.get_dtype() == DataType::BFLOAT16); + } + } +} + +std::vector LayerNormPostAllGather::compute_output_shapes(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + return {input_tensor.get_legacy_shape()}; +} + +std::vector LayerNormPostAllGather::create_output_tensors(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config); +} + +operation::ProgramWithCallbacks LayerNormPostAllGather::create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector &output_tensors +) const { + const auto& a = input_tensors.at(0); + const auto& stats = input_tensors.at(1); + const auto& gamma = optional_input_tensors.at(0); + const auto& beta = optional_input_tensors.at(1); + auto& output_tensor = output_tensors.at(0); + + return layernorm_post_allgather_multi_core( + a, stats, gamma, beta, output_tensor, this->norm_type, this->eps, this->compute_kernel_config + ); +} + +tt::stl::reflection::Attributes LayerNormPostAllGather::attributes() const { + return { + {"norm_type", this->norm_type}, + {"eps", 
this->eps}, + {"output_mem_config", this->output_mem_config}, + {"compute_kernel_config", this->compute_kernel_config} + // {"program_config", this->program_config} + }; +} + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp new file mode 100644 index 00000000000..ddda925f713 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp @@ -0,0 +1,102 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "tt_dnn/op_library/compute_kernel_config.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tensor/tensor.hpp" +#include "ttnn/operations/core.hpp" + +using namespace tt::constants; + +namespace tt { + +namespace tt_metal { + +operation::ProgramWithCallbacks layernorm_post_allgather_multi_core( + const Tensor &a, + const Tensor &stats, + const std::optional gamma, + const std::optional beta, + Tensor& output, + LayerNormType norm_type, + float eps, + DeviceComputeKernelConfig compute_kernel_config); + + + +struct LayerNormPostAllGather { + LayerNormType norm_type; + float eps; + MemoryConfig output_mem_config; + // LayerNormProgramConfig program_config; + const DeviceComputeKernelConfig compute_kernel_config; + + void validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector &output_tensors + ) const; + tt::stl::reflection::Attributes attributes() const; +}; + +} // namespace metal + +namespace operations { + +namespace primary { + +template +struct make_layernorm_post_allgather { + Tensor operator()( + const Tensor& a, + const Tensor& stats, + float eps, + std::optional gamma = std::nullopt, + std::optional beta = std::nullopt, + const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + // const LayerNormProgramConfig& program_config = LayerNormDefaultProgramConfig{}, + std::optional compute_kernel_config = std::nullopt) const { + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a}))}; + log_debug("layernorm_post_allgather: before launch_op"); + operation::launch_op( + [eps, mem_config, + // program_config, + compute_kernel_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& a = input_tensors.at(0); + const auto& stats = input_tensors.at(1); + const auto& gamma = optional_input_tensors.at(0); + const auto& beta = optional_input_tensors.at(1); + auto arch = a.storage_type() == StorageType::DEVICE ? 
a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, false, false, false); + return operation::run( + LayerNormPostAllGather{ + .norm_type = layernorm_type, + .eps = eps, + .output_mem_config = mem_config, + // .program_config = program_config, + .compute_kernel_config = kernel_config_val}, + {a, stats}, + {gamma, beta}); + }, {a, stats}, output_tensors, {gamma, beta}); + return output_tensors.at(0); + } +}; + +constexpr auto layernorm_post_allgather = make_layernorm_post_allgather{}; +constexpr auto rmsnorm_post_allgather = make_layernorm_post_allgather{}; + + +} // namespace primary + +} // namespace operations + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op_multi_core.cpp new file mode 100644 index 00000000000..1a2aa601327 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op_multi_core.cpp @@ -0,0 +1,450 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" + +#include +#include + +using uint32_t = std::uint32_t; +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace tt { + +namespace tt_metal { + +inline bool is_dram(const Tensor& input_tensor) { return input_tensor.memory_config().buffer_type == BufferType::DRAM; } +inline bool is_dram(const std::optional input_tensor) { + return input_tensor.has_value() ? is_dram(input_tensor.value()) : true; +} +inline bool is_dram(const Buffer* b) { return b->buffer_type() == BufferType::DRAM; } + +// computes layernorm(a)*gamma + beta +operation::ProgramWithCallbacks layernorm_post_allgather_multi_core( + const Tensor &a, + const Tensor &stats, + const std::optional gamma, + const std::optional beta, + Tensor& output, + LayerNormType norm_type, + float eps, + DeviceComputeKernelConfig compute_kernel_config +) { + const bool is_rmsnorm = norm_type == LayerNormType::RMSNORM; + const auto shape = a.get_legacy_shape(); + const uint32_t W = shape[-1], H = shape[-2]; + const uint32_t HW = H*W; + const uint32_t NC = a.volume() / HW; + + + // Kernels are configured to support BFLOAT8_B, but bad pcc so we need mixed precision support in compute + const auto& a_dtype = a.get_dtype(); + + const uint32_t Wt = W/TILE_WIDTH; + const uint32_t Ht = H/TILE_HEIGHT; + const uint32_t stats_tiles_cols = stats.get_legacy_shape()[-1] / TILE_WIDTH; + const uint32_t tile_cols_per_device = is_rmsnorm ? 
1 : 2; + const uint32_t num_devices = stats_tiles_cols / tile_cols_per_device; + TT_FATAL(num_devices > 0, "Number of devices must be greater than 0"); + TT_FATAL(num_devices * tile_cols_per_device == stats_tiles_cols, "Number of devices must divide number of stats tiles"); + + uint32_t num_tile_rows = NC * Ht; + + log_debug("is_rmsnorm: {}", is_rmsnorm); + log_debug("W: {}", W); + log_debug("H: {}", H); + log_debug("num_tile_rows: {}", num_tile_rows); + log_debug("Wt: {}", Wt); + log_debug("Ht: {}", Ht); + log_debug("stats_tiles_cols: {}", stats_tiles_cols); + log_debug("num_devices: {}", num_devices); + + + //////////////////////////////////////////////////////////////////////////// + // Device Setup + ////////////////////////////////////////////////////////////////////////// + Device *device = a.device(); + + //////////////////////////////////////////////////////////////////////////// + // Circular Buffer Data Format Setup + ////////////////////////////////////////////////////////////////////////// + MathFidelity math_fidelity; + bool math_approx_mode; + bool fp32_dest_acc_en; + + std::visit([&](auto&& compute_kernel_config) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for graykull"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = false; + } else if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = tt_metal::datatype_to_dataformat_converter(a.get_dtype()) == tt::DataFormat::Float32 ? true : compute_kernel_config.fp32_dest_acc_en; + } else { + TT_FATAL("arch not supported"); + } + + }, compute_kernel_config); + + uint32_t block_size = fp32_dest_acc_en ? find_max_divisor(Wt, 4) : find_max_divisor(Wt, 8); + + tt::DataFormat in_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::DataFormat stats_data_format = tt_metal::datatype_to_dataformat_converter(stats.get_dtype()); + tt::DataFormat out_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + tt::DataFormat cb_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b; + tt::DataFormat gamma_cb_data_format = gamma.has_value() ? tt_metal::datatype_to_dataformat_converter(gamma.value().get_dtype()) : tt::DataFormat::Float16_b; + tt::DataFormat beta_cb_data_format = beta.has_value() ? 
tt_metal::datatype_to_dataformat_converter(beta.value().get_dtype()) : tt::DataFormat::Float16_b; + uint32_t in_single_tile_size = tt_metal::detail::TileSize(in_data_format); + uint32_t stats_single_tile_size = tt_metal::detail::TileSize(stats_data_format); + uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); + uint32_t out_single_tile_size = tt_metal::detail::TileSize(out_data_format); + uint32_t bfloat16_tile_size = tt_metal::detail::TileSize(tt::DataFormat::Float16_b); + uint32_t gamma_single_tile_size = tt_metal::detail::TileSize(gamma_cb_data_format); + uint32_t beta_single_tile_size = tt_metal::detail::TileSize(beta_cb_data_format); + + log_debug("in_data_format: {}", in_data_format); + log_debug("out_data_format: {}", out_data_format); + log_debug("cb_data_format: {}", cb_data_format); + log_debug("gamma_cb_data_format: {}", gamma_cb_data_format); + log_debug("beta_cb_data_format: {}", beta_cb_data_format); + log_debug("math_fidelity: {}", math_fidelity); + log_debug("math_approx_mode: {}", math_approx_mode); + log_debug("fp32_dest_acc_en: {}", fp32_dest_acc_en); + + tt::DataFormat inb_data_format = tt::DataFormat::Invalid; + uint32_t inb_single_tile_size = 0; + + auto a_addr = a.buffer()->address(); + auto stats_addr = stats.buffer()->address(); + auto gamma_dram_addr = gamma.has_value() ? gamma.value().buffer()->address() : 0; + TT_FATAL(gamma_dram_addr != 0, "Gamma must be provided"); + auto beta_dram_addr = beta.has_value() ? beta.value().buffer()->address() : 0; + auto dst_addr = output.buffer()->address(); + + uint32_t num_tiles = a.volume()/TILE_HW; + uint32_t num_gamma_tiles = gamma.has_value() ? gamma.value().volume()/TILE_HW : 0; + uint32_t num_beta_tiles = beta.has_value() ? beta.value().volume()/TILE_HW : 0; + + // For bert, tensor is packed as RM with width 32 + if (gamma.has_value() and gamma.value().get_layout() == Layout::ROW_MAJOR) { + num_gamma_tiles = gamma.has_value() ? gamma.value().volume()/TILE_WIDTH : 0; + } + if (beta.has_value() and beta.value().get_layout() == Layout::ROW_MAJOR) { + num_beta_tiles = beta.has_value() ? beta.value().volume()/TILE_WIDTH : 0; + } + + log_debug("num_gamma_tiles: {}", num_gamma_tiles); + log_debug("num_beta_tiles: {}", num_beta_tiles); + + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + /* + in0_cb: a + in1_cb: stats + in2_cb: gamma + in3_cb: beta + in4_cb: epsilon + in5_cb: 1/row_size (reduction scalar) + + intermediate CBs are packed such that in layernorm, first tile is for x**2 stats, second tile is for x stats + in RMSNorm, only first tile has valid data. 
+ + intermed0_cb: [mean(x**2), mean(x)] # reduce with reduce_scalar + intermed1_cb: mean(x)**2 # LN only + intermed2_cb: var = mean(x**2) - mean(x)**2 # for RMSNorm, this is just mean(x**2) + intermed3_cb: var + epsilon # RMSNorm takes mean(x**2) instead of var + intermed4_cb: 1/sqrt(var + epsilon) + intermed5_cb: x - mean(x) # LN only + intermed6_cb: (x - mean(x)) * 1/sqrt(var + epsilon) # RMSNorm takes x instead of (x - mean(x)) + intermed7_cb: (x - mean(x)) * 1/sqrt(var + epsilon) * gamma + out0_cb: (x - mean(x)) * 1/sqrt(var + epsilon) * gamma + beta # RMSNorm doesn't include beta + + */ + + const uint32_t in0_tiles = Wt; + const uint32_t in1_tiles = stats_tiles_cols; + const uint32_t in2_tiles = Wt; + const uint32_t in3_tiles = Wt; + const uint32_t in4_tiles = 1; // epsilon + const uint32_t in5_tiles = 1; // reduce scalar + + const uint32_t intermed0_tiles = tile_cols_per_device; + const uint32_t intermed1_tiles = 1; + const uint32_t intermed2_tiles = 1; + const uint32_t intermed3_tiles = 1; + const uint32_t intermed4_tiles = 1; + const uint32_t intermed5_tiles = Wt; + const uint32_t intermed6_tiles = Wt; + const uint32_t intermed7_tiles = Wt; + const uint32_t out0_tiles = Wt; + + TT_ASSERT(W <= TILE_WIDTH*in0_tiles && "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + TT_ASSERT(in0_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(in2_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(in3_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(out0_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(intermed5_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(intermed6_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(intermed7_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + + + auto grid_size = device->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tile_rows_per_core_group_1, num_tile_rows_per_core_group_2] = split_work_to_cores(grid_size, num_tile_rows, true); + + log_debug("num_cores: {}", num_cores); + log_debug("grid_size: {}", grid_size); + log_debug("core_group_1: {}", core_group_1.str()); + log_debug("num_tile_rows_per_core_group_1: {}", num_tile_rows_per_core_group_1); + log_debug("core_group_2: {}", core_group_2.str()); + log_debug("num_tile_rows_per_core_group_2: {}", num_tile_rows_per_core_group_2); + + //////////////////////////////////////////////////////////////////////////// + // Application Setup + //////////////////////////////////////////////////////////////////////////// + Program program = CreateProgram(); + + std::vector reader_compile_time_args = { + // interleaved accessor args + (std::uint32_t) is_dram(a), + (std::uint32_t) is_dram(stats), + (std::uint32_t) is_dram(gamma), + (std::uint32_t) is_dram(beta), + (std::uint32_t) block_size, + (std::uint32_t) stats_tiles_cols, + }; + + if (gamma.has_value() and gamma.value().get_layout() == Layout::ROW_MAJOR) { + auto gamma_stick_size = 
gamma.value().get_legacy_shape()[-1] * gamma.value().element_size(); + bool gamma_stick_size_is_power_of_two = is_power_of_two_at_least_32(gamma_stick_size); + TT_FATAL(gamma_stick_size_is_power_of_two, "Only power of 2 gammas are supported"); + reader_compile_time_args.push_back((std::uint32_t) gamma_stick_size_is_power_of_two); + // if (gamma_stick_size_is_power_of_two) { + uint32_t gamma_log2_stick_size = gamma_stick_size_is_power_of_two ? (std::uint32_t)log2(gamma_stick_size) : 0; + reader_compile_time_args.push_back((std::uint32_t) gamma_log2_stick_size); + } + + std::vector writer_compile_time_args = { + // interleaved accessor args + (std::uint32_t) is_dram(output), + (std::uint32_t) block_size + }; + + + bool tile_dtype_is_bfloat16 = a.get_dtype() == tt::tt_metal::DataType::BFLOAT16; + std::map reader_defines; + std::map compute_defines; + if (gamma.has_value()) { + reader_defines["FUSE_GAMMA"] = "1"; + } + if (beta.has_value()) { + reader_defines["FUSE_BETA"] = "1"; + } + + if (is_rmsnorm) { + compute_defines["RMSNORM"] = "1"; + } + + auto use_row_major_kernel = (gamma.has_value() and gamma.value().get_layout() == Layout::ROW_MAJOR) or (beta.has_value() and beta.value().get_layout() == Layout::ROW_MAJOR); + TT_FATAL(use_row_major_kernel, "Only row major gamma and beta are supported"); + auto reader_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/dataflow/reader_unary_interleaved_ln_rm_gb_post_allgather.cpp", + all_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines) + ); + + auto writer_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_interleaved_start_id_blocked.cpp", + all_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args) + ); + + vector compute_args = { Wt, block_size, stats_tiles_cols, gamma.has_value(), beta.has_value(), fp32_dest_acc_en }; + + auto compute_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/compute/layernorm_post_allgather.cpp", + all_cores, + tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = compute_args, .defines = compute_defines} + ); + + // Create circular buffers + // c_in0 -> a + CircularBufferConfig cb_src0_config = CircularBufferConfig(in0_tiles*in_single_tile_size, {{CB::c_in0, in_data_format}}).set_page_size(CB::c_in0, in_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_src0_config ); + // c_in1 -> stats + CircularBufferConfig cb_stats_config = CircularBufferConfig(in1_tiles*stats_single_tile_size, {{CB::c_in1, stats_data_format}}).set_page_size(CB::c_in1, stats_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_stats_config ); + // c_in2 -> gamma + if (gamma.has_value()) { + CircularBufferConfig cb_gamma_config = CircularBufferConfig(in2_tiles*gamma_single_tile_size, {{CB::c_in2, gamma_cb_data_format}}).set_page_size(CB::c_in2, gamma_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_gamma_config ); + } + // c_in3 -> beta + if (beta.has_value()) { + CircularBufferConfig cb_beta_config = CircularBufferConfig(in3_tiles*beta_single_tile_size, {{CB::c_in3, beta_cb_data_format}}).set_page_size(CB::c_in3, beta_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_beta_config ); + } + // c_in4 -> epsilon + CircularBufferConfig cb_eps_config = CircularBufferConfig(in4_tiles*bfloat16_tile_size, 
{{CB::c_in4, DataFormat::Float16_b}}).set_page_size(CB::c_in4, bfloat16_tile_size); + CreateCircularBuffer( program, all_cores, cb_eps_config ); + // c_in5 -> reduce scalar + CircularBufferConfig cb_reduce_config = CircularBufferConfig(in5_tiles*bfloat16_tile_size, {{CB::c_in5, DataFormat::Float16_b}}).set_page_size(CB::c_in5, bfloat16_tile_size); + CreateCircularBuffer( program, all_cores, cb_reduce_config ); + + // LN and RMS shared intermediates // + // c_intermed0 -> [mean(x**2), mean(x)] + CircularBufferConfig cb_intermed0_config = CircularBufferConfig(intermed0_tiles*single_tile_size, {{CB::c_intermed0, cb_data_format}}).set_page_size(CB::c_intermed0, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed0_config ); + // c_intermed2 -> var = mean(x**2) - mean(x)**2 + CircularBufferConfig cb_intermed2_config = CircularBufferConfig(intermed2_tiles*single_tile_size, {{CB::c_intermed2, cb_data_format}}).set_page_size(CB::c_intermed2, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed2_config ); + // c_intermed3 -> var + epsilon + CircularBufferConfig cb_intermed3_config = CircularBufferConfig(intermed3_tiles*single_tile_size, {{CB::c_intermed3, cb_data_format}}).set_page_size(CB::c_intermed3, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed3_config ); + // c_intermed4 -> 1/sqrt(var + epsilon) + CircularBufferConfig cb_intermed4_config = CircularBufferConfig(intermed4_tiles*single_tile_size, {{CB::c_intermed4, cb_data_format}}).set_page_size(CB::c_intermed4, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed4_config ); + // c_intermed6 -> (x - mean(x)) * 1/sqrt(var + epsilon) + CircularBufferConfig cb_intermed6_config = CircularBufferConfig(intermed6_tiles*single_tile_size, {{CB::c_intermed6, cb_data_format}}).set_page_size(CB::c_intermed6, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed6_config ); + + + // LN-specific intermediates + if (!is_rmsnorm) { + // c_intermed1 -> mean(x)**2 + CircularBufferConfig cb_intermed1_config = CircularBufferConfig(intermed1_tiles*single_tile_size, {{CB::c_intermed1, cb_data_format}}).set_page_size(CB::c_intermed1, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed1_config ); + // c_intermed5 -> x - mean(x) + CircularBufferConfig cb_intermed5_config = CircularBufferConfig(intermed5_tiles*single_tile_size, {{CB::c_intermed5, cb_data_format}}).set_page_size(CB::c_intermed5, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed5_config ); + if (beta.has_value()) { + // Layernorm has gamma and beta so we need an extra intermediate buffer + // c_intermed7 -> (x - mean(x)) * 1/sqrt(var + epsilon) * gamma + CircularBufferConfig cb_intermed7_config = CircularBufferConfig(intermed7_tiles*single_tile_size, {{CB::c_intermed7, cb_data_format}}).set_page_size(CB::c_intermed7, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed7_config ); + } + } + + + CircularBufferConfig cb_out0_config = CircularBufferConfig(out0_tiles*out_single_tile_size, {{CB::c_out0, out_data_format}}).set_page_size(CB::c_out0, out_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_out0_config ); + + // Log all circular buffers with program.circular_buffers_on_corerange(all_cores), which returns std::vector> + + for (const auto& cb : program.circular_buffers_on_corerange(*all_cores.ranges().begin())) { + for (const auto index : cb->buffer_indices()) { + log_debug("cb_id {}", index); + 
log_debug("page_size: {}", cb->page_size(index)); + log_debug("num_pages: {}", cb->num_pages(index)); + log_debug("data_format: {}", cb->data_format(index)); + } + } + + uint32_t curr_row = 0; + float winv = 1.0f / (W * num_devices); // bcast-w scaler + bfloat16 bfloat_winv_value = bfloat16(winv); + uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value}); + union { float f; uint32_t u; } e; e.f = eps; // epsilon + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + uint32_t num_tile_rows_per_core = 0; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_2; + } else { + TT_ASSERT(false, "Core not in specified core ranges"); + } + + uint32_t tile_offset = curr_row * Wt; + uint32_t stats_offset = curr_row * stats_tiles_cols; + + SetRuntimeArgs(program, reader_kernels_id, core, + { a_addr, num_tile_rows_per_core, Wt, tile_offset, stats_offset, packed_winv_value, e.u, // 0-5 + gamma_dram_addr, beta_dram_addr, stats_addr } // 6-8 + ); + SetRuntimeArgs(program, compute_kernels_id, core, { num_tile_rows_per_core }); + SetRuntimeArgs(program, writer_kernels_id, core, { dst_addr, num_tile_rows_per_core * Wt, tile_offset } ); + curr_row += num_tile_rows_per_core; + } + + auto override_runtime_args_callback = [ + reader_kernel_id=reader_kernels_id, + writer_kernel_id=writer_kernels_id, + num_cores, + grid_size + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + + const auto& input_tensor = input_tensors.at(0); + const auto& stats_tensor = input_tensors.at(1); + const auto& gamma_tensor = optional_input_tensors.at(0); + const auto& beta_tensor = optional_input_tensors.at(1); + + const auto input_addr = input_tensor.buffer()->address(); + const auto stats_addr = stats_tensor.buffer()->address(); + const bool has_gamma = gamma_tensor.has_value(); + const bool has_beta = beta_tensor.has_value(); + const auto gamma_addr = has_gamma ? gamma_tensor.value().buffer()->address() : 0; + const auto beta_addr = has_beta ? 
beta_tensor.value().buffer()->address() : 0; + + const auto& output_tensor = output_tensors.at(0); + const auto output_addr = output_tensor.buffer()->address(); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + + for (uint32_t i = 0; i < num_cores; ++i) { + const CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + { + auto& reader_args = reader_runtime_args_by_core.at(core.x).at(core.y); + + reader_args[0] = input_addr; + reader_args[9] = stats_addr; + if (has_gamma) { + reader_args[7] = gamma_addr; + } + if (has_beta) { + reader_args[8] = beta_addr; + } + } + + { + auto& writer_args = writer_runtime_args_by_core.at(core.x).at(core.y); + writer_args[0] = output_addr; + } + } + }; + + return {std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} + + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.cpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.cpp new file mode 100644 index 00000000000..3281a964c7a --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.cpp @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" + +#include "third_party/magic_enum/magic_enum.hpp" + +#include + +using uint32_t = std::uint32_t; +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace tt { + +namespace tt_metal { + +void LayerNormPreAllGather::validate(const std::vector &input_tensors) const { + TT_FATAL(input_tensors.size() == 1, "Must have 1 input tensor"); + auto& tensor = input_tensors.at(0); + + TT_FATAL(tensor.get_layout() == Layout::TILE, "Only tilized inputs supported."); + TT_FATAL(tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Only interleaved inputs supported."); + TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::BFLOAT8_B, "Input data format not supported."); + TT_FATAL(tensor.storage_type() == StorageType::DEVICE, "Operands to layernorm need to be on device!"); + TT_FATAL(tensor.buffer() != nullptr, "Operands to layernorm need to be allocated in buffers on device!"); +} + +std::vector LayerNormPreAllGather::compute_output_shapes(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + + auto output_shape = input_tensor.get_legacy_shape(); + auto padding = output_shape.padding(); + uint32_t num_tiles_w = 1; + if (this->norm_type == LayerNormType::LAYERNORM) { + num_tiles_w = 2; + } + output_shape[3] = num_tiles_w * TILE_WIDTH; + padding[3] = Padding::PadDimension{0, 31}; + + return {Shape(output_shape, padding)}; +} + +std::vector LayerNormPreAllGather::create_output_tensors(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + return operation::generic_create_output_tensors(*this, input_tensors, this->output_dtype, Layout::TILE, input_tensor.memory_config()); +} 
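For orientation, below is a minimal host-side sketch of the arithmetic this pre-/post-all-gather pair is built around. It is plain standard C++, not part of this diff and not the device kernels; it assumes each device contributes the pair (sum(x), sum(x^2)) over its local width-W shard, and the names PartialStats, pre_allgather, and post_allgather are illustrative only.

// Host-side reference for the distributed layernorm math (standard C++ only).
#include <cmath>
#include <cstddef>
#include <vector>

struct PartialStats { double sum_x = 0.0; double sum_x2 = 0.0; };

// "pre_allgather": per-device partial sums over the local chunk.
PartialStats pre_allgather(const std::vector<double>& chunk) {
    PartialStats s;
    for (double v : chunk) { s.sum_x += v; s.sum_x2 += v * v; }
    return s;
}

// "post_allgather": combine the gathered partials into global mean/variance,
// then normalize and scale/shift the local chunk.
// Assumes gamma_chunk and beta_chunk have the same length as chunk.
std::vector<double> post_allgather(const std::vector<double>& chunk,
                                   const std::vector<PartialStats>& gathered,
                                   std::size_t total_width,   // W * num_devices
                                   const std::vector<double>& gamma_chunk,
                                   const std::vector<double>& beta_chunk,
                                   double eps = 1e-5) {
    double sum_x = 0.0, sum_x2 = 0.0;
    for (const auto& s : gathered) { sum_x += s.sum_x; sum_x2 += s.sum_x2; }
    const double mean   = sum_x  / static_cast<double>(total_width);
    const double mean_x2 = sum_x2 / static_cast<double>(total_width);
    const double var = mean_x2 - mean * mean;          // E[x^2] - E[x]^2
    const double inv_std = 1.0 / std::sqrt(var + eps);
    std::vector<double> out(chunk.size());
    for (std::size_t i = 0; i < chunk.size(); ++i) {
        out[i] = (chunk[i] - mean) * inv_std * gamma_chunk[i] + beta_chunk[i];
    }
    return out;
}

For the RMSNorm variant only sum(x^2) is needed and the mean subtraction is skipped, which is why the stats tensor produced here carries one tile column per device instead of two.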
+ +operation::ProgramWithCallbacks LayerNormPreAllGather::create_program( + const std::vector& input_tensors, + std::vector &output_tensors +) const { + const auto& a = input_tensors.at(0); + auto& output_tensor = output_tensors.at(0); + + return layernorm_pre_allgather_multi_core( + a, output_tensor, this->norm_type, this->compute_kernel_config + ); +} + +tt::stl::reflection::Attributes LayerNormPreAllGather::attributes() const { + return { + {"norm_type", this->norm_type}, + {"compute_kernel_config", this->compute_kernel_config} + }; +} + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp new file mode 100644 index 00000000000..0477a90ae7f --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "tt_dnn/op_library/compute_kernel_config.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tensor/tensor.hpp" +#include "ttnn/operations/core.hpp" + +using namespace tt::constants; + +namespace tt { + +namespace tt_metal { + +operation::ProgramWithCallbacks layernorm_pre_allgather_multi_core( + const Tensor &a, + Tensor& output, + LayerNormType norm_type, + DeviceComputeKernelConfig compute_kernel_config); + + + +struct LayerNormPreAllGather { + LayerNormType norm_type; + const DeviceComputeKernelConfig compute_kernel_config; + const DataType output_dtype; + + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; + tt::stl::reflection::Attributes attributes() const; +}; + +} // namespace metal + +namespace operations { + +namespace primary { + +template +struct make_layernorm_pre_allgather { + Tensor operator()( + const Tensor& a, + std::optional compute_kernel_config = std::nullopt, + const DataType output_dtype = DataType::BFLOAT16) const { + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a}))}; + operation::launch_op( + [compute_kernel_config, output_dtype] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& a = input_tensors.at(0); + auto arch = a.storage_type() == StorageType::DEVICE ? 
a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, false, false, false); + return operation::run( + LayerNormPreAllGather{ + .norm_type = layernorm_type, + .compute_kernel_config = kernel_config_val, + .output_dtype = output_dtype}, + {a}); + }, {a}, output_tensors); + return output_tensors.at(0); + } +}; + +constexpr auto layernorm_pre_allgather = make_layernorm_pre_allgather{}; +constexpr auto rmsnorm_pre_allgather = make_layernorm_pre_allgather{}; + + +} // namespace primary + +} // namespace operations + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op_multi_core.cpp new file mode 100644 index 00000000000..2bc905c91c1 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op_multi_core.cpp @@ -0,0 +1,300 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_eager/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" + +#include +#include + +using uint32_t = std::uint32_t; +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace tt { + +namespace tt_metal { + +inline bool is_dram(const Tensor& input_tensor) { return input_tensor.memory_config().buffer_type == BufferType::DRAM; } +inline bool is_dram(const std::optional input_tensor) { + return input_tensor.has_value() ? is_dram(input_tensor.value()) : true; +} +inline bool is_dram(const Buffer* b) { return b->buffer_type() == BufferType::DRAM; } + +operation::ProgramWithCallbacks layernorm_pre_allgather_multi_core( + const Tensor &a, + Tensor& output, + LayerNormType norm_type, + DeviceComputeKernelConfig compute_kernel_config +) { + const bool is_rmsnorm = norm_type == LayerNormType::RMSNORM; + const auto shape = a.get_legacy_shape(); + const uint32_t W = shape[-1], H = shape[-2]; + const uint32_t HW = H*W; + const uint32_t NC = a.volume() / HW; + + + // Kernels are configured to support BFLOAT8_B, but bad pcc so we need mixed precision support in compute + const auto& a_dtype = a.get_dtype(); + + const uint32_t Wt = W/TILE_WIDTH; + const uint32_t Ht = H/TILE_HEIGHT; + const uint32_t tile_cols_per_device = is_rmsnorm ? 
1 : 2; + + uint32_t num_tile_rows = NC * Ht; + + log_debug("is_rmsnorm: {}", is_rmsnorm); + log_debug("W: {}", W); + log_debug("H: {}", H); + log_debug("num_tile_rows: {}", num_tile_rows); + log_debug("Wt: {}", Wt); + log_debug("Ht: {}", Ht); + + + //////////////////////////////////////////////////////////////////////////// + // Device Setup + ////////////////////////////////////////////////////////////////////////// + Device *device = a.device(); + + //////////////////////////////////////////////////////////////////////////// + // Circular Buffer Data Format Setup + ////////////////////////////////////////////////////////////////////////// + MathFidelity math_fidelity; + bool math_approx_mode; + bool fp32_dest_acc_en; + + std::visit([&](auto&& compute_kernel_config) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::GRAYSKULL, "kernel config is not for graykull"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = false; + } else if constexpr (std::is_same_v) { + TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "kernel config is not for wormhole_b0"); + math_fidelity = compute_kernel_config.math_fidelity; + math_approx_mode = compute_kernel_config.math_approx_mode; + fp32_dest_acc_en = tt_metal::datatype_to_dataformat_converter(a.get_dtype()) == tt::DataFormat::Float32 ? true : compute_kernel_config.fp32_dest_acc_en; + } else { + TT_FATAL("arch not supported"); + } + + }, compute_kernel_config); + + uint32_t block_size = 1; // find_max_divisor(Wt, 8); + uint32_t writer_block_size = 1; + + tt::DataFormat in_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + tt::DataFormat out_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + tt::DataFormat cb_data_format = tt::DataFormat::Float16_b; + uint32_t in_single_tile_size = tt_metal::detail::TileSize(in_data_format); + uint32_t out_single_tile_size = tt_metal::detail::TileSize(out_data_format); + uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); + uint32_t bfloat16_tile_size = tt_metal::detail::TileSize(tt::DataFormat::Float16_b); + + log_debug("in_data_format: {}", in_data_format); + log_debug("out_data_format: {}", out_data_format); + + tt::DataFormat inb_data_format = tt::DataFormat::Invalid; + uint32_t inb_single_tile_size = 0; + + auto a_addr = a.buffer()->address(); + auto dst_addr = output.buffer()->address(); + + uint32_t num_tiles = a.volume()/TILE_HW; + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + /* + in0_cb: a + in1_cb: 1 (reduction scalar) + + output CB is packed such that the first tile is for x**2 stats, second tile is for x stats + in RMSNorm, only first tile has valid data. 
+ + intermed0_cb: xˆ2 + out0_cb: [sum(xˆ2), sum(x)] # For layernorm + out0_cb: [sum(xˆ2)] # RMSNorm + + */ + const uint32_t double_buffer_constant = 2; + const uint32_t in0_tiles = Wt * double_buffer_constant; + const uint32_t in1_tiles = 1; // reduce scalar + + const uint32_t intermed0_tiles = Wt * double_buffer_constant; // xˆ2 + uint32_t out0_tiles = 1; + if (!is_rmsnorm) { + out0_tiles = 2; + } + + TT_ASSERT(W <= TILE_WIDTH*in0_tiles && "W exceeds the maximum supported size of tile buffer (kernel limitation right now)."); + TT_ASSERT(in0_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + TT_ASSERT(intermed0_tiles % block_size == 0 && "Size of buffer must be divisible by the size of block used by the reader and compute kernel."); + + + auto grid_size = device->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tile_rows_per_core_group_1, num_tile_rows_per_core_group_2] = split_work_to_cores(grid_size, num_tile_rows, true); + + log_debug("num_cores: {}", num_cores); + log_debug("grid_size: {}", grid_size); + log_debug("core_group_1: {}", core_group_1.str()); + log_debug("num_tile_rows_per_core_group_1: {}", num_tile_rows_per_core_group_1); + log_debug("core_group_2: {}", core_group_2.str()); + log_debug("num_tile_rows_per_core_group_2: {}", num_tile_rows_per_core_group_2); + + //////////////////////////////////////////////////////////////////////////// + // Application Setup + //////////////////////////////////////////////////////////////////////////// + Program program = CreateProgram(); + + std::vector reader_compile_time_args = { + // interleaved accessor args + (std::uint32_t) is_dram(a), + (std::uint32_t) block_size, + }; + + std::vector writer_compile_time_args = { + // interleaved accessor args + (std::uint32_t) is_dram(output), + (std::uint32_t) writer_block_size + }; + + + bool tile_dtype_is_bfloat16 = a.get_dtype() == tt::tt_metal::DataType::BFLOAT16; + std::map compute_defines; + + if (is_rmsnorm) { + compute_defines["RMSNORM"] = "1"; + } + + auto reader_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/dataflow/reader_unary_interleaved_ln_rm_gb_pre_allgather.cpp", + all_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args) + ); + + auto writer_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm/kernels/dataflow/writer_unary_interleaved_start_id_blocked.cpp", + all_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args) + ); + + vector compute_args = { Wt, block_size }; + + auto compute_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/layernorm_distributed/kernels/compute/layernorm_pre_allgather.cpp", + all_cores, + tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = compute_args, .defines = compute_defines} + ); + + // Create circular buffers + // c_in0 -> a + CircularBufferConfig cb_src0_config = CircularBufferConfig(in0_tiles*in_single_tile_size, {{CB::c_in0, in_data_format}}).set_page_size(CB::c_in0, in_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_src0_config ); + // c_in1 -> reduce scalar + CircularBufferConfig cb_reduce_config = CircularBufferConfig(in1_tiles*bfloat16_tile_size, {{CB::c_in1, cb_data_format}}).set_page_size(CB::c_in1, bfloat16_tile_size); + CreateCircularBuffer( program, all_cores, cb_reduce_config ); 
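+    // Note: winv (the reduce scaler set below) is 1.0 for this pre-all-gather pass,
+    // so c_out0 carries raw partial sums [sum(x^2), sum(x)] ([sum(x^2)] for RMSNorm);
+    // the 1/(W * num_devices) scaling is applied in the post-all-gather op.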
+ + // LN and RMS shared intermediates // + // c_intermed0 -> xˆ2 + CircularBufferConfig cb_intermed0_config = CircularBufferConfig(intermed0_tiles*single_tile_size, {{CB::c_intermed0, cb_data_format}}).set_page_size(CB::c_intermed0, single_tile_size); + CreateCircularBuffer( program, all_cores, cb_intermed0_config ); + + CircularBufferConfig cb_out0_config = CircularBufferConfig(out0_tiles*out_single_tile_size, {{CB::c_out0, out_data_format}}).set_page_size(CB::c_out0, out_single_tile_size); + CreateCircularBuffer( program, all_cores, cb_out0_config ); + + // Log all circular buffers with program.circular_buffers_on_corerange(all_cores), which returns std::vector> + for (const auto& cb : program.circular_buffers_on_corerange(*all_cores.ranges().begin())) { + for (const auto index : cb->buffer_indices()) { + log_debug("cb_id {}", index); + log_debug("page_size: {}", cb->page_size(index)); + log_debug("num_pages: {}", cb->num_pages(index)); + log_debug("data_format: {}", cb->data_format(index)); + } + } + + uint32_t curr_row = 0; + float winv = 1.0f; + bfloat16 bfloat_winv_value = bfloat16(winv); + uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value}); + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + uint32_t num_tile_rows_per_core = 0; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tile_rows_per_core = num_tile_rows_per_core_group_2; + } else { + TT_ASSERT(false, "Core not in specified core ranges"); + } + + uint32_t in_tile_offset = curr_row * Wt; + uint32_t out_tile_offset = curr_row * out0_tiles; + + SetRuntimeArgs(program, reader_kernels_id, core, + { a_addr, num_tile_rows_per_core, Wt, in_tile_offset, packed_winv_value } + ); + SetRuntimeArgs(program, compute_kernels_id, core, { num_tile_rows_per_core }); + SetRuntimeArgs(program, writer_kernels_id, core, { dst_addr, num_tile_rows_per_core * out0_tiles, out_tile_offset } ); + curr_row += num_tile_rows_per_core; + } + + auto override_runtime_args_callback = [ + reader_kernel_id=reader_kernels_id, + writer_kernel_id=writer_kernels_id, + num_cores, + grid_size + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + + const auto& input_tensor = input_tensors.at(0); + + const auto input_addr = input_tensor.buffer()->address(); + + const auto& output_tensor = output_tensors.at(0); + const auto output_addr = output_tensor.buffer()->address(); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + + for (uint32_t i = 0; i < num_cores; ++i) { + const CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + { + auto& reader_args = reader_runtime_args_by_core.at(core.x).at(core.y); + + reader_args[0] = input_addr; + } + + { + auto& writer_args = writer_runtime_args_by_core.at(core.x).at(core.y); + writer_args[0] = output_addr; + } + } + }; + + return {std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} + + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index d24465cacf6..fd844c659d9 100644 --- 
a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -11,6 +11,8 @@ #include "tt_dnn/op_library/bmm/bmm_op.hpp" #include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp" #include "tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp" +#include "tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp" #include "tt_dnn/op_library/moreh_adam/moreh_adam_op.hpp" #include "tt_dnn/op_library/moreh_adamw/moreh_adamw_op.hpp" #include "tt_dnn/op_library/moreh_arange/moreh_arange_op.hpp" @@ -534,6 +536,54 @@ void py_module(py::module& m_primary) { Performs a rmsnorm(a+b)*gamma + beta operation. )doc"); + m_primary.def( + "layernorm_pre_allgather", + tt::operations::primary::layernorm_pre_allgather, + py::arg("input").noconvert(), + py::arg("compute_kernel_config").noconvert() = std::nullopt, + py::arg("output_dtype").noconvert() = DataType::BFLOAT16, + R"doc( + Performs the first part of a distributed layernorm operation collecting local statistics E(x) and E(xˆ2). + )doc"); + + m_primary.def( + "rmsnorm_pre_allgather", + tt::operations::primary::rmsnorm_pre_allgather, + py::arg("input").noconvert(), + py::arg("compute_kernel_config").noconvert() = std::nullopt, + py::arg("output_dtype").noconvert() = DataType::BFLOAT16, + R"doc( + Performs the first part of a distributed rms norm operation collecting local statistics E(x) and E(xˆ2). + )doc"); + + m_primary.def( + "layernorm_post_allgather", + tt::operations::primary::layernorm_post_allgather, + py::arg("input").noconvert(), + py::arg("stats").noconvert(), + py::arg("eps").noconvert(), + py::arg("gamma").noconvert() = std::nullopt, + py::arg("beta").noconvert() = std::nullopt, + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + R"doc( + Performs the second part of a distributed layernorm operation normalizing the input based on the gathered statistics input. + )doc"); + + m_primary.def( + "rmsnorm_post_allgather", + tt::operations::primary::rmsnorm_post_allgather, + py::arg("input").noconvert(), + py::arg("stats").noconvert(), + py::arg("eps").noconvert(), + py::arg("gamma").noconvert() = std::nullopt, + py::arg("beta").noconvert() = std::nullopt, + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + R"doc( + Performs the second part of a distributed rms norm operation normalizing the input based on the gathered statistics input. 
+ )doc"); + // prod along all dimensions m_primary.def( "prod_all", diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp index d0cbcf31a27..f40bd1ff5fa 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -18,6 +18,8 @@ #include "tt_dnn/op_library/fully_connected/fully_connected_op.hpp" #include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp" #include "tt_dnn/op_library/layernorm/layernorm_op.hpp" +#include "tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp" +#include "tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp" #include "tt_dnn/op_library/pool/average_pool.hpp" #include "tt_dnn/op_library/pool/max_pool.hpp" #include "tt_dnn/op_library/reduce/reduce_op.hpp" @@ -628,6 +630,53 @@ void TensorModule(py::module& m_tensor) { R"doc( "Performs a rmsnorm(a+b)*gamma + beta operation. )doc"); + m_tensor.def( + "layernorm_pre_allgather", + tt::operations::primary::layernorm_pre_allgather, + py::arg("input").noconvert(), + py::arg("compute_kernel_config").noconvert() = std::nullopt, + py::arg("output_dtype").noconvert() = DataType::BFLOAT16, + R"doc( + Performs the first part of a distributed layernorm operation collecting local statistics E(x) and E(xˆ2). + )doc"); + + m_tensor.def( + "rmsnorm_pre_allgather", + tt::operations::primary::rmsnorm_pre_allgather, + py::arg("input").noconvert(), + py::arg("compute_kernel_config").noconvert() = std::nullopt, + py::arg("output_dtype").noconvert() = DataType::BFLOAT16, + R"doc( + Performs the first part of a distributed rms norm operation collecting local statistics E(x) and E(xˆ2). + )doc"); + + m_tensor.def( + "layernorm_post_allgather", + tt::operations::primary::layernorm_post_allgather, + py::arg("input").noconvert(), + py::arg("stats").noconvert(), + py::arg("eps").noconvert(), + py::arg("gamma").noconvert() = std::nullopt, + py::arg("beta").noconvert() = std::nullopt, + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + R"doc( + Performs the second part of a distributed layernorm operation normalizing the input based on the gathered statistics input. + )doc"); + + m_tensor.def( + "rmsnorm_post_allgather", + tt::operations::primary::rmsnorm_post_allgather, + py::arg("input").noconvert(), + py::arg("stats").noconvert(), + py::arg("eps").noconvert(), + py::arg("gamma").noconvert() = std::nullopt, + py::arg("beta").noconvert() = std::nullopt, + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + R"doc( + Performs the second part of a distributed rms norm operation normalizing the input based on the gathered statistics input. + )doc"); m_tensor.def( "rotate_half", &rotate_half,