Commit

Merge branch 'main' into debug-module-checkpoint-extra-state
timmoon10 authored Nov 28, 2024
2 parents 055100d + a132ac4 commit d940df5
Showing 27 changed files with 1,314 additions and 343 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -70,6 +70,7 @@ jobs:
        run: pip install . -v
        env:
          NVTE_FRAMEWORK: jax
          MAX_JOBS: 1
      - name: 'Sanity check'
        run: python tests/jax/test_sanity_import.py
  paddle:
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
-1.13.0.dev0
+1.14.0.dev0
58 changes: 58 additions & 0 deletions qa/L1_pytorch_mcore_integration/test.sh
@@ -0,0 +1,58 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}

# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
    pushd $(dirname ${MCORE_PATH})
    git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
    popd
fi

# Megatron-LM invocation
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0
python
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1
${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 2048
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
--transformer-impl transformer_engine
--fp8-format hybrid
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')

# Launch Megatron-LM
bash -c "${COMMAND}"
10 changes: 10 additions & 0 deletions tests/cpp/util/test_string.cpp
@@ -5,6 +5,7 @@
************************************************************************/

#include <string>
#include <vector>

#include <gtest/gtest.h>

@@ -57,6 +58,12 @@ TEST(UtilTest, ToStringLike) { // to_string_like
  EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f);
  EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25);
  EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5);

  // Container types
  EXPECT_EQ(to_string_like(std::vector<int>{-3,1,-4}), "(-3,1,-4)");
  EXPECT_EQ(to_string_like(std::vector<std::string>{"Accept", "no", "substitutes", ".",
                                                    "Buy", "N", "V", "IDIA"}),
            "(Accept,no,substitutes,.,Buy,N,V,IDIA)");
}

TEST(UtilTest, ConcatStringsTest) { // concat_strings
@@ -88,6 +95,9 @@ TEST(UtilTest, ConcatStringsTest) { // concat_strings
  EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f);
  EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25);
  EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5);

  // Container types
  EXPECT_EQ(concat_strings("vector ", std::vector<int>{1,-2,3}), "vector (1,-2,3)");
}

TEST(UtilTest, RegexReplaceTest) { // regex_replace
160 changes: 160 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -1362,6 +1362,166 @@ def test_make_extra_output(
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("fp8_input", (False, True))
    @pytest.mark.parametrize("fp8_output", (False, True))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        fp8_input: bool,
        fp8_output: bool,
    ) -> None:
        """Activation functions"""

        # Tensor dimensions
        in_shape = list(out_shape)
        if activation in ("geglu", "reglu", "swiglu"):
            in_shape[-1] *= 2

        # Skip invalid configurations
        if fp8_input or fp8_output:
            if not fp8_available:
                pytest.skip(reason_for_no_fp8)
            if torch.device(device).type != "cuda":
                pytest.skip("FP8 is only supported on CUDA devices")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8_input,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref: torch.Tensor
        if activation == "gelu":
            y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
        elif activation == "relu":
            y_ref = torch.nn.functional.relu(x_ref)
        elif activation == "geglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
        elif activation == "reglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.relu(x1) * x2
        elif activation == "swiglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.silu(x1) * x2
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        make_op = dict(
            gelu=te_ops.GELU,
            relu=te_ops.ReLU,
            geglu=te_ops.GEGLU,
            reglu=te_ops.ReGLU,
            swiglu=te_ops.SwiGLU,
        )[activation]
        forward = te_ops.Sequential(
            make_op(),
            te_ops.Quantize(forward=fp8_output, backward=False),
        )
        with te.fp8_autocast(enabled=fp8_output):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if fp8_output:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("fp8_output", (False, True))
    @pytest.mark.parametrize("fp8_grad_input", (False, True))
    def test_swiglu(
        self,
        *,
        out_shape: Iterable[int] = (16, 16),
        dtype: torch.dtype,
        device: torch.device = "cuda",
        fp8_output: bool,
        fp8_grad_input: bool,
    ):

        # Tensor dimensions
        in_shape = list(out_shape)
        in_shape[-1] *= 2

        # Skip invalid configurations
        fp8 = fp8_output or fp8_grad_input
        if fp8:
            if not fp8_available:
                pytest.skip(reason_for_no_fp8)
            if torch.device(device).type != "cuda":
                pytest.skip("FP8 is only supported on CUDA devices")

        # FP8 recipe
        fp8_recipe = None
        if fp8_grad_input:
            fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                fp8_format=transformer_engine.common.recipe.Format.E4M3,
            )

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        x1, x2 = x_ref.chunk(2, dim=-1)
        y_ref = torch.nn.functional.silu(x1) * x2
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        forward = te_ops.Sequential(
            te_ops.Quantize(forward=False, backward=fp8_grad_input),
            te_ops.SwiGLU(),
            te_ops.Quantize(forward=fp8_output, backward=False),
        )
        with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if fp8:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
    """Tests for fused operations"""
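
Note (not part of the diff): the two new tests above exercise the operation-based API through te_ops.Sequential, with te_ops.Quantize controlling which pass produces FP8: the forward flag quantizes the forward output, while the backward flag quantizes the gradient passed back through that point, which is how test_swiglu separates fp8_output from fp8_grad_input. Below is a minimal usage sketch of that pattern, assuming te and te_ops resolve to transformer_engine.pytorch and its ops submodule as elsewhere in this test file, and that an FP8-capable CUDA GPU is available.

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Pipeline of fusible ops: SwiGLU with its forward output quantized to FP8.
forward = te_ops.Sequential(
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),
)

# SwiGLU halves the last dimension: (16, 32) -> (16, 16).
x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
dy = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True):
    y = forward(x)
y.backward(dy)
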
7 changes: 4 additions & 3 deletions transformer_engine/common/CMakeLists.txt
@@ -30,14 +30,14 @@ endif()

# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
-    "${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
+    "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
  message(FATAL_ERROR
-          "Could not find cuDNN frontend API. "
+          "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
          "Try running 'git submodule update --init --recursive' "
          "within the Transformer Engine source.")
endif()
-include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
+include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
@@ -61,6 +61,7 @@ list(APPEND transformer_engine_SOURCES
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/thd_utils.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
76 changes: 76 additions & 0 deletions transformer_engine/common/fused_attn/thd_utils.cu
@@ -0,0 +1,76 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../cudnn_utils.h"
#include "thd_utils.h"

namespace transformer_engine {
namespace fused_attn {

__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
                                             int total_tokens, int world_size, int rank) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    int seqlen = cu_seqlens[i];
    // Currently we assume that each sequence length is divisible by (world_size*2) since we have
    // to distribute each sequence evenly to different GPUs.
    assert(seqlen % (world_size * 2) == 0);
    cu_seqlens_s[i] = seqlen / world_size;
  }
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int num_threads = blockDim.x * gridDim.x;

  for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
    int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
    int index = token_id - cu_seqlens_s[seq_id];
    int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
    index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
    output[token_id] = index;
  }
}

__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i] / 2;
  }
  __syncthreads();

  int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  int laneid = threadIdx.x % 32;
  int num_warps = (blockDim.x * gridDim.x) / 32;
  int num_total_tokens = cu_seqlens_s[batch];
  int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);

  size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
  half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
  tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);

  for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
    int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);

    size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
    float4 *cur_half_token =
        reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);

    offset_in_bytes =
        (static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
    float4 *cur_token =
        reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);

    for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
      cur_half_token[idx] = cur_token[idx];
    }
  }
}

} // namespace fused_attn
} // namespace transformer_engine
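
Note (not part of the commit): a host-side Python sketch can make the index arithmetic in the two kernels above easier to follow. It mirrors thd_partition_indices_kernel and thd_read_half_tensor_kernel for a single [total_tokens, hidden] tensor, assumes binary_search(token_id, cu_seqlens_s, n) returns the index i with cu_seqlens_s[i] <= token_id < cu_seqlens_s[i+1], and its function names and toy example are illustrative only.

import bisect
import torch

def thd_partition_indices(cu_seqlens, world_size, rank):
    """Mirror of thd_partition_indices_kernel: global token indices owned by `rank`."""
    # Each sequence must split evenly into 2 * world_size chunks (the kernel asserts this too).
    for i in range(len(cu_seqlens) - 1):
        assert (cu_seqlens[i + 1] - cu_seqlens[i]) % (2 * world_size) == 0
    cu_local = [s // world_size for s in cu_seqlens]  # per-rank cumulative sequence lengths
    out = []
    for token_id in range(cu_local[-1]):
        seq_id = bisect.bisect_right(cu_local, token_id) - 1  # assumed binary_search behavior
        seq_len = cu_local[seq_id + 1] - cu_local[seq_id]
        index = token_id - cu_local[seq_id]
        # Rank `rank` keeps slice `rank` and the mirrored slice (2*world_size - 1 - rank) of
        # each sequence, which balances causal-attention work across ranks.
        offset = rank if index < seq_len // 2 else (world_size - 1) * 2 - rank
        out.append(cu_local[seq_id] * world_size + (seq_len // 2) * offset + index)
    return out

def thd_read_half_tensor(tensor, cu_seqlens, half_idx):
    """Mirror of thd_read_half_tensor_kernel: first (0) or second (1) half of every sequence."""
    halves = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        mid = start + (end - start) // 2
        halves.append(tensor[start:mid] if half_idx == 0 else tensor[mid:end])
    return torch.cat(halves, dim=0)

# Toy example: one sequence of 8 tokens split across world_size=2 ranks.
print(thd_partition_indices([0, 8], world_size=2, rank=0))  # [0, 1, 6, 7]
print(thd_partition_indices([0, 8], world_size=2, rank=1))  # [2, 3, 4, 5]
print(thd_read_half_tensor(torch.arange(8), [0, 8], half_idx=1))  # tensor([4, 5, 6, 7])
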