diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 54e8a6d93487..2cad504f3391 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache @@ -160,9 +160,7 @@ jobs: --ignore tests/test_gptq \ --ignore tests/test_infer_ops \ --ignore tests/test_legacy \ - --ignore tests/test_moe \ --ignore tests/test_smoothquant \ - --ignore tests/test_checkpoint_io \ tests/ env: LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 5b0103eb770d..ae1a5275e5da 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -12,7 +12,7 @@ jobs: if: github.repository == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 90 steps: @@ -23,6 +23,7 @@ jobs: ngpu=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) endIndex=$(($ngpu-1)) for i in $(seq 0 $endIndex); + do gpu_used=$(nvidia-smi -i $i --query-gpu=memory.used --format=csv,noheader,nounits) [ "$gpu_used" -gt "2000" ] && avai=false done @@ -54,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 02e30f52a459..bba321fd2d59 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -45,9 +45,9 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 6d6952aa169a..fcff8e569ff7 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -77,7 +77,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 20 concurrency: diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 919fa5092a6c..abb9479492e7 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 timeout-minutes: 10 steps: - name: 📚 Checkout diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index f9e9f400962e..bb0ceb4a8296 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -18,7 +18,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index ec5c8ffa319f..7986889e006b 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -20,7 +20,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_colossalqa_unit_tests.yml b/.github/workflows/run_colossalqa_unit_tests.yml index 763db277289f..00944b92d9b6 100644 --- a/.github/workflows/run_colossalqa_unit_tests.yml +++ b/.github/workflows/run_colossalqa_unit_tests.yml @@ -19,7 +19,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 volumes: - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa - /data/scratch/llama-tiny:/data/scratch/llama-tiny @@ -51,4 +51,4 @@ jobs: TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/* - SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path \ No newline at end of file + SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path diff --git a/MANIFEST.in b/MANIFEST.in index ad26b634ac3e..f0a5611efc7d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include *.txt README.md recursive-include requirements *.txt recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi -recursive-include op_builder *.py +recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/README.md b/README.md index 971f4375a289..13757eece7db 100644 --- a/README.md +++ b/README.md @@ -141,25 +141,26 @@ distributed training and inference in a few lines. [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-13b-base) [[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary) -| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | -| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: | -| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | -| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | -| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | -| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | -| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | -| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | -| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | -| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | -| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | -| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | -| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | -| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | -| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | -| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | -| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | -| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | -| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | +| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | +| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: | +| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | +| **Colossal-LLaMA-2-13b-base** | Llama-2-13B | **0.025T** | 56.42 | 61.80 | 54.69 | 69.53 | 60.3 | ### ColossalChat diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index d6966689885e..330e4e0e395e 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -10,7 +10,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .base import OnPolicyTrainer from .callbacks import Callback @@ -105,7 +105,7 @@ def __init__( self.critic_optim = critic_optim self.offload_inference_models = offload_inference_models - self.device = get_current_device() + self.device = get_accelerator().get_current_device() def _before_fit( self, diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 7129edb060ef..95f01678640c 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -6,7 +6,6 @@ import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.utils import get_current_device from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy @@ -158,9 +157,19 @@ def __init__( warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + chunk_init_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + chunk_init_device = get_current_device() + # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( - chunk_init_device=get_current_device(), + chunk_init_device=chunk_init_device, placement_policy=placement_policy, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py index 43297633db1a..439135503002 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py @@ -6,12 +6,12 @@ """ import argparse -import os import json +import os from typing import List, Union -from transformers.models.llama.tokenization_llama import LlamaTokenizer from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model +from transformers.models.llama.tokenization_llama import LlamaTokenizer from colossalai.logging import get_dist_logger diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 079faaace0ed..9f6c9c1cc6f3 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -16,7 +16,10 @@ def unwrap(model): - return model.unwrap().module + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model def neftune_post_forward_hook(module, input, output): diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 41b4ef031b46..92863e8e4bba 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,44 +1,37 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team """ -import json import argparse +import json import os import resource from contextlib import nullcontext -from tqdm import tqdm import torch import torch.distributed as dist +from colossal_llama2.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.froze import freeze_non_embeds_parameters from torch.utils.tensorboard import SummaryWriter -from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import ( - GeminiPlugin, - LowLevelZeroPlugin, - HybridParallelPlugin, -) +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - -from colossal_llama2.dataset.loader import ( - load_tokenized_dataset, - setup_distributed_dataloader, - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, -) - -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.froze import freeze_non_embeds_parameters def get_model_numel(model: torch.nn.Module) -> int: @@ -215,9 +208,18 @@ def main() -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) + + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + current_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + current_device = get_current_device() + + init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() with init_ctx: model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) # Freeze part of parameters. @@ -320,7 +322,7 @@ def main() -> None: initial=start_step, ) as pbar: for step, batch in pbar: - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} batch_output = model(**batch) @@ -372,9 +374,7 @@ def main() -> None: # Final save. coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master( - f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" - ) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index ff7346adcf61..0aa383e9d0b9 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -136,6 +136,19 @@ def _identifying_params(self) -> Mapping[str, int]: """Get the identifying parameters.""" return {"n": self.n} + def get_token_ids(self, text: str) -> List[int]: + """Return the ordered ids of the tokens in a text. + + Args: + text: The string input to tokenize. + + Returns: + A list of ids corresponding to the tokens in the text, in order they occur + in the text. + """ + # use the colossal llm's tokenizer instead of langchain's cached GPT2 tokenizer + return self.api.tokenizer.encode(text) + class VllmLLM(LLM): """ diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 7da55590305b..6b7f5d055207 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,4 +1,5 @@ from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch +from . import accelerator try: # .version will be created by setup.py diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md new file mode 100644 index 000000000000..8c644493b03a --- /dev/null +++ b/colossalai/accelerator/README.md @@ -0,0 +1,20 @@ +# 🚀 Accelerator + +## 🔗 Table of Contents + +- [🚀 Accelerator](#-accelerator) + - [🔗 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [📌 Design and Acknowledgement](#-design-and-acknowledgement) + +## 📚 Introduction + +This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API. + +## 📌 Design and Acknowledgement + +Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work. + +We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications: +1. we updated the accelerator API names to be aligned with PyTorch's native API names. +2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py new file mode 100644 index 000000000000..1405133affe2 --- /dev/null +++ b/colossalai/accelerator/__init__.py @@ -0,0 +1,15 @@ +from .api import auto_set_accelerator, get_accelerator, set_accelerator +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = [ + "get_accelerator", + "set_accelerator", + "auto_set_accelerator", + "BaseAccelerator", + "CudaAccelerator", + "NpuAccelerator", + "CpuAccelerator", +] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py new file mode 100644 index 000000000000..02b3055d7380 --- /dev/null +++ b/colossalai/accelerator/api.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +from collections import OrderedDict +from typing import Union + +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"] + + +_ACCELERATOR = None + + +# we use ordered dictionary here to associate the +# order with device check priority +# i.e. auto_set_accelerator will check cuda first +_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator) + + +def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: + """ + Set the global accelerator for the current process. + + Args: + accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs. + """ + + global _ACCELERATOR + + if isinstance(accelerator, str): + _ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]() + elif isinstance(accelerator, BaseAccelerator): + _ACCELERATOR = accelerator + else: + raise TypeError("accelerator must be either a string or an instance of BaseAccelerator") + + +def auto_set_accelerator() -> None: + """ + Automatically check if any accelerator is available. + If an accelerator is availabe, set it as the global accelerator. + """ + global _ACCELERATOR + + for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items(): + try: + accelerator = accelerator_cls() + if accelerator_name == "cpu" or accelerator.is_available(): + _ACCELERATOR = accelerator + break + except: + pass + + if _ACCELERATOR is None: + raise RuntimeError("No accelerator is available.") + + +def get_accelerator() -> BaseAccelerator: + """ + Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized + to the default accelerator type. + + Returns: the accelerator for the current process. + """ + global _ACCELERATOR + + if _ACCELERATOR is None: + auto_set_accelerator() + return _ACCELERATOR diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py new file mode 100644 index 000000000000..33c113999018 --- /dev/null +++ b/colossalai/accelerator/base_accelerator.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +__all__ = ["BaseAccelerator"] + + +class BaseAccelerator(ABC): + support_set_device: bool = True + + def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None: + self._name = name + self._communication_backend = communication_backend + self._is_synchronous = is_synchronous + + # ======================= + # immutable attributes + # ======================= + + @property + def name(self) -> str: + """ + Return the name of the accelerator. + """ + return self._name + + @property + def communication_backend(self) -> str: + """ + Return the name of the backend communication library. + """ + return self._communication_backend + + @property + def is_synchronous(self) -> bool: + """ + Return whether the accelerator is a synchronous device. + """ + return self._is_synchronous + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})" + + # ======================= + # device APIs + # ======================= + @abstractmethod + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + + @abstractmethod + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + + @abstractmethod + def current_device(self) -> int: + """ + Return the current device index. + """ + + @abstractmethod + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + + @abstractmethod + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + + @abstractmethod + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + + @abstractmethod + def is_available(self): + """ + Check if the accelerator is available. + """ + + @abstractmethod + def device_count(self): + """ + Return the number of devices on the machine. + """ + + def set_to_device(self, models: Any) -> Any: + """ + Send model to device. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(self.get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(self.get_current_device()) + else: + return models.to(self.get_current_device()) + + @abstractmethod + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the capability of a device. + """ + + @abstractmethod + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + + @abstractmethod + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + + @abstractmethod + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc. + """ + + # ======================= + # random number generator APIs + # ======================= + @abstractmethod + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified device as a ByteTensor. + """ + + @abstractmethod + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + + @abstractmethod + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified device. + """ + + @abstractmethod + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + + @abstractmethod + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current device. + """ + + @abstractmethod + def manual_seed_all(self, seed: int) -> None: + """ + Sets the seed for generating random numbers on all devices. + """ + + @abstractmethod + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current device. + """ + + @abstractmethod + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all devices. + """ + + @abstractmethod + def initial_seed(self) -> int: + """ + Returns the current random seed of the current device. + """ + + # ======================= + # memory management APIs + # ======================= + @abstractmethod + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi. + """ + + @abstractmethod + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + + @abstractmethod + def memory_allocated(self, device=None) -> int: + """ + Returns the current device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory occupied by tensors for a given device. + """ + + @abstractmethod + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device. + """ + + @abstractmethod + def memory_reserved(self, device=None) -> int: + """ + Returns the current device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + + @abstractmethod + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the device memory allocator. + """ + + # ======================= + # streams and events APIs + # ======================= + + @abstractmethod + def Stream(self, device=None, priority=0, **kwargs): + """ + A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + + @abstractmethod + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + + @abstractmethod + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + + @abstractmethod + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + + @abstractmethod + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + + @abstractmethod + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + + # ======================= + # amp APIs + # ======================= + @abstractmethod + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py new file mode 100644 index 000000000000..080aa61e8e3a --- /dev/null +++ b/colossalai/accelerator/cpu_accelerator.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python + +import resource +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch + +from .base_accelerator import BaseAccelerator + +__all__ = ["CpuAccelerator"] + + +class CpuAccelerator(BaseAccelerator): + support_set_device: bool = False + """ + Accelerator class for cpu. + """ + + def __init__(self): + super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return "" + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device("cpu") + + def current_device(self) -> int: + """ + Return the current device index. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def is_available(self): + """ + Check if the accelerator is available. + """ + return True + + def device_count(self): + """ + Return the number of devices on the machine. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device=None) -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.set_rng_state(new_state) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return psutil.Process().memory_info().rss + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + max_memory = int(psutil.virtual_memory().total * fraction) + _, hard = resource.getrlimit(resource.RLIMIT_AS) + resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard)) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return nullcontext diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py new file mode 100644 index 000000000000..f1ab487d4f58 --- /dev/null +++ b/colossalai/accelerator/cuda_accelerator.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +__all__ = ["CudaAccelerator"] + + +class CudaAccelerator(BaseAccelerator): + """ + Accelerator class for Nvidia CUDA devices. + """ + + def __init__(self): + super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cuda + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"cuda:{torch.cuda.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.cuda.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.cuda.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.cuda.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.cuda.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.cuda.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.cuda.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + return torch.cuda.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.cuda.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.cuda.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.cuda.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.cuda.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.cuda.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.cuda.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.cuda.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.cuda.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.cuda.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.cuda.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.cuda.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.cuda.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + return torch.cuda.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.cuda.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + return torch.cuda.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.cuda.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.cuda.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.cuda.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + torch.cuda.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + return torch.cuda.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + return torch.cuda.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.cuda.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.cuda.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.cuda.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.cuda.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py new file mode 100644 index 000000000000..b28492968eeb --- /dev/null +++ b/colossalai/accelerator/npu_accelerator.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +try: + import torch_npu # noqa +except ImportError: + pass + + +__all__ = ["NpuAccelerator"] + + +class NpuAccelerator(BaseAccelerator): + """ + Accelerator class for Huawei NPU devices. + """ + + def __init__(self): + super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cann + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"npu:{torch.npu.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.npu.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.npu.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.npu.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.npu.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.npu.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.npu.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the npu capability of a device. + """ + return torch.npu.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.npu.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.npu.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.npu.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="npu") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.npu.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.npu.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.npu.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.npu.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.npu.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.npu.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.npu.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.npu.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.npu.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.npu.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of npu memory allocator statistics for a given device. + """ + return torch.npu.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.npu.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the npu memory allocator state across all devices. + """ + return torch.npu.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.npu.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.npu.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.npu.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the npu memory allocator. + """ + torch.npu.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details. + """ + return torch.npu.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams. + """ + return torch.npu.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.npu.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.npu.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.npu.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.npu.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 439d13dcfc11..fc4c884d4c5d 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -7,8 +7,8 @@ import torch from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device __all__ = ["BaseGradScaler"] @@ -23,7 +23,7 @@ class BaseGradScaler(ABC): def __init__(self, initial_scale: float, verbose: bool): assert initial_scale > 0 - self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) + self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float) self._verbose = verbose if self._verbose: diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 86ba919ee696..5cd8035d7987 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -5,7 +5,7 @@ import torch -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .base_grad_scaler import BaseGradScaler @@ -37,14 +37,20 @@ def __init__( hysteresis: int = 2, verbose: bool = False, ): + a = get_accelerator() + a.device_count() super().__init__(initial_scale, verbose) if min_scale: - self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float) + self._min_scale = torch.tensor( + [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._min_scale = None if max_scale: - self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float) + self._max_scale = torch.tensor( + [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._max_scale = None @@ -117,7 +123,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].to(get_current_device()) + self._scale = state_dict["scale"].to(get_accelerator().get_current_device()) self._growth_factor = state_dict["growth_factor"] self._backoff_factor = state_dict["backoff_factor"] self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 9ce272356797..2e7c8a281916 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -5,8 +5,8 @@ import torch.distributed as dist from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.utils import get_current_device from .base import MixedPrecisionMixin @@ -40,7 +40,7 @@ def __init__( max_scale=max_scale, ) self.optim_state = OptimState.UNSCALED - self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) @property def loss_scale(self) -> float: diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 601bf2926d99..fe8439269f48 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -4,10 +4,10 @@ import torch from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule from .region import Region @@ -79,7 +79,9 @@ def __init__( hysteresis=hysteresis, max_scale=max_scale, ) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._found_overflow: torch.Tensor = torch.zeros( + 1, dtype=torch.int64, device=get_accelerator().get_current_device() + ) self._logger = get_dist_logger() def _set_grad_ptr(self): diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index a6628e29c2bc..3ad210de9f0a 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -11,7 +11,7 @@ import torch from torch.fx.node import Node -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .region import Region from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator @@ -57,7 +57,10 @@ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor + self.memory_budget = ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * self.error_factor + ) self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 443c4094c0e1..c757a878d97a 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -5,8 +5,8 @@ from torch import Tensor from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.utils.device import autocast from .mixed_precision_base import MixedPrecision @@ -89,7 +89,7 @@ def __init__(self, module: nn.Module): super().__init__(module) def forward(self, *args, **kwargs): - with autocast(): + with get_accelerator().autocast(): return self.module(*args, **kwargs) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a891db422d67..d14109dd43e5 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -15,6 +15,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( get_model_base_filenames, @@ -27,8 +28,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -366,11 +365,11 @@ def __init__( ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" - if IS_NPU_AVAILABLE: + if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, - chunk_init_device=(chunk_init_device or get_current_device()), + chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), placement_policy=placement_policy, enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, @@ -455,7 +454,7 @@ def control_device(self) -> bool: def supported_devices(self) -> List[str]: return ["cuda", "npu"] - + def prepare_dataloader( self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs ): @@ -486,7 +485,10 @@ def prepare_dataloader( zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) sampler = DistributedSampler( - dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle + dataset, + num_replicas=zero_world_size * extra_dp_world_size, + rank=zero_rank * extra_dp_world_size + extra_dp_rank, + shuffle=shuffle, ) # Deterministic dataloader diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ee1e97c6ce3..5837156a90cd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -18,6 +18,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh @@ -28,7 +29,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor -from colossalai.utils.device import get_current_device from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -82,7 +82,7 @@ def __init__( self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) # setting input type cast when using mixed precision self.convert_fn = None @@ -165,7 +165,6 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): Returns: None """ - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. @@ -346,7 +345,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) if self.pp_size > 1: @@ -386,7 +387,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -487,7 +488,6 @@ def backward(self, loss: Tensor, *args, **kwargs): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward(loss, *args, **kwargs) @@ -513,7 +513,6 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward_by_grad(tensor, grad) @@ -545,7 +544,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ # so we need to calculate the norm of 'tp' and 'pp' gradients. total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -589,7 +590,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -674,7 +675,6 @@ def sync_dp_grads(self): Returns: None """ - # Call the superclass `_sync_grad` method to synchronize gradients. super()._sync_grad() @@ -802,7 +802,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # so we only need to calculate the norm 'tp' of 'pp' gradients. total_norm = super()._compute_grad_norm(gradients, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -842,7 +844,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if dp_size > 1: # compute norm in dp process group @@ -1081,7 +1083,7 @@ def control_precision(self) -> bool: return True def support_no_sync(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1175,9 +1177,14 @@ def execute_pipeline( model, data_iter, criterion, optimizer, return_loss, return_outputs ) + # run with gradients accumulation + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + ): + return outputs + # Synchronize the grads of shared parameters of the model. model.sync_shared_params() - # Synchronize sequence parallelism gradients of the model. model.sync_sp_grads() @@ -1241,5 +1248,8 @@ def seed_worker(worker_id): def get_checkpoint_io(self) -> CheckpointIO: return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - def no_sync(self, model: Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 89102820cd38..d21496f0b758 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -12,6 +12,7 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, @@ -24,7 +25,6 @@ sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.utils import get_current_device from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -52,7 +52,7 @@ def __init__(self, module: nn.Module, precision: str) -> None: self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 25076b742c26..aaeaad3828f5 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -6,12 +6,12 @@ from pathlib import Path from typing import Dict, Union -import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed +from colossalai.utils import set_seed def launch( @@ -47,17 +47,18 @@ def launch( if rank == 0: warnings.warn("`config` is deprecated and will be removed soon.") - if IS_NPU_AVAILABLE and backend == "nccl": - backend = "hccl" + cur_accelerator = get_accelerator() + + backend = cur_accelerator.communication_backend # init default process group init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device - if torch.cuda.is_available() or IS_NPU_AVAILABLE: - # if local rank is not given, calculate automatically - set_device(local_rank) + # if local rank is not given, calculate automatically + if cur_accelerator.support_set_device: + cur_accelerator.set_device(local_rank) set_seed(seed) diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8933fc0a3c2f..e69de29bb2d1 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,7 +0,0 @@ -from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention - -__all__ = [ - "LayerNorm", - "FusedScaleMaskSoftmax", - "MultiHeadAttention", -] diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu deleted file mode 100644 index 2b1b366b1c02..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu +++ /dev/null @@ -1,63 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - if (x_column >= x_width) return; - //if (x_row >= x_height) return; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh deleted file mode 100644 index 0364e38c4779..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh deleted file mode 100644 index c5258813e147..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu deleted file mode 100644 index 4416027c8387..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu +++ /dev/null @@ -1,75 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state_size(_temp_state_size), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state_size, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh deleted file mode 100644 index 0bf2057c665c..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - int temp_state_size; - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t alt_stream_2; - cudaStream_t alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh deleted file mode 100644 index 5cd2e8553ef6..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh +++ /dev/null @@ -1,49 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Workaround for hipify_python using rocblas instead of hipblas. -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} - -#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp deleted file mode 100644 index bcc0e43901de..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include -#include -#include -#include -#include -#include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "q4_matrix.cuh" -#include "q4_matmul.cuh" -#include "column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} - -// Some decluttering macros - -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} - - -// Release all unmanaged objects allocated by the extension - -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} - - -// Prepare buffers for forward pass - -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} - - -// Create Q4Matrix, return handle - -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} - - -// Matmul half @ quant -> half - -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } -} - - -// Remap columns in half tensor - -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh deleted file mode 100644 index 2fd5ab0b36cd..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale - -__device__ __forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu deleted file mode 100644 index f47daeb0e877..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu +++ /dev/null @@ -1,260 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include "util.cuh" -#include "matrix.cuh" -#include "cu_compat.cuh" -#include "cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void (*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; - const float beta = no_zero ? 1.0f : 0.0f; - cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, - x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -#else - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -#endif -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh deleted file mode 100644 index 09f3e1a63362..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "tuning.h" - -// Workaround for hipify_python using rocblas instead of hipblas. -#if defined(USE_ROCM) -#include -#define rocblas_handle hipblasHandle_t -#endif - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero = false -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu deleted file mode 100644 index 9c61143f565e..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matrix.cuh" -#include -#include "util.cuh" -#include "matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks - ( - (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), - height / 8, - 1 - ); - - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - if (column >= width) return; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh deleted file mode 100644 index 50cb72a41518..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h deleted file mode 100644 index e413b8a96c11..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h +++ /dev/null @@ -1,12 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning { - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh deleted file mode 100644 index 7b397573214b..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu deleted file mode 100644 index 58d26235a9cc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include "block_reduce.h" -#include "cuda_util.h" -#include "kernels.h" -#include "ls_cub.cuh" - -ls::cub::CachingDeviceAllocator g_allocator(true); - -template -__global__ void ls_cross_entropy_fw_kernel( - const T *__restrict__ inputs, const int *__restrict__ targets, - float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[2] = {0.f, 0.f}; // logit and logit exp - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - if (threadIdx.x == 0) { - nll_loss_outputs[blockIdx.x] = 0.f; - outputs[blockIdx.x] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += logit; - sum_logits[1] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_logit; - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_logit = sum_logits[0]; - s_sum_exp = sum_logits[1]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - if (threadIdx.x == 0) { - // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) - float nll_loss = logf(s_sum_exp) - - static_cast(inputs[block_start + target_tid]) + - s_max_input; - nll_loss_outputs[blockIdx.x] = nll_loss; - float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; - outputs[blockIdx.x] = - (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; - } -} - -template -__global__ void ls_cross_entropy_bw_kernel( - const float *__restrict__ grad_outputs, const T *__restrict__ inputs, - const int *__restrict__ targets, T *__restrict__ grad_inputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[1] = {0.f}; - const float grad_out = static_cast(grad_outputs[0]); - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - for (int i = left_idx; i < right_idx; i += blockDim.x) { - grad_inputs[i] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_exp = sum_logits[0]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - float nll_weight = 1.0 - epsilon - eps_i; - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; - float grad = 0; - grad += (vocab_size * prob - 1) * eps_i; - grad += prob * nll_weight; - if ((i - block_start) == target_tid) { - grad -= nll_weight; - } - grad_inputs[i] = grad_out * grad; - } -} - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - float *nll_loss_buffer = loss_buffer + grid_dim; - ls_cross_entropy_fw_kernel<<>>( - inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, - epsilon, vocab_size); - - int num_items = grid_dim; - void *d_temp_storage = NULL; - size_t temp_storage_bytes = 0; - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR( - g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - nll_loss_buffer, nll_loss_ptr, - num_items, stream)); - CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); -} - -template void launch_cross_entropy_fw( - const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_fw<__half>( - const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - ls_cross_entropy_bw_kernel<<>>( - grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, - epsilon, vocab_size); -} - -template void launch_cross_entropy_bw( - const float *grad_outputs_ptr, const float *inputs_ptr, - const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_bw<__half>( - const float *grad_outputs_ptr, const __half *inputs_ptr, - const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu deleted file mode 100644 index 09f34763f9b2..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = - cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, - (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, - (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, - (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmEx( - handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, - CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, - CUDA_R_16F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " - "error: %d) \n", - batch, m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - - return 0; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu deleted file mode 100644 index e5ac17308640..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "cuda_util.h" - -/* GPU function guard */ -std::string _cudaGetErrorString(cudaError_t error) { - return cudaGetErrorString(error); -} - -std::string _cudaGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "CUBLAS_UNKNOW"; -} - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line) { - if (result) { - throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + - std::to_string(line) + - "): " + (_cudaGetErrorString(result)) + "\n"); - } -} - -template void check_gpu_error(cudaError_t result, - char const *const func, - const char *const file, - int const line); -template void check_gpu_error(cublasStatus_t result, - char const *const func, - const char *const file, - int const line); - -template -void print_vec(const T *outv, std::string outn, int num_output_ele) { - std::cout << outn << ": "; - std::vector hout(num_output_ele, (T)0); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << hout[i] << ", "; - } - std::cout << std::endl; -} - -template <> -void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele) { - std::cout << outn << ": "; - std::vector<__half> hout(num_output_ele, (__half)0.f); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << __half2float(hout[i]) << ", "; - } - std::cout << std::endl; -} - -template void print_vec(const float *outv, std::string outn, - int num_output_ele); - -template void print_vec(const int *outv, std::string outn, - int num_output_ele); - -template void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele); - -template -T *cuda_malloc(size_t ele_num) { - size_t byte_size = ele_num * sizeof(T); - T *pdata = nullptr; - CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); - return pdata; -} - -template float *cuda_malloc(size_t ele_num); - -template __half *cuda_malloc<__half>(size_t ele_num); - -template uint8_t *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata) { - if (pdata != nullptr) { - cudaFree(pdata); - } -} - -template -struct _isnan { - __device__ bool operator()(T a) const { return isnan(a); } -}; - -template <> -struct _isnan<__half> { - __device__ bool operator()(const __half a) const { return __hisnan(a); } -}; - -template -struct _isinf { - __device__ bool operator()(T a) const { return isinf(a); } -}; - -template <> -struct _isinf<__half> { - __device__ bool operator()(const __half a) const { return __hisinf(a); } -}; - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream) { - // check_nan_inf = 0 for checking nan - // check_nan_inf = 1 for checking inf - bool res = false; - std::string msg = file + "(" + std::to_string(line) + "): "; - if (check_nan_inf) { - msg += "nan."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isnan(), false, - thrust::logical_or()); - } else { - msg += "inf."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isinf(), false, - thrust::logical_or()); - } - if (res) { - throw std::runtime_error(msg); - } - std::cout << msg << " [check pass]." << std::endl; -} - -template void check_nan_inf(const float *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); - -template void check_nan_inf<__half>(const __half *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu deleted file mode 100644 index ce0b017f12e1..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ /dev/null @@ -1,1002 +0,0 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu deleted file mode 100644 index 625b02cd25d9..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h deleted file mode 100644 index f7d75f38cc2b..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -#include -#include - -#include "cuda_util.h" - -class Context { - public: - Context() : _stream(nullptr) { - CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); - } - - virtual ~Context() {} - - static Context &Instance() { - static Context _ctx; - return _ctx; - } - - void set_stream(cudaStream_t stream) { - _stream = stream; - CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); - } - - cudaStream_t get_stream() { return _stream; } - - cublasHandle_t get_cublashandle() { return _cublasHandle; } - - private: - cudaStream_t _stream; - cublasHandle_t _cublasHandle; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h deleted file mode 100644 index f4e9befc6588..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "cuda_util.h" - -template -class CrossEntropyLayer { - public: - CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); - - virtual ~CrossEntropyLayer(); - - void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr); - - void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr); - - void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - _loss_buffer = cuda_malloc(_max_batch_tokens * 2); - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_loss_buffer); - } - - const int _padding_idx; - const float _epsilon; - const int _max_batch_tokens; - - size_t _batch_size; - size_t _seq_len; - size_t _vocab_size; - - float *_loss_buffer; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h deleted file mode 100644 index 90255152b2c8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_strided_batched_gemm( - cublasHandle_t handle, int m, int n, int k, const float *alpha, - const float *beta, const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, - int stride_C, int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h deleted file mode 100644 index 1595257be0f5..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line); - -#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) - -template -void print_vec(const T *outv, std::string outn, int num_output_ele); - -template -T *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata); - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream); - -#define CHECK_NAN_INF(ptr, size, stream) \ - check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ - check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h deleted file mode 100644 index 025fbf3f8f15..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h deleted file mode 100644 index 8186da1eed5f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include -#include -#include - -#include - -#include "cublas_wrappers.h" -#include "kernels.h" - -template -class FeedForward { - public: - struct Config { - int outputSize; - int inputSize; - std::array gemm_algos; - Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), - gemm_algos(std::array({99, 99, 99})) {} - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, const T *input_ptr, const T *weights, T *out, - cublasHandle_t &_cublasHandle) { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, - bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, - out, cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, const T *out_grad, const T *input_ptr, - const T *weights, T *weights_grad, T *bias_grad, - cublasHandle_t &_cublasHandle, cudaStream_t &stream, - T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, - bool compute_bias = true) { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, - config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, - weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, - bsz, config_.outputSize, &alpha, &beta, weights, out_grad, - inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); - if (compute_bias) { - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, - config_.outputSize, stream); - } - } - - void reset_size(int outputSize, int inputSize) { - config_.outputSize = outputSize; - config_.inputSize = inputSize; - } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h deleted file mode 100644 index 735e1363cc46..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ /dev/null @@ -1,275 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#define MAX_THREADS 1024 -#define WARP_SIZE 32 - -enum class ActivationType { kRelu, kGelu }; - -void launch_curand_init(int total_count, int dim, cudaStream_t stream); - -template -void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int batch_size, - int hidden_dim, cudaStream_t stream); - -template -void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, const T *vars, const T *means, int batch, - int hidden_dim, cudaStream_t stream[2]); - -template -void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, - int from_len, int to_len, bool mask_future, - cudaStream_t stream); - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream); - -// [b, s, h] -> [b, nh, s, ad] -template -void launch_transform_0213(T *output, const T *vals, int batch_size, - int seq_length, int hidden_dim, int nhead, - cudaStream_t stream); - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template -void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, - int dim_0, int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream); - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template -void launch_transform4d_0213(T *output, const T *vals, int batch_size, - int seq_len, int hidden_dim, int nhead, - int trans_count, cudaStream_t stream); - -template -void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, - float ratio, cudaStream_t stream, bool backward = false); - -template -void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, const T *residual, - int total_count, int dim, float ratio, - cudaStream_t stream); - -template -void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, int total_count, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, - cudaStream_t stream); - -void launch_param_update(const float *input, __half *output, int size, - cudaStream_t stream); - -template -void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, - int sz2, int sz1_1, int sz1_2, cudaStream_t stream); - -template -void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, - int seq_len, int hidden_size, cudaStream_t &stream); - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_lookup_scale_pos_dropout( - T *output, const int *input, const T *embeddings, const T *pos_embeddings, - uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); - -template -void launch_d_lookup_scale_pos_dropout( - T *grad_embeddings, const T *grad_output, const int *input, - const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); - -/* Convert 2-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { - return id1 * dim2 + id2; -} - -/* Convert 3-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, - int dim2, int dim3) { - return id1 * dim2 * dim3 + id2 * dim3 + id3; -} - -/* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, - int id4, int dim2, int dim3, - int dim4) { - // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; - int res = id4; - - int ld = dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 5-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, - int id4, int id5, int dim2, - int dim3, int dim4, - int dim5) { - // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + - // id4*dim5 + dim5; - int res = id5; - - int ld = dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 6-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, - int id4, int id5, int id6, - int dim2, int dim3, int dim4, - int dim5, int dim6) { - // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + - // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; - int res = id6; - - int ld = dim6; - res += id5 * ld; - - ld *= dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_6dim( - int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, - int *id1, int *id2, int *id3, int *id4, int *id5) { - *id5 = src % dim5; - src /= dim5; - - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, - int dim2, int dim3, - int dim4, int *id0, - int *id1, int *id2, - int *id3, int *id4) { - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 4-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, - int dim2, int dim3, - int *id0, int *id1, - int *id2, int *id3) { - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, - int dim2, int *id0, - int *id1, int *id2) { - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 2-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, - int *id0, int *id1) { - *id1 = src % dim1; - *id0 = src / dim1; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh deleted file mode 100644 index 4f65e7b54ba1..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh +++ /dev/null @@ -1,12 +0,0 @@ -// copied from https://github.com/dmlc/dgl/pull/2758 -#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ -#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ - -#define CUB_NS_PREFIX namespace ls { -#define CUB_NS_POSTFIX } -#include "cub/cub.cuh" -#include "cub/util_allocator.cuh" -#undef CUB_NS_POSTFIX -#undef CUB_NS_PREFIX - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h deleted file mode 100644 index a7767e187ffc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Normalize_Layer { - public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - - private: - Config config_; - T *vars_; - T *means_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h deleted file mode 100644 index b917abaf0336..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h deleted file mode 100644 index d386650e8235..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include - -#include - -#include "cublas_wrappers.h" - -template -class StridedBatchGemm { - public: - struct Config { - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(float param_alpha, float param_beta, cublasOperation_t opA, - cublasOperation_t opB) - : alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(std::array({99, 99, 99})) {} - void SetConfig(int mm, int nn, int kk) { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config &config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, - cublasHandle_t handle) { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm( - handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, - _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, - stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void Backward(int bsz, const T *d_output, const T *_buffer_a, - const T *_buffer_b, cublasHandle_t handle, - T *inpGradA = nullptr, T *inpGradB = nullptr) { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = - (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm( - handle, mb, kb, _config.n, &_config.alpha, &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, - CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = - (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm( - handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, - _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, - stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - - private: - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu deleted file mode 100644 index e2f1869b165e..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ /dev/null @@ -1,1172 +0,0 @@ -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template -__forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, - const T *out_grad, const T *inp_or_out, - const T *gamma, const T *betta, - const T *vars, const T *means, int rows, - int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu deleted file mode 100644 index 3862a699d3c3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ /dev/null @@ -1,365 +0,0 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu deleted file mode 100644 index 04de3c092ee0..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ /dev/null @@ -1,314 +0,0 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void bias_add_transform_20314(float *output, - const float *input, - const float *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp deleted file mode 100644 index d08f3dbc74d8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ /dev/null @@ -1,406 +0,0 @@ -#include "multihead_attention_1d.h" - -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif -#include - -#include "context.h" -#include "kernels.h" - -template -MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_size, - int num_heads, - float attn_prob_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm) - : _layer_id(layer_id), - _max_batch_tokens(max_batch_tokens), - _max_seq_len(max_seq_len), - _hidden_size(hidden_size), - _heads(num_heads), - _training(true), - _pre_or_postLayerNorm(pre_or_postLayerNorm), - _qkv_linear( - typename FeedForward::Config(3 * hidden_size, hidden_size)), - _attn_out_linear( - typename FeedForward::Config(hidden_size, hidden_size)), - _attn_ln(typename Normalize_Layer::Config(hidden_size, false), - _max_batch_tokens), - _softmax(typename Softmax::Config(num_heads)), - _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), - _max_batch_tokens * _heads * _max_seq_len), - _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), - _max_batch_tokens * _hidden_size), - _attn_scores(typename StridedBatchGemm::Config( - (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, - CUBLAS_OP_N)), - _attn_context(typename StridedBatchGemm::Config( - T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { - assert(_hidden_size % _heads == 0); -} - -template -MultiHeadAttention::~MultiHeadAttention() { - free_mem_buffer(); -} - -template -void MultiHeadAttention::attn_layer_fw(const T *input_ptr, - const T *input_mask_ptr, - T *output_ptr, T *buffer) { - T *q_tf_ptr = _qkv_ptr; - T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - - if (_pre_or_postLayerNorm) { - _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, - _cublasHandle); - - launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, - _batch_size, _seq_len, 3, _heads / pg_size, - _hidden_size / _heads, _stream); - - // attention scores, q*k - _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle); - - // Softmax + Mask - _softmax.reset_size(_heads / pg_size); - _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, - _seq_len, _stream, true); - - // attn prob dropout. - _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - // attention context, score * v - _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, - _cublasHandle); - - // [b, nh, s, ad] -> [b, s, nh, ad] - launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, - _hidden_size / pg_size, _heads / pg_size, 1, - _stream); - - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, - output_ptr, _cublasHandle); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto output_tensor = torch::from_blob( - output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {output_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, - _attn_ob_ptr, _batch_tokens, _hidden_size, - _stream); - if (!_pre_or_postLayerNorm) { - // in-place ln since ln-input will not be used in post-ln mode - _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } -} - -template -void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, - T *out_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim - - attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); -} - -template -void MultiHeadAttention::attn_layer_bw(const T *input_ptr, - const T *input_mask_ptr, - const T *output_ptr, - const T *grad_output_ptr, - T *grad_input_ptr, T *buffer) { - cudaStream_t streams[2] = {_stream, _stream}; - - const T *q_tf_ptr = _qkv_ptr; - const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - // batch_dim = batch_size * seq_len * hidden_size - // buffer size: batch_dim * 3 + max(batch_dim * 3, - // batch_size * head_num * seq_len * seq_len) - T *grad_residual_ptr = buffer; - buffer += _batch_dim; - - T *grad_input_buf_ptr = buffer; // batch_dim - T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 - buffer += 3 * _batch_dim / pg_size; - - T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 - T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len - // buffer += max(3 * _batch_dim, - // batch_size * head_num * seq_len * seq_len); - - if (_pre_or_postLayerNorm) { - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_output_ptr, _batch_tokens, - _hidden_size, _stream); - } else { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, - grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, - _attn_nb_ptr, _batch_tokens, streams); - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_residual_ptr, _batch_tokens, - _hidden_size, _stream); - } - - // bw of output project - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, - _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - false); - launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - _stream); - - // bw of score * v - _attn_context.Backward( - _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, - grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); - - _attn_prob_dropout.d_dropout(grad_softmax_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - _softmax.reset_size(_heads / pg_size); - _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, - _seq_len, _stream); - - // bw of q * k - _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, - grad_qkv_5d_ptr); - - // [3, b, nh, s, ad] -> [b, s, 3, h] - launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - 3, _stream); - - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, - _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - true); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto grad_input_tensor = - torch::from_blob(grad_input_buf_ptr, - {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {grad_input_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - if (_pre_or_postLayerNorm) { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, - grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, - _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); - } else { - // FIXME later - launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, - _batch_size, _seq_len, _hidden_size, _stream); - } -} - -template -void MultiHeadAttention::Backward(const T *grad_output_ptr, - const T *input_ptr, const T *output_ptr, - const T *input_mask_ptr, - T *grad_input_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *buffer = _shared_mem_ptr; - - /* - buffer size needed by attn bw: - 4 * _batch_dim + max(3 * _batch_dim, - _batch_size * _head_num * _seq_len * _seq_len); - */ - attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, - grad_input_ptr, buffer); -} - -template -void MultiHeadAttention::SetTrainingMode(bool training) { - // Dropout will be skipped when not in training model. - _attn_prob_dropout.SetTrainingMode(training); - _attn_dropout.SetTrainingMode(training); -} - -template -T *MultiHeadAttention::_shared_mem_ptr = nullptr; - -template class MultiHeadAttention; -template class MultiHeadAttention<__half>; - -// x is torch::Tensor -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -static std::unordered_map> s_multihead_attention; - -template -int create_multihead_attention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_dim, int num_heads, - float attn_prob_dropout_ratio, - float hidden_dropout_ratio, - bool pre_or_postLayerNorm, - c10::intrusive_ptr pg_) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - Context::Instance().set_stream(stream); - auto layer = std::make_shared>( - layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, - attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); - - layer->SetPG(pg_); - - s_multihead_attention[layer_id] = layer; - - std::string dtype = (std::is_same::value) ? "half" : "float"; - - return 0; -} - -template -std::vector multihead_attention_fw( - int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, - const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, - const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, - const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, - bool training_mode, bool prelayernorm) { - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - const T *input_ptr = (const T *)input.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - auto output = torch::empty_like(input); - T *out_ptr = (T *)output.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(input.size(0), input.size(1)); - layer->SetTrainingMode(training_mode); - - layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); - layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); - layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); - layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); - layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); - layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); - - layer->Forward(input_ptr, input_mask_ptr, out_ptr); - - return {output}; -} - -template -std::vector multihead_attention_bw( - int layer_id, const torch::Tensor &grad_dec_output, - const torch::Tensor &output, const torch::Tensor &input, - const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, - const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, - const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, - const torch::Tensor &norm_bias) { - auto g_output = grad_dec_output.contiguous(); - CHECK_INPUT(g_output); - CHECK_INPUT(output); - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - auto grad_input = torch::empty_like(input); - auto grad_in_proj_weight = torch::empty_like(in_proj_weight); - auto grad_in_proj_bias = torch::empty_like(in_proj_bias); - auto grad_out_proj_weight = torch::empty_like(out_proj_weight); - auto grad_out_proj_bias = torch::empty_like(out_proj_bias); - auto grad_norm_weight = torch::empty_like(norm_weight); - auto grad_norm_bias = torch::empty_like(norm_bias); - - // inputs. - const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); - const T *input_ptr = (const T *)input.data_ptr(); - const T *output_ptr = (const T *)output.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - // outputs. - T *grad_input_ptr = (T *)grad_input.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); - - layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); - layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); - layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); - layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); - layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); - layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); - - layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, - grad_input_ptr); - - return {grad_input, grad_in_proj_weight, grad_in_proj_bias, - grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, - grad_norm_bias}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multihead_attention_fw_fp32", &multihead_attention_fw, - "Multi-head Attention forward with fp32 (CUDA)"); - m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, - "Multi-head Attention forward with fp16 (CUDA)"); - m.def("multihead_attention_bw_fp32", &multihead_attention_bw, - "Multi-head Attention backward with fp32 (CUDA)"); - m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, - "Multi-head Attention backward with fp16 (CUDA)"); - m.def("create_multihead_attention_fp32", &create_multihead_attention, - "Create Multi-head Attention with fp32 (CUDA)"); - m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, - "Create Multi-head Attention with fp16 (CUDA)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h deleted file mode 100644 index 6505eb31fb9f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif - -#include -#include - -#include "cuda_util.h" -#include "dropout.h" -#include "feed_forward.h" -#include "normalize_layer.h" -#include "softmax.h" -#include "strided_batch_gemm.h" - -template -class MultiHeadAttention { - public: - MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, - int hidden_size, int num_heads, float attn_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm); - - virtual ~MultiHeadAttention(); - - void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); - - void Backward(const T *grad_output_ptr, const T *input_ptr, - const T *output_ptr, const T *input_mask_ptr, - T *grad_input_ptr); - - void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, - T *buffer); - - void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, - const T *output_ptr, const T *grad_output_ptr, - T *grad_input_attn_layer_bwptr, T *buffer); - - void set_cur_batch_shape(int batch_size, int seq_len) { - _batch_size = batch_size; - _seq_len = seq_len; - _batch_tokens = batch_size * seq_len; - _batch_heads = batch_size * _heads / pg_size; - _batch_dim = _batch_tokens * _hidden_size; - _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); - _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); - } - - void SetTrainingMode(bool training); - inline bool IsTrainingMode() const { return _training; } - - void SetPG(c10::intrusive_ptr pg_) { - pg = pg_; - pg_size = 1; - if (pg != c10::detail::UniqueVoidPtr()) { - pg_size = pg->getSize(); - } - allocate_mem_buffer(); - } - - // weights ptr - const T *_attn_qkvw_ptr; - const T *_attn_qkvb_ptr; - const T *_attn_ow_ptr; - const T *_attn_ob_ptr; - const T *_attn_nw_ptr; - const T *_attn_nb_ptr; - - // grads ptr - T *_grad_attn_qkvw_ptr; - T *_grad_attn_qkvb_ptr; - T *_grad_attn_ow_ptr; - T *_grad_attn_ob_ptr; - T *_grad_attn_nw_ptr; - T *_grad_attn_nb_ptr; - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - if (_pre_or_postLayerNorm) { - _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - } else { - _gemmQKV_inp_ptr = nullptr; - } - - _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); - _soft_out_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _ctx_bufB_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - - // buffer size needed by attn bw - size_t smem_size = - 4 * _max_batch_tokens * _hidden_size / pg_size + - std::max(3 * _max_batch_tokens * _hidden_size / pg_size, - _max_batch_tokens * _heads / pg_size * _max_seq_len); - - if (!_shared_mem_ptr) { - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = cuda_malloc(smem_size); - } - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_gemmQKV_inp_ptr); - cuda_free(_qkv_ptr); - cuda_free(_soft_out_ptr); - cuda_free(_ctx_bufB_ptr); - cuda_free(_attn_o_inp_ptr); - - // free shared gpu memory between layers - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = nullptr; - } - - // const parameter between batch - const size_t _layer_id; - const size_t _hidden_size; - const size_t _heads; - const size_t _max_batch_tokens; - const size_t _max_seq_len; - const bool _pre_or_postLayerNorm; - // dynamic parameter between batch - size_t _batch_size; - size_t _seq_len; - size_t _batch_tokens; - size_t _batch_heads; - size_t _batch_dim; - bool _training; - - cublasHandle_t _cublasHandle; - cudaStream_t _stream; - - // layers - FeedForward _qkv_linear; - FeedForward _attn_out_linear; - Normalize_Layer _attn_ln; - Softmax _softmax; - Dropout _attn_prob_dropout; - Dropout _attn_dropout; - StridedBatchGemm _attn_scores; - StridedBatchGemm _attn_context; - - // local GPU memory - T *_gemmQKV_inp_ptr; - T *_qkv_ptr; - T *_soft_out_ptr; - T *_ctx_bufB_ptr; - T *_attn_o_inp_ptr; - // shared GPU memory between layer - static T *_shared_mem_ptr; - - c10::intrusive_ptr pg; - int pg_size; -}; diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp deleted file mode 100644 index 8444272940b4..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -#include "linear.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, - "Linear SiLU (INT8)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu deleted file mode 100644 index a30d02a4cf42..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu +++ /dev/null @@ -1,162 +0,0 @@ -// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu - -#include "linear.h" -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - -#if CUDA_ARCH >= 800 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -#elif CUDA_ARCH >= 750 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - EpilogueOp>; -#elif CUDA_ARCH >= 700 - #define USE_TORCH_SILU - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; -#else - #error "Unsupported cuda arch" -#endif - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } -#ifdef USE_TORCH_SILU -#undef USE_TORCH_SILU - out = torch::silu(out); -#endif - return out; -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h deleted file mode 100644 index b62a27f3f8f3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include - -#include -#include - -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -); diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py deleted file mode 100644 index cad36e598d14..000000000000 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import ColoAttention - -__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py deleted file mode 100644 index 9ee83915b1b4..000000000000 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ /dev/null @@ -1,80 +0,0 @@ -import warnings -from typing import Optional - -import torch - - -def is_ampere_or_better_gpu(): - if torch.cuda.is_available(): - device = torch.device("cuda") - properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer - return True - return False - - -# "Check Ampere GPUs or newer" -HAS_FLASH_ATTN = False -if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True -else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False - -if HAS_FLASH_ATTN: - pass - - from .utils import SeqLenInfo - - def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( - q, - k, - v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, - ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py deleted file mode 100644 index 649e74d61bab..000000000000 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings - -HAS_MEM_EFF_ATTN = False -try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - - HAS_MEM_EFF_ATTN = True -except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False - -if HAS_MEM_EFF_ATTN: - """ - A general attention module using the flash attention kernels from xformers: - https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha - """ - from typing import Optional - - import torch - - from .utils import SeqLenInfo - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py deleted file mode 100644 index 1c778439d33f..000000000000 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ /dev/null @@ -1,113 +0,0 @@ -import math -from typing import Optional - -import torch -from einops import rearrange - -from ..scaled_softmax import AttnMaskType -from .flash_attn_2 import HAS_FLASH_ATTN -from .mem_eff_attn import HAS_MEM_EFF_ATTN -from .utils import Repad, SeqLenInfo, Unpad - -if HAS_FLASH_ATTN: - from .flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .mem_eff_attn import mem_eff_attention - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - attn = flash_attention - else: - attn = mem_eff_attention - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = attn( - query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py deleted file mode 100644 index 5f01e3ef327d..000000000000 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.utils.device import get_current_device - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py deleted file mode 100644 index 87afc1862847..000000000000 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ /dev/null @@ -1,338 +0,0 @@ -import math -from dataclasses import dataclass - -import torch -from torch import nn -from torch.autograd import Function - - -def check_config(config): - if config.hidden_size % config.nhead != 0: - raise Exception("hidden_size % nhead != 0") - - factor = 8 if config.fp16 else 4 - upbound = factor * 1024 * 4 - if config.hidden_size > upbound: - # as required by ln backward kernel currently - raise Exception(f"hidden_size > {upbound}") - - head_dim = config.hidden_size // config.nhead - if head_dim % factor != 0: - # as required by reshape kernel - raise Exception(f"head_dim({head_dim}) % {factor} != 0") - - -def calc_offset(sizes): - offsets = [0] - tmp = 0 - for x in sizes: - tmp += x - offsets.append(tmp) - return offsets - - -colossal_multihead_attention = None - - -@dataclass -class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision - - -class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward( - ctx, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config, - ): - cuda_module = colossal_multihead_attention - forward_func = ( - cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 - ) - if config.fp16: - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - - (output,) = forward_func( - config.layer_id, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config.training, - config.norm_first, - ) - - if config.is_grad_enabled and config.training: - ctx.save_for_backward( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - ctx.config = config - return output - - @staticmethod - def backward(ctx, grad_output): - assert ctx.config.training - - cuda_module = colossal_multihead_attention - backward_func = ( - cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 - ) - - ( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) = ctx.saved_tensors - - grad_input = None - grad_in_proj_weight = None - grad_in_proj_bias = None - grad_out_proj_weight = None - grad_out_proj_bias = None - grad_norm_weight = None - grad_norm_bias = None - - if ctx.config.fp16: - grad_output = grad_output.to(torch.half) - output = output.to(torch.half) - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - ( - grad_input, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - ) = backward_func( - ctx.config.layer_id, - grad_output, - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - - return ( - grad_input, - None, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - None, - ) - - -class MultiHeadAttention(nn.Module): - """Initialize the MultiHeadAttention. - - Static variable: - - layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, - e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. - - Arguments: - hidden_size: Total dimension of hidden_size. - nhead: Number of parallel attention heads. - batch_size: Batch Size for one forward - max_seq_len: Max length of input sequence - dropout: Dropout probability - norm_first: perform LayerNorms before attention - """ - - layer_id = 0 - - def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): - super(MultiHeadAttention, self).__init__() - - self.config = Config( - batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 - ) - check_config(self.config) - self.pg = pg - self.pg_size = 1 - if self.pg: - self.pg_size = pg.size() - self.config.layer_id = MultiHeadAttention.layer_id - MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 - - # Load cuda modules if needed - global colossal_multihead_attention - if colossal_multihead_attention is None: - from colossalai.kernel.op_builder import MultiHeadAttnBuilder - - multihead_attention = MultiHeadAttnBuilder().load() - colossal_multihead_attention = multihead_attention - - # create the layer in cuda kernels. - cuda_module = colossal_multihead_attention - create_layer_func = ( - cuda_module.create_multihead_attention_fp16 - if self.config.fp16 - else cuda_module.create_multihead_attention_fp32 - ) - - create_layer_func( - self.config.layer_id, - self.config.max_batch_tokens, - self.config.max_seq_len, - self.config.hidden_size, - self.config.nhead, - self.config.attn_prob_dropout_ratio, - self.config.hidden_dropout_ratio, - self.config.norm_first, - self.pg, - ) - - hs = self.config.hidden_size - - self.precision = torch.float32 - if self.config.fp16: - self.precision = torch.half - - self.hs_per_rank = int(hs / self.pg_size) - - self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) - self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) - self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) - self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) - self.norm_weight = nn.Parameter(torch.Tensor(hs)) - self.norm_bias = nn.Parameter(torch.Tensor(hs)) - - self.reset_parameters() - torch.cuda.empty_cache() - - def calc_bound(self, w): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) - bound = 1.0 / math.sqrt(fan_in) - return bound - - def reset_parameters(self): - hs = self.config.hidden_size - - nn.init.zeros_(self.out_proj_bias) - - nn.init.ones_(self.norm_weight) - nn.init.zeros_(self.norm_bias) - - if self.pg_size > 1: - rank_in_pg = torch.distributed.get_rank(self.pg) - attn_qkvw_global = torch.empty(hs * 3, hs) - attn_qkvb_global = torch.empty(hs * 3) - nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw_global) - nn.init.uniform_(attn_qkvb_global, -bound, bound) - - attn_qkvw_global = attn_qkvw_global.cuda() - attn_qkvb_global = attn_qkvb_global.cuda() - torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) - torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) - attn_qkvw_global = attn_qkvw_global.cpu() - attn_qkvb_global = attn_qkvb_global.cpu() - - with torch.no_grad(): - self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : - ] - ) - self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) - ] - ) - - attn_ow_global = torch.empty(hs, hs) - nn.init.xavier_uniform_(attn_ow_global, 1.0) - attn_ow_global = attn_ow_global.cuda() - torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) - attn_ow_global = attn_ow_global.cpu() - with torch.no_grad(): - self.out_proj_weight.copy_( - attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] - ) - - else: - attn_qkvw = self.in_proj_weight.view(-1, hs) - nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw) - nn.init.uniform_(self.in_proj_bias, -bound, bound) - - nn.init.xavier_uniform_(self.out_proj_weight, 1.0) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) - return destination - - def forward(self, hidden_states, encoder_padding_mask): - self.config.training = self.training - self.config.is_grad_enabled = torch.is_grad_enabled() - hidden_states = hidden_states.contiguous() - encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() - - bs, sl, dim = hidden_states.size() - if bs * sl > self.config.max_batch_tokens: - raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") - if sl > self.config.max_seq_len: - raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") - if len(encoder_padding_mask.size()) == 1: - assert bs == 1 and sl == encoder_padding_mask.size(0) - else: - assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - - output = MultiHeadAttention1DFunc.apply( - hidden_states, - encoder_padding_mask, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.norm_weight, - self.norm_bias, - self.config, - ) - - return output.to(self.precision) diff --git a/colossalai/kernel/extensions b/colossalai/kernel/extensions new file mode 120000 index 000000000000..e8eb45a54893 --- /dev/null +++ b/colossalai/kernel/extensions @@ -0,0 +1 @@ +../../extensions \ No newline at end of file diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8bebad894ca4..d392649a62f2 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,7 +1,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear -from colossalai.utils import get_current_device from .bias_dropout_add import bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl @@ -46,11 +46,13 @@ def warmup_jit_fusion( ): """Compile JIT functions before the main training steps""" - embed = Embedding(vocab_size, hidden_size).to(get_current_device()) - linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) - linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device()) - x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = torch.randint( + vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device() + ) x = embed(x) y, y_bias = linear_1(x) z, z_bias = linear_2(y) @@ -58,8 +60,8 @@ def warmup_jit_fusion( # prop and recomputation for bias_grad, input_grad in zip([True, True], [False, True]): for _ in range(10): - bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) - input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device()) bias.requires_grad, input_.requires_grad = bias_grad, input_grad bias_gelu_impl(input_, bias) @@ -69,9 +71,9 @@ def warmup_jit_fusion( # prop and recomputation for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): for _ in range(10): - input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) - residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) - bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device()) input_.requires_grad = input_grad bias.requires_grad = bias_grad residual.requires_grad = residual_grad diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py new file mode 100644 index 000000000000..148c3e3fc08a --- /dev/null +++ b/colossalai/kernel/kernel_loader.py @@ -0,0 +1,109 @@ +import warnings +from typing import List + +from .extensions import ( + CpuAdamArmExtension, + CpuAdamX86Extension, + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, + FusedOptimizerCudaExtension, + LayerNormCudaExtension, + MoeCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, +) +from .extensions.base_extension import _Extension + +__all__ = [ + "KernelLoader", + "CPUAdamLoader", + "LayerNormLoader", + "MoeLoader", + "FusedOptimizerLoader", + "ScaledMaskedSoftmaxLoader", + "ScaledUpperTriangleMaskedSoftmaxLoader", +] + + +class KernelLoader: + """ + An abstract class which offers encapsulation to the kernel loading process. + + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + REGISTRY: List[_Extension] = [] + + @classmethod + def register_extension(cls, extension: _Extension): + """ + This classmethod is an extension point which allows users to register their customized + kernel implementations to the loader. + + Args: + extension (_Extension): the extension to be registered. + """ + cls.REGISTRY.append(extension) + + def load(self, ext_name: str = None): + """ + Load the kernel according to the current machine. + + Args: + ext_name (str): the name of the extension to be loaded. If not specified, the loader + will try to look for an kernel available on the current machine. + """ + exts = [ext_cls() for ext_cls in self.__class__.REGISTRY] + + # look for exts which can be built/loaded on the current machine + + if ext_name: + usable_exts = list(filter(lambda ext: ext.name == ext_name, exts)) + else: + usable_exts = [] + for ext in exts: + if ext.is_hardware_available(): + # make sure the machine is compatible during kernel loading + ext.assert_hardware_compatible() + usable_exts.append(ext) + + assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." + + if len(usable_exts) > 1: + # if more than one usable kernel is found, we will try to load the kernel with the highest priority + usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True) + warnings.warn( + f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}" + ) + return usable_exts[0].load() + + +class CPUAdamLoader(KernelLoader): + REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension] + + +class LayerNormLoader(KernelLoader): + REGISTRY = [LayerNormCudaExtension] + + +class MoeLoader(KernelLoader): + REGISTRY = [MoeCudaExtension] + + +class FusedOptimizerLoader(KernelLoader): + REGISTRY = [FusedOptimizerCudaExtension] + + +class ScaledMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledMaskedSoftmaxCudaExtension] + + +class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension] + + +class FlashAttentionLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder deleted file mode 120000 index db4f9c335065..000000000000 --- a/colossalai/kernel/op_builder +++ /dev/null @@ -1 +0,0 @@ -../../op_builder \ No newline at end of file diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 97ec57fbd007..d2dceb50b240 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -7,7 +7,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes @@ -28,7 +28,7 @@ def load_fused_optim(): global fused_optim if fused_optim is None: - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index 0a8d09be21ea..08f867eee96c 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -1,18 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from colossalai.utils.device import autocast - import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler +autocast = get_accelerator().autocast + class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index 19c3919b6e29..cf0bd4ba2437 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -8,9 +8,9 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks @@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): if isinstance(recv_shapes, torch.Size): recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) - buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) return buffer_recv, recv_split buffer_recv = [] for recv_shape in recv_shapes: recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) - tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + tensor_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) buffer_recv.append(tensor_recv) return buffer_recv, recv_split diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index a61dae56cd42..792a15abdfae 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -3,9 +3,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device, synchronize def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: @@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> current_rank = gpc.get_global_rank() tensor_recv_prev = torch.empty( - buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype ) # send to next rank @@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> req.wait() # To protect against race condition when using batch_isend_irecv(). - synchronize() + get_accelerator().synchronize() return tensor_recv_prev diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 6d77f3753fe8..0b7c0eb74651 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -3,9 +3,9 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 4a3ccfda1bb5..9b2913442225 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -6,8 +6,8 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device class BaseSchedule(ABC): @@ -29,12 +29,12 @@ def __init__(self, data_process_func: Callable = None): def _move_tensor(element): if torch.is_tensor(element): if not element.is_cuda: - return element.to(get_current_device()).detach() + return element.to(get_accelerator().get_current_device()).detach() return element def _move_to_device(self, data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + data = data.to(get_accelerator().get_current_device()) elif isinstance(data, (list, tuple)): data_to_return = [] for element in data: diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 5fd5602e790c..4a23853c137a 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -7,12 +7,12 @@ import torch.cuda import colossalai.legacy.communication as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp.naive_amp import NaiveAMPModel from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device from ._base_schedule import BaseSchedule @@ -352,7 +352,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None # Used for tensor meta information communication @@ -584,7 +584,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if not forward_only: output_obj_grads = [[] for _ in range(len(model))] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 4cd7e47c37f1..6e7760218c16 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -6,10 +6,10 @@ import torch.cuda import colossalai.legacy.communication.p2p_v2 as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine -from colossalai.utils.device import get_current_device from ._pipeline_schedule import PipelineSchedule @@ -99,7 +99,7 @@ def forward_backward_step( output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index 4035bd6b54ef..d99a7d3f0c65 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -15,6 +15,7 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.context import Config, ConfigException from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp @@ -34,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device def get_default_parser(): @@ -309,9 +309,9 @@ def initialize( else: if isinstance(model, nn.Module): # first sync model across dp ranks - model.to(get_current_device()) + model.to(get_accelerator().get_current_device()) elif isinstance(model, Callable): - model = model().to(get_current_device()) + model = model().to(get_accelerator().get_current_device()) # optimizer maybe a optimizer_cls if isinstance(optimizer, Callable): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index e1db0fe98a02..aa661664f4e8 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -3,8 +3,8 @@ from torch import dtype, nn +from colossalai.accelerator import get_accelerator from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D @@ -83,7 +83,7 @@ def __init__( embed = ( nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) .to(dtype) - .to(get_current_device()) + .to(get_accelerator().get_current_device()) ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py index f8e317e723f1..58842f481a10 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py @@ -1,6 +1,6 @@ from torch import nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from ..parallel_1d import LayerNorm1D from ..parallel_2d import LayerNorm2D @@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule): def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device()) else: norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) super().__init__(norm) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index b6ec5347f2e2..b38e1c4338b2 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -10,7 +10,7 @@ from torch import Tensor from torch.nn.parameter import Parameter -from colossalai.kernel import LayerNorm +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.context.parallel_context import global_context as gpc @@ -22,7 +22,7 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule @@ -221,7 +221,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -357,7 +357,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -499,7 +499,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -638,7 +638,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -802,7 +802,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -912,7 +914,11 @@ def __init__( self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index f1eff7128e7a..f67ee2e60be1 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def matmul_2d( @@ -250,7 +250,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -399,7 +399,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -556,7 +556,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index f81c5334ad77..4987afa18672 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -18,7 +19,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -82,7 +82,7 @@ def __init__( self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -259,7 +259,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -438,18 +438,24 @@ def __init__( self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -619,7 +625,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -758,7 +766,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -895,11 +903,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1052,7 +1067,7 @@ def __init__( self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 50900c135cab..43328bd033c8 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def get_parallel_group(parallel_mode: ParallelMode): @@ -205,7 +205,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -362,7 +362,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -527,7 +527,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -661,7 +661,9 @@ def forward( if row_rank == 0: bias_temp = bias.clone() else: - bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + bias_temp = torch.zeros( + output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device() + ) src_rank = ( col_rank + dep_rank * tesseract_dim**2 @@ -984,7 +986,7 @@ def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: Par @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device()) dist.all_gather( list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b451a4031c25..d9410f1cbcbc 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -19,7 +20,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -84,7 +84,7 @@ def __init__( self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -272,7 +272,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -451,18 +451,24 @@ def __init__( self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -632,7 +638,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -772,7 +780,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -910,11 +918,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1068,7 +1083,7 @@ def __init__( self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index 16e515f87da3..bb01ec85130a 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce, broadcast from colossalai.legacy.constants import ( INPUT_GROUP_3D, @@ -27,7 +28,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( @@ -69,11 +69,13 @@ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=N self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -202,13 +204,15 @@ def __init__( torch.empty( self.in_features_per_partition, self.out_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -380,11 +384,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.in_features_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -523,14 +534,16 @@ def __init__( torch.empty( self.out_features_per_partition, self.in_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -705,16 +718,24 @@ def __init__( self.weight = nn.Parameter( torch.empty( - (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + (embed_size_per_partition, in_chans, *self.patch_size), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) - self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype) ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -880,7 +901,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -1019,7 +1042,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index 24d5499e3a5f..4e9bf364d8eb 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -5,11 +5,11 @@ from torch import distributed as dist from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ring_forward from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range -from colossalai.utils import get_current_device class RingQK(torch.autograd.Function): @@ -30,7 +30,7 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): sub_seq_length, sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute local QK^T @@ -71,7 +71,7 @@ def backward(ctx, grad_output): grad_q = torch.zeros_like( sub_q, dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute with local sub_k @@ -105,7 +105,7 @@ def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attent batch_size * num_attention_heads, sub_seq_length, attention_head_size, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=attention_score.dtype, ) @@ -142,7 +142,9 @@ def backward(ctx, grad_output): grad_v /= local_world_size # calculate gradient for attention score - grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + grad_attention_score = torch.zeros_like( + attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device() + ) # compute with local sub_k grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 063b0cd8e2b2..445b7e4cda2a 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,13 +8,12 @@ import torch.nn.functional as F from torch.nn import Parameter -from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.registry import LAYERS +from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax @LAYERS.register_module diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 590ad5ff6085..3a1c2e57b4be 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -7,10 +7,10 @@ from torch import nn as nn from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import to_2tuple @@ -173,12 +173,18 @@ def __init__( self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + torch.empty( + (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype + ) + ) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype) ) - self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -242,11 +248,15 @@ def __init__( self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.has_weight = True if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -287,7 +297,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -333,7 +343,7 @@ def __init__( self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 44f39a6db262..474fd4a2cb9c 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -118,7 +118,7 @@ def backward(ctx, output_grad): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index c57bf26e9139..b423ab3d8699 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -112,7 +112,7 @@ def backward(ctx, output_grad): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index 988317cae3eb..de6a674d61db 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -80,7 +80,7 @@ def forward(ctx, logits, targets, output_parallel_mode): target_mask = (targets < vocab_start) | (targets > vocab_end) masked_target = targets.clone() - vocab_start masked_target[target_mask] = 0 - arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) predicted_logits[target_mask] = 0.0 @@ -110,7 +110,7 @@ def backward(ctx, output_grad): grad_2d = input_grad.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 35a7f0a156ab..0e6731db5a77 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -7,12 +7,12 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS from colossalai.legacy.utils import is_no_pp_or_last_stage -from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number @@ -82,8 +82,8 @@ class LossMetric(Metric): def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) - self.last_step_loss = torch.zeros(1, device=get_current_device()) - self.accum_loss = torch.zeros(1, device=get_current_device()) + self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) self.count = 0 def reset(self) -> None: @@ -164,10 +164,10 @@ class AccuracyMetric(Metric): def __init__(self, epoch_only: bool, accuracy_func: Callable): super().__init__(epoch_only=epoch_only) self.acc = accuracy_func - self.last_step_sum = torch.zeros(1, device=get_current_device()) - self.last_step_correct = torch.zeros(1, device=get_current_device()) - self.accumulated_sum = torch.zeros(1, device=get_current_device()) - self.accumulated_correct = torch.zeros(1, device=get_current_device()) + self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device()) def reset(self) -> None: self.last_step_sum.zero_() @@ -320,10 +320,10 @@ def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 - self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) - self.accumulated_used_time = torch.zeros(1, device=get_current_device()) - self.last_step_num_samples = torch.zeros(1, device=get_current_device()) - self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) self._tflop_per_step = tflop_per_step self._use_local = use_local diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index 9a8051ae937f..d1382cb1e36d 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -6,8 +6,8 @@ import torch from torch.utils.checkpoint import check_backward_validity, detach_variable +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states -from colossalai.utils.device import autocast, get_current_device def copy_to_device(obj, device): @@ -33,7 +33,7 @@ def forward(ctx, run_function, activation_offload=False, *args): check_backward_validity(args) ctx.run_function = run_function ctx.activation_offload = activation_offload - ctx.device = get_current_device() + ctx.device = get_accelerator().get_current_device() # preserve rng states ctx.fwd_cpu_rng_state = torch.get_rng_state() @@ -110,7 +110,7 @@ def backward(ctx, *args): inputs[idx] = tensors[i] detached_inputs = detach_variable(tuple(inputs)) if ctx.had_autocast_in_fwd: - with torch.enable_grad(), autocast(): + with torch.enable_grad(), get_accelerator().autocast()(): outputs = ctx.run_function(*detached_inputs) else: with torch.enable_grad(): @@ -226,7 +226,7 @@ def inner_unpack(packed): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks( + with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks( inner_pack, inner_unpack ): _unused = function(*args) @@ -245,7 +245,7 @@ def inner_unpack(packed): # get device if we need to offload the activation if activation_offload: - device = get_current_device() + device = get_accelerator().get_current_device() # run function with pack and unpack as saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 671bcc3d6ad7..76ec08e96a6d 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -96,9 +96,9 @@ def _calc_l2_norm(grads): global fused_optim if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() norm = 0.0 if len(grads) > 0: diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 2f99a7d2f72e..cfb22d3153d9 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -6,9 +6,9 @@ import torch.distributed as dist from packaging import version +from colossalai.accelerator import get_accelerator from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == "cuda": - return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + return ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * _GLOBAL_CUDA_MEM_FRACTION + ) def colo_device_memory_used(device: torch.device) -> int: @@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None: return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio - torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device()) def colo_set_cpu_memory_capacity(size: int) -> None: diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index ad54b989f412..a9e3ffe1a2ec 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -8,7 +8,7 @@ from torch.autograd.profiler import profile from torch.distributed import ReduceOp -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time @@ -177,7 +177,7 @@ def close_profiler(self, group=None): assert current_comm_event is not None, "dist op has not been found" - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device()) torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) current_comm_event.self_cuda_time = buffer.item() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index e336717f4164..b0360880e7ad 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,7 +3,7 @@ from time import time from typing import List -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy @@ -69,7 +69,7 @@ def adjust_layout(self) -> None: # move COMPUTE tensors to CUDA self._cpu_gpu_move_volume += cuda_demand for t in move_to_cuda_tensor_list: - colo_model_data_tensor_move_inline(t, get_current_device()) + colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device()) @property def cpu_gpu_move_volume(self): diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 3aca80cfe56a..6fde91d4a3a3 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,8 +5,8 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor @@ -38,7 +38,7 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CUDATensorPlacementPolicy(TensorPlacementPolicy): def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" - super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: return 0, 0 @@ -78,7 +78,7 @@ def evict_tensors( int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index b9d3071a877e..e5a35dea1b94 100644 --- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device from .tensor_shard_strategy import TensorShardStrategy @@ -30,9 +30,11 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist. rank = dist.get_rank(process_group) for i in range(world_size): if i == rank: - buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + buffer_list.append( + flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device()) + ) else: - buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device())) dist.all_gather(buffer_list, buffer_list[rank], group=process_group) # Move to target device before splitting buffer # Ensure we utilize maximum PCIE bandwidth diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index ebaef774bd06..fb6ef534be56 100644 --- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils.commons import get_shard from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device class TensorShardStrategy(BaseShardStrategy): @@ -34,9 +34,9 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr if t.is_sharded: return if t.payload.device.type == "cuda": - assert t.payload.device == get_current_device(), ( + assert t.payload.device == get_accelerator().get_current_device(), ( f"shard tensor on cuda device index {t.payload.device.index}," - f" but current cuda device is {get_current_device()}" + f" but current cuda device is {get_accelerator().get_current_device()}" ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) @@ -50,7 +50,9 @@ def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessG world_size = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer = torch.empty( + payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device() + ) buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) buffer_list[rank].copy_(t.payload) diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 85f2ac2159f4..bb7744a80851 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils.memory import colo_device_memory_capacity @@ -22,7 +23,7 @@ from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.logging import get_dist_logger -from colossalai.utils import disposable, get_current_device +from colossalai.utils import disposable from colossalai.zero.gemini.memory_tracer import MemStatsCollector from ._utils import ( @@ -212,8 +213,12 @@ def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> N self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: with open(filename, "w+") as f: - f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") - f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write( + f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n" + ) + f.write( + f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n" + ) f.write("CUDA model data (GB)\n") f.write("\n") f.write("CUDA non model data (GB)\n") @@ -266,7 +271,8 @@ def _update_memstats(self): # model data is fixed in cuda during training. # cuda margin space can be used to store OS. self._cuda_margin_space = ( - colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + colo_device_memory_capacity(get_accelerator().get_current_device()) + - self._memstats_collector._memstats.max_overall_cuda ) @torch.no_grad() diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 892e9f31ded4..332f44d5397b 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -3,13 +3,13 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.registry import OPHOOKS from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.legacy.zero.gemini.stateful_tensor import TensorState from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector @@ -33,7 +33,7 @@ def __init__( self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU - self.computing_device = get_current_device() + self.computing_device = get_accelerator().get_current_device() self._memstarts_collector = memstarts_collector self._stateful_tensor_mgr = stateful_tensor_mgr diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index c71e6c1f40c7..34342436f263 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -11,9 +11,9 @@ def load_moe(): global MOE_KERNEL - from colossalai.kernel.op_builder import MOEBuilder + from colossalai.kernel.kernel_loader import MoeLoader - MOE_KERNEL = MOEBuilder().load() + MOE_KERNEL = MoeLoader().load() class AllGather(torch.autograd.Function): @@ -145,14 +145,8 @@ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: class HierarchicalAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - inputs: Tensor, - groups: Tuple[ProcessGroup, ProcessGroup], - src_rank: int - ) -> Tensor: + def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor: """ Returns: outputs: Tensor @@ -276,8 +270,9 @@ def backward(ctx, tokens_grad): if tokens_grad.dtype != torch.float32: tokens_grad = tokens_grad.to(torch.float32) - d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, - mask, dest_idx) + d_expert, d_logits = MOE_KERNEL.combine_backward( + ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx + ) if d_expert.dtype != ctx.dtype: d_expert = d_expert.to(ctx.dtype) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 3e64d796cce7..eaca75b8f18e 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -69,7 +69,7 @@ def setup( fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. - use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. + use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True. """ assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index c5bb508621b2..f5815d05d111 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -8,9 +8,9 @@ import torch.nn.functional as F from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator from colossalai.moe._operation import moe_cumsum from colossalai.moe.manager import MOE_MANAGER -from colossalai.utils import get_current_device class MoeRouter(nn.Module, ABC): @@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False): + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False, + ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -68,8 +70,9 @@ def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) - assert router_probs.dim() == expert_indices.dim() == 3, \ - "router_probs must be 3D tensor and expert_indices must be 4D tensor" + assert ( + router_probs.dim() == expert_indices.dim() == 3 + ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_indices, num_experts) @@ -122,25 +125,29 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) + low=torch.tensor(0.0, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: @@ -216,18 +223,22 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ @@ -255,8 +266,8 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = (mask1 + mask2) # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + cmask = mask1 + mask2 # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) @@ -269,7 +280,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -336,15 +347,18 @@ class TopKRouter(MoeRouter): oversubscribed / reach capacity. """ - def __init__(self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, - drop_tks) + def __init__( + self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks + ) def forward( self, @@ -410,7 +424,7 @@ def forward( # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 5a17a6e0d769..e25e7dd48892 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -7,13 +7,12 @@ import torch.nn as nn import torch.nn.functional as F +from colossalai.accelerator import get_accelerator from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor -from colossalai.utils import get_current_device class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): return self.data.clone() @@ -30,8 +29,8 @@ class NormalNoiseGenerator: def __init__(self, num_experts: int): self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + loc=torch.tensor(0.0, device=get_accelerator().get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -52,8 +51,8 @@ class UniformNoiseGenerator: def __init__(self, eps: float = 1e-2): self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, device=get_current_device()), + low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] epsize_param_dict = dict() for param in model.parameters(): if not is_moe_tensor(param): - ep_size = 1 # set ep_size to 1 for dp parameters + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = get_ep_size(param) if ep_size not in epsize_param_dict: @@ -193,18 +192,13 @@ def create_ep_hierarchical_group( assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." nproc_per_node = int(nproc_per_node) else: - assert dist.get_world_size() % nproc_per_node == 0, \ - "nproc_per_node should be a divisor of world_size." + assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size." num_node = dist.get_world_size() // nproc_per_node intra_src_rank = None ep_intra_node_group = None for i in range(num_node): - ep_intra_ranks = [ - i * nproc_per_node + j - for j in range(nproc_per_node) - if j in ep_group_ranks - ] + ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks] group = dist.new_group(ep_intra_ranks) if rank in ep_intra_ranks: assert ep_intra_node_group is None @@ -212,10 +206,7 @@ def create_ep_hierarchical_group( intra_src_rank = ep_intra_ranks[0] ep_inter_node_group = None - ep_inter_ranks = [ - ep_group_ranks[0] + i * nproc_per_node - for i in range(num_node) - ] + ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)] if len(ep_inter_ranks) > 1: group = dist.new_group(ep_inter_ranks) if rank in ep_inter_ranks: diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py new file mode 100644 index 000000000000..0b7011e8e2d8 --- /dev/null +++ b/colossalai/nn/layer/colo_attention.py @@ -0,0 +1,209 @@ +import enum +import math +import warnings +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.accelerator import get_accelerator +from colossalai.kernel.kernel_loader import FlashAttentionLoader + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize( + attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() + ): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + self.attn = FlashAttentionLoader().load() + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + origin_attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + """ + ColoAttention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + origin_attn_mask: (nheads, q_seqlen, kv_seqlen) + bias: will not be used + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # if flash attention is not applicable, switch to memory effcient attention + if self.attn.__name__ == "flash_attention" and ( + query.dtype not in [torch.float16, torch.bfloat16] or bias != None + ): + warnings.warn( + f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." + ) + self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = self.attn( + query, + key, + value, + seq_len_info_q=seq_len_info_q, + seq_len_info_kv=seq_len_info_kv, + origin_attn_mask=origin_attn_mask, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + if len(out.shape) == 4: + out = rearrange(out, "b s h d -> b s (h d)") + return out diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/nn/layer/layernorm.py similarity index 95% rename from colossalai/kernel/cuda_native/layer_norm.py rename to colossalai/nn/layer/layernorm.py index c7d2a3a45022..1db48faee213 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/nn/layer/layernorm.py @@ -9,7 +9,7 @@ from torch.nn import init from torch.nn.parameter import Parameter -from colossalai.kernel.op_builder.layernorm import LayerNormBuilder +from colossalai.kernel.kernel_loader import LayerNormLoader try: from colossalai._C import layer_norm @@ -29,7 +29,7 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() + layer_norm = LayerNormLoader().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm ctx.save_for_backward(input_, weight_, bias_, mean, invvar) diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py new file mode 100644 index 000000000000..a8d72ddd90c9 --- /dev/null +++ b/colossalai/nn/layer/scaled_softmax.py @@ -0,0 +1,184 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type.value > 1: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type.value > 1: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 7d53a1dd6834..5be629fb2045 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,10 +1,9 @@ import math -import platform from typing import Optional import torch -from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder +from colossalai.kernel.kernel_loader import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer @@ -78,7 +77,7 @@ def __init__( default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() + cpu_adam = CPUAdamLoader().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index fcdd3257d700..aeb5cc91bb9e 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -70,9 +70,9 @@ def __init__( self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 3e1d5a7ba539..da8d1608a072 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -77,9 +77,9 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 95a6354208a8..3fae9bbca765 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -72,9 +72,9 @@ def __init__( self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index d34fd601ab25..c9c1f81bfc9a 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,7 +2,7 @@ import torch -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.utils import multi_tensor_applier from .cpu_adam import CPUAdam @@ -85,7 +85,7 @@ def __init__( nvme_offload_dir, ) if torch.cuda.is_available(): - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 72480526bd5c..20f316c2ae48 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -7,10 +7,10 @@ from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule @@ -86,7 +86,7 @@ def load_micro_batch(self) -> Any: """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): """ diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 0a01a1e7864b..a4ace5e1baad 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -6,10 +6,11 @@ from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -72,6 +73,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatch + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) @@ -96,7 +101,7 @@ def load_micro_batch(self, model_chunk_id: int) -> Any: assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index cb078b25faeb..bf2f01b10e9b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,10 +6,11 @@ from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import ( detach, @@ -85,6 +86,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatches + assert ( + self.num_microbatches >= self.stage_manager.num_stages + ), "Number of microbatch should be larger than number of stages" + if self.forward_only: self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) @@ -106,7 +111,7 @@ def load_micro_batch(self) -> Any: assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def recv_forward(self, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. @@ -313,7 +318,7 @@ def run_forward_only( accum_loss = None if return_loss and self.stage_manager.is_last_stage(): - accum_loss = torch.scalar_tensor(0, device=get_current_device()) + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None for _ in range(self.num_microbatches): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4b6343adcd3b..0d2cc1b3370d 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -6,7 +6,8 @@ from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed + +from colossalai.accelerator import get_accelerator class SeqParallelUtils: @@ -109,10 +110,10 @@ def __init__(self, seed: int): # 1. get the current rng state # 2. set the seed and store the rng state # 3. recover the original rng state - device_original_rng_state = get_rng_state() - manual_seed(seed) - self.device_rng_state = get_rng_state() - set_rng_state(device_original_rng_state) + device_original_rng_state = get_accelerator().get_rng_state() + get_accelerator().manual_seed(seed) + self.device_rng_state = get_accelerator().get_rng_state() + get_accelerator().set_rng_state(device_original_rng_state) # to the same for cpu rng state cpu_original_rng_state = torch.get_rng_state() @@ -121,10 +122,10 @@ def __init__(self, seed: int): torch.set_rng_state(cpu_original_rng_state) def _set_device_rng_state(self, rng_state): - set_rng_state(rng_state) + get_accelerator().set_rng_state(rng_state) def _get_device_rng_state(self): - current_state = get_rng_state() + current_state = get_accelerator().get_rng_state() return current_state def _set_cpu_rng_state(self, rng_state): @@ -209,7 +210,7 @@ def is_randomizer_index_synchronized(process_group: ProcessGroup = None): index = Randomizer.index() if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] @@ -231,7 +232,7 @@ def synchronize_index(process_group: ProcessGroup = None): if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 00b2037fbdc8..d5c10541a28f 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward( def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index c8a311df7c6d..d13bd34926a5 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f456353742c..055e3096d794 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ def gpt2_for_sequence_classification_forward( def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index ad51bf2c709b..22b0f7a90656 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -530,7 +530,7 @@ def gptj_for_question_answering_forward( def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_attention_heads, attn_head_size, rotary): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1b53ce4afebb..e10a7ed7da0c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torch.distributed as dist from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -15,14 +14,17 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig + from ..layer import cross_entropy_1d try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + LATEST_VERSION = True except ImportError: LATEST_VERSION = False + class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -203,7 +205,7 @@ def llama_for_causal_lm_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None + shard_config: ShardConfig = None, ): r""" Args: @@ -279,12 +281,13 @@ def llama_for_causal_lm_forward( if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -417,7 +420,7 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention llama_version = 2 try: @@ -480,7 +483,12 @@ def forward( attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type, + origin_attn_mask=attention_mask, ) attn_output = self.o_proj(attn_output) @@ -492,7 +500,7 @@ def forward( def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import LlamaForCausalLM - + def forward( self: LlamaForCausalLM, input_ids: torch.LongTensor = None, @@ -573,12 +581,13 @@ def forward( if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -590,4 +599,5 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 1ddb26c25d5c..0da1a35a0278 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -6,7 +6,7 @@ def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: MistralAttention, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f2ca3353bc..7f6cbbbcf4f3 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ def opt_for_question_answering_forward( def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index f67aa84e4e72..dcb1785207eb 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -593,10 +593,6 @@ def t5_encoder_model_forward( def get_t5_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from transformers.models.t5.modeling_t5 import T5Attention def forward( @@ -632,11 +628,11 @@ def forward( def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) def unshape(states): """reshape""" - return states.view(batch_size, -1, self.inner_dim) + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -653,8 +649,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if key_value_states is None: # self-attn # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=1) - elif past_key_value.shape[1] != key_value_states.shape[1]: + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning # cross-attn @@ -701,10 +697,15 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias_masked = position_bias - position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention( - query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 - ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout, + scale=1.0, + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output) diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 5a50e7379cdc..ab141a74aef8 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def pp_forward( def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 9827d4801f8d..cb8b45ae7d01 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index f2eeb9d69c81..5c148880f980 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -49,7 +49,7 @@ def module_policy(self): if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( - "Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." + "Falcon doesn't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." ) self.shard_config.enable_tensor_parallelism = False diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1faa24f71e0b..42bf0825b045 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -46,7 +46,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c16aa6deab3b..c0b8b3375836 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -35,7 +35,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) if self.shard_config.enable_tensor_parallelism: @@ -136,7 +136,7 @@ def __init__(self) -> None: def module_policy(self): if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") return super().module_policy() @@ -160,7 +160,7 @@ def module_policy(self): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) @@ -186,7 +186,7 @@ def module_policy(self): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index e2f3a829cc6f..a542808ba794 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -59,7 +59,7 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 4d906e3f4c04..e183b0632f88 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -66,7 +66,7 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription( @@ -263,7 +263,7 @@ def distribute_t5_layers( if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 6ef0e3b34b2b..584d4e2652c0 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -33,7 +33,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 6dae99e8cedb..b5b5db79d9de 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -69,13 +69,13 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False - warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription( @@ -302,7 +302,7 @@ def distribute_whisper_layers( if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index de0cba26b52a..27afac9e95d3 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -451,7 +451,7 @@ def __repr__(self): elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") res_list.append(f"gather_dim:{self.gather_dim}, ") - res_list.append(f"logical_process_asex:{self.logical_process_axes})") + res_list.append(f"logical_process_axes:{self.logical_process_axes})") return "".join(res_list) diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 74a785f2dcd4..da6ef275e108 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -96,9 +96,9 @@ def _apply_layout(tensor, layout): """ Apply the layout to the local tensor during initializing process. """ - # layout converter requires a source and target laytout + # layout converter requires a source and target layout # we construct the source layer for an unsharded tensor - # and use self.dist_layer as the targer layout for the sharded tensor + # and use self.dist_layer as the target layout for the sharded tensor source_spec = _construct_default_sharding_spec(tensor) source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 1e4486101dd3..b6843df7a478 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -40,7 +40,7 @@ def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> M ep_size (int): The expert parallel size. dp_size (int): The data parallel size. pp_size (int): The pipeline parallel size. - ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Returns: dict: The moe info of the given tensor. diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 5097ac1044e7..ba6c77056222 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -12,7 +12,7 @@ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1 ep_size (int): expert parallel size dp_size (int): data parallel (zero) size pp_size (int, optional): pipeline parallel size. Defaults to 1. - ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True. """ self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size if ep_inside: diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 7cd24b0adc60..5f6864ff0059 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -9,7 +9,8 @@ import torch import torch.multiprocessing as mp from packaging import version -from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count + +from colossalai.accelerator import get_accelerator def parameterize(argument: str, values: List[Any]) -> Callable: @@ -199,7 +200,7 @@ def test_something(): def _wrap_func(f): def _execute_by_gpu_num(*args, **kwargs): - num_avail_gpu = device_count() + num_avail_gpu = get_accelerator().device_count() if num_avail_gpu >= min_gpus: f(*args, **kwargs) @@ -263,11 +264,11 @@ def test_something(): def _wrap_func(f): def _clear_cache(*args, **kwargs): - empty_cache() - reset_peak_memory_stats() - reset_max_memory_allocated() - reset_max_memory_cached() - synchronize() + get_accelerator().empty_cache() + get_accelerator().reset_peak_memory_stats() + get_accelerator().reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().synchronize() gc.collect() f(*args, **kwargs) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 0246a35e2a1b..cdba467091be 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -4,20 +4,16 @@ disposable, ensure_path_exists, free_storage, + get_current_device, is_ddp_ignored, set_seed, ) -from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ "conditional_context", - "get_current_device", - "synchronize", - "empty_cache", - "set_to_cuda", "Timer", "MultiTimer", "multi_tensor_applier", @@ -27,7 +23,6 @@ "_cast_float", "free_storage", "set_seed", + "get_current_device", "is_ddp_ignored", - "set_device", - "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index c43caaff4806..4a1889eb57ff 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -10,6 +10,15 @@ import numpy as np import torch +from colossalai.accelerator import get_accelerator + + +def get_current_device(): + """ + A wrapper function for accelerator's API for backward compatibility. + """ + return get_accelerator().get_current_device() + def ensure_path_exists(filename: str): # ensure the path exists diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py deleted file mode 100644 index c70dbdaa5ee1..000000000000 --- a/colossalai/utils/device.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Any, Dict, List, Optional, Tuple, Callable - -import torch -import torch.distributed as dist - -IS_NPU_AVAILABLE: bool = False -try: - import torch_npu # noqa - - IS_NPU_AVAILABLE = torch.npu.is_available() -except ImportError: - pass - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif IS_NPU_AVAILABLE: - return torch.device(f"npu:{torch.npu.current_device()}") - else: - return torch.device("cpu") - - -def _dispatch_device_func(fn_name: str, *args, **kwargs): - if torch.cuda.is_available(): - return getattr(torch.cuda, fn_name)(*args, **kwargs) - elif IS_NPU_AVAILABLE: - return getattr(torch.npu, fn_name)(*args, **kwargs) - else: - raise RuntimeError("No device available") - - -# device semantics - - -def can_device_access_peer(device, peer_device) -> bool: - return _dispatch_device_func("can_device_access_peer", device, peer_device) - - -def current_device() -> int: - return _dispatch_device_func("current_device") - - -def current_stream(device=None): - return _dispatch_device_func("current_stream", device) - - -def default_stream(device=None): - return _dispatch_device_func("default_stream", device) - - -def device_count() -> int: - return _dispatch_device_func("device_count") - - -def get_device_capability(device=None) -> Tuple[int, int]: - return _dispatch_device_func("get_device_capability", device) - - -def get_device_name(device=None) -> str: - return _dispatch_device_func("get_device_name", device) - - -def get_device_properties(device): - return _dispatch_device_func("get_device_properties", device) - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % device_count() - _dispatch_device_func("set_device", index) - - -def set_stream(stream_): - return _dispatch_device_func("set_stream", stream_) - - -def stream(stream_): - return _dispatch_device_func("stream", stream_) - - -def synchronize(): - return _dispatch_device_func("synchronize") - - -def utilization(device=None) -> int: - return _dispatch_device_func("utilization", device) - - -# random number generator - - -def get_rng_state(device="cuda") -> torch.Tensor: - return _dispatch_device_func("get_rng_state", device) - - -def get_rng_state_all() -> List[torch.Tensor]: - return _dispatch_device_func("get_rng_state_all") - - -def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: - return _dispatch_device_func("set_rng_state", new_state, device) - - -def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: - return _dispatch_device_func("set_rng_state_all", new_states) - - -def manual_seed(seed: int) -> None: - return _dispatch_device_func("manual_seed", seed) - - -def manual_seed_all(seed: int) -> None: - return _dispatch_device_func("manual_seed_all", seed) - - -def seed() -> None: - return _dispatch_device_func("seed") - - -def seed_all() -> None: - return _dispatch_device_func("seed_all") - - -def initial_seed() -> int: - return _dispatch_device_func("initial_seed") - - -# streams and events - - -def Stream(device=None, priority=0, **kwargs): - return _dispatch_device_func("Stream", device, priority, **kwargs) - - -def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): - return _dispatch_device_func("Event", enable_timing, blocking, interprocess) - - -# memory management - - -def empty_cache() -> None: - return _dispatch_device_func("empty_cache") - - -def memory_stats(device=None) -> Dict[str, Any]: - return _dispatch_device_func("memory_stats", device) - - -def memory_summary(device=None, abbreviated=False) -> str: - return _dispatch_device_func("memory_summary", device, abbreviated) - - -def memory_snapshot(): - return _dispatch_device_func("memory_snapshot") - - -def memory_allocated(device=None) -> int: - return _dispatch_device_func("memory_allocated", device) - - -def max_memory_allocated(device=None) -> int: - return _dispatch_device_func("max_memory_allocated", device) - - -def reset_max_memory_allocated(device=None) -> None: - return _dispatch_device_func("reset_max_memory_allocated", device) - - -def reset_max_memory_cached(device=None) -> None: - return _dispatch_device_func("reset_max_memory_cached", device) - - -def memory_reserved(device=None) -> int: - return _dispatch_device_func("memory_reserved", device) - - -def max_memory_reserved(device=None) -> int: - return _dispatch_device_func("max_memory_reserved", device) - - -def set_per_process_memory_fraction(fraction: float, device=None) -> None: - return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) - - -def reset_peak_memory_stats(device=None) -> None: - return _dispatch_device_func("reset_peak_memory_stats", device) - - -# amp - - -def autocast() -> Callable: - if torch.cuda.is_available(): - return torch.cuda.amp.autocast() - elif IS_NPU_AVAILABLE: - return torch.npu.amp.autocast() - else: - raise RuntimeError("No device available") diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 8ab6b46f28b6..2feded7751ea 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .device import synchronize +from colossalai.accelerator import get_accelerator class Timer: @@ -21,13 +21,13 @@ def has_history(self): @property def current_time(self) -> float: - synchronize() + get_accelerator().synchronize() return time.time() def start(self): """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 - synchronize() + get_accelerator().synchronize() self._start_time = time.time() self._started = True @@ -44,7 +44,7 @@ def stop(self, keep_in_history: bool = False): Returns: int: Start-stop interval. """ - synchronize() + get_accelerator().synchronize() end_time = time.time() elapsed = end_time - self._start_time if keep_in_history: @@ -123,7 +123,7 @@ def stop(self, name: str, keep_in_history: bool): return None def get_timer(self, name): - """Get timer by its name (from multitimer) + """Get timer by its name (from multimer) Args: name (str): Timer's key. diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index defc6c4cb150..cad2622f2851 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -6,8 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE +from colossalai.accelerator import get_accelerator class TensorState(Enum): @@ -107,7 +106,7 @@ def __init__( self.valid_end = self.shard_size self.dtype = dtype - device = init_device or get_current_device() + device = init_device or get_accelerator().get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero @@ -125,7 +124,7 @@ def __init__( # configure the init device of the shard # no-offload default: fp16, fp32 -> CUDA # offload default: fp16, fp32 -> CPU - self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device() self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() self.shard_mem = self.chunk_mem // self.pg_size @@ -191,11 +190,10 @@ def memory_usage(self) -> Dict[str, int]: def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type + elif self.is_gathered or self.cuda_shard is not None: + return get_accelerator().name else: - if self.is_gathered or self.cuda_shard is not None: - return "npu" if IS_NPU_AVAILABLE else "cuda" - else: - return "cpu" + return "cpu" @property def payload(self) -> torch.Tensor: @@ -297,7 +295,7 @@ def close_chunk(self): self.valid_end = self.utilized_size - self.shard_begin if self.chunk_temp.device.type == "cpu": - self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device()) self.__update_tensors_ptr() else: self.cuda_global_chunk = self.chunk_temp @@ -334,12 +332,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False): return if device.type == "cuda" or device.type == "npu": - assert device == get_current_device(), "can't move chunk to another device" + assert device == get_accelerator().get_current_device(), "can't move chunk to another device" if self.cuda_shard: return - self.cuda_shard = self.cpu_shard.to(get_current_device()) + self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device()) if not self.pin_memory: self.cpu_shard = None @@ -394,7 +392,9 @@ def reduce(self): if self.extra_dp_group is not None: dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + self.cuda_shard = torch.empty( + self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() + ) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) @@ -533,7 +533,7 @@ def __paired_shard_move(self): # only be called when optimizer state is in CPU memory # the grad and param should be in the same device assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) + temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device()) # avoid to transform FP32 in CPU self.cuda_shard = temp.to(self.dtype) @@ -631,7 +631,7 @@ def init_grad_chunk(self) -> "Chunk": grad_chunk.valid_end = self.valid_end if grad_chunk.chunk_temp.device.type == "cpu": - grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device()) else: grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp grad_chunk.chunk_temp = None diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5f4f37c267aa..5bc662a6189c 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import free_storage, get_current_device +from colossalai.accelerator import get_accelerator +from colossalai.utils import free_storage from .chunk import Chunk, ChunkFullError, TensorState @@ -20,7 +21,7 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() + self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): @@ -107,7 +108,7 @@ def access_chunk(self, chunk: Chunk) -> None: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": - chunk.shard_move(get_current_device()) + chunk.shard_move(get_accelerator().get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) @@ -276,7 +277,10 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) else: accumulated_grad = ( - chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device()) + .clone() + .detach() + .mul_(chunk.pg_size) ) accumulated_grad_gathered = False diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5217b8036bcd..79831cf33dbc 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor @@ -27,7 +28,7 @@ is_distributed_tensor, ) from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.utils import _cast_float, free_storage, is_ddp_ignored from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -766,7 +767,7 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) + p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision) continue # create a fp16 parameter @@ -815,7 +816,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.to(get_current_device()) + buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 8f828bd6cf20..98fbb0c50e24 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -11,6 +11,7 @@ from torch.nn import Parameter from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper @@ -26,7 +27,7 @@ is_customized_distributed_tensor, is_distributed_tensor, ) -from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP @@ -233,7 +234,7 @@ def _calc_global_norm(self) -> float: grad_chunk.l2_norm = None # clear l2 norm - comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) for group, part_norm in group_to_norm.items(): comm_buffer.fill_(part_norm) dist.all_reduce(comm_buffer, group=group) @@ -314,10 +315,10 @@ def _maybe_move_fp32_params(self): continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(chunk32, get_current_device()) + self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device()) # stores grad now - self.chunk_manager.move_chunk(chunk16, get_current_device()) - self.module.set_chunk_grad_device(chunk16, get_current_device()) + self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device()) fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: @@ -328,7 +329,7 @@ def _maybe_move_fp32_params(self): state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): - state[k] = v.to(get_current_device()) + state[k] = v.to(get_accelerator().get_current_device()) def _register_states_(self): for group in self.optim.param_groups: @@ -413,7 +414,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank. Returns: - collected_states(dict): the gathered optimzier state of parameter with given id + collected_states(dict): the gathered optimizer state of parameter with given id if this method is called by master rank, otherwise an empty dict. This method can work only when called by all processes simultaneously. @@ -461,7 +462,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: global_shape = self.optimizer_params_info["id2shape"][param_id] # If the chunk is kept gathered, - # the parameteres are treated the same as that of those in strict DDP during training. + # the parameters are treated the same as that of those in strict DDP during training. # So states can be directly fetched from current device. if chunk.keep_gathered: assert param_id in self.id_to_fake_params @@ -551,7 +552,7 @@ def pack_optimizer_states_to_tensor( self, param_id: int, state_names: list, - device: torch.device = get_current_device(), + device: torch.device = get_accelerator().get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -644,7 +645,7 @@ def state_dict(self, only_rank_0: bool = True) -> dict: """ Args: only_rank_0 (bool): a boolean value indicating whether the state_dict is collected - only on rank 0, dafault to True. + only on rank 0, default to True. Returns: The complete state of the optimizer as a :class:`dict`. @@ -783,7 +784,7 @@ def state_shard( prefix (str, optional): the prefix for states. Default to ''. max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected - only on rank 0, dafault to True. + only on rank 0, default to True. Yields: Iterator[OrderedDict]: A generator of state dict shard of optimizer states. diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b5e40a817e58..e302805dfbb7 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,6 +1,6 @@ from typing import Optional -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -33,4 +33,4 @@ def record_model_data_volume(self) -> None: def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity - return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda + return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 513a6326d5f1..82c8e9dab098 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ import torch -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class MemoryMonitor: @@ -77,7 +77,7 @@ def __init__(self, power: int = 10): super().__init__() self.keep_measuring = False - current_device = get_current_device() + current_device = get_accelerator().get_current_device() def _set_cuda_device(): torch.cuda.set_device(current_device) @@ -116,7 +116,7 @@ def _measure_usage(self): while self.keep_measuring: max_usage = max( max_usage, - colo_device_memory_used(get_current_device()), + colo_device_memory_used(get_accelerator().get_current_device()), ) sleep(self.interval) return max_usage diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c410ad3793c9..388999549bd8 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ import torch -from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.accelerator import get_accelerator +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager @@ -85,7 +85,7 @@ def setup_grads_device( # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: - device = get_current_device() + device = get_accelerator().get_current_device() else: device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here @@ -140,7 +140,7 @@ def evict_tensors( int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. @@ -194,7 +194,7 @@ def setup_grads_device( # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered: - grads_device_map[p] = get_current_device() + grads_device_map[p] = get_accelerator().get_current_device() else: grads_device_map[p] = torch.device("cpu") diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 5305953fe1ee..b563ea5b2de6 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .chunk import Chunk @@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.cuda_shard is not None: shard_temp = chunk.cuda_shard else: - shard_temp = chunk.cpu_shard.to(get_current_device()) + shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device()) shard_temp = shard_temp.to(dtype) - total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2828d517573d..f395fc60ec42 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -15,7 +15,7 @@ def __init__(self, torch_pg: ProcessGroup): # init self.current_group_id = 0 self._num_elements_in_bucket = 0 - # mapping gardient slices and parameter + # mapping gradient slices and parameter self.grad_to_param_mapping = dict() self._grad_in_bucket = dict() @@ -59,7 +59,7 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): self.offset_list[-1] += 1 def build_grad_in_bucket(self): - """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method + """Organize parameters' gradient(padding and split), follows the parameters' splitting method Data structure of self._grad_in_bucket: { @@ -91,7 +91,7 @@ def get_grad(self) -> Dict: return self._grad_in_bucket def get_flatten_grad(self) -> Tensor: - """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: + """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor: [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] Returns: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 1164532fa3a3..73a1db5a0c0d 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -9,7 +9,7 @@ class GradientStore(BaseStore): def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ - self._grads_of_params mapping the paramater and its gradient slices + self._grads_of_params mapping the parameter and its gradient slices data structure: { group_id:{ diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index c1b35ee17f91..e01c852bee50 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,7 +12,7 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import ( BF16MixedPrecisionMixin, FP16MixedPrecisionMixin, @@ -22,9 +22,6 @@ from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor -# from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device - from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -171,7 +168,7 @@ def __init__( # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in addtional group in optim + # if there are moe params, store in additional group in optim if len(moe_params) > 0: param_group = dict() for key, value in self.optim.param_groups[0].items(): @@ -180,10 +177,10 @@ def __init__( param_group["params"] = moe_params self.optim.param_groups.append(param_group) - # intialize communication stream for - # communication-compuation overlapping + # initialize communication stream for + # communication-computation overlapping if self._overlap_communication: - self._comm_stream = device_utils.Stream() + self._comm_stream = get_accelerator().Stream() # reduction hook is only used if overlapping communication # or stage 2 is used @@ -217,7 +214,7 @@ def num_param_groups(self): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required" + assert get_accelerator().name in ["cuda", "npu"], "device is required" for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: @@ -228,7 +225,7 @@ def _sanity_checks(self): def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = "cpu" if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -340,11 +337,11 @@ def _run_reduction(self): if len(moe_grad_list) > 0: moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing - stream.wait_stream(device_utils.current_stream()) + stream.wait_stream(get_accelerator().current_stream()) else: - stream = device_utils.current_stream() + stream = get_accelerator().current_stream() - with device_utils.stream(stream): + with get_accelerator().stream(stream): group_id = self._bucket_store.current_group_id if self.moe_extra_dp_pg is None: @@ -486,7 +483,7 @@ def backward(self, loss, retain_graph=False): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -505,7 +502,7 @@ def backward_by_grad(self, tensor, grad): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -621,7 +618,7 @@ def step(self, closure=None): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - device = get_current_device() + device = get_accelerator().get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): @@ -661,7 +658,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -673,7 +672,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float ) torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg @@ -765,7 +764,7 @@ def state_dict(self) -> Dict: Dict: the pytorch form state_dict """ zero_state = dict() - device = get_current_device() + device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): @@ -827,7 +826,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i ret_block = dict() ret_block_size = 0 - device = get_current_device() + device = get_accelerator().get_current_device() local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 7a0e3b1a0276..e87eafb6eec7 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -45,7 +45,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ## Define Plugin Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously. @@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost( ## Training GPT-2 using hybrid parallelism -In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. ```python def train_epoch( @@ -204,4 +203,4 @@ Training the gpt-2 model for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md index 4d7ffe5a4cbf..2c75dd9acfea 100644 --- a/docs/source/en/basics/booster_api.md +++ b/docs/source/en/basics/booster_api.md @@ -32,7 +32,7 @@ Plugin is an important component that manages parallel configuration (eg: The ge More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md). -Some plugins support lazy initialization, which can be used to save memory when initializating large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). +Some plugins support lazy initialization, which can be used to save memory when initializing large models. For more details, please see [Lazy Initialization](../features/lazy_init.md). ### API of booster diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 11740698057f..ae941b489b90 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -43,7 +43,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ### 定义plugin 定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1. @@ -201,4 +200,4 @@ def train_epoch( for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 5396de6935cb..40b11d649ae0 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -16,10 +16,10 @@ from utils.logger import Logger import colossalai +from colossalai.accelerator import get_accelerator from colossalai.context import ParallelMode from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext @@ -53,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - get_current_device() + get_accelerator().get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -67,7 +67,10 @@ def main(): # build GPT model with ColoInitContext( - device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + device=get_accelerator().get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg, ): config, model, numel = get_model(args, logger) @@ -78,7 +81,7 @@ def main(): elif args.distplan == "CAI_Gemini": gemini_config = dict( strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), + device=get_accelerator().get_current_device(), placement_policy=args.placement, pin_memory=True, hidden_dim=model.config.hidden_size, diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 1a7f8da7f7d0..cc2b2ebc7b88 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -20,11 +20,11 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -386,7 +386,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -401,7 +401,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -578,8 +578,8 @@ def collate_fn(examples): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -613,7 +613,7 @@ def collate_fn(examples): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index ea6dde8bb578..227488abe204 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -21,13 +21,13 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -385,7 +385,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -400,7 +400,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -598,8 +598,8 @@ def collate_fn(examples): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -633,7 +633,7 @@ def collate_fn(examples): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index 13df516d4189..5871bbf8748b 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -13,12 +13,12 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index b770bc9cfb95..0780173241aa 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224 def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.accelerator import get_accelerator + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print(f"Limiting GPU memory usage to {size_in_GB} GB") diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 772fe2200fed..c49d9898238b 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -8,11 +8,9 @@ from transformers import AutoTokenizer, GenerationConfig import colossalai -import colossalai.utils.device as device_utils -from colossalai.inference.config import InferenceConfig -from colossalai.inference.core.engine import InferenceEngine +from colossalai.accelerator import get_accelerator +from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -from colossalai.utils.device import get_current_device GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -55,7 +53,7 @@ def data_gen(batch_size: int = 4, seq_len: int = 512): - input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device()) + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) return input_ids @@ -78,9 +76,9 @@ def print_details_info(model_config, args, whole_end2end): msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" if torch.cuda.is_available(): - msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n" - msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n" - msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n" + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" print(msg) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index 8f85a936352b..b5228c64efa5 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -5,9 +5,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import spawn -from colossalai.utils.device import get_current_device INPUT_TEXTS = [ "What is the longest river in the world?", @@ -57,7 +57,7 @@ def run_inference(args): ) inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_current_device()) for k, v in inputs.items()} + inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} outputs = engine.generate(inputs) if rank == 0: diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index aad12c9c2c59..0b1e77ffff06 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -18,11 +18,11 @@ ) import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -59,7 +59,7 @@ def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -89,8 +89,10 @@ def evaluate_subset(dataloader: DataLoader): object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index e811e1acbf7e..b35112498978 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -7,13 +7,13 @@ from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn -from colossalai.utils import get_current_device def parse_args(): @@ -41,7 +41,7 @@ def train_gpt(args): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = GPTLMLoss() diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 88b76c654b1d..78d090ba29da 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -12,12 +12,12 @@ from packaging import version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device CAI_VERSION = colossalai.__version__ @@ -141,7 +141,11 @@ def main(): criterion = GPTLMLoss() torch.manual_seed(123) if args.distplan.startswith("CAI"): - ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.distplan == "CAI_Gemini" + else nullcontext() + ) # build GPT model with ctx: model = model_builder(args.model_type)(checkpoint=True) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 62804eff8ea5..eb56ee530a0a 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -13,11 +13,11 @@ from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -54,7 +54,7 @@ def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -83,8 +83,10 @@ def evaluate_subset(dataloader: DataLoader): object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index b2e3f71a5387..ec3df50c4e67 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -5,6 +5,7 @@ from torch.nn import functional as F from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.base_layer import ParallelLayer @@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide from colossalai.legacy.registry import LAYERS, LOSSES -from colossalai.utils import get_current_device class VocabParallelEmbedding(torch.nn.Module): @@ -96,7 +96,9 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -194,7 +196,7 @@ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None): self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) @@ -439,7 +441,9 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -532,7 +536,7 @@ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx self._weight = None # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index a4c29b7c8231..b8f70ce9c9d8 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,13 +13,12 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Constants @@ -74,8 +73,8 @@ def main(): parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") - parser.add_argument("--mbs", type=int, default=1) - parser.add_argument("--zero", type=int, default=0) + parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") + parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") args = parser.parse_args() colossalai.launch_from_torch({}) @@ -98,7 +97,13 @@ def empty_init(): extra_dp_size=args.extra_dp, ) elif args.plugin == "gemini_auto": - plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp) + plugin = GeminiPlugin( + placement_policy="auto", + precision="bf16", + warmup_non_model_data_ratio=args.warmup_ratio, + tp_size=args.tp, + extra_dp_size=args.extra_dp, + ) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( @@ -137,7 +142,7 @@ def empty_init(): zero_stage=args.zero, num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, precision="bf16", ) elif args.plugin == "3d_cpu": @@ -147,7 +152,7 @@ def empty_init(): zero_stage=args.zero, cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), - num_microbatches=args.mbs, + microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", ) @@ -171,7 +176,7 @@ def empty_init(): # Initialize Model and Optimizer # ============================== init_ctx = ( - LazyInitContext(default_device=get_current_device()) + LazyInitContext(default_device=get_accelerator().get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) @@ -202,7 +207,9 @@ def empty_init(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -228,7 +235,7 @@ def empty_init(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index a438833e1680..6b9e8ef28eb7 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/llama2/data_utils.py @@ -8,7 +8,7 @@ from torch.distributed.distributed_c10d import _get_default_group from torch.utils.data import DataLoader, Dataset, DistributedSampler -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class StatefulDistributedSampler(DistributedSampler): @@ -108,7 +108,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index f7708b1a38ab..66b5400765f7 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -21,13 +21,13 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def get_model_numel(model: nn.Module) -> int: @@ -191,7 +191,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 6b1c92711d48..c2169a730a88 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,9 +5,8 @@ import torch.distributed as dist from torch import Tensor -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator -from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_current_device()) + tensor = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -86,13 +85,13 @@ def on_step_start(self, step: int) -> None: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index bb10f7a00e8a..4cdf93e1914b 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -20,13 +20,13 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device MODEL_CONFIGS = { "7b": LlamaConfig(max_position_embeddings=4096), @@ -227,7 +227,9 @@ def main(): config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: @@ -273,11 +275,10 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch) - step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) with tqdm( - range(step_nums), + range(start_step, num_steps_per_epoch), desc=f"Epoch {epoch}", disable=not print_flag, total=num_steps_per_epoch, diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama2/scripts/benchmark_70B/3d.sh index d50c57042d1a..cb8f218fa3fc 100644 --- a/examples/language/llama2/scripts/benchmark_70B/3d.sh +++ b/examples/language/llama2/scripts/benchmark_70B/3d.sh @@ -14,4 +14,4 @@ cd ../.. export OMP_NUM_THREADS=8 -colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4 +colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1 diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 65562b386cf9..03b660ecf446 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -14,6 +14,7 @@ from utils import PerformanceEvaluator, get_model_numel import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -21,7 +22,6 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -64,13 +64,15 @@ def __init__( ) self.input_ids.append(encode["input_ids"]) self.attention_mask.append(encode["attention_mask"]) - self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) - self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device()) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device()) repeat_times = num_samples // self.input_ids.shape[0] + 1 self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] else: - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317903..eee3b505a22a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,7 @@ replace_return_docstrings, ) -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index f354bbea990e..17e7aa46ce85 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -43,7 +43,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index b084361661ac..1ae661f548b8 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -15,6 +15,7 @@ from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -22,7 +23,6 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -61,7 +61,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 7af02e24e6cf..4fac7b5072ed 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -14,12 +14,12 @@ from torch.utils.data import DataLoader, Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import HybridAdam -from colossalai.utils import get_current_device # constants @@ -159,7 +159,11 @@ def __len__(self): logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.plugin == "gemini" + else nullcontext() + ) with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md index 13561567636e..6f11298fcfc8 100644 --- a/examples/tutorial/auto_parallel/README.md +++ b/examples/tutorial/auto_parallel/README.md @@ -49,7 +49,7 @@ You should expect to the log like this. This log shows the edge cost on the comp ### Auto-Checkpoint Tutorial -We prepare two bechmarks for you to test the performance of auto checkpoint +We prepare two benchmarks for you to test the performance of auto checkpoint The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 4407a51c3153..a4733126f3ee 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -13,12 +13,12 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 700e4d2e0cd9..ec6c852b5965 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -13,13 +13,13 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 990822c9feba..e97c9017fe56 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -12,11 +12,11 @@ from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -45,7 +45,7 @@ def evaluate( model.eval() def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) outputs = model(**batch) diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 9bd23ffc8aba..3f0d048795e6 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -51,13 +51,13 @@ from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.tensor import ProcessGroup from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -249,9 +249,9 @@ def parse_args(): def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print("Using {} GB of GPU memory".format(size_in_GB)) @@ -265,7 +265,9 @@ def __init__(self, length, batch_size, seq_len, vocab_size): self.vocab_size = vocab_size def generate(self): - input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device() + ) attention_mask = torch.ones_like(input_ids) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} @@ -390,7 +392,7 @@ def main(): if args.init_in_cpu: init_dev = torch.device("cpu") else: - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() cai_version = colossalai.__version__ logger.info(f"using Colossal-AI version {cai_version}") @@ -439,7 +441,9 @@ def main(): except ImportError: # this works for unreleased main branch, and this may be released on 0.2.9 from colossalai.zero import GeminiDDP - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + model = GeminiDDP( + model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True + ) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 7b0e93d958ca..64260374a0d5 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -3,13 +3,13 @@ import torch import torch.nn as nn -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers.init_method import init_normal, output_init_normal diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 75afeee60ad4..ff81ace39736 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,9 +3,9 @@ import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .linear import Linear from .pooler import Pooler diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index e9ceb8d70cb8..f25fc818981a 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,12 +8,12 @@ from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.kernel import LayerNorm from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from colossalai.nn.optimizer import FusedAdam from colossalai.utils import MultiTimer diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 000000000000..6f5feb55c2af --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,140 @@ +# 🔌 Extensions + +## 📌 Table of Contents + +- [🔌 Extensions](#-extensions) + - [📌 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [🪅 Design](#-design) + - [🛠 API Usage](#-api-usage) + - [🏗 Write a customized extension](#-write-a-customized-extension) + - [✏️ Acknowledgement](#️-acknowledgement) + +## 📚 Introduction + +This module is a designed to offer extensions to the existing ColossalAI framework. It is designed to be a collection of high-performance kernels to speed up the training and inference process. Different from writing an individual kernel, the `extensions` module offers a layer of abstraction to collate kernels written in different compiler backends and for different hardware backends in an organized way. Please see the design and usage in the sections below. + +## 🪅 Design + +The `extensions` module is a sub-module of the `colossalai.kernel` module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the `colossalai.kernel.extensions` path for runtime build. + +As we want to support multi-backend kernels, we have to consider multiple compiler options such as `torch.jit`, `CUDA`, `triton` and multiple hardware backends such as `CPU`, `GPU` and `NPU`. To make it easy for the users, we have abstract away the kernels into extensions and expose a single loader to the user for each kind of kernel. + +For example, if the user wants to use the CPU Adam kernel, he can just call `load()` on the kernel loader. The kernel loader will automatically select the correct extension based on the current hardware and compiler backend. The user does not need to worry about the details of the kernel implementation. For example, if the user is using ARM CPU, then Arm kernel will be built and loaded. If it is a X86 CPU, then it is the X86 kernel that will be loaded. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +![](https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/extensions.png?raw=true) + +## 🛠 API Usage + +To make the `colossalai.kernel` easy to use, we expose some simple APIs and you can use them based on your scenario. + +- Case 1: Simply load a kernel + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +- Case 2: Load a specific kernel + +This case applies if you are familiar with the extensions available. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel by giving the kernel name +kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm") +``` + +- Case 3: Register your own extension + +This case applies if you know how to write an extension. If you do not know how, you can refer to the section below. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader +from colossalai.kernel.base_extension import _Extension + +# create your own extension class +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + # implementation here + ... + +# register your extension +# you can use the priority value to make sure your kernel will be loaded by default +CPUAdamLoader.register_extension(MyExtension) + +# load the kernel +kernel = CPUAdamLoader().load() +``` + +## 🏗 Write a customized extension + +It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you should get familiar with the process quickly. + +You just need to inherit the `_Extension` base class or other backend-specific classes such as `_CudaExtension` and implement the abstract methods. Then, you need to register your extension to the kernel loader based on the Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, current hardware, compiler backend. + +```python +from colossalai.kernel.base_extension import _Extension + + +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + def is_hardware_available(self) -> bool: + """ + Return if the required hardware can be found. + """ + ... + + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + ... + + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + """ + If this kernel can be built AOT, it should return an extension object + to Python setuptools for compilation. + """ + ... + + def build_jit(self) -> Callable: + """ + Build extension kernel just in time. + """ + ... + + def load(self): + """ + The API called by the user to get the kernel. + """ + ... + +``` + +## ✏️ Acknowledgement + +This module is written from scratch but we learnt a lot by looking into [DeepSpeed' +s op_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder). We wish to acknowledge their great work and contributions to the open-source community. diff --git a/extensions/__init__.py b/extensions/__init__.py new file mode 100644 index 000000000000..9343cadda194 --- /dev/null +++ b/extensions/__init__.py @@ -0,0 +1,36 @@ +from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .flash_attention import ( + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, +) +from .layernorm import LayerNormCudaExtension +from .moe import MoeCudaExtension +from .optimizer import FusedOptimizerCudaExtension +from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension + +ALL_EXTENSIONS = [ + CpuAdamArmExtension, + CpuAdamX86Extension, + LayerNormCudaExtension, + MoeCudaExtension, + FusedOptimizerCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionXformersCudaExtension, + FlashAttentionNpuExtension, +] + +__all__ = [ + "CpuAdamArmExtension", + "CpuAdamX86Extension", + "LayerNormCudaExtension", + "MoeCudaExtension", + "FusedOptimizerCudaExtension", + "ScaledMaskedSoftmaxCudaExtension", + "ScaledUpperTriangleMaskedSoftmaxCudaExtension", + "FlashAttentionDaoCudaExtension", + "FlashAttentionXformersCudaExtension", + "FlashAttentionNpuExtension", +] diff --git a/extensions/base_extension.py b/extensions/base_extension.py new file mode 100644 index 000000000000..c815a7f2ac4a --- /dev/null +++ b/extensions/base_extension.py @@ -0,0 +1,82 @@ +import hashlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Union + +__all__ = ["_Extension"] + + +class _Extension(ABC): + def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1): + self._name = name + self._support_aot = support_aot + self._support_jit = support_jit + self.priority = priority + + @property + def name(self): + return self._name + + @property + def support_aot(self): + return self._support_aot + + @property + def support_jit(self): + return self._support_jit + + @staticmethod + def get_jit_extension_folder_path(): + """ + Kernels which are compiled during runtime will be stored in the same cache folder for reuse. + The folder is in the path ~/.cache/colossalai/torch_extensions/. + The name of the follows a common format: + torch._- + + The suffix is the hash value of the path of the `colossalai` file. + """ + import torch + + import colossalai + from colossalai.accelerator import get_accelerator + + # get torch version + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + + # get device version + device_name = get_accelerator().name + device_version = get_accelerator().get_version() + + # use colossalai's file path as hash + hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest() + + # concat + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}" + cache_directory = os.path.join(home_directory, extension_directory) + return cache_directory + + @abstractmethod + def is_hardware_available(self) -> bool: + """ + Check if the hardware required by the kernel is available. + """ + + @abstractmethod + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + + @abstractmethod + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + pass + + @abstractmethod + def build_jit(self) -> Callable: + pass + + @abstractmethod + def load(self) -> Callable: + pass diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py new file mode 100644 index 000000000000..b4c40c9f1105 --- /dev/null +++ b/extensions/cpp_extension.py @@ -0,0 +1,134 @@ +import importlib +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension + +__all__ = ["_CppExtension"] + + +class _CppExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=True, support_jit=True, priority=priority) + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op = None + + # build-related variables + self.prebuilt_module_path = "colossalai._C" + self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}" + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("csrc"), path) + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + + # get the current file path + # iteratively check the parent directory + # if the parent directory is "extensions", then the current file path is the root directory + # otherwise, the current file path is inside the root directory + current_file_path = Path(__file__) + while True: + if current_file_path.name == "extensions": + break + else: + current_file_path = current_file_path.parent + extension_module_path = current_file_path + code_abs_path = extension_module_path.joinpath(code_path) + return str(code_abs_path) + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def build_aot(self) -> "CppExtension": + from torch.utils.cpp_extension import CppExtension + + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + def build_jit(self) -> None: + from torch.utils.cpp_extension import load + + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + def load(self): + try: + op_kernel = self.import_op() + except ImportError: + # if import error occurs, it means that the kernel is not pre-built + # so we build it jit + op_kernel = self.build_jit() + + return op_kernel diff --git a/extensions/cpu_adam/__init__.py b/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000000..cfd26a6a20f8 --- /dev/null +++ b/extensions/cpu_adam/__init__.py @@ -0,0 +1,5 @@ +from .cpu_adam_arm import CpuAdamArmExtension +from .cpu_adam_x86 import CpuAdamX86Extension + +__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] + diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py new file mode 100644 index 000000000000..35bff3b55928 --- /dev/null +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -0,0 +1,41 @@ +import platform + +from ..cpp_extension import _CppExtension + + +class CpuAdamArmExtension(_CppExtension): + def __init__(self): + super().__init__(name="cpu_adam_arm") + + def is_hardware_available(self) -> bool: + # only arm allowed + return platform.machine() == "aarch64" + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "aarch64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}" + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/op_builder/cpu_adam.py b/extensions/cpu_adam/cpu_adam_x86.py similarity index 60% rename from op_builder/cpu_adam.py rename to extensions/cpu_adam/cpu_adam_x86.py index 7988aae4be12..a38194167b01 100644 --- a/op_builder/cpu_adam.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -1,19 +1,27 @@ -from .builder import Builder -from .utils import append_nvcc_threads +import platform +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class CPUAdamBuilder(Builder): - NAME = "cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" +class CpuAdamX86Extension(_CudaExtension): def __init__(self): - super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + super().__init__(name="cpu_adam_x86") + + def is_hardware_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_hardware_available() + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_hardware_compatible() # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cpu_adam.cpp"), + self.csrc_abs_path("cuda/cpu_adam.cpp"), ] return ret diff --git a/colossalai/kernel/cuda_native/__init__.py b/extensions/csrc/__init__.py similarity index 86% rename from colossalai/kernel/cuda_native/__init__.py rename to extensions/csrc/__init__.py index f8a974b5fb26..0eac28d23e24 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/extensions/csrc/__init__.py @@ -1,5 +1,4 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax @@ -8,6 +7,5 @@ "MultiHeadAttention", "FusedScaleMaskSoftmax", "ScaledUpperTriangMaskedSoftmax", - "ColoAttention", "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/extensions/csrc/arm/cpu_adam_arm.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp rename to extensions/csrc/arm/cpu_adam_arm.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/extensions/csrc/arm/cpu_adam_arm.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h rename to extensions/csrc/arm/cpu_adam_arm.h diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/extensions/csrc/cuda/colossal_C_frontend.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp rename to extensions/csrc/cuda/colossal_C_frontend.cpp diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/extensions/csrc/cuda/compat.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/compat.h rename to extensions/csrc/cuda/compat.h diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/extensions/csrc/cuda/cpu_adam.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.cpp rename to extensions/csrc/cuda/cpu_adam.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/extensions/csrc/cuda/cpu_adam.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.h rename to extensions/csrc/cuda/cpu_adam.h diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h rename to extensions/csrc/cuda/include/block_reduce.h diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp rename to extensions/csrc/cuda/layer_norm_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu rename to extensions/csrc/cuda/layer_norm_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/extensions/csrc/cuda/moe_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda.cpp rename to extensions/csrc/cuda/moe_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/extensions/csrc/cuda/moe_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu rename to extensions/csrc/cuda/moe_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu rename to extensions/csrc/cuda/multi_tensor_adam.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh rename to extensions/csrc/cuda/multi_tensor_apply.cuh diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu rename to extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu rename to extensions/csrc/cuda/multi_tensor_lamb.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu rename to extensions/csrc/cuda/multi_tensor_scale_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu rename to extensions/csrc/cuda/multi_tensor_sgd_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/extensions/csrc/cuda/scaled_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h rename to extensions/csrc/cuda/scaled_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/extensions/csrc/cuda/type_shim.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/type_shim.h rename to extensions/csrc/cuda/type_shim.h diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/extensions/csrc/scaled_softmax.py similarity index 94% rename from colossalai/kernel/cuda_native/scaled_softmax.py rename to extensions/csrc/scaled_softmax.py index 26a5bce16d5c..7c220d60dd19 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/extensions/csrc/scaled_softmax.py @@ -6,8 +6,7 @@ import torch import torch.nn as nn -from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader try: from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax @@ -35,7 +34,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs, scale): global scaled_upper_triang_masked_softmax if scaled_upper_triang_masked_softmax: - scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load() + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() scale_t = torch.tensor([scale]) softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) @@ -67,7 +66,7 @@ def forward(ctx, inputs, mask, scale): # build and load kernel if not pre-built global scaled_masked_softmax if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py new file mode 100644 index 000000000000..b5e8a285b7e0 --- /dev/null +++ b/extensions/cuda_extension.py @@ -0,0 +1,106 @@ +import os +from abc import abstractmethod +from typing import List + +from .cpp_extension import _CppExtension +from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list + +__all__ = ["_CudaExtension"] + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 + + +class _CudaExtension(_CppExtension): + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME + + if not CUDA_HOME: + raise AssertionError( + "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions" + ) + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def build_jit(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME, load + + set_cuda_arch_list(CUDA_HOME) + + # get build dir + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + def build_aot(self) -> "CUDAExtension": + from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension + + set_cuda_arch_list(CUDA_HOME) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py new file mode 100644 index 000000000000..18abb6191035 --- /dev/null +++ b/extensions/flash_attention/__init__.py @@ -0,0 +1,20 @@ +from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension +from .flash_attention_npu import FlashAttentionNpuExtension +from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension + +try: + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False + +try: + import xformers # noqa + + HAS_MEM_EFF_ATTN = True +except: + HAS_MEM_EFF_ATTN = False + + +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py new file mode 100644 index 000000000000..1b7f8ac4736a --- /dev/null +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -0,0 +1,93 @@ +from ..base_extension import _Extension + + +class FlashAttentionDaoCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + + def load(self): + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + ) + + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # check if the input is in allowed dtypes + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py new file mode 100644 index 000000000000..58d0f9306e3d --- /dev/null +++ b/extensions/flash_attention/flash_attention_npu.py @@ -0,0 +1,73 @@ +from ..base_extension import _Extension + + +class FlashAttentionNpuExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + try: + import torch_npu # noqa + + return True + except: + return False + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu." + ) + + def load(self): + import torch + from einops import rearrange + + def npu_sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, + ): + """ + The scaled dot product attention. + + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + scale: float. The scaling of QK^T before applying softmax. + Default to 1. + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] + output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=origin_attn_mask, + dropout_p=dropout_p, + is_causal=origin_attn_mask is None, + scale=scale, + ) + output = rearrange(output, "b h s d -> b s (h d)") + return output + + return npu_sdpa_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py new file mode 100644 index 000000000000..27cd823de14b --- /dev/null +++ b/extensions/flash_attention/flash_attention_xformers_cuda.py @@ -0,0 +1,94 @@ +from ..base_extension import _Extension + + +class FlashAttentionXformersCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def load(self): + try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + ) + from typing import Optional + + import torch + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out + + return mem_eff_attention diff --git a/extensions/layernorm/__init__.py b/extensions/layernorm/__init__.py new file mode 100644 index 000000000000..9d1bd2d019ee --- /dev/null +++ b/extensions/layernorm/__init__.py @@ -0,0 +1,3 @@ +from .layernorm_cuda import LayerNormCudaExtension + +__all__ = ["LayerNormCudaExtension"] \ No newline at end of file diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/layernorm/layernorm_cuda.py new file mode 100644 index 000000000000..db5f2fce1368 --- /dev/null +++ b/extensions/layernorm/layernorm_cuda.py @@ -0,0 +1,24 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-maxrregcount=50"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/extensions/moe/__init__.py b/extensions/moe/__init__.py new file mode 100644 index 000000000000..962084d4bdde --- /dev/null +++ b/extensions/moe/__init__.py @@ -0,0 +1,3 @@ +from .moe_cuda import MoeCudaExtension + +__all__ = ['MoeCudaExtension'] \ No newline at end of file diff --git a/op_builder/moe.py b/extensions/moe/moe_cuda.py similarity index 56% rename from op_builder/moe.py rename to extensions/moe/moe_cuda.py index 6f8028b1720c..52883e97fc3a 100644 --- a/op_builder/moe.py +++ b/extensions/moe/moe_cuda.py @@ -1,20 +1,17 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag -class MOEBuilder(Builder): - NAME = "moe" - PREBUILT_IMPORT_PATH = "colossalai._C.moe" - +class MoeCudaExtension(_CudaExtension): def __init__(self): - super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name="moe_cuda") def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] return ret def cxx_flags(self): diff --git a/extensions/optimizer/__init__.py b/extensions/optimizer/__init__.py new file mode 100644 index 000000000000..9c8e87cae5de --- /dev/null +++ b/extensions/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .fused_optimizer_cuda import FusedOptimizerCudaExtension + +__all__ = ['FusedOptimizerCudaExtension'] \ No newline at end of file diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/optimizer/fused_optimizer_cuda.py new file mode 100644 index 000000000000..e065cf34a17d --- /dev/null +++ b/extensions/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/colossal_C_frontend.cpp", + "cuda/multi_tensor_sgd_kernel.cu", + "cuda/multi_tensor_scale_kernel.cu", + "cuda/multi_tensor_adam.cu", + "cuda/multi_tensor_l2norm_kernel.cu", + "cuda/multi_tensor_lamb.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/extensions/softmax/__init__.py b/extensions/softmax/__init__.py new file mode 100644 index 000000000000..9bc50c6cd91c --- /dev/null +++ b/extensions/softmax/__init__.py @@ -0,0 +1,4 @@ +from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension +from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension + +__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] \ No newline at end of file diff --git a/op_builder/scaled_masked_softmax.py b/extensions/softmax/scaled_masked_softmax_cuda.py similarity index 50% rename from op_builder/scaled_masked_softmax.py rename to extensions/softmax/scaled_masked_softmax_cuda.py index d9239a80eef6..5b4208dba895 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/extensions/softmax/scaled_masked_softmax_cuda.py @@ -1,23 +1,20 @@ -from .builder import Builder -from .utils import append_nvcc_threads +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class ScaledMaskedSoftmaxBuilder(Builder): - NAME = "scaled_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" - +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): def __init__(self): - super().__init__( - name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH - ) + super().__init__(name="scaled_masked_softmax_cuda") - # necessary 4 functions def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] + ret = [ + self.csrc_abs_path(fname) + for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + ] return ret def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return [self.get_cuda_home_include()] def cxx_flags(self): return ["-O3"] + self.version_dependent_macros diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 000000000000..d4f27a9218ff --- /dev/null +++ b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/extensions/triton_extension.py b/extensions/triton_extension.py new file mode 100644 index 000000000000..9f0792f8ce68 --- /dev/null +++ b/extensions/triton_extension.py @@ -0,0 +1,21 @@ +from .base_extension import _Extension + +__all__ = ["_TritonExtension"] + + +class _TritonExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=False, support_jit=True, priority=priority) + + def is_hardware_compatible(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def load(self): + return self.build_jit() diff --git a/op_builder/utils.py b/extensions/utils.py similarity index 100% rename from op_builder/utils.py rename to extensions/utils.py diff --git a/op_builder/README.md b/op_builder/README.md deleted file mode 100644 index 9c33a4a328d7..000000000000 --- a/op_builder/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Build PyTorch Extensions - -## Overview - -Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. - -1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` -2. Build the extension during runtime - -The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. - -These two methods have different advantages and disadvantages. -Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. -Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. - -## PyTorch Extensions in Colossal-AI - -The project [DeepSpeed](https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel-build during either installation or runtime. -We have adapted from DeepSpeed's solution to build extensions. The extension build requires two main functions from PyTorch: - -1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. -2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime - -Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). - -Based on the DeepSpeed's work, we have make several modifications and improvements: - -1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` -2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) -3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading. - -When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py deleted file mode 100644 index 21e216437c47..000000000000 --- a/op_builder/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from .arm_cpu_adam import ArmCPUAdamBuilder -from .cpu_adam import CPUAdamBuilder -from .fused_optim import FusedOptimBuilder -from .layernorm import LayerNormBuilder -from .moe import MOEBuilder -from .multi_head_attn import MultiHeadAttnBuilder -from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder - -ALL_OPS = { - "cpu_adam": CPUAdamBuilder, - "fused_optim": FusedOptimBuilder, - "moe": MOEBuilder, - "multi_head_attn": MultiHeadAttnBuilder, - "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, - "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, - "layernorm": LayerNormBuilder, -} - -__all__ = [ - "ALL_OPS", - "CPUAdamBuilder", - "FusedOptimBuilder", - "MultiHeadAttnBuilder", - "ScaledMaskedSoftmaxBuilder", - "ScaledUpperTrainglemaskedSoftmaxBuilder", - "MOEBuilder", - "MultiTensorSGDBuilder", - "MultiTensorAdamBuilder", - "MultiTensorLambBuilder", - "MultiTensorScaleBuilder", - "MultiTensorL2NormBuilder", - "ArmCPUAdamBuilder", -] diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py deleted file mode 100644 index 18dd519fae46..000000000000 --- a/op_builder/arm_cpu_adam.py +++ /dev/null @@ -1,34 +0,0 @@ -from .builder import Builder - - -class ArmCPUAdamBuilder(Builder): - NAME = "arm_cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" - ext_type = "cpu" - - def __init__(self): - super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam_arm.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes")] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-g", - "-Wno-reorder", - "-fopenmp", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - return [] diff --git a/op_builder/builder.py b/op_builder/builder.py deleted file mode 100644 index d804cb1602e4..000000000000 --- a/op_builder/builder.py +++ /dev/null @@ -1,236 +0,0 @@ -# This code has been adapted from the DeepSpeed library. -# Copyright (c) Microsoft Corporation. - -# Licensed under the MIT License. -import importlib -import os -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional, Union - -from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 - - -class Builder(ABC): - """ - Builder is the base class to build extensions for PyTorch. - - Args: - name (str): the name of the kernel to be built - prebuilt_import_path (str): the path where the extension is installed during pip install - """ - - ext_type: str = "cuda" - - def __init__(self, name: str, prebuilt_import_path: str): - self.name = name - self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # we store the op as an attribute to avoid repeated building and loading - self.cached_op_module = None - - assert prebuilt_import_path.startswith( - "colossalai._C" - ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" - - def relative_to_abs_path(self, code_path: str) -> str: - """ - This function takes in a path relative to the colossalai root directory and return the absolute path. - """ - op_builder_module_path = Path(__file__).parent - - # if we install from source - # the current file path will be op_builder/builder.py - # if we install via pip install colossalai - # the current file path will be colossalai/kernel/op_builder/builder.py - # this is because that the op_builder inside colossalai is a symlink - # this symlink will be replaced with actual files if we install via pypi - # thus we cannot tell the colossalai root directory by checking whether the op_builder - # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): - root_path = op_builder_module_path.parent.parent - else: - root_path = op_builder_module_path.parent.joinpath("colossalai") - - code_abs_path = root_path.joinpath(code_path) - return str(code_abs_path) - - def get_cuda_home_include(self): - """ - return include path inside the cuda home. - """ - from torch.utils.cpp_extension import CUDA_HOME - - if CUDA_HOME is None: - raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") - cuda_include = os.path.join(CUDA_HOME, "include") - return cuda_include - - def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) - - # functions must be overrided begin - @abstractmethod - def sources_files(self) -> List[str]: - """ - This function should return a list of source files for extensions. - """ - raise NotImplementedError - - @abstractmethod - def include_dirs(self) -> List[str]: - """ - This function should return a list of include files for extensions. - """ - - @abstractmethod - def cxx_flags(self) -> List[str]: - """ - This function should return a list of cxx compilation flags for extensions. - """ - - @abstractmethod - def nvcc_flags(self) -> List[str]: - """ - This function should return a list of nvcc compilation flags for extensions. - """ - - # functions must be overrided over - def strip_empty_entries(self, args): - """ - Drop any empty strings from the list of compile and link flags - """ - return [x for x in args if len(x) > 0] - - def import_op(self): - """ - This function will import the op module by its string name. - """ - return importlib.import_module(self.prebuilt_import_path) - - def check_runtime_build_environment(self): - """ - Check whether the system environment is ready for extension compilation. - """ - try: - from torch.utils.cpp_extension import CUDA_HOME - - TORCH_AVAILABLE = True - except ImportError: - TORCH_AVAILABLE = False - CUDA_HOME = None - - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" - ) - - if CUDA_HOME is None: - raise RuntimeError( - "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - # make sure CUDA is available for compilation during - cuda_available = check_cuda_availability() - if not cuda_available: - raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") - - # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not - check_system_pytorch_cuda_match(CUDA_HOME) - - def load(self, verbose: Optional[bool] = None): - """ - load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. - If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the - kernel is built during pip install, it can be accessed through `colossalai._C`. - - Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - if verbose is None: - verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" - # if the kernel has be compiled and cached, we directly use it - if self.cached_op_module is not None: - return self.cached_op_module - - try: - # if the kernel has been pre-built during installation - # we just directly import it - op_module = self.import_op() - if verbose: - print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." - ) - except ImportError: - # check environment - if self.ext_type == "cuda": - self.check_runtime_build_environment() - - # time the kernel compilation - start_build = time.time() - - # construct the build directory - import torch - from torch.utils.cpp_extension import load - - torch_version_major = torch.__version__.split(".")[0] - torch_version_minor = torch.__version__.split(".")[1] - torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser("~") - extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" - build_directory = os.path.join(home_directory, extension_directory) - Path(build_directory).mkdir(parents=True, exist_ok=True) - - if verbose: - print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") - - # load the kernel - op_module = load( - name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose, - ) - - build_duration = time.time() - start_build - - # log jit compilation time - if verbose: - print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") - - # cache the built/loaded kernel - self.cached_op_module = op_module - - return op_module - - def builder(self) -> Union["CUDAExtension", "CppExtension"]: - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CppExtension, CUDAExtension - - if self.ext_type == "cpp": - return CppExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args=self.strip_empty_entries(self.cxx_flags()), - ) - - return CUDAExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - "cxx": self.strip_empty_entries(self.cxx_flags()), - "nvcc": self.strip_empty_entries(self.nvcc_flags()), - }, - ) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py deleted file mode 100644 index 3baa0880d801..000000000000 --- a/op_builder/fused_optim.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import get_cuda_cc_flag - - -class FusedOptimBuilder(Builder): - NAME = "fused_optim" - PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" - - def __init__(self): - super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "colossal_C_frontend.cpp", - "multi_tensor_sgd_kernel.cu", - "multi_tensor_scale_kernel.cu", - "multi_tensor_adam.cu", - "multi_tensor_l2norm_kernel.cu", - "multi_tensor_lamb.cu", - ] - ] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - return ["-O3"] + version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-lineinfo"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/gptq.py b/op_builder/gptq.py deleted file mode 100644 index a17801f8783c..000000000000 --- a/op_builder/gptq.py +++ /dev/null @@ -1,56 +0,0 @@ -import re - -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class GPTQBuilder(Builder): - NAME = "cu_gptq" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" - - def __init__(self): - super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "gptq/linear_gptq.cpp", - "gptq/column_remap.cu", - "gptq/cuda_buffers.cu", - "gptq/q4_matmul.cu", - "gptq/q4_matrix.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-v", - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - "-lcublas", - ] - - for arch in torch.cuda.get_arch_list(): - res = re.search(r"sm_(\d+)", arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 80: - extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py deleted file mode 100644 index 2684c6ddb7f7..000000000000 --- a/op_builder/layernorm.py +++ /dev/null @@ -1,27 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class LayerNormBuilder(Builder): - NAME = "layernorm" - PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" - - def __init__(self): - super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-maxrregcount=50"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros - return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py deleted file mode 100644 index cb8fc489ced1..000000000000 --- a/op_builder/multi_head_attn.py +++ /dev/null @@ -1,46 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" - PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" - - def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "multihead_attention_1d.cpp", - "kernels/cublas_wrappers.cu", - "kernels/transform_kernels.cu", - "kernels/dropout_kernels.cu", - "kernels/normalize_kernels.cu", - "kernels/softmax_kernels.cu", - "kernels/general_kernels.cu", - "kernels/cuda_util.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py deleted file mode 100644 index 1445230acbc1..000000000000 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): - NAME = "scaled_upper_triangle_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" - - def __init__(self): - super().__init__( - name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, - prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, - ) - - def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py deleted file mode 100644 index d562a4c4f626..000000000000 --- a/op_builder/smoothquant.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class SmoothquantBuilder(Builder): - NAME = "cu_smoothquant" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" - - def __init__(self): - super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "smoothquant/binding.cpp", - "smoothquant/linear.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - compute_capability = torch.cuda.get_device_capability() - cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 - - extra_cuda_flags = [ - "-v", - f"-DCUDA_ARCH={cuda_arch}", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) - - def builder(self): - try: - super().builder() - except: - warnings.warn("build smoothquant lib not successful") diff --git a/setup.py b/setup.py index cda1ba7ee7a6..1244bfff0327 100644 --- a/setup.py +++ b/setup.py @@ -5,55 +5,23 @@ from setuptools import find_packages, setup -from op_builder.utils import ( - check_cuda_availability, - check_pytorch_version, - check_system_pytorch_cuda_match, - get_cuda_bare_metal_version, - get_pytorch_version, - set_cuda_arch_list, -) - try: - from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + import torch # noqa + from torch.utils.cpp_extension import BuildExtension TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False - CUDA_HOME = None -# Some constants for installation checks -MIN_PYTORCH_VERSION_MAJOR = 1 -MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 -# a variable to store the op builder -ext_modules = [] - # we do not support windows currently if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") -# check for CUDA extension dependencies -def environment_check_for_cuda_extension_build(): - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" - ) - - if not CUDA_HOME: - raise RuntimeError( - "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - check_system_pytorch_cuda_match(CUDA_HOME) - check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) - check_cuda_availability() - - def fetch_requirements(path) -> List[str]: """ This function reads the requirements file. @@ -98,46 +66,35 @@ def get_version() -> str: # write version into version.py with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") - - # look for pytorch and cuda version - if BUILD_CUDA_EXT: - torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f"{torch_major}.{torch_minor}" - cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) - else: - torch_version = None - cuda_version = None - - # write the version into the python file - if torch_version: - f.write(f'torch = "{torch_version}"\n') - else: - f.write("torch = None\n") - - if cuda_version: - f.write(f'cuda = "{cuda_version}"\n') - else: - f.write("cuda = None\n") - return version -if BUILD_CUDA_EXT: - environment_check_for_cuda_extension_build() - set_cuda_arch_list(CUDA_HOME) +if BUILD_EXT: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) - from op_builder import ALL_OPS + from extensions import ALL_EXTENSIONS op_names = [] + ext_modules = [] - # load all builders - for name, builder_cls in ALL_OPS.items(): - op_names.append(name) - ext_modules.append(builder_cls().builder()) + for ext_cls in ALL_EXTENSIONS: + ext = ext_cls() + if ext.support_aot and ext.is_hardware_available(): + ext.assert_hardware_compatible() + op_names.append(ext.name) + ext_modules.append(ext.build_aot()) # show log - op_name_list = ", ".join(op_names) - print(f"[extension] loaded builders for {op_name_list}") + if len(ext_modules) == 0: + raise RuntimeError("[extension] Could not find any kernel compatible with the current environment.") + else: + op_name_list = ", ".join(op_names) + print(f"[extension] Building extensions{op_name_list}") +else: + ext_modules = [] # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 5e8e0b3822df..a16b16ad6af7 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -61,7 +61,9 @@ def register( """ self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) - def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None): + def get_sub_registry( + self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None, allow_empty: bool = False + ): """ Get a sub registry with models that contain the keyword. @@ -95,7 +97,8 @@ def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, L if not should_exclude: new_dict[k] = v - assert len(new_dict) > 0, f"No model found with keyword {keyword}" + if not allow_empty: + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index 9eefbb43dad8..c89124f0164d 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -63,6 +63,9 @@ def data_gen_for_sequence_classification(): n_layer=2, n_head=4, vocab_size=50258, + n_embd=256, + hidden_size=256, + n_positions=512, attn_pdrop=0, embd_pdrop=0, resid_pdrop=0, diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 2c8b260e6498..373ba28b8545 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -5,13 +5,13 @@ from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -31,7 +31,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = LMLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index aba746f1992d..d577173266da 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -10,12 +10,12 @@ except: NO_CODEGEN = True +from colossalai.accelerator import get_accelerator from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn -from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper @@ -72,7 +72,11 @@ def check_auto_parallel_with_gemini(rank, world_size, port): print("=" * msg_length) gemini_config = dict( - strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + strict_ddp_mode=False, + device=get_accelerator().get_current_device(), + placement_policy="cpu", + pin_memory=True, + search_range_m=128, ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e724d7359c54..67b0bef50594 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -1,19 +1,49 @@ +import copy from contextlib import nullcontext from typing import Optional import torch import torch.distributed as dist +from torch.testing import assert_close +from torch.utils.data import Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed from tests.kit.model_zoo import model_zoo +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + set_seed(42) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + @clear_cache_before_run() def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: @@ -85,10 +115,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) +@parameterize( + "test_args", + [ + { + "batch_size": 8, + "num_steps": 4, + "tp": 2, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 8, + "num_steps": 4, + "tp": 1, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 1, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 2, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + ], +) +def run_grad_acc_test(test_args): + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())) + model = model_fn() + optimizer = HybridAdam(model.parameters()) + origin_model = copy.deepcopy(model).cuda() + origin_optimizer = HybridAdam(origin_model.parameters()) + + plugin = HybridParallelPlugin( + tp_size=test_args["tp"], + pp_size=test_args["pp"], + pp_style=test_args["pp_style"], + zero_stage=test_args["zero"], + num_model_chunks=test_args["num_model_chunks"], + enable_fused_normalization=True, + num_microbatches=test_args["num_microbatches"], + precision=test_args["precision"], + ) + booster = Booster(plugin=plugin) + + dataset = RandomDataset( + num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size, + max_length=test_args["max_length"], + vocab_size=model.config.vocab_size, + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True) + + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + grad_accu_step = test_args["gradient_accumulation_step"] + for step, batch in enumerate(dataloader): + batch = move_to_cuda(batch) + # train origin model + origin_output = origin_model(**batch) + origin_loss = origin_output[0] / grad_accu_step + origin_loss.backward() + + if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2: + ctx = booster.no_sync(model, optimizer) + else: + ctx = nullcontext() + + with ctx: + if plugin.stage_manager is not None: + batch = iter([batch]) + booster.execute_pipeline( + batch, + model, + criterion=lambda outputs, inputs: outputs[0] / grad_accu_step, + optimizer=optimizer, + return_loss=False, + ) + else: + outputs = model(**batch) + loss = outputs[0] / grad_accu_step + booster.backward(loss, optimizer) + + if (step + 1) % grad_accu_step == 0: + # update origin model weight + origin_optimizer.step() + origin_optimizer.zero_grad() + + # update sharded model + optimizer.step() + optimizer.zero_grad() + + # tricky code here, shard the origin model inorder to check the parameters in the same stage. + origin_model, origin_optimizer, _, dataloader, _ = booster.boost( + origin_model, origin_optimizer, dataloader=dataloader + ) + for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) + run_grad_acc_test() @rerun_if_address_is_in_use() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 9952e41e5b13..17dfa3a1860d 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -12,7 +12,13 @@ from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import ( + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + skip_if_not_enough_gpus, + spawn, +) from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @@ -172,6 +178,7 @@ def test_gemini_plugin(early_stop: bool = True): @pytest.mark.largedist +@skip_if_not_enough_gpus(8) @rerun_if_address_is_in_use() def test_gemini_plugin_3d(early_stop: bool = True): spawn(run_dist, 8, early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index bcdcc1470e6c..861fa0131397 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -5,13 +5,13 @@ from torch.optim import Adam import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin # from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # These models are not compatible with AMP _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] @@ -21,8 +21,9 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] +@clear_cache_before_run() def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: - device = device_utils.get_current_device() + device = get_accelerator().get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) @@ -74,7 +75,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - device_utils.empty_cache() + get_accelerator().empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fa32feb2ff85..e785843fb053 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -10,10 +10,11 @@ from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 8a14d7cf872d..f698070465d6 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -11,11 +11,12 @@ from colossalai.booster.plugin import TorchFSDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo # test basic fsdp function +@clear_cache_before_run() def run_fn(model_fn, data_gen_fn, output_transform_fn): plugin = TorchFSDPPlugin() booster = Booster(plugin=plugin) @@ -40,12 +41,18 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): optimizer.clip_grad_by_norm(1.0) optimizer.step() + del model + del optimizer + del criterion + del booster + del plugin + def check_torch_fsdp_plugin(): if IS_FAST_TEST: registry = model_zoo.get_sub_registry(COMMON_MODELS) else: - registry = model_zoo + registry = model_zoo.get_sub_registry("transformers_gptj") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items(): if any( @@ -59,6 +66,7 @@ def check_torch_fsdp_plugin(): ] ): continue + print(name) run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -73,3 +81,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_fsdp_plugin(): spawn(run_dist, 2) + + +if __name__ == "__main__": + test_torch_fsdp_plugin() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 49fd85ffba0a..708a1906b118 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -7,7 +7,6 @@ from utils import shared_tempdir import colossalai -from colossalai.testing import skip_if_not_enough_gpus from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.lazy import LazyInitContext @@ -17,6 +16,7 @@ clear_cache_before_run, parameterize, rerun_if_address_is_in_use, + skip_if_not_enough_gpus, spawn, ) from tests.kit.model_zoo import model_zoo @@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model.config.save_pretrained(save_directory=pretrained_path) extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size) + plugin = GeminiPlugin( + **placement_config, + tp_size=tp_size, + enable_all_optimization=enable_all_optimization, + extra_dp_size=extra_dp_size, + ) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha criterion = lambda x: x.mean() enable_all_optimization = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) + plugin = GeminiPlugin( + **placement_config, + precision="fp16", + initial_scale=(2**14), + tp_size=tp_size, + extra_dp_size=extra_dp_size, + enable_all_optimization=enable_all_optimization, + ) booster = Booster(plugin=plugin) model = model_fn() @@ -161,8 +173,13 @@ def run_dist(rank, world_size, port): def test_gemini_ckpIO(): spawn(run_dist, 4) + @pytest.mark.largedist @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_gemini_ckpIO_3d(): - spawn(run_dist, 8) \ No newline at end of file + spawn(run_dist, 8) + + +if __name__ == "__main__": + test_gemini_ckpIO() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index db3c56da874d..a42b550cd6fc 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -38,11 +38,11 @@ ] -@clear_cache_before_run() @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) +@clear_cache_before_run() def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( iter(model_zoo.get_sub_registry(model_name).values()) @@ -104,30 +104,32 @@ def _preprocess_data(data): # Check whether the loaded model & optimizer works smoothly. model.train() new_model.train() + data_for_shard = data_gen_fn() + data_for_origin = data_gen_fn() if booster.plugin.stage_manager is not None: booster.execute_pipeline( - _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False ) booster.execute_pipeline( - _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_origin), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False, ) else: - old_model_loss = criterion(model(**_preprocess_data(data))) + old_model_loss = criterion(model(**_preprocess_data(data_for_shard))) optimizer.backward(old_model_loss) - new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin))) new_optimizer.backward(new_model_loss) optimizer.step() new_optimizer.step() # Check updated weights. - stage_manager = booster.plugin.stage_manager - - if stage_manager is None or stage_manager.is_first_stage(): - assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) - assert_close_loose( - model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 - ) + for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()): + assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3) dist.barrier() Randomizer.reset_index() @@ -145,3 +147,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_hybrid_ckpIO(world_size): spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_hybrid_ckpIO(4) diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py deleted file mode 100644 index 5341aa35ab90..000000000000 --- a/tests/test_infer_ops/triton/test_llama_act_combine.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine - -try: - import triton - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') - -BATCH_SIZE = 4 -SEQ_LEN = 16 -HIDDEN_SIZE = 32 - - -def SwiGLU(x): - """Gated linear unit activation function. - Args: - x : input array - axis: the axis along which the split should be computed (default: -1) - """ - size = x.shape[-1] - assert size % 2 == 0, "axis size must be divisible by 2" - x1, x2 = torch.split(x, size // 2, -1) - return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) - - -@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_llama_act_combine(dtype: str): - x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() - x_gate_torch = nn.Parameter(x_gate.detach().clone()) - x_gate_kernel = nn.Parameter(x_gate.detach().clone()) - x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() - x_up_torch = nn.Parameter(x_up.detach().clone()) - x_up_kernel = nn.Parameter(x_up.detach().clone()) - - torch_out = SwiGLU(x_gate_torch) * x_up_torch - kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) - atol = 1e-5 if dtype == torch.float32 else 5e-2 - assert torch.allclose(torch_out, kernel_out, atol=atol) - - torch_out.mean().backward() - kernel_out.mean().backward() - assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) - assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) - assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) - - -if __name__ == '__main__': - test_llama_act_combine(torch.float16) diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py deleted file mode 100644 index 43b9c0929c4a..000000000000 --- a/tests/test_infer_ops/triton/test_softmax.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -import torch -from packaging import version -from torch import nn - -try: - from colossalai.kernel.triton.softmax import softmax - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_softmax_op(): - data_samples = [ - torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), - torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), - torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), - ] - - for data in data_samples: - module = nn.Softmax(dim=-1) - data_torch_out = module(data) - data_triton_out = softmax(data) - check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) - assert check is True, "softmax outputs from triton and torch are not matched" - - -if __name__ == "__main__": - test_softmax_op() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index ee50e5b61009..d0c4cd0a7c48 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -1,14 +1,19 @@ import pytest from lazy_init_utils import SUPPORT_LAZY, check_lazy_init -from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS +from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") -@pytest.mark.parametrize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize( + "subset", + [COMMON_MODELS] + if IS_FAST_TEST + else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"], +) @pytest.mark.parametrize("default_device", ["cpu", "cuda"]) def test_torchvision_models_lazy_init(subset, default_device): - sub_model_zoo = model_zoo.get_sub_registry(subset) + sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith( diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 7d2c81972e5a..079022e930cf 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -2,12 +2,12 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -16,7 +16,7 @@ def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -27,7 +27,7 @@ def check_all_gather(): def check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -38,7 +38,7 @@ def check_reduce_scatter(): def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 8a9a73d65f38..f09df9253a38 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,6 +2,7 @@ import torch.distributed as dist from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.global_variables import tensor_parallel_env as env @@ -16,13 +17,12 @@ VocabParallelEmbedding1D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear_col(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -68,7 +68,7 @@ def check_linear_col(): print_rank_0("linear_col forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] grad = grad.clone() @@ -91,7 +91,7 @@ def check_linear_col(): def check_linear_row(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -137,7 +137,7 @@ def check_linear_row(): print_rank_0("linear_row forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = grad_master.clone() out.backward(grad) @@ -159,7 +159,7 @@ def check_linear_row(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -201,7 +201,7 @@ def check_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -243,7 +243,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -309,7 +309,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -420,7 +420,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -508,7 +508,7 @@ def check_vocab_parallel_loss(): @torch.no_grad() def check_linear_row_stream_inference(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 0bbc72eca809..78bd407b9193 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,5 +1,6 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -16,13 +17,12 @@ VocabParallelEmbedding2D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = HIDDEN_SIZE @@ -74,7 +74,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -103,7 +103,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -139,7 +139,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -154,7 +154,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -201,7 +201,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -274,7 +274,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -321,7 +321,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -371,7 +371,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] # grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -399,7 +399,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -519,7 +519,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -608,7 +608,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -645,7 +645,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -683,7 +683,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -716,7 +716,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 9c126cefeba8..4506cfee686d 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -3,11 +3,11 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal @@ -27,7 +27,7 @@ def check_AB(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] @@ -35,7 +35,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] @@ -72,7 +72,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -105,7 +105,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 283e7f68374f..914607614a00 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,6 +1,7 @@ import torch from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -17,13 +18,12 @@ VocabParallelEmbedding2p5D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -76,7 +76,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -104,7 +104,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -141,7 +141,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -156,7 +156,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -204,7 +204,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -278,7 +278,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -326,7 +326,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -377,7 +377,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -405,7 +405,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -524,7 +524,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -613,7 +613,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -650,7 +650,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -689,7 +689,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -725,7 +725,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index 992bd6107f08..91a15c81dfe5 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -1,10 +1,10 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * @@ -25,7 +25,7 @@ def check_AB(): k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] @@ -33,7 +33,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] @@ -70,7 +70,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -103,7 +103,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index a4a4ae9a5ba4..f9f19a17b9d1 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,6 +5,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context from colossalai.legacy.nn import ( @@ -23,7 +24,6 @@ from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.utils import print_rank_0 from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -31,7 +31,7 @@ def check_linear(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -84,7 +84,7 @@ def check_linear(): logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -119,7 +119,7 @@ def check_linear(): def check_layernorm(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -206,7 +206,7 @@ def check_layernorm(): def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -258,7 +258,7 @@ def check_classifier_no_given_weight(): logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -306,7 +306,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -463,7 +463,7 @@ def check_classifier_given_embed_weight(): logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -497,7 +497,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_patch_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -678,7 +678,7 @@ def check_patch_embed(): def check_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -746,7 +746,7 @@ def check_embed(): def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -823,7 +823,7 @@ def check_vocab_parallel_embed(): def check_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -876,7 +876,7 @@ def check_loss(): def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index aa4d5d6ceeb3..f4ad0d6d1671 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -1,9 +1,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import TransformerSelfAttentionRing -from colossalai.utils import get_current_device def check_selfattention(): @@ -13,10 +13,10 @@ def check_selfattention(): HIDDEN_SIZE = 16 layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) - hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device()) attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( - get_current_device() + get_accelerator().get_current_device() ) layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index a5a2d38577dc..cab111358c9c 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ( recv_backward, recv_forward, @@ -18,7 +19,6 @@ from colossalai.legacy.initialize import launch from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -73,7 +73,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger): def check_comm(size, rank, prev_rank, next_rank, logger): dtype = torch.float32 - device = get_current_device() + device = get_accelerator().get_current_device() tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) tensor = torch.randn(tensor_shape, dtype=dtype, device=device) diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 9df7cf75aae5..4993df4f3713 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -1,15 +1,15 @@ import pytest import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn -from colossalai.utils.device import get_current_device def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): - frac1 = colo_device_memory_capacity(get_current_device()) + frac1 = colo_device_memory_capacity(get_accelerator().get_current_device()) colo_set_process_memory_fraction(0.5) - frac2 = colo_device_memory_capacity(get_current_device()) + frac2 = colo_device_memory_capacity(get_accelerator().get_current_device()) assert frac2 * 2 == frac1 diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index b5f2be705890..9975cc04ff30 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -4,12 +4,12 @@ from torch.nn.utils import clip_grad_norm_ import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec from colossalai.legacy.utils.common import clip_grad_norm from colossalai.logging import disable_existing_loggers from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -36,7 +36,7 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: @parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): print(f"{world_size}, {dtype}, {device}, {norm_type}") - cuda_device = get_current_device() + cuda_device = get_accelerator().get_current_device() devices = [cuda_device] * 4 if device == "cpu": devices = [torch.device("cpu")] * 4 diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 3fac624729db..a349bc5a910a 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -4,10 +4,10 @@ import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 @@ -38,7 +38,7 @@ def run_test(rank, world_size, port): layer_list.append(moe_layer) model = nn.ModuleList(layer_list) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) @@ -52,7 +52,7 @@ def run_test(rank, world_size, port): rank = dist.get_rank() torch.cuda.manual_seed(78 + rank) - data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + data = torch.randn(BATCH_SIZE, DIM, device=get_accelerator().get_current_device()) grad = torch.randn_like(data) MOE_MANAGER.reset_loss() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 255ec7444a2c..62d61a3d4b2c 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,10 +3,10 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 NUM_EXPERTS = 4 @@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data - tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + tokens = torch.randn( + BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True + ) layer = SparseMLP( hidden_size=hidden_size, @@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f router_top_k=topk, router_capacity_factor_train=1.0, ) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape - grad = torch.randn(ech, device=get_current_device()) + grad = torch.randn(ech, device=get_accelerator().get_current_device()) old_out.backward(grad) # get gradient # save all results diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index bd1103df30d3..8f51e1663727 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -9,11 +9,11 @@ from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device sys.path.append( os.path.join( @@ -28,7 +28,7 @@ def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index f87d4c792155..74feeeb59722 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -7,12 +7,12 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler @@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_model (MoeModule) local_model (MoeModule) """ - for (tp_name, tp_param), (local_name, local_param) in \ - zip(tp_model.named_parameters(), local_model.named_parameters()): + for (tp_name, tp_param), (local_name, local_param) in zip( + tp_model.named_parameters(), local_model.named_parameters() + ): assert tp_name == local_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: tp_model (MoeModule) ep_model (MoeModule) """ - for (tp_name, tp_param), (ep_name, ep_param) in \ - zip(tp_model.named_parameters(), ep_model.named_parameters()): + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): assert tp_name == ep_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ local_model (MoeModule) ep_model (MoeModule) """ - for (local_name, local_param), (ep_name, ep_param) in \ - zip(local_model.named_parameters(), ep_model.named_parameters()): + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): assert local_name == ep_name if "experts" not in local_name: if assert_grad_flag: @@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm + enable_hierarchical_comm=enable_hierarchical_comm, ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_current_device()) - tp_model = tp_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) + ep_model = ep_model.to(get_accelerator().get_current_device()) + tp_model = tp_model.to(get_accelerator().get_current_device()) + local_model = local_model.to(get_accelerator().get_current_device()) # sync ep param sync_moe_model_param(ep_model) @@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size tp_grad_handler = MoeGradientHandler(tp_model) rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_current_device()) + input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) micro_batch_size = batch_size // world_size index = rank * micro_batch_size # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index:index + micro_batch_size] + shard_data = input_data.detach()[index : index + micro_batch_size] out_local = local_model(input_data) MOE_MANAGER.reset_loss() @@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_tp, out_ep, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" + assert torch.allclose( + out_tp, out_ep, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" try: - out_local_slice = out_local[index:index + micro_batch_size] - assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError as e: + out_local_slice = out_local[index : index + micro_batch_size] + assert torch.allclose( + out_ep, out_local_slice, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" + except AssertionError: """ e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 @@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. """ warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) out_local.mean().backward() @@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) try: sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError as e: + except AssertionError: warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) @@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize("config", [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, -]) +@pytest.mark.parametrize( + "config", + [ + {"enable_hierarchical_comm": False}, + {"enable_hierarchical_comm": True}, + ], +) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 95c0e715dc34..2f08a335de5a 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,11 +3,11 @@ import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device HIDDEN_SIZE = 4 INTERMEDIATE_SIZE = 8 @@ -46,7 +46,7 @@ def run_moe_init(expert_parallel): assert dist.get_rank(parallel_info_dict[1].dp_group) == rank model = nn.ModuleList([exp0, exp1, exp2]) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) sync_moe_model_param(model) # MOE experts layout success when ep_size = 1 diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6bbe3e4e8172..6d932156a270 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -8,7 +8,8 @@ import torch from torch import Tensor -from colossalai.utils import get_current_device, multi_tensor_applier +from colossalai.accelerator import get_accelerator +from colossalai.utils import multi_tensor_applier _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), @@ -64,9 +65,9 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class FusedAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -90,9 +91,9 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import CPUAdamBuilder + from colossalai.kernel.kernel_loader import CPUAdamLoader - cpu_optim = CPUAdamBuilder().load() + cpu_optim = CPUAdamLoader().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) @@ -155,7 +156,9 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-3, 1e-3 if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: rtol, atol = 4e-3, 4e-3 - check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + check_adam_kernel( + FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol + ) @pytest.mark.parametrize("adamw", [False, True]) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index caf6e6bbbd42..6f5e734b7472 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -3,11 +3,11 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device WORLD_SIZE = 2 @@ -19,7 +19,7 @@ def check_p2p_communication(): rank = dist.get_rank() - tensor = torch.ones(1, device=get_current_device()) + tensor = torch.ones(1, device=get_accelerator().get_current_device()) data = [ "tensor", tensor, diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 5f27be39657d..a08dc6d277d0 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -155,7 +155,7 @@ def run_dist( @pytest.mark.dist -@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("num_microbatch", [4, 6]) @pytest.mark.parametrize("batch_size", [12]) @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 87e6618023d3..62d4d1bf3c7c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -154,7 +154,7 @@ def _criterion(outputs, inputs): data = data_gen_fn() - if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: seq_len = data["input_ids"].shape[-1] lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) times = lcm // seq_len diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index a5c465ba0b07..3ec1700045e3 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,13 +4,11 @@ import torch from einops import rearrange -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native import ColoAttention - from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32] diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5977c706fdd1..e4dc569b825b 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -4,15 +4,15 @@ from torch.distributed.distributed_c10d import _get_default_group import colossalai +from colossalai.accelerator import get_accelerator from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) + temp = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(temp) return temp.item() @@ -66,7 +66,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert my_chunk.cpu_shard.size(0) == 1024 // world_size assert my_chunk.device_type == "cpu" assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) + my_chunk.shard_move(get_accelerator().get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 assert my_chunk.device_type == "cuda" diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 21afff753ae6..3a9742e01566 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -5,11 +5,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd( use_grad_checkpoint: bool = False, master_weights: bool = True, ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 35323e516071..36a803492b6d 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -6,10 +6,10 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd @@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): def exam_gemini_grad_acc( placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 152bf289502a..7f3c7176e99e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,11 +7,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd @@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict): def single_chunk_init(model: torch.nn.Module, placement_config: dict): - model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config) + model = GeminiDDP( + model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config + ) return model @@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() model = model_builder().to(init_dev) for torch_p, p in zip(torch_model.parameters(), model.parameters()): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 405d7d789b01..71bb27b4aca1 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,11 +5,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. model = GeminiDDP( model, - chunk_init_device=get_current_device(), + chunk_init_device=get_accelerator().get_current_device(), search_range_m=1, pin_memory=True, mixed_precision=mixed_precision, diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e99f6d59ba8e..cf3658bf9920 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -2,8 +2,8 @@ import torch import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.kit.model_zoo import model_zoo @@ -34,7 +34,7 @@ def exam_chunk_manager(): sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, - get_current_device(), + get_accelerator().get_current_device(), hidden_dim=128, search_range_m=1, min_chunk_size_m=0, diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 351ae5f67ff7..11f738615d16 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -7,9 +7,10 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import conditional_context, get_current_device +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -28,7 +29,7 @@ def forward(self, x): def exam_zero_1_2_grad_acc(): local_rank = torch.distributed.get_rank() seed_all(2009) - device = get_current_device() + device = get_accelerator().get_current_device() # create model zero1_model = MlpModel().to(device) zero2_model = copy.deepcopy(zero1_model) @@ -71,7 +72,7 @@ def fwd_bwd_func(number, cur_data, check_flag): def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) - device = get_current_device() + device = get_accelerator().get_current_device() # create models zero_model = MlpModel()