Commit
Merge branch 'feature/npu' into feature/npu-accelerator
Showing 35 changed files with 1,064 additions and 272 deletions.
@@ -1,7 +1,14 @@
from .cpu_adam_loader import CPUAdamLoader
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
from .extensions.flash_attention import AttnMaskType
from .flash_attention_loader import ColoAttention, FlashAttentionLoader

__all__ = [
    "LayerNorm",
    "FusedScaleMaskSoftmax",
    "MultiHeadAttention",
    "CPUAdamLoader",
    "FlashAttentionLoader",
    "ColoAttention",
    "AttnMaskType",
]
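
For reference, a short sketch of how the re-exported names above might be consumed, assuming this file is the package-level __init__.py of colossalai.kernel (the file path is not shown in this diff):

# Hypothetical consumer of the re-exports above; the package path is an assumption.
from colossalai.kernel import AttnMaskType, ColoAttention, CPUAdamLoader

cpu_adam = CPUAdamLoader().load()  # imports or builds the platform-specific CPU Adam kernel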
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from typing import Dict, List

from .extensions.base_extension import BaseExtension


class BaseKernelLoader(ABC):
    """
    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
        self._extension_map = extension_map
        self._supported_device = supported_device

    def run_checks(self):
        # run supported device check and other possible checks
        pass

    @abstractmethod
    def fetch_kernel(self):
        pass

    def load(self):
        self.run_checks()
        return self.fetch_kernel()
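
To illustrate the loader/extension contract, here is a minimal hypothetical subclass; DummyExtension and DummyKernelLoader are not part of this commit and exist only to show the load() -> run_checks() -> fetch_kernel() flow:

from collections import OrderedDict


class DummyExtension(BaseExtension):
    """Hypothetical extension, used only to illustrate the interface."""

    @property
    def requires_build(self) -> bool:
        return False  # nothing to compile

    def build(self) -> None:
        pass

    def load(self):
        return lambda: "kernel"


class DummyKernelLoader(BaseKernelLoader):
    def __init__(self):
        super().__init__(
            extension_map=OrderedDict(dummy=DummyExtension),
            supported_device=["cpu"],
        )

    def fetch_kernel(self):
        # fetch() on the extension handles the optional build step before loading.
        return self._extension_map["dummy"]().fetch()


kernel = DummyKernelLoader().load()  # runs checks, then returns the loaded callable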
@@ -0,0 +1,64 @@
import platform
from collections import OrderedDict

from .base_kernel_loader import BaseKernelLoader
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension


class CPUAdamLoader(BaseKernelLoader):
    """
    CPU Adam Loader
    Usage:
        # init
        cpu_adam = CPUAdamLoader().load()
        cpu_adam_op = cpu_adam.CPUAdamOptimizer(
            alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
        )
        ...
        # optim step
        cpu_adam_op.step(
            step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
            params, grads, exp_avg, exp_avg_sq, loss_scale,
        )
    Args:
        func CPUAdamOptimizer:
            alpha (float): learning rate. Default to 1e-3.
            beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
            beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
            epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
            weight_decay (float): weight decay (L2 penalty). Default to 0.
            adamw_mode (bool): whether to use the adamw. Default to True.
        func step:
            step (int): current step.
            lr (float): learning rate.
            beta1 (float): coefficients used for computing running averages of gradient.
            beta2 (float): coefficients used for computing running averages of its square.
            epsilon (float): term added to the denominator to improve numerical stability.
            weight_decay (float): weight decay (L2 penalty).
            bias_correction (bool): whether to use bias correction.
            params (torch.Tensor): parameter.
            grads (torch.Tensor): gradient.
            exp_avg (torch.Tensor): exp average.
            exp_avg_sq (torch.Tensor): exp average square.
            loss_scale (float): loss scale value.
    """

    def __init__(self):
        super().__init__(
            extension_map=OrderedDict(
                arm=ArmCPUAdamExtension,
                x86=X86CPUAdamExtension,
            ),
            supported_device=["cpu"],
        )

    def fetch_kernel(self):
        if platform.machine() == "x86_64":
            kernel = self._extension_map["x86"]().fetch()
        elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
            kernel = self._extension_map["arm"]().fetch()
        else:
            raise Exception("not supported")
        return kernel
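
A minimal end-to-end usage sketch based on the docstring above; the tensor sizes and hyperparameter values are illustrative, and passing -1 as loss_scale (commonly used to mean "no loss scaling") is an assumption, not something documented in this diff:

import torch

cpu_adam = CPUAdamLoader().load()  # dispatches to the x86 or ARM extension
# Constructor args follow the docstring: alpha, beta1, beta2, epsilon, weight_decay, adamw_mode.
cpu_adam_op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)

params = torch.rand(1024)
grads = torch.rand(1024)
exp_avg = torch.zeros(1024)
exp_avg_sq = torch.zeros(1024)

# One update; argument order follows the docstring's "func step" section.
cpu_adam_op.step(1, 1e-3, 0.9, 0.99, 1e-8, 0.0, True,
                 params, grads, exp_avg, exp_avg_sq, -1)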
This file was deleted.
This file was deleted.
Empty file.
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable


class BaseExtension(ABC):
    @abstractmethod
    def requires_build(self) -> bool:
        pass

    @abstractmethod
    def build(self) -> None:
        pass

    @abstractmethod
    def load(self) -> Callable:
        pass

    def fetch(self) -> Callable:
        if self.requires_build:
            self.build()
        return self.load()
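
Note that fetch() reads requires_build as an attribute, so the concrete extensions in this commit expose it as a property. A tiny hypothetical extension, purely to illustrate the build-once-then-load flow:

from typing import Callable


class EchoExtension(BaseExtension):
    """Hypothetical extension; not part of this commit."""

    def __init__(self) -> None:
        self._built = False

    @property
    def requires_build(self) -> bool:
        return not self._built

    def build(self) -> None:
        # A real extension would compile or import a prebuilt kernel here.
        self._built = True

    def load(self) -> Callable:
        return lambda x: x


identity_kernel = EchoExtension().fetch()  # builds on first fetch, then loads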
@@ -0,0 +1,4 @@
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension

__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]
@@ -0,0 +1,53 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder


class ArmCPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = ArmCPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class ArmCPUAdamBuilder(ExtensionBuilder):
    NAME = "arm_cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
    ext_type = "cpu"

    def __init__(self):
        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        return []
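
The four hooks above (sources_files, include_dirs, cxx_flags, nvcc_flags) are what ExtensionBuilder needs to drive a build. ExtensionBuilder itself is not shown in this diff, so the following is only a guess at how such hooks could feed a JIT compile, using torch.utils.cpp_extension.load for illustration:

from torch.utils.cpp_extension import load


def jit_build(builder):
    # Hypothetical sketch: compile an extension from the builder's four hooks.
    return load(
        name=builder.NAME,
        sources=builder.sources_files(),
        extra_include_paths=builder.include_dirs(),
        extra_cflags=builder.cxx_flags(),
        extra_cuda_cflags=builder.nvcc_flags(),
        verbose=True,
    )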
@@ -0,0 +1,65 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads


class X86CPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = X86CPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class X86CPUAdamBuilder(ExtensionBuilder):
    NAME = "cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"

    def __init__(self):
        super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-lcudart",
            "-lcublas",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
            "-march=native",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-std=c++14",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]
        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
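
Unlike the ARM builder, the x86 path also links cudart/cublas, adds the CUDA include directory, and runs append_nvcc_threads over the nvcc flag list. That helper is imported from ..utils and not shown in this diff; a plausible sketch of what it might do (an assumption, not the actual implementation):

def append_nvcc_threads(nvcc_flags, num_threads=4):
    # Hypothetical sketch: ask nvcc to compile with multiple threads when supported.
    return nvcc_flags + [f"--threads={num_threads}"]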