From 1c2f03240cc40997ddda5e4f673f0fd82121148b Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 20 Jun 2023 10:39:06 +0800 Subject: [PATCH 01/80] [cluster] add process group mesh (#4039) * [cluster] add process group mesh * [test] add process group mesh test * force sync --- colossalai/cluster/__init__.py | 3 +- colossalai/cluster/process_group_mesh.py | 203 ++++++++++++++++++ tests/test_cluster/test_process_group_mesh.py | 151 +++++++++++++ 3 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 colossalai/cluster/process_group_mesh.py create mode 100644 tests/test_cluster/test_process_group_mesh.py diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py index 2fbdfd3cc999..44f571ca2501 100644 --- a/colossalai/cluster/__init__.py +++ b/colossalai/cluster/__init__.py @@ -1,5 +1,6 @@ from .device_mesh_manager import DeviceMeshManager from .dist_coordinator import DistCoordinator from .process_group_manager import ProcessGroupManager +from .process_group_mesh import ProcessGroupMesh -__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager'] +__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py new file mode 100644 index 000000000000..1dfd261d5d01 --- /dev/null +++ b/colossalai/cluster/process_group_mesh.py @@ -0,0 +1,203 @@ +import itertools +from functools import reduce +from operator import mul +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def prod(nums: List[int]) -> int: + """Product of a list of numbers. + + Args: + nums (List[int]): A list of numbers. + + Returns: + int: The product of the numbers. + """ + return reduce(mul, nums) + + +class ProcessGroupMesh: + """A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled with parallel method. + It just initialize process groups and cache them. The parallel method should manage them and use them to do the parallel computation. + + We use a ND-tuple to represent the process group mesh. And a ND-coordinate is to represent each process. + For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``. + + Args: + *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size. + + Attributes: + shape (Tuple[int, ...]): The shape of the process group mesh. + rank (int): The rank of the current process. + """ + + def __init__(self, *size: int) -> None: + assert dist.is_initialized(), "Please initialize torch.distributed first." + assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size." + self._shape = size + self._rank = dist.get_rank() + self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) + self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} + self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} + + @property + def shape(self) -> Tuple[int, ...]: + return self._shape + + @property + def rank(self) -> int: + return self._rank + + def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the size of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._shape + else: + return self._shape[dim] + + def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]: + """Get the coordinate of the process group mesh. + + Args: + dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None. + + Returns: + Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh. + """ + if dim is None: + return self._coord + else: + return self._coord[dim] + + @staticmethod + def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Convert a rank to a coordinate. + + Args: + rank (int): Rank to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + + Returns: + Tuple[int, ...]: Coordinate of the rank. + """ + return np.unravel_index(rank, shape) + + @staticmethod + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + """Convert a coordinate to a rank. + + Args: + coords (Tuple[int, ...]): Coordinate to be converted. + shape (Tuple[int, ...]): Shape of the process group mesh. + + Returns: + int: Rank of the coordinate. + """ + return np.ravel_multi_index(coord, shape) + + def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: + """Get the process group with the given ranks. It the process group doesn't exist, it will be created. + + Args: + ranks_in_group (List[int]): Ranks in the process group. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group with the given ranks. + """ + ranks_in_group = sorted(ranks_in_group) + if tuple(ranks_in_group) not in self._group_to_ranks: + group = dist.new_group(ranks_in_group, backend=backend) + self._ranks_to_group[tuple(ranks_in_group)] = group + self._group_to_ranks[group] = tuple(ranks_in_group) + return self._ranks_to_group[tuple(ranks_in_group)] + + def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: + """Get the ranks in the given process group. The process group must be created by this class. + + Args: + group (ProcessGroup): The process group. + + Returns: + List[int]: Ranks in the process group. + """ + return list(self._group_to_ranks[group]) + + @staticmethod + def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, + indices_at_axis: List[int]) -> List[Tuple[int, ...]]: + """Get coordinates along the given axis. + + Args: + base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on. + axis (int): Axis along which the coordinates are generated. + indices_at_axis (List[int]): Indices at the axis. + + Returns: + List[Tuple[int, ...]]: Coordinates along the axis. + """ + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) + return coords_in_group + + def create_group_along_axis(self, + axis: int, + indices_at_axis: Optional[List[int]] = None, + backend: Optional[str] = None) -> ProcessGroup: + """Create all process groups along the given axis, and return the one which the current process belongs to. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + reduced_shape = list(self._shape) + # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` + reduced_shape[axis] = 1 + target_group = None + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + group = self.get_group(ranks_in_group, backend=backend) + if self._rank in ranks_in_group: + target_group = group + return target_group + + def get_group_along_axis(self, + axis: int, + indices_at_axis: Optional[List[int]] = None, + backend: Optional[str] = None) -> ProcessGroup: + """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. + + Args: + axis (int): Axis along which the process groups are created. + indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. + backend (Optional[str], optional): Backend of the process group. Defaults to None. + + Returns: + ProcessGroup: The process group along the given axis which the current process belongs to. + """ + indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + if ranks_in_group not in self._ranks_to_group: + # no need to cache it explicitly, since it will be cached in `create_group_along_axis` + return self.create_group_along_axis(axis, indices_at_axis, backend=backend) + return self._ranks_to_group[ranks_in_group] diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py new file mode 100644 index 000000000000..13b7119424e4 --- /dev/null +++ b/tests/test_cluster/test_process_group_mesh.py @@ -0,0 +1,151 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.testing import spawn + + +def check_process_group_mesh_with_gpc(): + from colossalai.context import ParallelMode + from colossalai.core import global_context as gpc + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, 2, 2) + + # check world size + assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( + TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}' + assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) + assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) + + # check locak rank (coordinate) + assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( + TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}' + assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) + assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group) + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group) + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group) + + # check prev rank + coord = pg_mesh.coordinate() + if not gpc.is_first_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != 0 + prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:] + assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + if not gpc.is_first_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != 0 + prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:] + assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) + + # check next rank + if not gpc.is_last_rank(ParallelMode.TENSOR): + assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 + next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:] + assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) + if not gpc.is_last_rank(ParallelMode.PIPELINE): + assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 + next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:] + assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) + + +def check_process_group_mesh_with_cases(): + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0, 0), + 1: (0, 0, 1), + 2: (0, 1, 0), + 3: (0, 1, 1), + } + TP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + PP_RANKS_IN_GROUP = { + 0: [0, 2], + 1: [1, 3], + 2: [0, 2], + 3: [1, 3], + } + DP_RANKS_IN_GROUP = { + 0: [0], + 1: [1], + 2: [2], + 3: [3], + } + + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) + + rank = dist.get_rank() + assert rank == pg_mesh.rank + + # check world size + assert pg_mesh.size(TP_DIM) == 2 + assert pg_mesh.size(PP_DIM) == 2 + assert pg_mesh.size(DP_DIM) == 1 + + # check coordinate + assert pg_mesh.coordinate(TP_DIM) == RANK_TO_COORDINATE[rank][TP_DIM] + assert pg_mesh.coordinate(PP_DIM) == RANK_TO_COORDINATE[rank][PP_DIM] + assert pg_mesh.coordinate(DP_DIM) == RANK_TO_COORDINATE[rank][DP_DIM] + + # check ranks in group + tp_group = pg_mesh.get_group_along_axis(TP_DIM) + assert pg_mesh.get_ranks_in_group(tp_group) == TP_RANKS_IN_GROUP[rank] + pp_group = pg_mesh.get_group_along_axis(PP_DIM) + assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] + dp_group = pg_mesh.get_group_along_axis(DP_DIM) + assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + + # check prev rank + if RANK_TO_COORDINATE[rank][TP_DIM] != 0: + prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \ + RANK_TO_COORDINATE[rank][TP_DIM + 1:] + prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != 0: + prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \ + RANK_TO_COORDINATE[rank][PP_DIM + 1:] + prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] + assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank + + # check next rank + if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: + next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \ + RANK_TO_COORDINATE[rank][TP_DIM + 1:] + next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: + next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \ + RANK_TO_COORDINATE[rank][PP_DIM + 1:] + next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] + assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))), + rank=rank, + world_size=world_size, + port=port, + host='localhost') + # TODO(ver217): this function should be removed when gpc is removed + check_process_group_mesh_with_gpc() + check_process_group_mesh_with_cases() + + +@pytest.mark.dist +def test_process_group_mesh(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_process_group_mesh() From d71b8e44df080f1e23d34650128ee0913b1b7f7b Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 27 Jun 2023 16:17:01 +0800 Subject: [PATCH 02/80] [pipeline] add stage manager (#4093) * [pipeline] add stage manager * [test] add pipeline stage manager test * [pipeline] add docstring for stage manager --- colossalai/pipeline/stage_manager.py | 176 ++++++++++++++++++++++ tests/test_pipeline/test_stage_manager.py | 86 +++++++++++ 2 files changed, 262 insertions(+) create mode 100644 colossalai/pipeline/stage_manager.py create mode 100644 tests/test_pipeline/test_stage_manager.py diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py new file mode 100644 index 000000000000..fe228e2270dd --- /dev/null +++ b/colossalai/pipeline/stage_manager.py @@ -0,0 +1,176 @@ +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh + + +class PipelineStageManager: + """PipelineStageManager is a helper class to manage pipeline stages. + + Args: + pg_mesh (ProcessGroupMesh): Process group mesh. + pipeline_axis (int): The axis along which the pipeline is constructed. + + Attributes: + num_stages (int): Number of stages in the pipeline. + stage (int): The current stage. + num_virtual_stages (int): Number of virtual stages in the pipeline. + virtual_stage (int): The current virtual stage. + """ + + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + self.pg_mesh = pg_mesh + self.pipeline_axis = pipeline_axis + self.num_virtual_stages: Optional[int] = None + self.virtual_stage: Optional[int] = None + self.prev_rank: Optional[Tuple[int, ...]] = None + self.next_rank: Optional[Tuple[int, ...]] = None + self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord + coord = self.pg_mesh.coordinate() + if self.stage > 0: + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) + if self.stage < self.num_stages - 1: + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + + # init p2p process groups + stages = list(range(self.num_stages)) + for prev, cur in zip(stages[:-1], stages[1:]): + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) + if self.stage in [prev, cur]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group + + def is_first_stage(self, virtual: bool = False) -> bool: + """Is the current stage the first stage. + + Args: + virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + + Returns: + bool: Whether the current stage is the first stage. + """ + if virtual: + assert self.num_virtual_stages is not None + return self.virtual_stage == 0 + return self.stage == 0 + + def is_last_stage(self, virtual: bool = False) -> bool: + """Is the current stage the last stage. + + Args: + virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + + Returns: + bool: Whether the current stage is the last stage. + """ + if virtual: + assert self.num_virtual_stages is not None + return self.virtual_stage == self.num_virtual_stages - 1 + return self.stage == self.num_stages - 1 + + @property + def num_stages(self) -> int: + """Number of stages in the pipeline. + + Returns: + int: Number of stages in the pipeline. + """ + return self.pg_mesh.size(self.pipeline_axis) + + @property + def stage(self) -> int: + """Current stage. + + Returns: + int: Current stage. + """ + return self.pg_mesh.coordinate(self.pipeline_axis) + + def get_rank(self) -> int: + """Get the rank of the current process. + + Returns: + int: Rank of the current process. + """ + return dist.get_rank() + + def get_prev_rank(self) -> int: + """Get the rank of the previous stage. + + Returns: + int: Rank of the previous stage. + """ + assert not self.is_first_stage(), "Cannot get previous rank in the first stage." + return self.prev_rank + + def get_next_rank(self) -> int: + """Get the rank of the next stage. + + Returns: + int: Rank of the next stage. + """ + assert not self.is_last_stage(), "Cannot get next rank in the last stage." + return self.next_rank + + def set_num_virtual_stages(self, num_virtual_stages: int) -> None: + """Set the number of virtual stages. + + Args: + num_virtual_stages (int): Number of virtual stages. + """ + self.num_virtual_stages = num_virtual_stages + + def set_virtual_stage(self, virtual_stage: int) -> None: + """Set the virtual stage. + + Args: + virtual_stage (int): Virtual stage. + """ + self.virtual_stage = virtual_stage + + @contextmanager + def switch_virtual_stage(self, virtual_stage: int) -> None: + """A context manager to switch virtual stage. + + Args: + virtual_stage (int): Target virtual stage. + """ + old_stage = self.virtual_stage + try: + self.set_virtual_stage(virtual_stage) + yield + finally: + self.set_virtual_stage(old_stage) + + def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: + """Get the p2p process group between two ranks. The order of the two ranks does not matter. + + Args: + first_rank (int): The first rank. + second_rank (int): The second rank. + + Returns: + ProcessGroup: P2P process group between the two ranks. + """ + if first_rank > second_rank: + first_rank, second_rank = second_rank, first_rank + return self.p2p_groups[(first_rank, second_rank)] + + def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: + """Get the process group of the given stages. + + Args: + stages (List[int]): List of stages. + + Returns: + ProcessGroup: Process group of the given stages. + """ + return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py new file mode 100644 index 000000000000..b920f88dbfae --- /dev/null +++ b/tests/test_pipeline/test_stage_manager.py @@ -0,0 +1,86 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import spawn + + +def check_stage_manager(): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + # check stage info + assert stage_manager.num_stages == PP_SIZE + assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM] + + # check is_first_stage + ranks_in_group = PP_RANKS_IN_GROUP[rank] + is_first_stage = ranks_in_group.index(rank) == 0 + assert stage_manager.is_first_stage() == is_first_stage + + # check is_last_stage + is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1 + assert stage_manager.is_last_stage() == is_last_stage + + # check prev rank + if not is_first_stage: + prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1] + assert stage_manager.get_prev_rank() == prev_rank + + # check next rank + if not is_last_stage: + next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] + assert stage_manager.get_next_rank() == next_rank + + # check virtual stage + stage_manager.set_num_virtual_stages(PP_SIZE * 2) + assert stage_manager.num_virtual_stages == PP_SIZE * 2 + stage_manager.set_virtual_stage(stage_manager.stage * 2) + assert stage_manager.virtual_stage == stage_manager.stage * 2 + with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): + assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1 + assert stage_manager.virtual_stage == stage_manager.stage * 2 + + # check p2p groups + for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): + if rank in [prev, cur]: + group = stage_manager.get_p2p_process_group(prev, cur) + dist.barrier(group=group) + + # check stage groups + pg_mesh = ProcessGroupMesh(4) + stage_manager = PipelineStageManager(pg_mesh, 0) + group = stage_manager.init_process_group_by_stages([0, 2]) + if rank in [0, 2]: + dist.barrier(group=group) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_stage_manager() + + +@pytest.mark.dist +def test_process_group_mesh(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_process_group_mesh() From a636ea485ce58b0c908014dfa22e0cc030791c66 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 28 Jun 2023 17:12:19 +0800 Subject: [PATCH 03/80] [pipeline] implement p2p communication (#4100) * [pipeline] add p2p communication * [test] add p2p communication test * [test] add rerun decorator * [test] rename to avoid conflict --- colossalai/pipeline/p2p.py | 224 ++++++++++++++++++ tests/test_pipeline/test_p2p_communication.py | 59 +++++ tests/test_pipeline/test_stage_manager.py | 7 +- 3 files changed, 287 insertions(+), 3 deletions(-) create mode 100644 colossalai/pipeline/p2p.py create mode 100644 tests/test_pipeline/test_p2p_communication.py diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py new file mode 100644 index 000000000000..203b7439d7ef --- /dev/null +++ b/colossalai/pipeline/p2p.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import io +import pickle +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed import distributed_c10d as c10d + +from .stage_manager import PipelineStageManager + +_unpickler = pickle.Unpickler + + +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling + + Args: + tensor (:class:`torch.tensor`): tensor to be unpickled + tensor_size (:class:`torch.Size`): Size of the real info in bytes + + Returns: + Any: object after unpickled + """ + buf = tensor.numpy().tobytes()[:tensor_size] + if b'cuda' in buf: + buf_array = bytearray(buf) + device_index = torch.cuda.current_device() + buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf = bytes(buf_array) + + io_bytes = io.BytesIO(buf) + byte_pickler = _unpickler(io_bytes) + unpickle = byte_pickler.load() + + return unpickle + + +def _broadcast_object_list(object_list: List[Any], + src: int, + group: ProcessGroup, + device: Optional[Union[torch.device, str, int]] = None): + """This is a modified version of the broadcast_object_list in torch.distribution + The only difference is that object will be move to correct device after unpickled. + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + be updated with data sent from rank src. + + Args: + object_list (List[Any]): list of object to broadcast + src (int): source rank to broadcast + dst (int): dst rank to broadcast + device (:class:`torch.device`): device to do broadcast. current device in default + + """ + + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("broadcast_object_list") + return + + my_rank = dist.get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + is_nccl_backend = c10d._check_for_nccl_backend(group) + current_device = None + + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + if is_nccl_backend: + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + ) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + c10d.broadcast(object_tensor, src=src, group=group, async_op=False) + + # Deserialize objects using their stored sizes. + offset = 0 + + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device("cpu"): + obj_view = obj_view.cpu() + offset += obj_size + # unpickle + unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) + + # unconsistence in device + if isinstance(unpickle_object, + torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + unpickle_object = unpickle_object.cuda() + + object_list[i] = unpickle_object + + +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: + """send anything to dst rank + + Args: + object (Any): object needed to be sent + dst (int): rank of the destination + + Returns: + None + """ + # then broadcast safely + _broadcast_object_list([object], src, group) + + +def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: + """recv anything from src + + Args: + src (int): source rank of data. local rank will receive data from src rank. + + Returns: + Any: Object received from src. + """ + object_list = [None] + _broadcast_object_list(object_list, src, group) + + return object_list[0] + + +class PipelineP2PCommunication: + + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, + self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object(next_rank, cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, + self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(input_object, cur_rank, prev_rank, + self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py new file mode 100644 index 000000000000..71946f6b988a --- /dev/null +++ b/tests/test_pipeline/test_p2p_communication.py @@ -0,0 +1,59 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def check_p2p_communication(): + pg_mesh = ProcessGroupMesh(2) + stage_manager = PipelineStageManager(pg_mesh, 0) + p2p = PipelineP2PCommunication(stage_manager) + + rank = dist.get_rank() + + tensor = torch.ones(1, device=get_current_device()) + + if rank == 0: + p2p.send_forward(tensor) + p2p.send_forward([tensor]) + p2p.send_forward({'tensor': tensor}) + else: + obj = p2p.recv_forward() + assert torch.equal(obj, tensor) + obj = p2p.recv_forward() + assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) + obj = p2p.recv_forward() + assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + + if rank == 1: + p2p.send_backward(tensor) + p2p.send_backward([tensor]) + p2p.send_backward({'tensor': tensor}) + else: + obj = p2p.recv_backward() + assert torch.equal(obj, tensor) + obj = p2p.recv_backward() + assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) + obj = p2p.recv_backward() + assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_p2p_communication() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pipeline_p2p(): + spawn(run_dist, 2) + + +if __name__ == '__main__': + test_pipeline_p2p() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index b920f88dbfae..be4591d58f74 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -4,7 +4,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_stage_manager(): @@ -78,9 +78,10 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -def test_process_group_mesh(): +@rerun_if_address_is_in_use() +def test_pipeline_stage_manager(): spawn(run_dist, 4) if __name__ == '__main__': - test_process_group_mesh() + test_pipeline_stage_manager() From aa690f9dc2cf7fa1f8db58b48d9dd269bd6cd77c Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 29 Jun 2023 13:35:39 +0800 Subject: [PATCH 04/80] [pipeline] refactor 1f1b schedule (#4115) * [api] update optimizer wrapper to fit pipeline * [pipeline] add base schedule * [pipeline] add 1f1b schedule * [test] add pipeline schedule utils test * [pipeline] fix import --- colossalai/interface/optimizer.py | 4 + colossalai/pipeline/schedule/__init__.py | 7 + colossalai/pipeline/schedule/_utils.py | 129 ++++++++++ colossalai/pipeline/schedule/base.py | 35 +++ colossalai/pipeline/schedule/one_f_one_b.py | 229 ++++++++++++++++++ .../test_pipeline_schedule_utils.py | 47 ++++ 6 files changed, 451 insertions(+) create mode 100644 colossalai/pipeline/schedule/__init__.py create mode 100644 colossalai/pipeline/schedule/_utils.py create mode 100644 colossalai/pipeline/schedule/base.py create mode 100644 colossalai/pipeline/schedule/one_f_one_b.py create mode 100644 tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index 0eaf2e1ef8ba..bc270b1d9c89 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -1,5 +1,6 @@ from typing import Union +import torch import torch.nn as nn from torch import Tensor from torch.optim import Optimizer @@ -53,6 +54,9 @@ def backward(self, loss: Tensor, *args, **kwargs): """ loss.backward(*args, **kwargs) + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + torch.autograd.backward(tensor, grad) + def state_dict(self): """ Returns the optimizer state. diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py new file mode 100644 index 000000000000..8b13413b1a31 --- /dev/null +++ b/colossalai/pipeline/schedule/__init__.py @@ -0,0 +1,7 @@ +from .base import PipelineSchedule +from .one_f_one_b import OneForwardOneBackwardSchedule + +__all__ = [ + 'PipelineSchedule', + 'OneForwardOneBackwardSchedule', +] diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py new file mode 100644 index 000000000000..045c86e40e63 --- /dev/null +++ b/colossalai/pipeline/schedule/_utils.py @@ -0,0 +1,129 @@ +from typing import Any, List, Optional + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + + +def to_device(x: Any, device: Optional[torch.device] = None) -> Any: + """Move object to device if it is a tensor. + + Args: + x (Any): Object to be moved. + device (Optional[torch.device], optional): Target device. Defaults to None. + + Returns: + Any: Moved object. + """ + if isinstance(x, torch.Tensor): + return x.to(device) + return x + + +def get_batch_size(batch: Any) -> int: + """Get the batch size (size of dimension-0) of the first tensor in the batch. + + Args: + batch (Any): Batch to be inspected. + + Raises: + RuntimeError: If no tensor is found in the batch. + + Returns: + int: Batch size. + """ + data_list, _ = tree_flatten(batch) + for data in data_list: + if isinstance(data, torch.Tensor): + return data.size(0) + raise RuntimeError('No tensor found in the batch') + + +def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: + """Get a micro batch of the original batch. + + Args: + batch (Any): Batch to be sliced. + start (int): Start index of the micro batch. + micro_batch_size (int): Size of the micro batch. + + Returns: + Any: Target micro batch. + """ + + def _get_tensor_slice(x: Any): + if isinstance(x, torch.Tensor): + return x[start:start + micro_batch_size] + return x + + return tree_map(_get_tensor_slice, batch) + + +def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any: + """Call model forward function with data and internal inputs. + + Args: + model (Module): Model to be called. + data (Any): Data loaded from data iterator. + internal_inputs (Optional[dict]): Data from previous stage. It must be a dict or None if it's the first stage. + + Returns: + Any: Outputs of the model. + """ + if internal_inputs is None: + internal_inputs = {} + if isinstance(data, (list, tuple)): + return model(*data, **internal_inputs) + elif isinstance(data, dict): + return model(**data, **internal_inputs) + return model(data, **internal_inputs) + + +def retain_grad(x: Any) -> None: + """Call retain_grad() on a tensor. + + Args: + x (Any): Object to be called. + """ + if isinstance(x, torch.Tensor): + x.retain_grad() + + +def detach(x: Any) -> Any: + """Call detach() on a tensor. + + Args: + x (Any): Object to be called. + + Returns: + Any: The detached object. + """ + if isinstance(x, torch.Tensor): + return x.detach() + return x + + +def merge_batch(data: List[Any]) -> Any: + """Merge micro batches into a batch. + + Args: + data (List[Any]): A list of micro batches. + + Returns: + Any: Merge batch. + """ + if len(data) == 0: + return + flattened_data = [] + tree_spec = None + for d in data: + elems, tree_spec = tree_flatten(d) + flattened_data.append(elems) + merged_data = [] + for elem_batch in zip(*flattened_data): + if isinstance(elem_batch[0], torch.Tensor): + merged_data.append(torch.cat(elem_batch, dim=0)) + else: + merged_data.append(list(elem_batch)) + return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py new file mode 100644 index 000000000000..9cd9beded65a --- /dev/null +++ b/colossalai/pipeline/schedule/base.py @@ -0,0 +1,35 @@ +from typing import Any, Callable, Iterable + +from torch import Tensor +from torch.nn import Module + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class PipelineSchedule: + + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def forward_backward_step(self, + model: Module, + optimizer: OptimizerWrapper, + data_iter: Iterable, + criterion: Callable[[Any, Any], Tensor], + return_loss: bool = False, + return_outputs: bool = False) -> dict: + """Forward and backward step for pipeline training. + + Args: + model (Module): Model to be trained. + optimizer (OptimizerWrapper): Optimizer to be used. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + raise NotImplementedError diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py new file mode 100644 index 000000000000..a8933bfbb4da --- /dev/null +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -0,0 +1,229 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from .base import PipelineSchedule + + +class OneForwardOneBackwardSchedule(PipelineSchedule): + + def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.microbatch_size: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = 0 + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + + def load_micro_batch(self) -> Any: + """Load a micro batch from the current batch. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) + self.microbatch_offset += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def forward_step(self, + model: Module, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + + Args: + model (Module): Model to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch() + + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model, micro_batch, input_obj) + + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step(self, + model: Module, + optimizer: OptimizerWrapper, + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False) -> dict: + """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model (Module): Model to be trained. + optimizer (OptimizerWrapper): Optimizer to be used. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + + self.load_batch(data_iter) + + # num_warmup_microbatches is the step when not all the processes are working + num_warmup_microbatches = self.stage_manager.num_stages - self.stage_manager.stage - 1 + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) + num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [] + output_objs = [] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + input_obj = self.comm.recv_forward() + + output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) + + self.comm.send_forward(output_obj) + + if not forward_only: + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_obj = self.comm.recv_forward() + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.comm.send_forward(output_obj) + + if not last_iteration: + input_obj = self.comm.recv_forward() + + else: + # TODO adjust here + self.comm.send_forward(output_obj) + output_obj_grad = self.comm.recv_backward() + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + input_obj = self.comm.recv_forward() + self.comm.send_backward(input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + + output_obj_grad = self.comm.recv_backward() + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.comm.send_backward(input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {'loss': accum_loss, 'outputs': outputs} diff --git a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py new file mode 100644 index 000000000000..4c23a23ebaba --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py @@ -0,0 +1,47 @@ +import torch + +from colossalai.pipeline.schedule._utils import get_batch_size, get_micro_batch, merge_batch + + +def test_get_batch_size(): + tensor = torch.rand(2, 3) + assert get_batch_size(tensor) == 2 + assert get_batch_size([tensor]) == 2 + assert get_batch_size((1, tensor)) == 2 + assert get_batch_size({'tensor': tensor}) == 2 + assert get_batch_size({'dummy': [1], 'tensor': tensor}) == 2 + assert get_batch_size({'tensor': [tensor]}) == 2 + + +def test_get_micro_batch(): + x = torch.rand(2, 1) + y = torch.rand(2, 3) + micro_batch = get_micro_batch(x, 0, 1) + assert torch.equal(micro_batch, x[0:1]) + micro_batch = get_micro_batch(x, 1, 1) + assert torch.equal(micro_batch, x[1:2]) + micro_batch = get_micro_batch([x, y], 0, 1) + assert torch.equal(micro_batch[0], x[0:1]) + assert torch.equal(micro_batch[1], y[0:1]) + micro_batch = get_micro_batch([x, y], 1, 1) + assert torch.equal(micro_batch[0], x[1:2]) + assert torch.equal(micro_batch[1], y[1:2]) + micro_batch = get_micro_batch({'x': x, 'y': y}, 0, 1) + assert torch.equal(micro_batch['x'], x[0:1]) + assert torch.equal(micro_batch['y'], y[0:1]) + micro_batch = get_micro_batch({'x': x, 'y': y}, 1, 1) + assert torch.equal(micro_batch['x'], x[1:2]) + assert torch.equal(micro_batch['y'], y[1:2]) + + +def test_merge_batch(): + x = torch.rand(2, 1) + y = torch.rand(2, 3) + merged = merge_batch([x[0:1], x[1:2]]) + assert torch.equal(merged, x) + merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]]) + assert torch.equal(merged[0], x) + assert torch.equal(merged[1], y) + merged = merge_batch([{'x': x[0:1], 'y': y[0:1]}, {'x': x[1:2], 'y': y[1:2]}]) + assert torch.equal(merged['x'], x) + assert torch.equal(merged['y'], y) From 3c2fbe2caa836aeddea05593389b58489355e692 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 4 Jul 2023 13:46:16 +0800 Subject: [PATCH 05/80] [pipeline]add pipeline policy and bert forward (#4130) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt --- colossalai/pipeline/policy/__init__.py | 22 + colossalai/pipeline/policy/base.py | 108 +++++ colossalai/pipeline/policy/bert.py | 390 ++++++++++++++++++ colossalai/pipeline/policy/bloom.py | 153 +++++++ .../test_policy/test_bert_model.py | 112 +++++ tests/test_pipeline/test_stage_manager.py | 2 +- 6 files changed, 786 insertions(+), 1 deletion(-) create mode 100644 colossalai/pipeline/policy/__init__.py create mode 100644 colossalai/pipeline/policy/base.py create mode 100644 colossalai/pipeline/policy/bert.py create mode 100644 colossalai/pipeline/policy/bloom.py create mode 100644 tests/test_pipeline/test_policy/test_bert_model.py diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py new file mode 100644 index 000000000000..fd9e6e04588e --- /dev/null +++ b/colossalai/pipeline/policy/__init__.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, List, Optional, Tuple, Type + +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy +from .bert import BertModel, BertModelPolicy + +POLICY_MAP: Dict[Type[Module], Type[Policy]] = { + BertModel: BertModelPolicy, +} + + +def pipeline_parallelize( + model: Module, + stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + if type(model) not in POLICY_MAP: + raise NotImplementedError(f"Policy for {type(model)} not implemented") + policy = POLICY_MAP[type(model)](stage_manager) + return policy.parallelize_model(model) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py new file mode 100644 index 000000000000..ad595a04b1b0 --- /dev/null +++ b/colossalai/pipeline/policy/base.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Optional, Tuple + +from colossalai.lazy import LazyTensor +from torch import Tensor +from torch.nn import Module, Parameter + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + + def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]: + """Setup model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers + """ + hold_params = set() + hold_buffers = set() + + def init_layer(layer: Module): + for p in layer.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + hold_params.add(p) + for b in layer.buffers(): + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + hold_buffers.add(b) + + hold_layers = self.get_hold_layers(module) + + for layer in hold_layers: + init_layer(layer) + + hold_params_dict = {} + hold_buffers_dict = {} + + # release other tensors + for n, p in module.named_parameters(): + if p in hold_params: + hold_params_dict[n] = p + else: + if isinstance(p, LazyTensor): + p.materialize() + p.data = p.cuda() + p.storage().resize_(0) + for n, b in module.named_buffers(): + if b in hold_buffers: + hold_buffers_dict[n] = b + else: + if isinstance(b, LazyTensor): + b.materialize() + b.data = b.cuda() + # FIXME(ver217): use meta tensor may be better + b.storage().resize_(0) + return hold_params_dict, hold_buffers_dict + + def replace_forward(self, module: Module) -> None: + """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict + + Args: + module (Module): _description_ + """ + raise NotImplementedError + + def get_hold_layers(self, module: Module) -> List[Module]: + """Get layers that should be hold in current stage. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Args: + module (Module): Module to be setup + + Returns: + List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + raise NotImplementedError + + def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + """Parallelize model for pipeline parallel + + Args: + module (Module): Module to be setup + + Returns: + Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters + """ + hold_params, hold_buffers = self.setup_model(module) + self.replace_forward(module) + shared_params = self.get_shared_params(module) + return hold_params, hold_buffers, shared_params diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py new file mode 100644 index 000000000000..6f912d2c6b80 --- /dev/null +++ b/colossalai/pipeline/policy/bert.py @@ -0,0 +1,390 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + +logger = logging.get_logger(__name__) + + +def bert_model_forward( + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + #labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage +): + #TODO: add explaination of the output here. + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + # debugging + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + #inherit from bert_layer + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + next_decoder_cache = () if use_cache else None + + #calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + #layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + #end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + layer_outputs[1:] + + #output of non-first and non-last stages: + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + + #return dict is not supported at this moment + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layer partition policy for bertmodel +class BertModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BertModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.embeddings) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + if self.stage_manager.is_last_stage(): + hold_layers.append(module.pooler) + + return hold_layers + + def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + +def bert_for_pretraining_forward( + self: BertForPreTraining, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.LongTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertForPreTrainingPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) + hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ + [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: + num_layers_per_stage_accumulated[self.stage_manager.stage]]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.model) + + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py new file mode 100644 index 000000000000..8dffcd8f9af5 --- /dev/null +++ b/colossalai/pipeline/policy/bloom.py @@ -0,0 +1,153 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + +from .base import Policy + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py new file mode 100644 index 000000000000..c92f7f6c34c0 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -0,0 +1,112 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert.modeling_bert import BertModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_model_forward(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 12, 3, 3)) + output = bert_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 768) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_model_policy(): + model = BertModel.from_pretrained('bert-base-uncased') + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + #print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_model_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_model_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert model forward and bert model policy""" + test_bert_model_forward() + test_bert_model_policy() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..67a2e90532e2 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -21,7 +21,7 @@ def check_stage_manager(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() From b902be27cdc89930ecd21b0abd81628ef0646a3a Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 27 Jun 2023 16:17:01 +0800 Subject: [PATCH 06/80] [pipeline] add stage manager (#4093) * [pipeline] add stage manager * [test] add pipeline stage manager test * [pipeline] add docstring for stage manager --- tests/test_pipeline/test_stage_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index 67a2e90532e2..be4591d58f74 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -21,7 +21,7 @@ def check_stage_manager(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() From 35213bc11f30d4dea0192ae70623a915acd9b85b Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 4 Jul 2023 13:46:16 +0800 Subject: [PATCH 07/80] [pipeline]add pipeline policy and bert forward (#4130) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt --- tests/test_pipeline/test_stage_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..67a2e90532e2 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -21,7 +21,7 @@ def check_stage_manager(): 1: [0, 1], 2: [2, 3], 3: [2, 3], - } + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() From c80f8cde407a19a5c1ac177d9d079a106245733f Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 5 Jul 2023 10:52:53 +0800 Subject: [PATCH 08/80] [pipeline] build bloom model and policy , revise the base class of policy (#4161) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining --- colossalai/pipeline/policy/base.py | 37 +++++- colossalai/pipeline/policy/bert.py | 94 ++++++-------- colossalai/pipeline/policy/bloom.py | 111 ++++++++++++---- .../test_policy/test_bert_model.py | 5 +- .../test_policy/test_bloom_model.py | 119 ++++++++++++++++++ 5 files changed, 286 insertions(+), 80 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index ad595a04b1b0..9736f1004fe4 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -1,13 +1,15 @@ from typing import Any, Dict, List, Optional, Tuple -from colossalai.lazy import LazyTensor +import numpy as np from torch import Tensor from torch.nn import Module, Parameter +from colossalai.lazy import LazyTensor from colossalai.pipeline.stage_manager import PipelineStageManager class Policy: + def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -93,7 +95,8 @@ def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: """ raise NotImplementedError - def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: + def parallelize_model(self, + module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: """Parallelize model for pipeline parallel Args: @@ -106,3 +109,33 @@ def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[ self.replace_forward(module) shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params + + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + """ + divide layers into stages + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + @staticmethod + def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 6f912d2c6b80..8cd0fadd167f 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -22,25 +22,26 @@ def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - #labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + # labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + # this is from the previous stage + hidden_states: Optional[torch.FloatTensor] = None, ): - #TODO: add explaination of the output here. + # TODO: add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -93,6 +94,7 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -144,7 +146,7 @@ def bert_model_forward( else: encoder_extended_attention_mask = None - #inherit from bert_layer + # inherit from bert_layer all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -156,12 +158,12 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - #calculate the num_layers + # calculate the num_layers num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages start_layer = stage_manager.stage * num_layers_per_stage end_layer = (stage_manager.stage + 1) * num_layers_per_stage - #layer_outputs + # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): if stage_manager.is_first_stage() and idx == 0: @@ -206,12 +208,13 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + \ + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - #end of a stage loop + # end of a stage loop sequence_output = layer_outputs[0] if layer_outputs is not None else None if stage_manager.is_last_stage(): @@ -219,7 +222,7 @@ def custom_forward(*inputs): if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] - #output of non-first and non-last stages: + # output of non-first and non-last stages: if not return_dict: return tuple(v for v in [ hidden_states, @@ -229,7 +232,7 @@ def custom_forward(*inputs): all_cross_attentions, ] if v is not None) - #return dict is not supported at this moment + # return dict is not supported at this moment return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -243,6 +246,7 @@ def custom_forward(*inputs): class BertModelPolicy(Policy): def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager self.layers_per_stage = self.distribute_layers(num_layers, num_stages) @@ -253,11 +257,8 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) - + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.encoder.layer[start_idx:end_idx]) if self.stage_manager.is_last_stage(): hold_layers.append(module.pooler) @@ -270,23 +271,6 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: def replace_forward(self, module: Module) -> None: module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - def bert_for_pretraining_forward( self: BertForPreTraining, @@ -306,8 +290,8 @@ def bert_for_pretraining_forward( ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( + outputs = bert_model_forward( + self.bert, input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -320,7 +304,8 @@ def bert_for_pretraining_forward( ) sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + if stage_manager.is_last_stage(): + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) total_loss = None if labels is not None and next_sentence_label is not None: @@ -355,11 +340,12 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: hold_layers = [] if self.stage_manager.is_first_stage(): hold_layers.append(module.bert.embeddings) - num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage) - hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \ - [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0: - num_layers_per_stage_accumulated[self.stage_manager.stage]]) + + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) hold_layers.append(module.cls) return hold_layers diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py index 8dffcd8f9af5..71d2913fc3aa 100644 --- a/colossalai/pipeline/policy/bloom.py +++ b/colossalai/pipeline/policy/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from types import MethodType from typing import Dict, List, Optional, Tuple, Union @@ -14,6 +15,8 @@ from .base import Policy +logger = logging.get_logger(__name__) + def bloom_model_forward( self: BloomModel, @@ -26,6 +29,8 @@ def bloom_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -44,28 +49,45 @@ def bloom_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage presents = () if use_cache else None all_self_attentions = () if output_attentions else None @@ -77,11 +99,14 @@ def bloom_model_forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values[0][0].shape[2] # source_len + seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) @@ -90,13 +115,19 @@ def bloom_model_forward( alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + # causal_mask is constructed every stage and its input is passed through different stages causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # calculate the num_layers + num_layers_per_stage = len(self.h) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -130,24 +161,60 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] + if use_cache is True: presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) - # Add last hidden state - hidden_states = self.ln_f(hidden_states) + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + # TODO: deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + +class BloomModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + + def get_hold_layers(self, module: BloomModel) -> List[Module]: + """ + get pipeline layers for current stage + """ + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.word_embeddings) + hold_layers.append(module.word_embeddings_layernorm) + + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.h[start_idx:end_idx]) + + if self.stage_manager.is_last_stage(): + hold_layers.append(module.ln_f) + + return hold_layers + + def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + pass + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index c92f7f6c34c0..cf5dc95feb8c 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -27,7 +27,8 @@ def check_bert_model_forward(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() @@ -72,7 +73,7 @@ def check_bert_model_policy(): 3: [2, 3], } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - #print(pg_mesh) + # print(pg_mesh) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py new file mode 100644 index 000000000000..5ba92d734590 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bloom_model.py @@ -0,0 +1,119 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bloom import BloomConfig, BloomModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bloom_model_forward(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32) + if stage_manager.is_first_stage(): + attention_mask = torch.ones_like(x) + output = bloom_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bloom_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 64) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bloom_model_policy(): + # create a BloomModel + configuration = BloomConfig() + model = BloomModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2) + assert model_policy.layers_per_stage == [1, 1] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bloom_model_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bloom_model_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bloom model forward and bloom model policy""" + test_bloom_model_forward() + test_bloom_model_policy() From 223de4648f374bad89520497d370ecffc087b35e Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 14:16:55 +0800 Subject: [PATCH 09/80] [pipeline] update shardformer policy --- colossalai/shardformer/policies/basepolicy.py | 33 ++++++++++++++++--- colossalai/shardformer/shard/shard_config.py | 7 +++- colossalai/shardformer/shard/sharder.py | 29 +++++++++++++++- colossalai/shardformer/shard/shardformer.py | 4 +-- colossalai/shardformer/shard/utils.py | 19 +++++++++++ 5 files changed, 84 insertions(+), 8 deletions(-) create mode 100644 colossalai/shardformer/shard/utils.py diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 2d347542fa7a..16f3fa14eca0 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -2,9 +2,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Type, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.pipeline.stage_manager import PipelineStageManager from ..shard.shard_config import ShardConfig @@ -71,9 +75,8 @@ class Policy(ABC): """ def __init__(self) -> None: - self.shard_config = None - self.model = None - self.shard_config = None + self.shard_config: Optional[ShardConfig] = None + self.model: Optional[Module] = None def set_model(self, model: nn.Module) -> None: r""" @@ -94,6 +97,12 @@ def set_shard_config(self, shard_config: ShardConfig) -> None: self.shard_config = shard_config self.config_sanity_check() + @property + def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: + if self.shard_config is not None: + return self.shard_config.pipeline_stage_manager + return None + @abstractmethod def config_sanity_check(self): """ @@ -151,3 +160,19 @@ def append_or_create_submodule_replacement( policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) return policy + + def get_held_layers(self) -> List[Module]: + """Get layers that should be held in current stage. This method should be implemented by subclass. + + Returns: + List[Module]: List of layers that should be hold in current stage + """ + raise NotImplementedError + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """Get parameters that should be shared across stages. This method should be implemented by subclass. + + Returns: + List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] + """ + return [] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..fba2c27a2a87 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +from typing import Optional import torch.distributed as dist from torch.distributed import ProcessGroup +from colossalai.pipeline.stage_manager import PipelineStageManager + __all__ = ['ShardConfig'] @@ -13,11 +16,13 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline. enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ - tensor_parallel_process_group: ProcessGroup = None + tensor_parallel_process_group: Optional[ProcessGroup] = None + pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 201e0a08cbfe..429ca8ed74be 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,11 +1,15 @@ from typing import Any, Callable, Dict, List, Union import torch.nn as nn +from torch import Tensor + +from colossalai.lazy import LazyTensor from .._utils import getattr_, setattr_ from ..policies.autopolicy import get_autopolicy from ..policies.basepolicy import Policy, SubModuleReplacementDescription from .shard_config import ShardConfig +from .utils import set_tensors_to_none __all__ = ['ModelSharder', 'shard_model'] @@ -25,15 +29,18 @@ def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = self.policy = get_autopolicy(self.model) if policy is None else policy self.shard_config = shard_config - def shard(self) -> None: + def shard(self) -> List[Dict[int, Tensor]]: r""" Shard the model according to the policy """ self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() + self._release_unheld_layers() self._replace_module() + self._materialize() self._postprocess() + return self.policy.get_shared_params() def _preprocess(self) -> None: self.model = self.policy.preprocess() @@ -172,3 +179,23 @@ def _replace_sub_module( ) setattr_(org_layer, suffix, replace_layer) + + def _release_unheld_layers(self) -> None: + r""" + Release the unheld layers in the model + """ + if self.shard_config and self.shard_config.pipeline_stage_manager: + held_layers = self.policy.get_held_layers() + set_tensors_to_none(self.model, exclude=set(held_layers)) + + def _materialize(self) -> None: + r""" + Materialize the model if lazy initialization is used + """ + for p in self.model.parameters(): + if isinstance(p, LazyTensor): + p.materialize() + + for b in self.model.buffers(): + if isinstance(b, LazyTensor): + b.materialize() diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 3fce12463414..069a46ca57ea 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -42,5 +42,5 @@ def optimize(self, model: nn.Module, policy: Policy = None): policy (`Policy`): the custom policy for sharding """ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) - sharder.shard() - return model + shared_params = sharder.shard() + return model, shared_params diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py new file mode 100644 index 000000000000..2bac37bfedda --- /dev/null +++ b/colossalai/shardformer/shard/utils.py @@ -0,0 +1,19 @@ +from typing import Set + +import torch.nn as nn + + +def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None: + """Set all parameters and buffers of model to None + + Args: + model (nn.Module): The model to set + """ + if model in exclude: + return + for child in model.children(): + set_tensors_to_none(child, exclude=exclude) + for n, p in model.named_parameters(recurse=False): + setattr(model, n, None) + for n, buf in model.named_buffers(recurse=False): + setattr(model, n, None) From 318b5e8ad2606247a522a95ecf08bb2af1cce3a2 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 14:19:12 +0800 Subject: [PATCH 10/80] [pipeline] update shardformer docstring --- colossalai/shardformer/shard/shardformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 069a46ca57ea..6e0f90257df6 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,4 +1,7 @@ +from typing import Dict, List, Tuple + import torch.nn as nn +from torch import Tensor from colossalai.cluster import DistCoordinator @@ -24,7 +27,7 @@ class ShardFormer: org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.optimize(org_model) + model, shared_params = shard_former.optimize(org_model) ``` """ @@ -32,7 +35,7 @@ def __init__(self, shard_config: ShardConfig): self.coordinator = DistCoordinator() self.shard_config = shard_config - def optimize(self, model: nn.Module, policy: Policy = None): + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: r""" This method will optimize the model based on the given policy. @@ -40,6 +43,8 @@ def optimize(self, model: nn.Module, policy: Policy = None): model (`torch.nn.Model`): the origin huggingface model shard_config (`ShardConfig`): the config for distribute information policy (`Policy`): the custom policy for sharding + + Returns: the sharded model and the shared parameters """ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) shared_params = sharder.shard() From 71470a2c7ae381b2131092f28aff0e375f538f25 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 14:30:17 +0800 Subject: [PATCH 11/80] [test] update shardformer tests --- tests/test_shardformer/test_model/_utils.py | 4 ++-- tests/test_shardformer/test_with_torch_ddp.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..e03014f3f234 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle enable_tensor_parallelism=enable_tensor_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.optimize(model_copy).cuda() - return org_model, sharded_model + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model, sharded_model.cuda() def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 9f8a5db6c94f..f29c8d6f605b 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model model = model_fn().cuda() - sharded_model = shardformer.optimize(model) + sharded_model, _ = shardformer.optimize(model) # add ddp sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) From fb4b18d4ea707d0c06948645fb6a5f3a93fcbcbd Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 14:49:05 +0800 Subject: [PATCH 12/80] [test] add shard util tests --- tests/test_shardformer/test_shard_utils.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_shardformer/test_shard_utils.py diff --git a/tests/test_shardformer/test_shard_utils.py b/tests/test_shardformer/test_shard_utils.py new file mode 100644 index 000000000000..220b8291c9c6 --- /dev/null +++ b/tests/test_shardformer/test_shard_utils.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + +from colossalai.shardformer.shard.utils import set_tensors_to_none + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) + self.out = nn.Linear(3, 1) + + +def test_release_layer(): + orig_cuda_allocated = torch.cuda.memory_allocated() + model = Net().cuda() + set_tensors_to_none(model, exclude={model.layers[0]}) + assert model.layers[1].weight is None + assert model.layers[1].bias is None + assert model.out.weight is None + assert model.out.bias is None + set_tensors_to_none(model) + assert model.layers[0].weight is None + assert model.layers[0].bias is None + assert len(list(model.parameters())) == 0 + assert torch.cuda.memory_allocated() == orig_cuda_allocated From d6101a8dbe5d80bad0b9bc5159f0d53f7c360c57 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 15:13:00 +0800 Subject: [PATCH 13/80] [shardformer] rename policy file name --- .../shardformer/policies/{autopolicy.py => auto_policy.py} | 2 +- .../shardformer/policies/{basepolicy.py => base_policy.py} | 0 colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/bloom.py | 2 +- colossalai/shardformer/policies/gpt2.py | 2 +- colossalai/shardformer/policies/llama.py | 2 +- colossalai/shardformer/policies/opt.py | 2 +- colossalai/shardformer/policies/t5.py | 4 ++-- colossalai/shardformer/policies/vit.py | 2 +- colossalai/shardformer/shard/sharder.py | 4 ++-- colossalai/shardformer/shard/shardformer.py | 2 +- 11 files changed, 12 insertions(+), 12 deletions(-) rename colossalai/shardformer/policies/{autopolicy.py => auto_policy.py} (99%) rename colossalai/shardformer/policies/{basepolicy.py => base_policy.py} (100%) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/auto_policy.py similarity index 99% rename from colossalai/shardformer/policies/autopolicy.py rename to colossalai/shardformer/policies/auto_policy.py index 085e3150c697..8e961a240758 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -3,7 +3,7 @@ import torch.nn as nn -from .basepolicy import Policy +from .base_policy import Policy __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/base_policy.py similarity index 100% rename from colossalai/shardformer/policies/basepolicy.py rename to colossalai/shardformer/policies/base_policy.py diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 9c2736cc64d3..b69ee72097c4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -3,7 +3,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index a0b5340f72bc..8d6f07d4a67d 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -4,7 +4,7 @@ from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class BloomPolicy(Policy): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 549cdbf87a80..598f393c029a 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -3,7 +3,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 157785bdcf13..391938b27167 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,7 +4,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..c4c6cde015b6 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,7 +1,7 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index cde59ab77042..6167e81613f2 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -6,10 +6,10 @@ Linear1D_Row, VocabParallelEmbedding1D, ) -from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index eaebe2eee0ba..3f6bbd10607a 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -4,7 +4,7 @@ from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ViTPolicy'] diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 429ca8ed74be..ca2f46a187d1 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -6,8 +6,8 @@ from colossalai.lazy import LazyTensor from .._utils import getattr_, setattr_ -from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy, SubModuleReplacementDescription +from ..policies.auto_policy import get_autopolicy +from ..policies.base_policy import Policy, SubModuleReplacementDescription from .shard_config import ShardConfig from .utils import set_tensors_to_none diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 6e0f90257df6..7a0d75bf2f2a 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -5,7 +5,7 @@ from colossalai.cluster import DistCoordinator -from ..policies.basepolicy import Policy +from ..policies.base_policy import Policy from .shard_config import ShardConfig from .sharder import ModelSharder From 951bde44b31b5830a2198fb8bfe3f24e07d3a981 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 15:20:59 +0800 Subject: [PATCH 14/80] [shardformer] fix type hint --- colossalai/shardformer/shard/shard_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index fba2c27a2a87..75fad4eb7431 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -15,8 +15,8 @@ class ShardConfig: The config for sharding the huggingface model Args: - tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline. + tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group. + pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline. enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. From c9511ad1769b0a506b8a06759f8cdf21743ab4bb Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 6 Jul 2023 14:49:10 +0800 Subject: [PATCH 15/80] [pipeline] add bert_for_pretraining bert_lmhead forward and policy (#4172) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining * add bert_for_pretraining forward and policy * fix typos * cancel warning * change the imediate output to default dict * change the default output of get_shared_params --- colossalai/pipeline/policy/bert.py | 367 ++++++++++++------ .../test_bert_for_pretraining_model.py | 118 ++++++ .../test_policy/test_bert_lmhead_model.py | 118 ++++++ .../test_policy/test_bert_model.py | 8 +- 4 files changed, 497 insertions(+), 114 deletions(-) create mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py create mode 100644 tests/test_pipeline/test_policy/test_bert_lmhead_model.py diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py index 8cd0fadd167f..abce504e9d61 100644 --- a/colossalai/pipeline/policy/bert.py +++ b/colossalai/pipeline/policy/bert.py @@ -10,9 +10,15 @@ BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, ) -from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel -from transformers.utils import logging +from transformers.models.bert.modeling_bert import ( + BertForPreTraining, + BertForPreTrainingOutput, + BertLMHeadModel, + BertModel, +) +from transformers.utils import ModelOutput, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -22,24 +28,23 @@ def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - # this is from the previous stage - hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage ): # TODO: add explaination of the output here. r""" @@ -85,10 +90,6 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -119,14 +120,29 @@ def bert_model_forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None + if stage_manager.is_first_stage(): hidden_states = self.embeddings( input_ids=input_ids, @@ -135,18 +151,8 @@ def bert_model_forward( inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - # inherit from bert_layer + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None @@ -221,34 +227,35 @@ def custom_forward(*inputs): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) - # output of non-first and non-last stages: - if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) - - # return dict is not supported at this moment - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) + # output of non-first and non-last stages: must be a dict + else: + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } # The layer partition policy for bertmodel class BertModelPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__( + self, + stage_manager: PipelineStageManager, + num_layers: int, + ): super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertModel) -> List[Module]: """ @@ -266,10 +273,10 @@ def get_hold_layers(self, module: BertModel) -> List[Module]: def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model) + module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) def bert_for_pretraining_forward( @@ -285,53 +292,74 @@ def bert_for_pretraining_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - hidden_states: Optional[torch.LongTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, -) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: - +): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output, pooled_output = outputs[:2] + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } class BertForPreTrainingPolicy(Policy): - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: """ @@ -352,25 +380,144 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: '''no shared params in bertmodel''' - pass + return [] def replace_forward(self, module: Module) -> None: - module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.model) + module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), + module.forward) + + +def bert_lmhead_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} + + +class BertLMHeadModelPolicy(Policy): + + def __init__(self, stage_manager: PipelineStageManager, num_layers: int): + super().__init__(stage_manager=stage_manager) + self.stage_manager = stage_manager + self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) + + def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: """ - divide layers into stages + get pipeline layers for current stage """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage + hold_layers = [] + if self.stage_manager.is_first_stage(): + hold_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) + hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if self.stage_manager.is_last_stage(): + hold_layers.append(module.bert.pooler) + hold_layers.append(module.cls) + + return hold_layers + + def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: + '''no shared params in bertmodel''' + return [] + + def replace_forward(self, module: Module) -> None: + module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py new file mode 100644 index 000000000000..afbea49c1829 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertForPreTraining + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_for_pretraining_forward(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_for_pretraining_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_for_pretraining_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_for_pretraining_policy(): + configuration = BertConfig() + model = BertForPreTraining(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_for_pretraining_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_for_pretraining_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_for_pretraining_forward() + test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py new file mode 100644 index 000000000000..d41eddc74dff --- /dev/null +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -0,0 +1,118 @@ +import pytest +import torch +import torch.distributed as dist +from transformers.models.bert import BertConfig +from transformers.models.bert.modeling_bert import BertLMHeadModel + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_bert_lmhead_forward(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + # print(rank) + + x = torch.randint(0, 1000, (2, 3)) + hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x) + output = bert_lmhead_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) + print('start the training') + else: + attention_mask = torch.ones((2, 3)) + output = bert_lmhead_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + print(output[0].shape) + assert output[0].shape == (2, 3, 30522) + print('end the training') + print(output) + + # assert output[1].shape == (2, 768) + + +def check_bert_lmhead_policy(): + configuration = BertConfig() + model = BertLMHeadModel(configuration) + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + # print(pg_mesh) + + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) + assert model_policy.layers_per_stage == [6, 6] + layers = model_policy.get_hold_layers(model) + for layer in layers: + print(layer) + + +def run_dist_model(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_forward() + + +def run_dist_policy(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_bert_lmhead_policy() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_forward(): + spawn(run_dist_model, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert_lmhead_policy(): + spawn(run_dist_policy, 4) + + +if __name__ == "__main__": + """test the bert for pretraining model forward and bert for pretraining model policy""" + test_bert_lmhead_forward() + test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index cf5dc95feb8c..92485072a5e4 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -39,11 +39,11 @@ def check_bert_model_forward(): if stage_manager.stage == 0: attention_mask = torch.ones_like(x) output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) + print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 768) print('start the training') else: - attention_mask = torch.ones((2, 12, 3, 3)) + attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, @@ -78,7 +78,7 @@ def check_bert_model_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2) + model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) assert model_policy.layers_per_stage == [6, 6] layers = model_policy.get_hold_layers(model) for layer in layers: From 6f476369b40aa76aa2bee33948f8ee65586e92b1 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:41:00 +0800 Subject: [PATCH 16/80] [pipeline] move bert related pipeline components to shardformer (#4187) * move bert related pipeline components to shardformer * fix bugs * revision * fix bert model tests * fix bert_lm_head model tests * fix tests * fix tests * done checks * skip bloom --- colossalai/pipeline/policy/base.py | 30 -- .../shardformer/policies/auto_policy.py | 2 +- .../shardformer/policies/base_policy.py | 31 ++ colossalai/shardformer/policies/bert.py | 487 +++++++++++++++++- .../test_bert_for_pretraining_model.py | 20 +- .../test_policy/test_bert_lmhead_model.py | 19 +- .../test_policy/test_bert_model.py | 22 +- .../test_policy/test_bloom_model.py | 8 +- .../test_layer/test_layernorm.py | 2 +- 9 files changed, 556 insertions(+), 65 deletions(-) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py index 9736f1004fe4..f51d74fdbac3 100644 --- a/colossalai/pipeline/policy/base.py +++ b/colossalai/pipeline/policy/base.py @@ -109,33 +109,3 @@ def parallelize_model(self, self.replace_forward(module) shared_params = self.get_shared_params(module) return hold_params, hold_buffers, shared_params - - @staticmethod - def distribute_layers(num_layers: int, num_stages: int) -> List[int]: - """ - divide layers into stages - """ - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_layers // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - - @staticmethod - def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: - """ - get the start index and end index of layers for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] - - return [start_idx, end_idx] diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 8e961a240758..640b61b579bd 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -29,7 +29,7 @@ class PolicyLocation: "transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"), "transformers.models.bert.modeling_bert.BertForPreTraining": - PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"), + PolicyLocation(file_name="bert", class_name="BertForPreTrainingPolicy"), "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), "transformers.models.bert.modeling_bert.BertForMaskedLM": diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 16f3fa14eca0..65aee13861ee 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -176,3 +177,33 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] + + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages + + """ + quotient = num_layers // num_stages + remainder = num_layers % num_stages + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages + + # deal with the rest layers + if remainder > 0: + start_position = num_layers // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + @staticmethod + def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index b69ee72097c4..e18cb6ece674 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,12 +1,35 @@ +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers.models.bert.modeling_bert import ( + BertForPreTraining, + BertForPreTrainingOutput, + BertLMHeadModel, + BertModel, +) +from transformers.utils import ModelOutput, logging import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', 'BertForMultipleChoicePolicy' ] @@ -153,9 +176,27 @@ class BertModelPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bert model""" + return [] + # BertForPreTraining -class BertForPretrainingPolicy(BertPolicy): +class BertForPreTrainingPolicy(BertPolicy): def __init__(self) -> None: super().__init__() @@ -165,6 +206,28 @@ def module_policy(self): module_policy = self.add_lm_head_policy(module_policy) return module_policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage""" + module = self.model + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + held_layers = [] + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''No shared params in bertmodel''' + return [] + def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): @@ -184,6 +247,27 @@ def module_policy(self): module_policy = self.add_lm_head_policy(module_policy) return module_policy + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''No shared params in bertmodel''' + return [] + def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): @@ -291,3 +375,402 @@ def module_policy(self): } module_policy.update(addon_module) return module_policy + + +def bert_model_forward( + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + # labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage +): + # TODO: add explaination of the output here. + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + # debugging + # preprocess: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = hidden_states if hidden_states is not None else None + + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + next_decoder_cache = () if use_cache else None + + # calculate the num_layers + num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + # layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + \ + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + # output of non-first and non-last stages: must be a dict + else: + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } + + +def bert_for_pretraining_forward( + self: BertForPreTraining, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } + + +def bert_lmhead_forward(self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index afbea49c1829..97d7d2fa538a 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -6,8 +6,9 @@ import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -45,7 +46,7 @@ def check_bert_for_pretraining_forward(): stage_manager=stage_manager) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') + else: attention_mask = torch.ones((2, 3)) output = bert_for_pretraining_forward(self=model, @@ -54,9 +55,6 @@ def check_bert_for_pretraining_forward(): stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 30522) - print('end the training') - print(output) - # assert output[1].shape == (2, 768) @@ -83,11 +81,13 @@ def check_bert_for_pretraining_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) + model_policy = BertForPreTrainingPolicy() + model_policy.set_model(model) + + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + model_policy.set_shard_config(model_config) + layers = model_policy.get_held_layers() + assert layers is not None def run_dist_model(rank, world_size, port): diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py index d41eddc74dff..b14dadf29e3c 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lmhead_model.py @@ -6,8 +6,9 @@ import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -45,7 +46,7 @@ def check_bert_lmhead_forward(): stage_manager=stage_manager) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') + else: attention_mask = torch.ones((2, 3)) output = bert_lmhead_forward(self=model, @@ -54,8 +55,6 @@ def check_bert_lmhead_forward(): stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 30522) - print('end the training') - print(output) # assert output[1].shape == (2, 768) @@ -83,11 +82,13 @@ def check_bert_lmhead_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) + model_policy = BertLMHeadModelPolicy() + model_policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + model_policy.set_shard_config(model_config) + layers = model_policy.get_held_layers() + + assert layers is not None def run_dist_model(rank, world_size, port): diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index 92485072a5e4..f5a443309cb2 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -5,8 +5,9 @@ import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward +from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -41,7 +42,6 @@ def check_bert_model_forward(): output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) - print('start the training') else: attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, @@ -50,8 +50,6 @@ def check_bert_model_forward(): stage_manager=stage_manager) print(output[0].shape) assert output[0].shape == (2, 3, 768) - print('end the training') - print(output) # assert output[1].shape == (2, 768) @@ -78,11 +76,14 @@ def check_bert_model_policy(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer)) - assert model_policy.layers_per_stage == [6, 6] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) + model_policy = BertModelPolicy() + model_policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + model_policy.set_shard_config(model_config) + + layers = model_policy.get_held_layers() + + assert layers is not None def run_dist_model(rank, world_size, port): @@ -109,5 +110,6 @@ def test_bert_model_policy(): if __name__ == "__main__": """test the bert model forward and bert model policy""" - test_bert_model_forward() + #test_bert_model_forward() test_bert_model_policy() + # this test need config to run diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py index 5ba92d734590..73584b4f8ef1 100644 --- a/tests/test_pipeline/test_policy/test_bloom_model.py +++ b/tests/test_pipeline/test_policy/test_bloom_model.py @@ -101,12 +101,15 @@ def run_dist_policy(rank, world_size, port): check_bloom_model_policy() +#TODO: Bloom model should be fixed after bert model +@pytest.mark.skip(reason="Bloom model should be fixed after bert model") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bloom_model_forward(): spawn(run_dist_model, 4) +@pytest.mark.skip(reason="Bloom model should be fixed after bert model") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bloom_model_policy(): @@ -115,5 +118,6 @@ def test_bloom_model_policy(): if __name__ == "__main__": """test the bloom model forward and bloom model policy""" - test_bloom_model_forward() - test_bloom_model_policy() + # test_bloom_model_forward() + # test_bloom_model_policy() + #TODO: Bloom model should be fixed after bert model is all ready diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index a117845545be..fc6d894c4aae 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -41,4 +41,4 @@ def test_layernorm(): if __name__ == '__main__': - test_layernorm_1d() + test_layernorm() From b66870a0dc972458514256f4d1a499b7acead310 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 10 Jul 2023 10:48:53 +0800 Subject: [PATCH 17/80] [shardformer] support lazy init (#4202) * [shardformer] support lazy init * [shardformer] linear support lazy init * [shardformer] embedding support lazy init * [shardformer] norm support lazy init * [shardformer] fused linear support lazy init * [test] update shardformer test layer * [test] shardformer with lazy init fit ddp * [lazy] hotfix deepcopy of param * [shardformer] fix bert policy and update test * [shardformer] fix bloom policy and update test * [shardformer] fix opt policy and update test * [shardformer] fix t5 policy and update test * [shardformer] fix gpt2 policy and update test * [shardformer] fix llama policy and update test --- colossalai/lazy/lazy_init.py | 41 +++++++++++------ colossalai/shardformer/layer/embedding.py | 5 ++- colossalai/shardformer/layer/linear.py | 3 ++ colossalai/shardformer/layer/normalization.py | 4 ++ .../shardformer/layer/qkv_fused_linear.py | 7 ++- colossalai/shardformer/policies/bert.py | 38 +++++++++------- colossalai/shardformer/policies/bloom.py | 26 +++++------ colossalai/shardformer/policies/gpt2.py | 29 ++++++------ colossalai/shardformer/policies/llama.py | 13 +++--- colossalai/shardformer/policies/opt.py | 28 ++++++------ colossalai/shardformer/policies/t5.py | 45 ++++++++++--------- colossalai/shardformer/shard/sharder.py | 10 +---- .../test_layer/test_embedding.py | 13 ++++-- .../test_layer/test_layernorm.py | 13 ++++-- .../test_layer/test_linear_1d.py | 31 +++++++++---- .../test_layer/test_qkv_fused_linear_1d.py | 21 ++++++--- .../test_vocab_parallel_embedding_1d.py | 14 ++++-- tests/test_shardformer/test_model/_utils.py | 17 ++++--- .../test_model/test_shard_bert.py | 10 +++-- .../test_model/test_shard_bloom.py | 6 ++- .../test_model/test_shard_gpt2.py | 6 ++- .../test_model/test_shard_llama.py | 6 ++- .../test_model/test_shard_opt.py | 6 ++- .../test_model/test_shard_t5.py | 6 ++- tests/test_shardformer/test_with_torch_ddp.py | 24 +++++++--- 25 files changed, 264 insertions(+), 158 deletions(-) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 1f5345015bf2..e071563c045a 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -6,6 +6,7 @@ import torch.distributed as dist import torch.nn as nn from torch import Tensor +from torch.nn import Parameter from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor @@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: the converted tensor """ - cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor + cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor tensor.__class__ = cls_to_become + if cls_to_become is Parameter: + # to fit UninitializedParameter + delattr(tensor, '_is_param') tensor.data = target tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method @@ -198,10 +202,10 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to def clean(self) -> None: """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. """ - self._factory_method = None - self._op_buffer = None - self._materialized_data = None - self._meta_data = None + delattr(self, '_factory_method') + delattr(self, '_op_buffer') + delattr(self, '_materialized_data') + delattr(self, '_meta_data') @staticmethod def _replace_with_materialized(x): @@ -350,20 +354,19 @@ def __deepcopy__(self, memo): def factory_fn(): # if self is materialized, return self new_tensor = self.materialize() if type(self) is LazyTensor else self - copied = new_tensor.detach().clone() - if new_tensor.requires_grad: - copied.requires_grad_() - return copied + return _copy_tensor(new_tensor, new_tensor.requires_grad) if self._materialized_data is not None: # self is early materialized - copied = self._materialized_data.detach().clone() - if self.requires_grad: - copied.requires_grad_() + copied = _copy_tensor(self._materialized_data, self.requires_grad) target = LazyTensor(lambda: None, concrete_data=copied) else: target = LazyTensor(factory_fn, meta_data=self._meta_data) + if isinstance(self, Parameter): + # hack isinstance check of parameter + target._is_param = True + memo[id(self)] = target return target @@ -408,6 +411,10 @@ def tolist(self) -> list: def __hash__(self): return id(self) + def __rpow__(self, other): + dtype = torch.result_type(self, other) + return torch.tensor(other, dtype=dtype, device=self.device)**self + class LazyInitContext: """Context manager for lazy initialization. Enables initializing the model without allocating real memory. @@ -536,7 +543,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @staticmethod def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: - """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: module (nn.Module): Target ``nn.Module`` @@ -553,7 +560,7 @@ def distribute(module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False) -> nn.Module: - """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. + """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: module (nn.Module): Target ``nn.Module`` @@ -625,3 +632,9 @@ def _is_int_tuple(args) -> bool: if not isinstance(x, int): return False return True + + +def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor: + copied = tensor.data.clone() + copied.requires_grad = requires_grad + return copied diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index db39a457b7fd..07341ef73515 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -9,8 +9,8 @@ import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param @@ -95,6 +95,7 @@ def from_native_module(module: nn.Embedding, r""" Build a 1D parallelized Embedding from a native nn.Embedding module. """ + LazyInitContext.materialize(module) # get the attributes num_embedding = module.num_embeddings embedding_dim = module.embedding_dim @@ -223,6 +224,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, r""" Convert a native pytorch embedding module to a parallel module. """ + LazyInitContext.materialize(module) # get the origin attributes num_embeddings = module.num_embeddings embedding_dim = module.embedding_dim @@ -243,6 +245,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, process_group=process_group, *args, **kwargs) + with torch.no_grad(): # shard and slice the weight along the vocabulary(num_embeddings) dimension # the shape of the weight is (num_embeddings, embedding_dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 26ba5883c64f..a8439f303bd1 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -12,6 +12,7 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param @@ -106,6 +107,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.in_features out_features = module.out_features @@ -242,6 +244,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.in_features out_features = module.out_features diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index b27307154a76..9bb7738c0f0a 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +from colossalai.lazy import LazyInitContext + __all__ = ['FusedLayerNorm', 'FusedRMSNorm'] FAST_LAYERNORM_SUPPORTED_SIZE = [ @@ -35,6 +37,7 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: raise ImportError( 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + LazyInitContext.materialize(module) # get the attributes of the module normalized_shape = module.normalized_shape eps = module.eps @@ -84,6 +87,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' ) + LazyInitContext.materialize(module) # to check if it is huggingface LlamaRMSNorm if module.__class__.__name__ == "LlamaRMSNorm": normalized_shape = module.weight.shape[0] diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 9d51670c65dd..c94d93069e93 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -12,6 +12,7 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import ( @@ -231,6 +232,7 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.weight.shape[0] out_features = module.weight.shape[1] @@ -380,6 +382,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ + LazyInitContext.materialize(module) # get the attributes in_features = module.weight.shape[0] out_features = module.weight.shape[1] @@ -428,9 +431,9 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) origin_device = self.bias.device - self.bias = self.bias.cuda() + self.bias.data = self.bias.cuda() dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + self.bias.data = self.bias.to(origin_device) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index e18cb6ece674..b80475e0552c 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -46,11 +46,12 @@ def preprocess(self): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ # TODO: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -229,10 +230,11 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -269,10 +271,11 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -288,10 +291,11 @@ def module_policy(self): return module_policy def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 8d6f07d4a67d..662ff5b4977a 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -17,11 +17,12 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -128,16 +129,13 @@ def module_policy(self): return policy def postprocess(self): - binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - - for k, v in binding_map.items(): - param = getattr_(self.model, k) - - if not isinstance(param, nn.Parameter): - param = nn.Parameter(param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - # tie weights - setattr_(self.model, v, param) + for k, v in binding_map.items(): + param = getattr_(self.model, k) + # tie weights + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 598f393c029a..8f9d90e67e59 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -21,11 +21,12 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -142,10 +143,11 @@ def module_policy(self): return module_policy def postprocess(self): - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model @@ -172,10 +174,11 @@ def module_policy(self): return module_policy def postprocess(self): - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) + if self.shard_config.enable_tensor_parallelism: + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 391938b27167..b10e07560d22 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -15,13 +15,14 @@ def config_sanity_check(self): pass def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index c4c6cde015b6..1435805d2846 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -19,11 +19,12 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -116,14 +117,15 @@ def module_policy(self): return policy def postprocess(self): - binding_map = { - 'model.decoder.embed_tokens': 'lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight + if self.shard_config.enable_tensor_parallelism: + binding_map = { + 'model.decoder.embed_tokens': 'lm_head', + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight return self.model diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 6167e81613f2..37864885b4cc 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -24,11 +24,12 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): @@ -164,11 +165,12 @@ def module_policy(self): return policy def postprocess(self): - binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] + if self.shard_config.enable_tensor_parallelism: + binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) return self.model @@ -211,13 +213,13 @@ def module_policy(self): def postprocess(self): super().postprocess() + if self.shard_config.enable_tensor_parallelism: + binding_map = {"shared": "lm_head"} - binding_map = {"shared": "lm_head"} - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight return self.model @@ -239,11 +241,12 @@ def module_policy(self): return base_policy def postprocess(self): - binding_map = [ - ["shared", "encoder.embed_tokens"], - ] + if self.shard_config.enable_tensor_parallelism: + binding_map = [ + ["shared", "encoder.embed_tokens"], + ] - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ca2f46a187d1..56eb76973807 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -3,7 +3,7 @@ import torch.nn as nn from torch import Tensor -from colossalai.lazy import LazyTensor +from colossalai.lazy import LazyInitContext from .._utils import getattr_, setattr_ from ..policies.auto_policy import get_autopolicy @@ -192,10 +192,4 @@ def _materialize(self) -> None: r""" Materialize the model if lazy initialization is used """ - for p in self.model.parameters(): - if isinstance(p, LazyTensor): - p.materialize() - - for b in self.model.buffers(): - if isinstance(b, LazyTensor): - b.materialize() + LazyInitContext.materialize(self.model) diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 8a6aa42a42f2..99e494359af7 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -1,15 +1,22 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import Embedding1D -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +@parameterize('lazy_init', [False, True]) +def check_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() -def check_embedding_1d(): - embedding = nn.Embedding(32, 128).cuda() + with ctx: + embedding = nn.Embedding(32, 128).cuda() embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) assert embedding_1d.weight.shape == torch.Size([32, 64]) diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index fc6d894c4aae..2cb6928edf83 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -1,14 +1,21 @@ +from contextlib import nullcontext + import torch import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import FusedLayerNorm -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +@parameterize('lazy_init', [False, True]) +def check_layernorm(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() -def check_layernorm(): - norm = nn.LayerNorm(128, 0.00001).cuda() + with ctx: + norm = nn.LayerNorm(128, 0.00001).cuda() norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) assert norm1d.weight.shape == torch.Size([128]) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3bdc1d78d3..da3cd85ec407 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -1,16 +1,23 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.tensor.d_tensor import is_distributed_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +@parameterize('lazy_init', [False, True]) +def check_linear_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() -def check_linear_1d_col(): - linear = nn.Linear(32, 128).cuda() + with ctx: + linear = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) # ensure that the parameters are distributed @@ -50,8 +57,12 @@ def check_linear_1d_col(): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_1d_row(): - linear = nn.Linear(32, 128).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = nn.Linear(32, 128).cuda() linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear_row.weight.shape == torch.Size([128, 16]) @@ -83,9 +94,13 @@ def check_linear_1d_row(): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(): - linear_1 = nn.Linear(32, 128).cuda() - linear_2 = nn.Linear(128, 32).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_col_plus_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 681c4f6dd9f1..186b1e8212cc 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -1,12 +1,15 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # This code is copied from https://github.com/huggingface/transformers @@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(): - linear = Conv1D(192, 48).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = Conv1D(192, 48).cuda() linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, @@ -80,8 +87,12 @@ def check_linear_conv_1d_col(): assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_conv_1d_row(): - linear = Conv1D(192, 48).cuda() +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + linear = Conv1D(192, 48).cuda() linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 8991d9b304f5..bf5803496f03 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -1,15 +1,23 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_vocab_embedding_1d(): - embedding = nn.Embedding(128, 32).to('cuda') +@parameterize('lazy_init', [False, True]) +def check_vocab_embedding_1d(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + + with ctx: + embedding = nn.Embedding(128, 32).to('cuda') dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e03014f3f234..f83cfcd499cb 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,19 +1,24 @@ import copy +from contextlib import nullcontext +from colossalai.lazy import LazyInitContext from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): - # create new model - org_model = model_fn().cuda() - +def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + # create new model + org_model = model_fn() + model_copy = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism) - model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) - return org_model, sharded_model.cuda() + return org_model.cuda(), sharded_model.cuda() def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 1afedb7079ea..7f179acd7356 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_fused_normalization', [False, True]) +@parameterize('enable_tensor_parallelism', [False, True]) +@parameterize('use_lazy_init', [False, True]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index a3389652269c..e18168292df5 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ee7737687d99..96c4b90a8075 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 74b5fdd18af8..4d63a43489a3 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..c008596fe2b6 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 0762dc09e5af..ccd7d3787d3d 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('use_lazy_init', [False, True]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index f29c8d6f605b..2b6933246298 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import pytest import torch import torch.distributed as dist @@ -5,15 +7,15 @@ import colossalai from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -def check_shardformer_with_ddp(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') +@parameterize('lazy_init', [True, False]) +def check_shardformer_with_ddp(lazy_init: bool): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') @@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port): shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) shardformer = ShardFormer(shard_config=shard_config) + ctx = LazyInitContext() if lazy_init else nullcontext() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model - model = model_fn().cuda() + with ctx: + model = model_fn().cuda() sharded_model, _ = shardformer.optimize(model) # add ddp @@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port): torch.cuda.empty_cache() +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_shardformer_with_ddp() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gpt2(): - spawn(check_shardformer_with_ddp, 4) + spawn(run_dist, 4) if __name__ == "__main__": test_gpt2() - test_gpt2() From 612a5d1920c3916641db51992c1783ce173692ca Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 10 Jul 2023 13:58:58 +0800 Subject: [PATCH 18/80] [pipeline] Bert pipeline for shardformer and its tests (#4197) * add pipeline forward * complete pipeline forward check * fix bert forward without pipeline * fix comments * discard useless line * add todo * clean prints * fix distribute layers --- .../shardformer/policies/base_policy.py | 2 +- colossalai/shardformer/policies/bert.py | 156 +++++++++++++++++- colossalai/shardformer/shard/sharder.py | 4 +- tests/test_shardformer/test_model/_utils.py | 23 +++ .../test_model/test_shard_bert_pipeline.py | 85 ++++++++++ 5 files changed, 259 insertions(+), 11 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_bert_pipeline.py diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 65aee13861ee..aac86eb20a56 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -191,7 +191,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: # deal with the rest layers if remainder > 0: - start_position = num_layers // 2 - remainder // 2 + start_position = num_stages // 2 - remainder // 2 for i in range(start_position, start_position + remainder): layers_per_stage[i] += 1 return layers_per_stage diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index b80475e0552c..eacd0b449ad4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -13,6 +13,8 @@ CausalLMOutputWithCrossAttentions, ) from transformers.models.bert.modeling_bert import ( + BertForMaskedLM, + BertForNextSentencePrediction, BertForPreTraining, BertForPreTrainingOutput, BertLMHeadModel, @@ -135,7 +137,6 @@ def module_policy(self): ], policy=policy, target_key=BertLayer) - # handle embedding layer self.append_or_create_submodule_replacement( description=[SubModuleReplacementDescription( @@ -144,6 +145,7 @@ def module_policy(self): )], policy=policy, target_key=BertEmbeddings) + return policy def add_lm_head_policy(self, base_policy): @@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertModel + if self.pipeline_stage_manager: + # set None as default + module_policy[BertModel] = ModulePolicyDescription( + method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) + return module_policy + def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" module = self.model @@ -444,6 +455,13 @@ def bert_model_forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) else: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape @@ -466,14 +484,6 @@ def bert_model_forward( if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) @@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel, hidden_states = outputs.get('hidden_states') # intermediate stage always return dict return {'hidden_states': hidden_states} + + +def bert_for_masked_lm_forward( + self: BertForMaskedLM, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, +): + #-> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + pass + + +def bert_for_next_sentence_prediction_forward( + self: BertForNextSentencePrediction, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + **kwargs, +): + #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 56eb76973807..882f93c7acc5 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,3 +1,4 @@ +from types import MethodType from typing import Any, Callable, Dict, List, Union import torch.nn as nn @@ -134,7 +135,8 @@ def _replace_param( def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): for method_name, new_method in method_replacement.items(): # bind the new method to the module - setattr(module, method_name, new_method.__get__(module, module.__class__)) + bound_method = MethodType(new_method, module) + setattr(module, method_name, bound_method) def _replace_sub_module( self, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f83cfcd499cb..de8cb65d21d0 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -2,6 +2,7 @@ from contextlib import nullcontext from colossalai.lazy import LazyInitContext +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle return org_model.cuda(), sharded_model.cuda() +def build_pipeline_model(model_fn, + stage_manager=None, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + use_lazy_init: bool = False): + ctx = LazyInitContext() if use_lazy_init else nullcontext() + with ctx: + # create new model + org_model = model_fn() + model_copy = copy.deepcopy(org_model) + if use_lazy_init: + ctx.materialize(org_model) + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model.cuda(), sharded_model.cuda() + + def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # prepare input data = data_gen_fn() diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py new file mode 100644 index 000000000000..9cca5ec8bc51 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -0,0 +1,85 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_bert +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_bert': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + # print(output['hidden_states'].shape) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + # print(output[0].shape) + assert output[0].shape == (2, 3, 128) + + torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 4) + + +if __name__ == "__main__": + test_bert() From fa4515f146740044484d58e870b621c3e4f69dcc Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 11 Jul 2023 11:37:26 +0800 Subject: [PATCH 19/80] [pipeline] Llama pipeline (#4205) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * Revert "bloom policy" This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0. This policy should be revert and copied to feature/bloom * revert the bloom changes * cancel unneeded inputs * gpt --- .../shardformer/policies/auto_policy.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/llama.py | 428 +++++++++++++++++- tests/kit/model_zoo/transformers/gpt.py | 2 +- tests/test_shardformer/test_model/_utils.py | 1 + .../test_model/test_shard_llama_pipeline.py | 85 ++++ 6 files changed, 516 insertions(+), 4 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_llama_pipeline.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 640b61b579bd..0ad9a3e95a0e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -45,7 +45,7 @@ class PolicyLocation: # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaPolicy"), + PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"), "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index eacd0b449ad4..2b2c003ffb04 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -193,7 +193,7 @@ def get_held_layers(self) -> List[Module]: module = self.model stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b10e07560d22..b2b6470188a4 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,11 +1,30 @@ -from typing import Dict, Union +import math +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.utils import ModelOutput, logging +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -106,6 +125,43 @@ def postprocess(self): return self.model +class LlamaModelPolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={ + 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) + }) + return module_policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bert model""" + return [] + + class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): @@ -144,3 +200,373 @@ def module_policy(self): } policy.update(new_item) return policy + + +def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx]): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {'hidden_states': hidden_states} + + +def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b9e0310780af..ac70138e3f8f 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -52,7 +52,7 @@ def data_gen_for_sequence_classification(): loss_fn = lambda x: x.loss config = transformers.GPT2Config(n_layer=2, - n_head=4, + n_head=2, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index de8cb65d21d0..f26c6622da7e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -39,6 +39,7 @@ def build_pipeline_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, pipeline_stage_manager=stage_manager) + shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py new file mode 100644 index 000000000000..81c183d3230e --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -0,0 +1,85 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_llama +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_llama': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + # print(output[0].shape) + assert output[0].shape == (2, 3, 128) + + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 4) + + +if __name__ == "__main__": + test_llama() From 8161405d67dcb0c4e7aad8bb4392707ee3680bbe Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 11 Jul 2023 15:23:33 +0800 Subject: [PATCH 20/80] [pipeline] Llama causal lm and llama for sequence classification pipeline (#4208) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * Revert "bloom policy" This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0. This policy should be revert and copied to feature/bloom * revert the bloom changes * cancel unneeded inputs * gpt * finish llama * causal lm and sequence classification * revision --- .../shardformer/policies/base_policy.py | 18 ++++ colossalai/shardformer/policies/llama.py | 82 +++++++++++++++++-- tests/kit/model_zoo/transformers/gpt.py | 2 +- .../test_model/test_shard_llama_pipeline.py | 28 +++---- 4 files changed, 109 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index aac86eb20a56..68fde0115de6 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -162,6 +162,24 @@ def append_or_create_submodule_replacement( return policy + def append_or_create_method_replacement( + self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new method replacement description to the policy for the given key. + + Args: + description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + if target_key in policy: + policy[target_key].method_replacement.update(description) + else: + policy[target_key] = ModulePolicyDescription(method_replacement=description) + + return policy + def get_held_layers(self) -> List[Module]: """Get layers that should be held in current stage. This method should be implemented by subclass. diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b2b6470188a4..a3ea807269bb 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -131,17 +131,20 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.llama.modeling_llama import LlamaModel if self.pipeline_stage_manager: # set None as default stage_manager = self.pipeline_stage_manager layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={ + method_replacement = { 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) - }) - return module_policy + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaModel) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -158,7 +161,7 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in bert model""" + """No shared params in llama model""" return [] @@ -179,8 +182,43 @@ def module_policy(self): ]) } policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaForCausalLM) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.model.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.model.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.model.norm) + held_layers.append(module.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + llama_model = self.model.model + if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight): + # tie weights + return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}] + return [] + class LlamaForSequenceClassificationPolicy(LlamaPolicy): @@ -199,8 +237,42 @@ def module_policy(self): ]) } policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(llama_for_sequence_classification_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaForSequenceClassification) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.model.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.model.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.model.norm) + held_layers.append(module.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama for sequence classification model""" + return [] + def llama_model_forward( self: LlamaModel, diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index ac70138e3f8f..b9e0310780af 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -52,7 +52,7 @@ def data_gen_for_sequence_classification(): loss_fn = lambda x: x.loss config = transformers.GPT2Config(n_layer=2, - n_head=2, + n_head=4, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 81c183d3230e..8fd9ed099478 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la x = torch.randint(0, 1000, (2, 3)).cuda() hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_llama': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - # print(output[0].shape) - assert output[0].shape == (2, 3, 128) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0] is not None torch.cuda.empty_cache() From 701257a3754279a8f088ea1b91cef9ea128cab61 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 13 Jul 2023 12:47:26 +0800 Subject: [PATCH 21/80] [pipeline] add bloom model pipeline (#4210) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * finish bloom model * test shard gpt2 * clear cache --- colossalai/shardformer/policies/bloom.py | 227 +++++++++++++++++- tests/kit/model_zoo/transformers/gpt.py | 20 +- .../test_model/test_shard_bloom_pipeline.py | 84 +++++++ .../test_model/test_shard_gpt2.py | 1 + 4 files changed, 322 insertions(+), 10 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_bloom_pipeline.py diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 662ff5b4977a..5cfc1ab29edf 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,11 +1,26 @@ +import warnings +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + class BloomPolicy(Policy): @@ -110,7 +125,46 @@ def postprocess(self): class BloomModelPolicy(BloomPolicy): - pass + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.bloom.modeling_bloom import BloomModel + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + policy[BloomModel] = ModulePolicyDescription(method_replacement={ + "forward": + partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager, stage_index=stage_index) + }) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + return [] class BloomForCausalLMPolicy(BloomPolicy): @@ -181,3 +235,174 @@ def module_policy(self): class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 pass + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] # source_len + + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + # causal_mask is constructed every stage and its input is passed through different stages + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx])): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + # TODO: deal with all_hidden_states, all_self_attentions, presents + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {'hidden_states': hidden_states} diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b9e0310780af..f9a0888ff80d 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -51,15 +51,17 @@ def data_gen_for_sequence_classification(): loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() loss_fn = lambda x: x.loss -config = transformers.GPT2Config(n_layer=2, - n_head=4, - vocab_size=50258, - attn_pdrop=0, - embd_pdrop=0, - resid_pdrop=0, - summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification") +config = transformers.GPT2Config( + n_layer=2, + n_head=4, + #n_embd=128, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py new file mode 100644 index 000000000000..31760d692694 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py @@ -0,0 +1,84 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_bloom +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_bloom': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 64) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0].shape == (2, 3, 64) + + torch.cuda.empty_cache() + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): + spawn(check_bloom, 4) + + +if __name__ == "__main__": + test_bloom() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 96c4b90a8075..9e5608e7fdcf 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -70,6 +70,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('use_lazy_init', [False, True]) +@clear_cache_before_run() def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From 6c9b16181061996dbe624de1307e74b60ec9ca8f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 13 Jul 2023 15:34:06 +0800 Subject: [PATCH 22/80] [pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224) * * fix typehint & docstring in sharder.py * * update pipeline forward for GPT2Model * * add test for pipeline forward of GPT2Model * * add cache cleaning in gpt2 test * * change assert to raise command --- colossalai/shardformer/layer/linear.py | 2 +- colossalai/shardformer/policies/gpt2.py | 270 +++++++++++++++++- colossalai/shardformer/shard/sharder.py | 15 +- .../test_model/test_shard_gpt2.py | 2 + .../test_model/test_shard_gpt2_pipeline.py | 77 +++++ 5 files changed, 357 insertions(+), 9 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a8439f303bd1..383d9b3f533a 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -129,7 +129,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis **kwargs) with torch.no_grad(): - # the weigh to the linear layer is a transpose + # the weight to the linear layer is a transpose # thus shard on row is equal to shard on column sharded_weight = shard_rowwise(module.weight.data, process_group) linear_1d.weight.data.copy_(sharded_weight) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 8f9d90e67e59..ffba27a50e72 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,6 +1,14 @@ -import torch.nn as nn +import logging +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -119,6 +127,46 @@ class GPT2ModelPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(GPT2PipelineForwards.gpt2_model_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=GPT2Model) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # TODO: check whether there is shared param in gpt2model + """No shared params in gpt2 model.""" + return [] + # GPT2LMHeadModel class GPT2LMHeadModelPolicy(GPT2Policy): @@ -194,3 +242,223 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + + +class GPT2PipelineForwards: + ''' + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + ''' + + @staticmethod + def gpt2_model_forward( + self: 'GPT2Model', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'BaseModelOutputWithPastAndCrossAttentions']: + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + + from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + + # Preprocess passed in arguments + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + input_shape = input_ids.size() + input_ids = input_ids.view(-1, seq_length) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + else: + if hidden_states is None: + raise ValueError("hidden_states shouln't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logging.warning('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logging.warning('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if self.gradient_checkpointing and self.training: + if use_cache: + logging.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + else: + # always return dict for intermediate stage + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 882f93c7acc5..5e0b572e259c 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -72,17 +72,18 @@ def _recursive_replace_layer( attr_replacement: Dict[str, Any], param_replacement: List[Callable], method_replacement: Dict[str, Callable], - sub_module_replacement: List[Callable], + sub_module_replacement: List[SubModuleReplacementDescription], ) -> None: r""" Reverse the replace layer operation Args: - layer (torch.nn.Module): The object of layer to shard - origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name. - attr_replacement (Dict): The attribute dict to modify + module (torch.nn.Module): The object of layer to shard + origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name + attr_replacement (Dict[str, Any]): The attribute dict to modify param_replacement (List[Callable]): The function list to get parameter shard information in policy - sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy + method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement + sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy """ if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): @@ -111,7 +112,7 @@ def _replace_attr( Replace the attribute of the layer Args: - layer (:class:`torch.nn.Module`): The object of layer to shard + module (:class:`torch.nn.Module`): The object of layer to shard attr_replacement (Dict): The attribute dict to modify """ for k, v in attr_replacement.items(): @@ -126,7 +127,7 @@ def _replace_param( Replace the parameter of the layer Args: - layer (:class:`torch.nn.Module`): The object of layer to shard + module (:class:`torch.nn.Module`): The object of layer to shard param_replacement (List[Callable]): The function list to get parameter shard information in policy """ for param_func in param_replacement: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 9e5608e7fdcf..552c6e2f4d53 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -65,6 +65,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo assert torch.allclose( org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" + torch.cuda.empty_cache() @parameterize('enable_fused_normalization', [True, False]) @@ -77,6 +78,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py new file mode 100644 index 000000000000..5f92f638f863 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -0,0 +1,77 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # TODO: add tests for forward/backward later + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_gpt2 +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_gpt": + continue + + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + org_model.train() + org_output = org_model(**inputs) + hidden_state_shape = org_output['last_hidden_state'].shape + + if stage_manager.is_first_stage(): + output = sharded_model(**inputs) + assert output['hidden_states'].shape == hidden_state_shape + else: + attention_mask = inputs['attention_mask'] + hidden_states = torch.zeros(*hidden_state_shape).cuda() + output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) + if stage_manager.is_last_stage(): + assert output['last_hidden_state'].shape == hidden_state_shape + else: + assert output['hidden_states'].shape == hidden_state_shape + + torch.cuda.empty_cache() + + +def check_gpt2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_gpt2, 4) + + +if __name__ == "__main__": + test_gpt2() From b7673a1c9e1f904369a7ec4475e5f455468e8574 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 14 Jul 2023 09:51:53 +0800 Subject: [PATCH 23/80] [shardformer] fix base policy (#4229) --- colossalai/shardformer/policies/base_policy.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 68fde0115de6..69493bfb6007 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -156,7 +156,10 @@ def append_or_create_submodule_replacement( # append or create a new description if target_key in policy: - policy[target_key].sub_module_replacement.extend(description) + if policy[target_key].sub_module_replacement is None: + policy[target_key].sub_module_replacement = description + else: + policy[target_key].sub_module_replacement.extend(description) else: policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) @@ -174,7 +177,10 @@ def append_or_create_method_replacement( target_key (Union[str, nn.Module]): the key of the policy to be updated """ if target_key in policy: - policy[target_key].method_replacement.update(description) + if policy[target_key].method_replacement is None: + policy[target_key].method_replacement = description + else: + policy[target_key].method_replacement.update(description) else: policy[target_key] = ModulePolicyDescription(method_replacement=description) From 59cbbe04ec1d3efe18f2a1b63d601b5cf23e0802 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 17 Jul 2023 15:21:51 +0800 Subject: [PATCH 24/80] [pipeline] add pipeline forward for variants of gpt2 (#4238) * add forward for GPTLMHeadModel * add test for gpt_lm * arranging get_held_layers method * arrange forward replacement * add forward for GPT2ForTokenClassification * add forward for GPT2ForSequenceClassification * fix test_shard_gpt2.py * add GPT2DoubleHeadsmodel & fix bugs * add id checking in get_shared_params --- colossalai/shardformer/policies/gpt2.py | 545 ++++++++++++++++-- .../test_model/test_shard_gpt2_pipeline.py | 50 +- 2 files changed, 529 insertions(+), 66 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ffba27a50e72..5d6f47636587 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,11 +1,11 @@ import logging from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from torch import Tensor -from torch.nn import Module +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager @@ -48,6 +48,10 @@ def module_policy(self): suffix="wte", target_module=col_nn.VocabParallelEmbedding1D, ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), ]) policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -120,6 +124,45 @@ def module_policy(self): def postprocess(self): return self.model + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + # GPT2Model class GPT2ModelPolicy(GPT2Policy): @@ -131,40 +174,16 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Model policy = super().module_policy() - if self.pipeline_stage_manager: - # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(GPT2PipelineForwards.gpt2_model_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=GPT2Model) + self.set_pipeline_forward(model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy) return policy - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - module = self.model - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.wpe) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - return held_layers + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() def get_shared_params(self) -> List[Dict[int, Tensor]]: - # TODO: check whether there is shared param in gpt2model - """No shared params in gpt2 model.""" + """No shared params in GPT2Model.""" return [] @@ -188,10 +207,31 @@ def module_policy(self): ]) } module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy) return module_policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''The weights of wte and lm_head are shared.''' + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + else: + return [] + def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism \ + and self.pipeline_stage_manager is None: binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -199,7 +239,7 @@ def postprocess(self): return self.model -# GPT22DoubleHeadsModel +# GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): def __init__(self) -> None: @@ -219,10 +259,38 @@ def module_policy(self): ]) } module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy) + return module_policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + multiple_choice_head = self.model.multiple_choice_head + held_layers.append(self.model.lm_head) + held_layers.append(multiple_choice_head.summary) + held_layers.append(multiple_choice_head.activation) + held_layers.append(multiple_choice_head.first_dropout) + held_layers.append(multiple_choice_head.last_dropout) + + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''The weights of wte and lm_head are shared.''' + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + else: + return [] + def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism \ + and self.pipeline_stage_manager is None: binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -236,6 +304,36 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2ForTokenClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) + ]) + } + module_policy.update(addon_module) + + self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + # GPT2ForSequenceClassification class GPT2ForSequenceClassificationPolicy(GPT2Policy): @@ -243,6 +341,25 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification + + module_policy = super().module_policy() + self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForTokenClassification.""" + return [] + class GPT2PipelineForwards: ''' @@ -299,8 +416,7 @@ def gpt2_model_forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_length) else: - if hidden_states is None: - raise ValueError("hidden_states shouln't be None for stages other than the first stage.") + assert hidden_states is not None input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -462,3 +578,356 @@ def custom_forward(*inputs): else: # always return dict for intermediate stage return {'hidden_states': hidden_states} + + @staticmethod + def gpt2_lmhead_model_forward( + self: 'GPT2LMHeadModel', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + + from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def gpt2_double_heads_model_forward( + self: 'GPT2DoubleHeadsModel', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. + Please refer to original code of transformers for more details. + ```""" + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_token_classification_forward( + self: 'GPT2ForTokenClassification', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. + # Please refer to original code of transformers for more details. + """ + + from transformers.modeling_outputs import TokenClassifierOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_sequence_classification_forward( + self: 'GPT2ForSequenceClassification', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + from transformers.modeling_outputs import SequenceClassifierOutputWithPast + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert (self.config.pad_token_id is not None + or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logging.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 5f92f638f863..dd439a394827 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -5,15 +5,9 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward +from tests.test_shardformer.test_model._utils import build_pipeline_model def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo pass -@parameterize('enable_fused_normalization', [False]) @parameterize('enable_tensor_parallelism', [False]) +@parameterize('enable_fused_normalization', [False]) @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_gpt2 def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): @@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_gpt": - continue + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids, _ = inputs['input_ids'], inputs['attention_mask'] + batch_size, seq_len = input_ids.shape + hidden_size = 768 + hidden_state_shape = (batch_size, seq_len, hidden_size) - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - org_model.train() - org_output = org_model(**inputs) - hidden_state_shape = org_output['last_hidden_state'].shape - - if stage_manager.is_first_stage(): - output = sharded_model(**inputs) - assert output['hidden_states'].shape == hidden_state_shape - else: - attention_mask = inputs['attention_mask'] + if not stage_manager.is_first_stage(): + # change inputs if not the first stage hidden_states = torch.zeros(*hidden_state_shape).cuda() - output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) - if stage_manager.is_last_stage(): - assert output['last_hidden_state'].shape == hidden_state_shape - else: - assert output['hidden_states'].shape == hidden_state_shape + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + sharded_model.train() + output = sharded_model(**inputs) + if stage_manager.is_last_stage(): + if name != 'transformers_gpt': + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape torch.cuda.empty_cache() From f2695759382bc622ad2eb3d313f8993549e5701e Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 17 Jul 2023 16:12:20 +0800 Subject: [PATCH 25/80] [pipeline] All bert models (#4233) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * Revert "bloom policy" This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0. This policy should be revert and copied to feature/bloom * revert the bloom changes * cancel unneeded inputs * gpt * finish llama * causal lm and sequence classification * revision * add pure pipeline test * finish some bert models * finish all bert models * finish bert tests * fix bugs * fix bugs * fix test pipeline * fix data gen for qa * update the set pipeline forward * shared params * fix bugs --- colossalai/pipeline/p2p.py | 5 +- colossalai/pipeline/schedule/one_f_one_b.py | 1 - .../shardformer/policies/auto_policy.py | 2 + colossalai/shardformer/policies/bert.py | 832 +++++++++++++++--- colossalai/shardformer/policies/llama.py | 6 +- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/bert.py | 17 + .../test_bert_for_pretraining_model.py | 19 +- ...ad_model.py => test_bert_lm_head_model.py} | 33 +- .../test_policy/test_bert_model.py | 16 +- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_pure_pipeline.py | 164 ++++ .../test_model/test_shard_bert_pipeline.py | 28 +- 13 files changed, 985 insertions(+), 141 deletions(-) rename tests/test_pipeline/test_policy/{test_bert_lmhead_model.py => test_bert_lm_head_model.py} (73%) create mode 100644 tests/test_shardformer/test_model/test_pure_pipeline.py diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 203b7439d7ef..2fd135d5475d 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any], my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + if torch.__version__ >= "1.13.0": + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list]) + else: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) object_sizes_tensor = torch.cat(size_list) else: object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index a8933bfbb4da..6ed3055d689b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -205,7 +205,6 @@ def forward_backward_step(self, # the backward pass. input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if last_iteration: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0ad9a3e95a0e..ccdb33b2efe5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -42,6 +42,8 @@ class PolicyLocation: PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), + "transformers.models.bert.modeling_bert.BertForQuestionAnswering": + PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"), # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 2b2c003ffb04..1af26f50484c 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,22 +1,30 @@ from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( - BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, ) from transformers.models.bert.modeling_bert import ( BertForMaskedLM, + BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForPreTrainingOutput, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, BertLMHeadModel, BertModel, ) @@ -31,9 +39,9 @@ logger = logging.get_logger(__name__) __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy', 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', - 'BertForMultipleChoicePolicy' + 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy' ] @@ -172,6 +180,25 @@ def add_lm_head_policy(self, base_policy): def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + # BertModel class BertModelPolicy(BertPolicy): @@ -180,13 +207,10 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel - if self.pipeline_stage_manager: - # set None as default - module_policy[BertModel] = ModulePolicyDescription( - method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)}) - return module_policy + self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -214,15 +238,17 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForPreTraining + self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage""" module = self.model stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) held_layers = [] if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) @@ -237,11 +263,18 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + model = self.model + if self.pipeline_stage_manager: + if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -256,9 +289,11 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertLMHeadModel + self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) + return policy def get_held_layers(self) -> List[Module]: """ @@ -267,7 +302,7 @@ def get_held_layers(self) -> List[Module]: module = self.model held_layers = [] stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(self.model.bert.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.bert.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) @@ -278,11 +313,18 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared params in bertmodel''' + bert_model = self.model.bert + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -297,12 +339,42 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + from transformers.models.bert.modeling_bert import BertForMaskedLM + self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bert_model = self.model.bert + if self.pipeline_stage_manager: + if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): + #tie weights + return [{ + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight + }] + return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -319,7 +391,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForSequenceClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -331,8 +403,35 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForSequenceClassification, + new_forward=bert_for_sequence_classification_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForTokenClassification @@ -344,7 +443,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -356,8 +455,35 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForTokenClassification, + new_forward=bert_for_token_classification_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] # BertForNextSentencePrediction @@ -366,6 +492,36 @@ class BertForNextSentencePredictionPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + policy = super().module_policy() + from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, + new_forward=bert_for_next_sentence_prediction_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.cls) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): @@ -376,7 +532,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice - module_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: addon_module = { @@ -388,28 +544,91 @@ def module_policy(self): ) ]) } - module_policy.update(addon_module) - return module_policy + policy.update(addon_module) + + self.set_pipeline_forward(model_cls=BertForMultipleChoice, + new_forward=bert_for_multiple_choice_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] + + +class BertForQuestionAnsweringPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForQuestionAnswering + policy = super().module_policy() + self.set_pipeline_forward(model_cls=BertForQuestionAnswering, + new_forward=bert_for_question_answering_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + held_layers = [] + stage_manager = self.pipeline_stage_manager + layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.bert.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.bert.pooler) + held_layers.append(module.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + stage_index: Optional[List[int]] = None, ): # TODO: add explaination of the output here. r""" @@ -528,14 +747,10 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - # calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - + start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -593,8 +808,9 @@ def custom_forward(*inputs): return (sequence_output, pooled_output) + layer_outputs[1:] # return dict is not supported at this moment else: - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, @@ -624,6 +840,7 @@ def bert_for_pretraining_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. @@ -637,18 +854,21 @@ def bert_for_pretraining_forward( logger.warning_once('return_dict is not supported for pipeline models at the moment') return_dict = False - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + ) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -684,23 +904,26 @@ def bert_for_pretraining_forward( } -def bert_lmhead_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): +def bert_lm_head_model_forward( + self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -754,7 +977,8 @@ def bert_lmhead_forward(self: BertLMHeadModel, output_hidden_states=output_hidden_states, return_dict=return_dict, stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -806,15 +1030,66 @@ def bert_for_masked_lm_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, ): - #-> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - pass + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} def bert_for_next_sentence_prediction_forward( @@ -831,6 +1106,7 @@ def bert_for_next_sentence_prediction_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, **kwargs, ): #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -881,6 +1157,259 @@ def bert_for_next_sentence_prediction_forward( return_dict = False return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} + + +def bert_for_sequence_classification_forward( + self: BertForSequenceClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_token_classification_forward( + self: BertForTokenClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_multiple_choice_forward( + self: BertForMultipleChoice, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # in our pipeline design,input ids are copied for every stage and shouldn't be none + # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] + if stage_manager.is_last_stage(): + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + outputs = bert_model_forward( self.bert, input_ids, @@ -892,27 +1421,128 @@ def bert_for_next_sentence_prediction_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, ) if stage_manager.is_last_stage(): pooled_output = outputs[1] - seq_relationship_scores = self.cls(pooled_output) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) - next_sentence_loss = None + loss = None if labels is not None: loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + loss = loss_fct(reshaped_logits, labels) if not return_dict: - output = (seq_relationship_scores,) + outputs[2:] - return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output - return NextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bert_for_question_answering_forward( + self: BertForQuestionAnswering, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, +): + # NOTE: the arg start_position and end_position are used only for the last stage + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) else: hidden_states = outputs.get('hidden_states') - # intermediate stage always return dict return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a3ea807269bb..b3757452c314 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -212,11 +212,13 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - """No shared params in llama model""" llama_model = self.model.model if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight): # tie weights - return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}] + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + }] return [] diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index d2d3de7b7bee..1993af51ad63 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -87,6 +87,17 @@ def data_gen_for_mcq(): return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) +def data_gen_for_qa(): + # generating data for question answering + # no need for labels and use start and end position instead + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data['start_positions'] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data['end_positions'] = end_positions + return data + + # define output transform function output_transform_fn = lambda x: x @@ -150,3 +161,9 @@ def data_gen_for_mcq(): output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bert_for_question_answering', + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 97d7d2fa538a..6a8d7b636375 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -7,6 +7,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -35,16 +36,20 @@ def check_bert_for_pretraining_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_for_pretraining_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output['hidden_states'].shape) + output = bert_for_pretraining_forward( + self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index, + ) assert output['hidden_states'].shape == (2, 3, 768) else: @@ -52,8 +57,8 @@ def check_bert_for_pretraining_forward(): output = bert_for_pretraining_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) + stage_manager=stage_manager, + stage_index=stage_index) assert output[0].shape == (2, 3, 30522) # assert output[1].shape == (2, 768) diff --git a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py similarity index 73% rename from tests/test_pipeline/test_policy/test_bert_lmhead_model.py rename to tests/test_pipeline/test_policy/test_bert_lm_head_model.py index b14dadf29e3c..cd47f7a33c4b 100644 --- a/tests/test_pipeline/test_policy/test_bert_lmhead_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py @@ -7,12 +7,13 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_lmhead_forward(): +def check_bert_lm_head_model_forward(): configuration = BertConfig() model = BertLMHeadModel(configuration) DP_DIM, PP_DIM = 0, 1 @@ -35,24 +36,28 @@ def check_bert_lmhead_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_lmhead_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) + + output = bert_lm_head_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) - output = bert_lmhead_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) + output = bert_lm_head_model_forward(self=model, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 30522) @@ -93,7 +98,7 @@ def check_bert_lmhead_policy(): def run_dist_model(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_forward() + check_bert_lm_head_model_forward() def run_dist_policy(rank, world_size, port): @@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() -def test_bert_lmhead_forward(): +def test_bert_lm_head_model_forward(): spawn(run_dist_model, 4) @@ -115,5 +120,5 @@ def test_bert_lmhead_policy(): if __name__ == "__main__": """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_lmhead_forward() + test_bert_lm_head_model_forward() test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index f5a443309cb2..f116bc761aa7 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -6,12 +6,14 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn def check_bert_model_forward(): + # this test may crash for internet reasons model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 @@ -34,20 +36,25 @@ def check_bert_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() # print(rank) - + layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) x = torch.randint(0, 1000, (2, 3)) hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) if stage_manager.stage == 0: attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - print(output['hidden_states'].shape) + output = bert_model_forward(self=model, + input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) assert output['hidden_states'].shape == (2, 3, 768) else: attention_mask = torch.ones((2, 3)) output = bert_model_forward(self=model, hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager) + stage_manager=stage_manager, + stage_index=stage_index) print(output[0].shape) assert output[0].shape == (2, 3, 768) @@ -112,4 +119,3 @@ def test_bert_model_policy(): """test the bert model forward and bert model policy""" #test_bert_model_forward() test_bert_model_policy() - # this test need config to run diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f26c6622da7e..825d6df6bb5e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, # prepare input data = data_gen_fn() data = {k: v.cuda() for k, v in data.items()} - # switch to train mode original_model.train() sharded_model.train() diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py new file mode 100644 index 000000000000..24cda193a5e6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -0,0 +1,164 @@ +import random +from contextlib import nullcontext +from typing import Any, Callable, Iterator, List, Optional, Tuple + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch import Tensor +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class PipelineOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module): + super().__init__(optim) + params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group['params'] if p in params] + new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + # TODO: support amp + + +class PipelinedModel(ModelWrapper): + + def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None: + self.stage_manager = stage_manager + shardformer = ShardFormer(shard_config) + module, self.shared_params = shardformer.optimize(module) + self.shared_param_process_groups = [] + super().__init__(module) + + +def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0): + sampler = DistributedSampler( + dataset, + #rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + ) + + +def execute_pipeline( + data_iter: Iterator, + model: PipelinedModel, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: PipelineOptimizer, + return_loss: bool = True, + return_outputs: bool = False, + schedule: OneForwardOneBackwardSchedule = None, +) -> dict: + # return loss or outputs if needed + outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) + return outputs + + +class data_iter(): + + def __getitem__(self, x): + return torch.randint(0, 100, (4, 128)).cuda() + + +def loss(x, y): + return (x[0].float().mean() - y[0].float().mean()) + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + PP_DIM = 0 + PP_SIZE = 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + from datasets import load_dataset + + #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi") + pg_mesh = ProcessGroupMesh(PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + num_microbatches = 2 + org_model = model_fn().cuda() + optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) + #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4) + schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager) + pipelined_model = PipelinedModel(org_model, shard_config, stage_manager) + pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) + data_it = iter(data_iter()) + results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule) + if stage_manager.is_last_stage(): + assert results['loss'] is not None + assert results['outputs'] is None + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 2) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 9cca5ec8bc51..4feaf982aa37 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -45,25 +45,37 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_bert': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if name == 'transformers_bert_for_mcq': + x = torch.randint(0, 1000, (2, 3, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + if stage_manager.stage == 0: + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (6, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() + output = sharded_model(input_ids=x, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape == (2, 3) + else: + x = torch.randint(0, 1000, (2, 3)).cuda() + # one batch, 2 single sentences, each sentence has 3 tokens + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output['hidden_states'].shape) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, stage_manager=stage_manager) - # print(output[0].shape) - assert output[0].shape == (2, 3, 128) + assert output[0].shape[0] == 2 torch.cuda.empty_cache() From db33d4a2a2320b1efc1cec6af2a95c467a53e5fa Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:10:55 +0800 Subject: [PATCH 26/80] [pipeline] finish bloom models pipeline and tests (#4223) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * finish bloom model * test shard gpt2 * clear cache * support all bloom models * add bloom models policies * finish bloom pipeline and tests * add set pipeline * finish bloom --- colossalai/shardformer/policies/bloom.py | 563 +++++++++++++++++- .../test_model/test_shard_bloom_pipeline.py | 31 +- 2 files changed, 562 insertions(+), 32 deletions(-) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 5cfc1ab29edf..8afaadefb696 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,15 +1,27 @@ import warnings from functools import partial from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom.modeling_bloom import BloomModel +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.bloom.modeling_bloom import ( + BloomForCausalLM, + BloomForQuestionAnswering, + BloomForSequenceClassification, + BloomForTokenClassification, + BloomModel, +) from transformers.utils import logging import colossalai.shardformer.layer as col_nn @@ -123,6 +135,24 @@ def module_policy(self): def postprocess(self): return self.model + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BloomModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + return + class BloomModelPolicy(BloomPolicy): @@ -132,14 +162,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - policy[BloomModel] = ModulePolicyDescription(method_replacement={ - "forward": - partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager, stage_index=stage_index) - }) + self.set_pipeline_forward(model_cls=BloomModel, new_forward=bloom_model_forward, policy=policy) return policy def get_held_layers(self) -> List[Module]: @@ -163,7 +186,7 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''no shared params in bloommodel''' + '''no shared params in bloom model''' return [] @@ -180,10 +203,38 @@ def module_policy(self): policy=policy, target_key=BloomForCausalLM) + self.set_pipeline_forward(model_cls=BloomForCausalLM, new_forward=bloom_for_causal_lm_forward, policy=policy) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.transformer.word_embeddings) + held_layers.append(module.transformer.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.transformer.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.transformer.ln_f) + held_layers.append(module.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + bloom_model = self.model + if self.pipeline_stage_manager: + if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight): + # tie weights + return [{ + 0: bloom_model.transformer.word_embeddings.weight, + self.stage_manager.num_stages - 1: bloom_model.lm_head.weight + }] + return [] + def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} for k, v in binding_map.items(): @@ -205,9 +256,31 @@ def module_policy(self): suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=BloomForSequenceClassification) - + self.set_pipeline_forward(model_cls=BloomForSequenceClassification, + new_forward=bloom_for_sequence_classification_forward, + policy=policy) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.transformer.word_embeddings) + held_layers.append(module.transformer.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.transformer.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.transformer.ln_f) + held_layers.append(module.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for sequence classification model""" + return [] + class BloomForTokenClassificationPolicy(BloomPolicy): @@ -229,12 +302,63 @@ def module_policy(self): policy=policy, target_key=BloomForTokenClassification) + self.set_pipeline_forward(model_cls=BloomForTokenClassification, + new_forward=bloom_for_token_classification_forward, + policy=policy) + return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.transformer.word_embeddings) + held_layers.append(module.transformer.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.transformer.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.transformer.ln_f) + held_layers.append(module.dropout) + held_layers.append(module.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for token classification model""" + return [] + class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 - pass + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering + policy = super().module_policy() + self.set_pipeline_forward(model_cls=BloomForQuestionAnswering, + new_forward=bloom_for_question_answering_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.transformer.word_embeddings) + held_layers.append(module.transformer.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.transformer.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.transformer.ln_f) + held_layers.append(module.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bloom for question answering model""" + return [] def bloom_model_forward( @@ -406,3 +530,410 @@ def custom_forward(*inputs): else: # always return dict for imediate stage return {'hidden_states': hidden_states} + + +def bloom_for_causal_lm_forward(self: 'BloomForCausalLM', + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = bloom_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bloom_for_sequence_classification_forward( + self: BloomForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + batch_size = hidden_states.shape[0] + #update batch size + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bloom_for_token_classification_forward( + self: BloomForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, +): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + +def bloom_for_question_answering_forward( + self: BloomForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = bloom_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py index 31760d692694..3a36479fc8bb 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py @@ -46,23 +46,22 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda() + x = torch.randint(0, 1000, (1, 3)).cuda() + hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name == 'transformers_bloom': - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 64) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0].shape == (2, 3, 64) + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (1, 3, 64) + else: + attention_mask = torch.ones((1, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0].shape[0] == 1 torch.cuda.empty_cache() From 1fcafe9f671df4cb4304af8f0928779e81639469 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 18 Jul 2023 11:42:58 +0800 Subject: [PATCH 27/80] [bugs] hot fix some testing bugs for new models (#4268) * hot fix * hot fx tracer --- tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py | 1 + tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py | 2 ++ tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py | 2 +- tests/test_shardformer/test_model/test_pure_pipeline.py | 2 -- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 58c8132e1490..e6f8df2e0af7 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -22,6 +22,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non try: meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) + except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 632ad366ccc4..7773de480302 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -14,6 +14,8 @@ def test_bert(): for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() + if model.__class__.__name__ == "BertForQuestionAnswering": + continue trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 31bcb7028e25..e29afe786c46 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -18,7 +18,7 @@ def test_gpt(): # TODO: support the following models # 1. GPT2DoubleHeadsModel # as they are not supported, let's skip them - if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: + if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']: continue trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 24cda193a5e6..80767f71c3fb 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -122,9 +122,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la 2: [2, 3], 3: [2, 3], } - from datasets import load_dataset - #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi") pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') From e9025441b8c2df9d74787b3ac39a60e991d0e02d Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 19 Jul 2023 09:28:27 +0800 Subject: [PATCH 28/80] [pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245) * change for transformers loggers * add forward for GPT2ForQuestionAnswering * fix assert * fix torchrec test --- .../shardformer/policies/auto_policy.py | 2 + colossalai/shardformer/policies/gpt2.py | 136 ++++++++++++++++-- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 17 +++ .../test_model/test_shard_gpt2_pipeline.py | 1 - 5 files changed, 147 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index ccdb33b2efe5..b31f1b35f580 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -68,6 +68,8 @@ class PolicyLocation: PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": + PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5d6f47636587..05178895d2e9 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,6 +1,4 @@ -import logging from functools import partial -from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -298,6 +296,33 @@ def postprocess(self): return self.model +# GPT2ForQuestionAnswering +class GPT2ForQuestionAnsweringPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering + + module_policy = super().module_policy() + self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy) + + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''No shared_params in gpt2 for QA.''' + return [] + + # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): @@ -391,6 +416,8 @@ def gpt2_model_forward( # Please refer to original code of transformers for more details. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + from transformers.utils import logging + logger = logging.get_logger(__name__) # Preprocess passed in arguments output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -416,7 +443,8 @@ def gpt2_model_forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_length) else: - assert hidden_states is not None + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -478,21 +506,21 @@ def gpt2_model_forward( # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None if output_attentions: - logging.warning('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False if output_hidden_states: - logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False if use_cache: - logging.warning('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') use_cache = False if self.gradient_checkpointing and self.training: if use_cache: - logging.warning( + logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False presents = () if use_cache else None @@ -751,6 +779,94 @@ def gpt2_double_heads_model_forward( attentions=outputs.attentions, ) + @staticmethod + def gpt2_for_question_answering_forward( + self: 'GPT2ForQuestionAnswering', + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + """ + from transformers.modeling_outputs import QuestionAnsweringModelOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + @staticmethod def gpt2_for_token_classification_forward( self: 'GPT2ForTokenClassification', @@ -852,6 +968,8 @@ def gpt2_for_sequence_classification_forward( # Please refer to original code of transformers for more details. """ from transformers.modeling_outputs import SequenceClassifierOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) if input_ids is not None: batch_size, _ = input_ids.shape[:2] @@ -892,7 +1010,7 @@ def gpt2_for_sequence_classification_forward( sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) else: sequence_lengths = -1 - logging.warning( + logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`") diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index f9a0888ff80d..0fbcaa1e2bb3 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -29,6 +29,17 @@ def data_gen_for_lm(): return data +def data_gen_for_question_answering(): + # question answering data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data['start_positions'] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data['end_positions'] = end_positions + return data + + def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 @@ -82,6 +93,12 @@ def data_gen_for_sequence_classification(): output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_for_question_answering', + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', model_fn=lambda: transformers.GPT2ForTokenClassification(config), data_gen_fn=data_gen_for_token_classification, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index dd439a394827..005e3d6f8759 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} input_ids, _ = inputs['input_ids'], inputs['attention_mask'] From 1d3a5c44e215b1a83d9249d03258d8326171446d Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 20 Jul 2023 10:39:06 +0800 Subject: [PATCH 29/80] [shardformer] support inplace sharding (#4251) * [shardformer] embedding support inplace sharding * [shardformer] linear support inplace sharding * [shardformer] layernorm support inplace sharding * [shardformer] qkv support inplace sharding * [test] update shardformer layer test * [shardformer] fix shared param sharding * [shardformer] fix bert policy * [shardformer] fix bloom policy * [shardformer] fix llama policy * [shardformer] fix opt policy * [shardformer] fix t5 policy * [shardformer] fix fused qkv linear * [shardformer] fix bugs * force sync * [test] fix bugs * [test] fix transformer version --- colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/layer/embedding.py | 68 +++++---- colossalai/shardformer/layer/linear.py | 120 +++++++++------ colossalai/shardformer/layer/normalization.py | 10 +- .../shardformer/layer/qkv_fused_linear.py | 138 ++++++++++-------- colossalai/shardformer/policies/bert.py | 54 +++---- colossalai/shardformer/policies/bloom.py | 17 +-- colossalai/shardformer/policies/gpt2.py | 99 +++++-------- colossalai/shardformer/policies/llama.py | 9 +- colossalai/shardformer/policies/opt.py | 14 -- colossalai/shardformer/policies/t5.py | 32 +--- colossalai/shardformer/shard/sharder.py | 4 +- colossalai/tensor/d_tensor/api.py | 20 +++ requirements/requirements-test.txt | 1 + .../test_layer/test_embedding.py | 6 +- .../test_layer/test_layernorm.py | 7 +- .../test_layer/test_linear_1d.py | 34 +++-- .../test_layer/test_qkv_fused_linear_1d.py | 19 ++- .../test_vocab_parallel_embedding_1d.py | 9 +- tests/test_shardformer/test_model/_utils.py | 15 +- .../test_model/test_shard_bert.py | 3 +- .../test_model/test_shard_bloom.py | 3 +- .../test_model/test_shard_gpt2.py | 3 +- .../test_model/test_shard_llama.py | 3 +- .../test_model/test_shard_opt.py | 3 +- .../test_model/test_shard_t5.py | 3 +- 26 files changed, 364 insertions(+), 333 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..7cdcfc31811f 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,10 +3,11 @@ from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm +from .parallel_module import ParallelModule from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule' ] diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 07341ef73515..09b22abb17cc 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch import torch.distributed as dist @@ -13,7 +13,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule @@ -60,6 +65,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = True, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -74,18 +80,24 @@ def __init__(self, self.embed_kwargs = kwargs self.gather_output = gather_output - # Parameters. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_colwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) + # Parameters. + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, @@ -121,14 +133,10 @@ def from_native_module(module: nn.Embedding, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, + weight=module.weight, *args, **kwargs) - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - return embedding def reset_parameters(self, weight_initializer) -> None: @@ -143,7 +151,6 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - if self.gather_output: output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) return output @@ -188,6 +195,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -207,16 +215,23 @@ def __init__(self, self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - # parameter - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_rowwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - self.reset_parameters(weight_initializer) + + # parameter + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -243,15 +258,10 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, padding_idx=padding_idx, device=device, process_group=process_group, + weight=module.weight, *args, **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - return vocab_embedding_1d def reset_parameters(self, weight_initializer) -> None: diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 383d9b3f533a..bb36854bd772 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -15,7 +15,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import ( gather_forward_split_backward, @@ -65,6 +70,8 @@ def __init__(self, process_group: ProcessGroup = None, gather_output: bool = False, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -80,26 +87,42 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Parameters. - factory_kwargs = {'device': device, 'dtype': dtype} + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + # Parameters. + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - sharded_bias = shard_colwise(bias, self.process_group) - self.bias = sharded_tensor_to_param(sharded_bias) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_distributed_tensor(self.bias): + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -125,17 +148,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - with torch.no_grad(): - # the weight to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -198,6 +215,8 @@ def __init__(self, process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -216,27 +235,44 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_colwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only self.chunk_weight() + if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -262,19 +298,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 9bb7738c0f0a..0aea295664a7 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -60,10 +60,8 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) - with torch.no_grad(): - # copy weight and bias - layernorm.weight.copy_(module.weight) - layernorm.bias.copy_(module.bias) + layernorm.weight = module.weight + layernorm.bias = module.bias return layernorm @@ -101,8 +99,6 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) - with torch.no_grad(): - # copy weight and bias - rmsnorm.weight.copy_(module.weight) + rmsnorm.weight = module.weight return rmsnorm diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index c94d93069e93..bcefcf058ce0 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -2,12 +2,11 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn -import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -16,10 +15,12 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import ( - customized_distributed_tensor_to_param, + customized_distributed_tensor_to_existing_param, distribute_tensor_with_customization, + is_customized_distributed_tensor, + is_distributed_tensor, shard_rowwise, - sharded_tensor_to_param, + sharded_tensor_to_existing_param, ) from ._operation import ( @@ -173,6 +174,8 @@ def __init__(self, gather_output: bool = False, skip_bias_add: bool = False, n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -190,40 +193,56 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight def shard_fn(tensor): return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) - with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) - self.weight = customized_distributed_tensor_to_param(sharded_weight) + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - - with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) - self.bias = customized_distributed_tensor_to_param(sharded_bias) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_customized_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, - *args, **kwargs) -> ParallelModule: + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. @@ -250,24 +269,11 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.bias.data.copy_(sharded_bias.data) - return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -333,6 +339,8 @@ def __init__(self, process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -351,30 +359,46 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + # Divide the weight matrix along the last dimension. self.input_size_per_partition = divide(in_features, self.num_partitions) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only self.chunk_weight() if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -400,19 +424,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1af26f50484c..0a1a466210b2 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,13 +1,11 @@ from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MultipleChoiceModelOutput, @@ -28,12 +26,11 @@ BertLMHeadModel, BertModel, ) -from transformers.utils import ModelOutput, logging +from transformers.utils import logging import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription logger = logging.get_logger(__name__) @@ -177,6 +174,17 @@ def add_lm_head_policy(self, base_policy): target_key=BertLMPredictionHead) return base_policy + def add_lm_prediction_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + method_replacement = { + '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict, + '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict, + } + self.append_or_create_method_replacement(description=method_replacement, + policy=base_policy, + target_key=BertLMPredictionHead) + return base_policy + def postprocess(self): return self.model @@ -240,6 +248,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) return policy @@ -266,21 +275,13 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: model = self.model if self.pipeline_stage_manager: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: model.bert.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): @@ -291,6 +292,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) return policy @@ -316,21 +318,13 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): @@ -341,6 +335,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + mpolicy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) return policy @@ -366,21 +361,13 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): @@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward( stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, ): + # -> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., @@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward( stage_index: Optional[List[int]] = None, **kwargs, ): - #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 8afaadefb696..b0e45452964e 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,9 +1,7 @@ import warnings from functools import partial -from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn from torch import Tensor @@ -27,7 +25,6 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -229,20 +226,10 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # tie weights return [{ 0: bloom_model.transformer.word_embeddings.weight, - self.stage_manager.num_stages - 1: bloom_model.lm_head.weight + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - - for k, v in binding_map.items(): - param = getattr_(self.model, k) - # tie weights - setattr_(self.model, v, param) - return self.model - class BloomForSequenceClassificationPolicy(BloomPolicy): @@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward( all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] - #update batch size + # update batch size hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 05178895d2e9..6614a32b54d0 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -8,7 +8,6 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -56,42 +55,42 @@ def module_policy(self): "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -99,8 +98,8 @@ def module_policy(self): suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - policy=policy, - target_key=GPT2Model) + policy=policy, + target_key=GPT2Model) self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( @@ -115,8 +114,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Block) return policy def postprocess(self): @@ -227,15 +226,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): @@ -286,15 +276,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2ForQuestionAnswering class GPT2ForQuestionAnsweringPolicy(GPT2Policy): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b3757452c314..c7cd8182a4ca 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,7 +1,5 @@ -import math from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch import torch.nn as nn @@ -9,14 +7,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel -from transformers.utils import ModelOutput, logging +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 1435805d2846..bbcc90e00157 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,6 +1,5 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -116,19 +115,6 @@ def module_policy(self): target_key=OPTForCausalLM) return policy - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = { - 'model.decoder.embed_tokens': 'lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class OPTForSequenceClassificationPolicy(OPTPolicy): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 37864885b4cc..6b8f404f1769 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -8,7 +8,6 @@ ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -53,7 +52,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="embed_tokens", - target_module=Embedding1D, + target_module=VocabParallelEmbedding1D, ) ]) policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ @@ -165,12 +164,6 @@ def module_policy(self): return policy def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) return self.model @@ -211,18 +204,6 @@ def module_policy(self): target_key=T5ForConditionalGeneration) return policy - def postprocess(self): - super().postprocess() - if self.shard_config.enable_tensor_parallelism: - binding_map = {"shared": "lm_head"} - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class T5EncoderPolicy(T5BasePolicy): @@ -239,14 +220,3 @@ def module_policy(self): policy=base_policy, target_key=T5EncoderModel) return base_policy - - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [ - ["shared", "encoder.embed_tokens"], - ] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) - return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 5e0b572e259c..b32c285bdaab 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -37,11 +37,13 @@ def shard(self) -> List[Dict[int, Tensor]]: self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() + # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) + shared_params = self.policy.get_shared_params() self._release_unheld_layers() self._replace_module() self._materialize() self._postprocess() - return self.policy.get_shared_params() + return shared_params def _preprocess(self) -> None: self.model = self.policy.preprocess() diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 95a44e09e16a..32182faf6981 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): return param +def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param.data = dtensor + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + def compute_global_numel(dtensor: torch.Tensor) -> int: """ Compute the global number of elements in the distributed tensor. @@ -432,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: param.gather_fn = dtensor.gather_fn _hijack_detach_and_clone_for_customized_distributed_tensor(param) return param + + +def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter): + """ + Convert the given customized distributed tensor to an existing parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + + param.data = dtensor.data + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e65271621ddd..2dae645f7eb9 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,3 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi SentencePiece ninja flash_attn>=2.0 +datasets diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 99e494359af7..d62dba7ea92a 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -15,11 +15,13 @@ def check_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(32, 128).cuda() with ctx: - embedding = nn.Embedding(32, 128).cuda() - embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None) assert embedding_1d.weight.shape == torch.Size([32, 64]) + assert embedding_1d.weight is embedding_copy.weight # ensure state dict is reversibly loadable embedding.load_state_dict(embedding_1d.state_dict()) diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 2cb6928edf83..f9c21b82a282 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -14,11 +14,14 @@ def check_layernorm(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + norm = nn.LayerNorm(128, 0.00001).cuda() with ctx: - norm = nn.LayerNorm(128, 0.00001).cuda() - norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) + norm_copy = nn.LayerNorm(128, 0.00001).cuda() + norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None) assert norm1d.weight.shape == torch.Size([128]) + assert norm_copy.weight is norm1d.weight + assert norm_copy.bias is norm1d.bias # ensure state dict is reversibly loadable norm.load_state_dict(norm1d.state_dict()) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3cd85ec407..aa75879e0313 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,14 +15,16 @@ @parameterize('lazy_init', [False, True]) def check_linear_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + linear_copy = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) assert is_distributed_tensor(linear_col.bias) + assert linear_copy.weight is linear_col.weight + assert linear_copy.bias is linear_col.bias # ensure the shape is correct assert linear_col.weight.shape == torch.Size([64, 32]) @@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool): def check_linear_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + linear.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() @@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool): def check_linear_col_plus_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + with ctx: - linear_1 = nn.Linear(32, 128).cuda() - linear_2 = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + linear_1_copy = nn.Linear(32, 128).cuda() + linear_2_copy = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + + linear_1.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear_1.state_dict()) + linear_2.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 186b1e8212cc..b45cd172c3ca 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int): @parameterize('lazy_init', [False, True]) def check_linear_conv_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, n_fused=3) @@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool): assert linear.bias.shape == torch.Size([192]) assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias # ensure weights are reversibly loadable linear_conv_col.load_state_dict(linear.state_dict()) @@ -91,13 +93,20 @@ def check_linear_conv_1d_col(lazy_init: bool): def check_linear_conv_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) # check computation correctness x = torch.rand(4, 48).cuda() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index bf5803496f03..6d2f087302d9 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -7,8 +7,7 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.shardformer.layer import VocabParallelEmbedding1D from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -16,13 +15,15 @@ def check_vocab_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(128, 32).to('cuda') with ctx: - embedding = nn.Embedding(128, 32).to('cuda') - dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(128, 32).to('cuda') + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) assert dist_embedding_1d.num_embeddings == 64 assert dist_embedding_1d.embedding_dim == 32 + assert embedding_copy.weight is dist_embedding_1d.weight # ensure state dict is reversibly loadable embedding.load_state_dict(dist_embedding_1d.state_dict()) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 825d6df6bb5e..2320c725d444 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,8 +1,10 @@ import copy from contextlib import nullcontext +import torch +from torch.nn import Module + from colossalai.lazy import LazyInitContext -from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) return org_output, org_loss, shard_output, shard_loss + + +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): + org_sd = org_model.state_dict() + shard_sd = sharded_model.state_dict() + for k, v in org_sd.items(): + assert k in shard_sd, f'{name} {k} not in sharded model' + shard_v = shard_sd[k] + assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' + assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' + assert torch.equal(v, shard_v), f'{name} {k} value mismatch' diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 7f179acd7356..ea0f122644dc 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index e18168292df5..fe4686aeb979 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 552c6e2f4d53..99451b403eb7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 4d63a43489a3..aaeef13ef873 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -14,7 +14,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index c008596fe2b6..297affceb68a 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -15,7 +15,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ccd7d3787d3d..96dfdeb73827 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -14,7 +14,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From bd6cf035780cdc407e679fb5c11cd9cd5c3953cf Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 20 Jul 2023 11:46:05 +0800 Subject: [PATCH 30/80] [pipeline] refactor gpt2 pipeline forwards (#4287) * move gpt2 pipeline forwards to modeling folder * check pipeline status when adding replacing policy * fix typehint * fix arguments processing in gpt2_model_forward --- colossalai/shardformer/modeling/gpt2.py | 668 +++++++++++++++++++++ colossalai/shardformer/policies/gpt2.py | 759 ++---------------------- 2 files changed, 718 insertions(+), 709 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2.py diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py new file mode 100644 index 000000000000..5519d0b3098c --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2.py @@ -0,0 +1,668 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2DoubleHeadsModel, + GPT2DoubleHeadsModelOutput, + GPT2ForQuestionAnswering, + GPT2ForSequenceClassification, + GPT2ForTokenClassification, + GPT2LMHeadModel, + GPT2Model, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class GPT2PipelineForwards: + ''' + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + ''' + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + input_shape = input_ids.size() + input_ids = input_ids.view(-1, seq_length) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + else: + # always return dict for intermediate stage + return {'hidden_states': hidden_states} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def gpt2_double_heads_model_forward( + self: GPT2DoubleHeadsModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. + Please refer to original code of transformers for more details. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_question_answering_forward( + self: GPT2ForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_token_classification_forward( + self: GPT2ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def gpt2_for_sequence_classification_forward( + self: GPT2ForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + logger = logging.get_logger(__name__) + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert (self.config.pad_token_id is not None + or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return {'hidden_states': outputs['hidden_states']} + + hidden_states = outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6614a32b54d0..6d734b063036 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,13 +1,11 @@ from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List -import torch from torch import Tensor, nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager +from ..modeling.gpt2 import GPT2PipelineForwards from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -146,19 +144,18 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': - module = self.model - else: - module = self.model.transformer + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) # GPT2Model @@ -171,9 +168,11 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Model policy = super().module_policy() - self.set_pipeline_forward(model_cls=GPT2Model, - new_forward=GPT2PipelineForwards.gpt2_model_forward, - policy=policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy) return policy def get_held_layers(self) -> List[nn.Module]: @@ -205,9 +204,10 @@ def module_policy(self): } module_policy.update(addon_module) - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -220,11 +220,11 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: '''The weights of wte and lm_head are shared.''' module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] - else: - return [] + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] # GPT2DoubleHeadsModel @@ -248,9 +248,10 @@ def module_policy(self): } module_policy.update(addon_module) - self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, - new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, - policy=module_policy) + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy) return module_policy @@ -270,11 +271,11 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: '''The weights of wte and lm_head are shared.''' module = self.model stage_manager = self.pipeline_stage_manager - if stage_manager and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] - else: - return [] + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] # GPT2ForQuestionAnswering @@ -287,9 +288,11 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering module_policy = super().module_policy() - self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, - new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, - policy=module_policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy) return module_policy @@ -324,9 +327,10 @@ def module_policy(self): } module_policy.update(addon_module) - self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, - new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, - policy=module_policy) + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -351,9 +355,11 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification module_policy = super().module_policy() - self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, - new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, - policy=module_policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -365,668 +371,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in GPT2ForTokenClassification.""" return [] - - -class GPT2PipelineForwards: - ''' - This class serves as a micro library for forward function substitution of GPT2 models - under pipeline setting. - ''' - - @staticmethod - def gpt2_model_forward( - self: 'GPT2Model', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'BaseModelOutputWithPastAndCrossAttentions']: - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. - # Please refer to original code of transformers for more details. - - from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions - from transformers.utils import logging - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - input_shape = input_ids.size() - input_ids = input_ids.view(-1, seq_length) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_length) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, seq_length) - else: - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') - past_key_values = None - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] - for i in range(start_idx, end_idx): - block = self.h[i] - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=None, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - else: - # always return dict for intermediate stage - return {'hidden_states': hidden_states} - - @staticmethod - def gpt2_lmhead_model_forward( - self: 'GPT2LMHeadModel', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'CausalLMOutputWithCrossAttentions']: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ - - from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - @staticmethod - def gpt2_double_heads_model_forward( - self: 'GPT2DoubleHeadsModel', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'GPT2DoubleHeadsModelOutput']: - r""" - mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to - `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` - mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward. - Please refer to original code of transformers for more details. - ```""" - from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} - - hidden_states = outputs[0] - lm_logits = self.lm_head(hidden_states) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) - - mc_loss = None - if mc_labels is not None: - loss_fct = CrossEntropyLoss() - mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) - lm_loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits, mc_logits) + outputs[1:] - if mc_loss is not None: - output = (mc_loss,) + output - return ((lm_loss,) + output) if lm_loss is not None else output - - return GPT2DoubleHeadsModelOutput( - loss=lm_loss, - mc_loss=mc_loss, - logits=lm_logits, - mc_logits=mc_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_question_answering_forward( - self: 'GPT2ForQuestionAnswering', - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward. - # Please refer to original code of transformers for more details. - """ - from transformers.modeling_outputs import QuestionAnsweringModelOutput - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1).to(start_logits.device) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1).to(end_logits.device) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_token_classification_forward( - self: 'GPT2ForTokenClassification', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'TokenClassifierOutput']: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward. - # Please refer to original code of transformers for more details. - """ - - from transformers.modeling_outputs import TokenClassifierOutput - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} - - hidden_states = outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @staticmethod - def gpt2_for_sequence_classification_forward( - self: 'GPT2ForSequenceClassification', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. - # Please refer to original code of transformers for more details. - """ - from transformers.modeling_outputs import SequenceClassifierOutputWithPast - from transformers.utils import logging - logger = logging.get_logger(__name__) - - if input_ids is not None: - batch_size, _ = input_ids.shape[:2] - else: - batch_size, _ = hidden_states.shape[:2] - assert (self.config.pad_token_id is not None - or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} - - hidden_states = outputs[0] - logits = self.score(hidden_states) - - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) From 6221a244ea9973a3f15502685cf5215236c26bd7 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 20 Jul 2023 11:49:46 +0800 Subject: [PATCH 31/80] [pipeline] OPT model pipeline (#4258) * opt forward and test * pause * finish opt model pipeline * finish opt pipeline * opt forward and test * pause * finish opt model pipeline * finish opt pipeline * fix opt * set transformers version * refactor the test pipeline --- colossalai/shardformer/policies/opt.py | 734 ++++++++++++++++++ .../test_bert_for_pretraining_model.py | 69 +- .../test_policy/test_bert_lm_head_model.py | 72 +- .../test_policy/test_bert_model.py | 75 +- .../test_policy/test_bloom_model.py | 86 +- .../test_model/test_shard_opt_pipeline.py | 70 ++ 6 files changed, 838 insertions(+), 268 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_opt_pipeline.py diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index bbcc90e00157..31934965ee56 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,3 +1,15 @@ +import logging +import random +from functools import partial +from types import MethodType +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -94,12 +106,69 @@ def module_policy(self): def postprocess(self): return self.model + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'OPTModel': + module = self.model.decoder + else: + module = self.model.model.decoder + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + held_layers.append(module.embed_positions) + held_layers.append(module.project_in) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.final_layer_norm) + held_layers.append(module.project_out) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'OPTModel': + module = self.model.decoder + else: + module = self.model.model.decoder + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + class OPTModelPolicy(OPTPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTModel + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=OPTModel, + new_forward=OPTPipelineForwards.opt_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in OPTModel.""" + return [] + class OPTForCausalLMPolicy(OPTPolicy): @@ -113,16 +182,681 @@ def module_policy(self): suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=OPTForCausalLM) + + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=OPTForCausalLM, + new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, + policy=policy) + return policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + opt_model = self.model + num_stages = self.pipeline_stage_manager.num_stages + if self.pipeline_stage_manager and num_stages > 1: + if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): + return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + + def postprocess(self): + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: + binding_map = { + 'model.decoder.embed_tokens': 'lm_head', + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + class OPTForSequenceClassificationPolicy(OPTPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForSequenceClassification + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=OPTForSequenceClassification, + new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in OPTForSequenceClassification" + return [] + class OPTForQuestionAnsweringPolicy(OPTPolicy): def __init__(self) -> None: super().__init__() + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForQuestionAnswering + + policy = super().module_policy() + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=OPTForQuestionAnswering, + new_forward=OPTPipelineForwards.opt_for_question_answering_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in OPTForSequenceClassification" + return [] + + +class OPTPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of OPT models + under pipeline setting. + ''' + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + from transformers.models.opt.modeling_opt import _make_causal_mask + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + _dtype, + device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, + tgt_len=input_shape[-1]).to(device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def opt_model_forward( + self: 'OPTModel', + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, 'BaseModelOutputWithPast']: + ''' + This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward + ''' + + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + decoder = self.decoder + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + batch_size, seq_length = input_shape + + if inputs_embeds is None: + inputs_embeds = decoder.embed_tokens(input_ids) + + if decoder.project_in is not None: + inputs_embeds = decoder.project_in(inputs_embeds) + device = input_ids.device if input_ids is not None else inputs_embeds.device + _dtype = inputs_embeds.dtype + + else: + if hidden_states is None: + raise ValueError("hidden_states shouln't be None for intermediate stages.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + _dtype = hidden_states.dtype + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)") + + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, + device, past_key_values_length) + + if stage_manager.is_first_stage(): + pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) + hidden_states = inputs_embeds + pos_embeds + + if decoder.gradient_checkpointing and decoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(decoder.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + torch.cuda.set_device(device) + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = decoder.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if decoder.training and (dropout_probability < decoder.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if decoder.gradient_checkpointing and decoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + if decoder.final_layer_norm is not None: + hidden_states = decoder.final_layer_norm(hidden_states) + if decoder.project_out is not None: + hidden_states = decoder.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_causal_lm_forward( + self: 'OPTForCausalLM', + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, 'CausalLMOutputWithPast']: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + from transformers.modeling_outputs import CausalLMOutputWithPast + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + logits = self.lm_head(outputs[0]).contiguous() + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_sequence_classification_forward( + self: 'OPTForSequenceClassification', + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + from transformers.modeling_outputs import SequenceClassifierOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_question_answering_forward( + self: 'OPTForQuestionAnswering', + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, 'QuestionAnsweringModelOutput']: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForQuestionAnswering + >>> import torch + + >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> # note: we are loading a OPTForQuestionAnswering from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> answer_offset = len(tokenizer(question)[0]) + + >>> predict_answer_tokens = inputs.input_ids[ + ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 + ... ] + >>> predicted = tokenizer.decode(predict_answer_tokens) + >>> predicted + ' a nice puppet' + ```""" + from transformers.modeling_outputs import QuestionAnsweringModelOutput + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py index 6a8d7b636375..bc3a9bf1b010 100644 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py @@ -8,61 +8,11 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward +from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_for_pretraining_forward(): - configuration = BertConfig() - model = BertForPreTraining(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - output = bert_for_pretraining_forward( - self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index, - ) - assert output['hidden_states'].shape == (2, 3, 768) - - else: - attention_mask = torch.ones((2, 3)) - output = bert_for_pretraining_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) - assert output[0].shape == (2, 3, 30522) - # assert output[1].shape == (2, 768) - - def check_bert_for_pretraining_policy(): configuration = BertConfig() model = BertForPreTraining(configuration) @@ -92,12 +42,10 @@ def check_bert_for_pretraining_policy(): model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) model_policy.set_shard_config(model_config) layers = model_policy.get_held_layers() - assert layers is not None - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_for_pretraining_forward() + if stage_manager.is_first_stage(): + assert len(layers) == 6 + 1 + else: + assert len(layers) == 6 + 2 def run_dist_policy(rank, world_size, port): @@ -105,12 +53,6 @@ def run_dist_policy(rank, world_size, port): check_bert_for_pretraining_policy() -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_for_pretraining_forward(): - spawn(run_dist_model, 4) - - @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_for_pretraining_policy(): @@ -119,5 +61,4 @@ def test_bert_for_pretraining_policy(): if __name__ == "__main__": """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_for_pretraining_forward() test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py index cd47f7a33c4b..1aeb00123db8 100644 --- a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py +++ b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py @@ -8,62 +8,11 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward +from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_lm_head_model_forward(): - configuration = BertConfig() - model = BertLMHeadModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - - output = bert_lm_head_model_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) - print(output['hidden_states'].shape) - assert output['hidden_states'].shape == (2, 3, 768) - - else: - attention_mask = torch.ones((2, 3)) - output = bert_lm_head_model_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) - print(output[0].shape) - assert output[0].shape == (2, 3, 30522) - - # assert output[1].shape == (2, 768) - - def check_bert_lmhead_policy(): configuration = BertConfig() model = BertLMHeadModel(configuration) @@ -93,12 +42,10 @@ def check_bert_lmhead_policy(): model_policy.set_shard_config(model_config) layers = model_policy.get_held_layers() - assert layers is not None - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lm_head_model_forward() + if stage_manager.is_first_stage(): + assert len(layers) == 6 + 1 + else: + assert len(layers) == 6 + 2 def run_dist_policy(rank, world_size, port): @@ -106,12 +53,6 @@ def run_dist_policy(rank, world_size, port): check_bert_lmhead_policy() -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_lm_head_model_forward(): - spawn(run_dist_model, 4) - - @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_lmhead_policy(): @@ -119,6 +60,5 @@ def test_bert_lmhead_policy(): if __name__ == "__main__": - """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_lm_head_model_forward() + """test the bert for lm head model policy""" test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py index f116bc761aa7..b366df01788b 100644 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ b/tests/test_pipeline/test_policy/test_bert_model.py @@ -1,5 +1,8 @@ +''' +In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model +''' + import pytest -import torch import torch.distributed as dist from transformers.models.bert.modeling_bert import BertModel @@ -7,60 +10,11 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward +from colossalai.shardformer.policies.bert import BertModelPolicy from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bert_model_forward(): - # this test may crash for internet reasons - model = BertModel.from_pretrained('bert-base-uncased') - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - # print(rank) - layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x) - output = bert_model_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) - assert output['hidden_states'].shape == (2, 3, 768) - else: - attention_mask = torch.ones((2, 3)) - output = bert_model_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) - print(output[0].shape) - assert output[0].shape == (2, 3, 768) - - # assert output[1].shape == (2, 768) - - def check_bert_model_policy(): model = BertModel.from_pretrained('bert-base-uncased') DP_DIM, PP_DIM = 0, 1 @@ -90,12 +44,10 @@ def check_bert_model_policy(): layers = model_policy.get_held_layers() - assert layers is not None - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_forward() + if stage_manager.is_first_stage(): + assert len(layers) == 6 + 1 + else: + assert len(layers) == 6 + 1 def run_dist_policy(rank, world_size, port): @@ -103,12 +55,6 @@ def run_dist_policy(rank, world_size, port): check_bert_model_policy() -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_model_forward(): - spawn(run_dist_model, 4) - - @pytest.mark.dist @rerun_if_address_is_in_use() def test_bert_model_policy(): @@ -116,6 +62,5 @@ def test_bert_model_policy(): if __name__ == "__main__": - """test the bert model forward and bert model policy""" - #test_bert_model_forward() + """test the bert model policy""" test_bert_model_policy() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py index 73584b4f8ef1..e6a222f4e3d5 100644 --- a/tests/test_pipeline/test_policy/test_bloom_model.py +++ b/tests/test_pipeline/test_policy/test_bloom_model.py @@ -5,12 +5,14 @@ import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.bloom import BloomModelPolicy +from colossalai.shardformer.shard import ShardConfig from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_bloom_model_forward(): +def check_bloom_model_policy(): # create a BloomModel configuration = BloomConfig() model = BloomModel(configuration) @@ -33,67 +35,16 @@ def check_bloom_model_forward(): stage_manager = PipelineStageManager(pg_mesh, PP_DIM) rank = dist.get_rank() - # print(rank) - x = torch.randint(0, 1000, (2, 3)) - hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32) + model_policy = BloomModelPolicy() + model_policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + model_policy.set_shard_config(model_config) + layers = model_policy.get_held_layers() if stage_manager.is_first_stage(): - attention_mask = torch.ones_like(x) - output = bloom_model_forward(self=model, - input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 64) - print('start the training') + assert len(layers) == 1 + 2 else: - attention_mask = torch.ones((2, 3)) - output = bloom_model_forward(self=model, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - print(output[0].shape) - assert output[0].shape == (2, 3, 64) - print('end the training') - print(output) - - # assert output[1].shape == (2, 768) - - -def check_bloom_model_policy(): - # create a BloomModel - configuration = BloomConfig() - model = BloomModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2) - assert model_policy.layers_per_stage == [1, 1] - layers = model_policy.get_hold_layers(model) - for layer in layers: - print(layer) - - -def run_dist_model(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bloom_model_forward() + assert len(layers) == 1 + 1 def run_dist_policy(rank, world_size, port): @@ -101,15 +52,6 @@ def run_dist_policy(rank, world_size, port): check_bloom_model_policy() -#TODO: Bloom model should be fixed after bert model -@pytest.mark.skip(reason="Bloom model should be fixed after bert model") -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bloom_model_forward(): - spawn(run_dist_model, 4) - - -@pytest.mark.skip(reason="Bloom model should be fixed after bert model") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bloom_model_policy(): @@ -117,7 +59,5 @@ def test_bloom_model_policy(): if __name__ == "__main__": - """test the bloom model forward and bloom model policy""" - # test_bloom_model_forward() - # test_bloom_model_policy() - #TODO: Bloom model should be fixed after bert model is all ready + """test the bloom model policy""" + test_bloom_model_policy() diff --git a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py b/tests/test_shardformer/test_model/test_shard_opt_pipeline.py new file mode 100644 index 000000000000..0684418d0a8d --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_opt_pipeline.py @@ -0,0 +1,70 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_pipeline_model + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # TODO: add tests for forward/backward later + pass + + +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('enable_fused_normalization', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_opt +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids, _ = inputs['input_ids'], inputs['attention_mask'] + batch_size, seq_len = input_ids.shape + hidden_size = 128 + hidden_state_shape = (batch_size, seq_len, hidden_size) + + if not stage_manager.is_first_stage(): + # change inputs if not the first stage + + hidden_states = torch.zeros(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + sharded_model.train() + + output = sharded_model(**inputs) + if stage_manager.is_last_stage(): + assert output[0] is not None + else: + assert output['hidden_states'].shape == hidden_state_shape + torch.cuda.empty_cache() + + +def check_opt(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_opt_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_opt(): + spawn(check_opt, 4) + + +if __name__ == "__main__": + test_opt() From 29efdc33aca57fb2936e075c9a3b48d20970d755 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 20 Jul 2023 17:21:28 +0800 Subject: [PATCH 32/80] [hotfix] fix opt pipeline (#4293) * opt forward and test * pause * finish opt model pipeline * finish opt pipeline * opt forward and test * pause * finish opt model pipeline * finish opt pipeline * fix opt * set transformers version * refactor the test pipeline * fix bug --- colossalai/shardformer/policies/opt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 31934965ee56..244a0a54ef63 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -12,6 +12,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -198,8 +199,8 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: opt_model = self.model - num_stages = self.pipeline_stage_manager.num_stages - if self.pipeline_stage_manager and num_stages > 1: + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + num_stages = self.pipeline_stage_manager.num_stages if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] From 7d323335600fa06e8d9635426fde3fa19d355682 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 21 Jul 2023 10:46:39 +0800 Subject: [PATCH 33/80] [pipeline] reformat for unified design (#4283) * bert_reformat * reformat * reformat * fix a typo * format * format * fix bug --- colossalai/shardformer/modeling/bert.py | 989 ++++++++++++++++++ colossalai/shardformer/modeling/bloom.py | 620 ++++++++++++ colossalai/shardformer/modeling/llama.py | 394 ++++++++ colossalai/shardformer/policies/bert.py | 1161 ++-------------------- colossalai/shardformer/policies/bloom.py | 724 +------------- colossalai/shardformer/policies/llama.py | 511 ++-------- 6 files changed, 2206 insertions(+), 2193 deletions(-) create mode 100644 colossalai/shardformer/modeling/bert.py create mode 100644 colossalai/shardformer/modeling/llama.py diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py new file mode 100644 index 000000000000..df64c93cf85a --- /dev/null +++ b/colossalai/shardformer/modeling/bert.py @@ -0,0 +1,989 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.models.bert.modeling_bert import ( + BertForMaskedLM, + BertForMultipleChoice, + BertForNextSentencePrediction, + BertForPreTraining, + BertForPreTrainingOutput, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, + BertLMHeadModel, + BertModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class BertPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Bert models + under pipeline setting. + ''' + + @staticmethod + def bert_model_forward( + self: BertModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + stage_index: Optional[List[int]] = None, + ): + # TODO: add explaination of the output here. + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + else: + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = hidden_states if hidden_states is not None else None + + if stage_manager.is_first_stage(): + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + # layer_outputs + layer_outputs = hidden_states if hidden_states is not None else None + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): + if stage_manager.is_first_stage() and idx == 0: + encoder_attention_mask = encoder_extended_attention_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[idx] if head_mask is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.encoder.gradient_checkpointing and self.encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + \ + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # end of a stage loop + sequence_output = layer_outputs[0] if layer_outputs is not None else None + + if stage_manager.is_last_stage(): + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + if not return_dict: + return (sequence_output, pooled_output) + layer_outputs[1:] + # return dict is not supported at this moment + else: + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + # output of non-first and non-last stages: must be a dict + else: + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } + + @staticmethod + def bert_for_pretraining_forward( + self: BertForPreTraining, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + # the last stage for pretraining model + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + + # intermediate stage always return dict + return { + 'hidden_states': hidden_states, + } + + @staticmethod + def bert_lm_head_model_forward( + self: BertLMHeadModel, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + logger = logging.get_logger(__name__) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + use_cache = False + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states if hidden_states is not None else None, + stage_index=stage_index) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_masked_lm_forward( + self: BertForMaskedLM, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_next_sentence_prediction_forward( + self: BertForNextSentencePrediction, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + **kwargs, + ): + #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + logger = logging.get_logger(__name__) + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + # intermediate stage always return dict + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_sequence_classification_forward( + self: BertForSequenceClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_token_classification_forward( + self: BertForTokenClassification, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_multiple_choice_forward( + self: BertForMultipleChoice, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # in our pipeline design,input ids are copied for every stage and shouldn't be none + # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] + if stage_manager.is_last_stage(): + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bert_for_question_answering_forward( + self: BertForQuestionAnswering, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.Tensor] = None, + stage_manager: Optional[PipelineStageManager] = None, + stage_index: Optional[List[int]] = None, + ): + # NOTE: the arg start_position and end_position are used only for the last stage + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index a3d774ff2abb..fd200665df3d 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -1,6 +1,27 @@ +import warnings +from typing import List, Optional, Tuple, Union + import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.bloom.modeling_bloom import ( + BloomForCausalLM, + BloomForQuestionAnswering, + BloomForSequenceClassification, + BloomForTokenClassification, + BloomModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -67,3 +88,602 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) return build_bloom_alibi_tensor + + +class BloomPipelineForwards: + ''' + This class serves as a micro library for bloom pipeline forwards. + ''' + + @staticmethod + def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: + + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] # source_len + + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + # causal_mask is constructed every stage and its input is passed through different stages + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), + start=start_idx): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + # TODO: deal with all_hidden_states, all_self_attentions, presents + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {'hidden_states': hidden_states} + + @staticmethod + def bloom_for_causal_lm_forward(self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bloom_for_sequence_classification_forward( + self: BloomForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + if stage_manager.is_last_stage(): + batch_size = hidden_states.shape[0] + #update batch size + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bloom_for_token_classification_forward( + self: BloomForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), + labels.view(batch_size * seq_length)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def bloom_for_question_answering_forward( + self: BloomForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py new file mode 100644 index 000000000000..7bc626fe6825 --- /dev/null +++ b/colossalai/shardformer/modeling/llama.py @@ -0,0 +1,394 @@ +from typing import Callable, List, Optional + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class LlamaPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {'hidden_states': hidden_states} + + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + def llama_for_sequence_classification_forward( + self: LlamaForSequenceClassification, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if return_dict: + logger.warning_once('return_dict is not supported for pipeline models at the moment') + return_dict = False + + transformer_outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0a1a466210b2..f6a4c706eb14 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,40 +1,15 @@ from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List -import torch import torch.nn as nn from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import ( - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - MultipleChoiceModelOutput, - NextSentencePredictorOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) -from transformers.models.bert.modeling_bert import ( - BertForMaskedLM, - BertForMultipleChoice, - BertForNextSentencePrediction, - BertForPreTraining, - BertForPreTrainingOutput, - BertForQuestionAnswering, - BertForSequenceClassification, - BertForTokenClassification, - BertLMHeadModel, - BertModel, -) -from transformers.utils import logging +from torch.nn import Module import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager +from ..modeling.bert import BertPipelineForwards from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -logger = logging.get_logger(__name__) - __all__ = [ 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy', 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', @@ -207,6 +182,27 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli return + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'BertModel': + module = self.model + else: + module = self.model.bert + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) + + return held_layers + # BertModel class BertModelPolicy(BertPolicy): @@ -217,21 +213,15 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel - self.set_pipeline_forward(model_cls=BertModel, new_forward=bert_model_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertModel, + new_forward=BertPipelineForwards.bert_model_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.pooler) + held_layers = super().get_held_layers() return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -250,30 +240,24 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining - self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForPreTraining, + new_forward=BertPipelineForwards.bert_for_pretraining_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage""" - module = self.model + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - held_layers = [] - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.cls) + held_layers.append(self.model.cls) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: model = self.model - if self.pipeline_stage_manager: + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): # tie weights return [{ @@ -294,29 +278,25 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel - self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertLMHeadModel, + new_forward=BertPipelineForwards.bert_lm_head_model_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.cls) + held_layers.append(self.model.cls) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert - if self.pipeline_stage_manager: + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights return [{ @@ -337,29 +317,25 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) mpolicy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM - self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForMaskedLM, + new_forward=BertPipelineForwards.bert_for_masked_lm_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.cls) + held_layers.append(self.model.cls) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert - if self.pipeline_stage_manager: + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights return [{ @@ -391,10 +367,10 @@ def module_policy(self): ]) } policy.update(addon_module) - - self.set_pipeline_forward(model_cls=BertForSequenceClassification, - new_forward=bert_for_sequence_classification_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForSequenceClassification, + new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, + policy=policy) return policy @@ -402,18 +378,11 @@ def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.dropout) - held_layers.append(module.classifier) + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -443,10 +412,10 @@ def module_policy(self): ]) } policy.update(addon_module) - - self.set_pipeline_forward(model_cls=BertForTokenClassification, - new_forward=bert_for_token_classification_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForTokenClassification, + new_forward=BertPipelineForwards.bert_for_token_classification_forward, + policy=policy) return policy @@ -454,18 +423,11 @@ def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.dropout) - held_layers.append(module.classifier) + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -482,9 +444,10 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertForNextSentencePrediction - self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, - new_forward=bert_for_next_sentence_prediction_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, + new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, + policy=policy) return policy @@ -492,17 +455,10 @@ def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.cls) + held_layers.append(self.model.cls) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -532,10 +488,10 @@ def module_policy(self): ]) } policy.update(addon_module) - - self.set_pipeline_forward(model_cls=BertForMultipleChoice, - new_forward=bert_for_multiple_choice_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForMultipleChoice, + new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, + policy=policy) return policy @@ -543,18 +499,11 @@ def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.dropout) - held_layers.append(module.classifier) + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -570,9 +519,10 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bert.modeling_bert import BertForQuestionAnswering policy = super().module_policy() - self.set_pipeline_forward(model_cls=BertForQuestionAnswering, - new_forward=bert_for_question_answering_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BertForQuestionAnswering, + new_forward=BertPipelineForwards.bert_for_question_answering_forward, + policy=policy) return policy @@ -580,957 +530,12 @@ def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - held_layers = [] + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - layers_per_stage = self.distribute_layers(len(module.bert.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.bert.pooler) - held_layers.append(module.qa_outputs) + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: # no shared params for sequence classification model return [] - - -def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - stage_index: Optional[List[int]] = None, -): - # TODO: add explaination of the output here. - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - # debugging - # preprocess: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - else: - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None - - if stage_manager.is_first_stage(): - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - - # inherit from bert_layer,this should be changed when we add the feature to record hidden_states - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - next_decoder_cache = () if use_cache else None - - start_idx, end_idx = stage_index[0], stage_index[1] - # layer_outputs - layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): - if stage_manager.is_first_stage() and idx == 0: - encoder_attention_mask = encoder_extended_attention_mask - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[idx] if head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + \ - (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None - - if stage_manager.is_last_stage(): - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + layer_outputs[1:] - # return dict is not supported at this moment - else: - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - # output of non-first and non-last stages: must be a dict - else: - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -def bert_for_pretraining_forward( - self: BertForPreTraining, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index, - ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - # the last stage for pretraining model - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -def bert_lm_head_model_forward( - self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - use_cache = False - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - # intermediate stage always return dict - return {'hidden_states': hidden_states} - - -def bert_for_masked_lm_forward( - self: BertForMaskedLM, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - # -> Union[Tuple[torch.Tensor], MaskedLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bert_for_next_sentence_prediction_forward( - self: BertForNextSentencePrediction, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, - **kwargs, -): - # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see `input_ids` docstring). Indices should be in `[0, 1]`: - - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, BertForNextSentencePrediction - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") - - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - - >>> outputs = model(**encoding, labels=torch.LongTensor([1])) - >>> logits = outputs.logits - >>> assert logits[0, 0] < logits[0, 1] # next sentence was random - ``` - """ - - if "next_sentence_label" in kwargs: - warnings.warn( - "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" - " `labels` instead.", - FutureWarning, - ) - labels = kwargs.pop("next_sentence_label") - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index) - - if stage_manager.is_last_stage(): - pooled_output = outputs[1] - seq_relationship_scores = self.cls(pooled_output) - - next_sentence_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) - - if not return_dict: - output = (seq_relationship_scores,) + outputs[2:] - return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output - - return NextSentencePredictorOutput( - loss=next_sentence_loss, - logits=seq_relationship_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - # intermediate stage always return dict - return {'hidden_states': hidden_states} - - -def bert_for_sequence_classification_forward( - self: BertForSequenceClassification, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index) - - if stage_manager.is_last_stage(): - pooled_output = outputs[1] - - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bert_for_token_classification_forward( - self: BertForTokenClassification, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bert_for_multiple_choice_forward( - self: BertForMultipleChoice, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., - num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See - `input_ids` above) - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # in our pipeline design,input ids are copied for every stage and shouldn't be none - # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] - if stage_manager.is_last_stage(): - num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - - input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) - - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) - if stage_manager.is_last_stage(): - pooled_output = outputs[1] - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, num_choices) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - - if not return_dict: - output = (reshaped_logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return MultipleChoiceModelOutput( - loss=loss, - logits=reshaped_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bert_for_question_answering_forward( - self: BertForQuestionAnswering, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - start_positions: Optional[torch.Tensor] = None, - end_positions: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.Tensor] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, -): - # NOTE: the arg start_position and end_position are used only for the last stage - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index b0e45452964e..15bae2f4a959 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,35 +1,15 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union -import torch import torch.nn as nn from torch import Tensor -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.models.bloom.modeling_bloom import ( - BloomForCausalLM, - BloomForQuestionAnswering, - BloomForSequenceClassification, - BloomForTokenClassification, - BloomModel, -) -from transformers.utils import logging +from torch.nn import Module import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager -from ..modeling.bloom import build_bloom_alibi_tensor_fn +from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -logger = logging.get_logger(__name__) - class BloomPolicy(Policy): @@ -150,6 +130,28 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli target_key=model_cls) return + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'BloomModel': + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + class BloomModelPolicy(BloomPolicy): @@ -159,27 +161,17 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel - self.set_pipeline_forward(model_cls=BloomModel, new_forward=bloom_model_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BloomModel, + new_forward=BloomPipelineForwards.bloom_model_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """ get pipeline layers for current stage """ - module = self.model - stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - held_layers.append(module.word_embeddings_layernorm) - - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - + held_layers = super().get_held_layers() return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -199,29 +191,23 @@ def module_policy(self): suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=BloomForCausalLM) - - self.set_pipeline_forward(model_cls=BloomForCausalLM, new_forward=bloom_for_causal_lm_forward, policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BloomForCausalLM, + new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.transformer.word_embeddings) - held_layers.append(module.transformer.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.transformer.h[start_idx:end_idx]) + held_layers = super().get_held_layers() if stage_manager.is_last_stage(): - held_layers.append(module.transformer.ln_f) - held_layers.append(module.lm_head) + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: bloom_model = self.model - if self.pipeline_stage_manager: + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight): # tie weights return [{ @@ -243,25 +229,18 @@ def module_policy(self): suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=BloomForSequenceClassification) - self.set_pipeline_forward(model_cls=BloomForSequenceClassification, - new_forward=bloom_for_sequence_classification_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BloomForSequenceClassification, + new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.transformer.word_embeddings) - held_layers.append(module.transformer.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.transformer.h[start_idx:end_idx]) + held_layers = super().get_held_layers() if stage_manager.is_last_stage(): - held_layers.append(module.transformer.ln_f) - held_layers.append(module.score) + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -288,28 +267,20 @@ def module_policy(self): ], policy=policy, target_key=BloomForTokenClassification) - - self.set_pipeline_forward(model_cls=BloomForTokenClassification, - new_forward=bloom_for_token_classification_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BloomForTokenClassification, + new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.transformer.word_embeddings) - held_layers.append(module.transformer.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.transformer.h[start_idx:end_idx]) + held_layers = super().get_held_layers() if stage_manager.is_last_stage(): - held_layers.append(module.transformer.ln_f) - held_layers.append(module.dropout) - held_layers.append(module.classifier) + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -322,605 +293,20 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering policy = super().module_policy() - self.set_pipeline_forward(model_cls=BloomForQuestionAnswering, - new_forward=bloom_for_question_answering_forward, - policy=policy) + if self.pipeline_stage_manager: + self.set_pipeline_forward(model_cls=BloomForQuestionAnswering, + new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model + held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.transformer.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.transformer.word_embeddings) - held_layers.append(module.transformer.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.transformer.h[start_idx:end_idx]) if stage_manager.is_last_stage(): - held_layers.append(module.transformer.ln_f) - held_layers.append(module.qa_outputs) + held_layers.append(self.model.qa_outputs) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in bloom for question answering model""" return [] - - -def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - **deprecated_arguments, -) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # add warnings here - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - # case: First stage of training - if stage_manager.is_first_stage(): - # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len - - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx])): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) - - if stage_manager.is_last_stage(): - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - # TODO: deal with all_hidden_states, all_self_attentions, presents - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - # attention_mask is not returned ; presents = past_key_values - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - else: - # always return dict for imediate stage - return {'hidden_states': hidden_states} - - -def bloom_for_causal_lm_forward(self: 'BloomForCausalLM', - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - **deprecated_arguments): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - transformer_outputs = bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bloom_for_sequence_classification_forward( - self: BloomForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - **deprecated_arguments, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - transformer_outputs = bloom_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - batch_size = hidden_states.shape[0] - # update batch size - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bloom_for_token_classification_forward( - self: BloomForTokenClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - **deprecated_arguments, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - transformer_outputs = bloom_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - batch_size, seq_length = labels.shape - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)) - - if not return_dict: - output = (logits,) + transformer_outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def bloom_for_question_answering_forward( - self: BloomForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, -): - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bloom_model_forward( - self.transformer, - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c7cd8182a4ca..5988366ed57b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,25 +1,15 @@ from functools import partial -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Union -import torch import torch.nn as nn from torch import Tensor -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager +from torch.nn import Module + from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from ..modeling.llama import LlamaPipelineForwards from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -logger = logging.get_logger(__name__) - __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -119,32 +109,35 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model - -class LlamaModelPolicy(LlamaPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - from transformers.models.llama.modeling_llama import LlamaModel + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: - # set None as default stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) - } + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, - target_key=LlamaModel) - return policy + target_key=model_cls) + + return def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'LlamaModel': + module = self.model + else: + module = self.model.model stage_manager = self.pipeline_stage_manager + held_layers = [] layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) if stage_manager.is_first_stage(): @@ -153,6 +146,28 @@ def get_held_layers(self) -> List[Module]: held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) + + return held_layers + + +class LlamaModelPolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward(model_cls=LlamaModel, + new_forward=LlamaPipelineForwards.llama_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -180,40 +195,30 @@ def module_policy(self): if self.pipeline_stage_manager: # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaForCausalLM) + self.set_pipeline_forward(model_cls=LlamaForCausalLM, + new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, + policy=policy) + return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.model.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.model.layers[start_idx:end_idx]) + held_layers = super().get_held_layers() if stage_manager.is_last_stage(): - held_layers.append(module.model.norm) - held_layers.append(module.lm_head) + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model - if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight): - # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight - }] + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(llama_model.embed_tokens.weight) == id( + self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1: + # tie weights + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + }] return [] @@ -237,405 +242,19 @@ def module_policy(self): # to be confirmed if self.pipeline_stage_manager: # set None as default - stage_manager = self.pipeline_stage_manager - layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - 'forward': - partial(llama_for_sequence_classification_forward, - stage_manager=stage_manager, - stage_index=stage_index) - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaForSequenceClassification) + self.set_pipeline_forward(model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy) return policy def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" - module = self.model stage_manager = self.pipeline_stage_manager - held_layers = [] - layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.model.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.model.layers[start_idx:end_idx]) + held_layers = super().get_held_layers() if stage_manager.is_last_stage(): - held_layers.append(module.model.norm) - held_layers.append(module.score) + held_layers.append(self.model.score) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in llama for sequence classification model""" return [] - - -def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, -): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - seq_length_with_past = seq_length - past_key_values_length = 0 - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - start_idx, end_idx = stage_index[0], stage_index[1] - for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx]): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - # always return dict for imediate stage - return {'hidden_states': hidden_states} - - -def llama_for_causal_lm_forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, -): - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = llama_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - - if stage_manager.is_last_stage(): - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - -def llama_for_sequence_classification_forward( - self: LlamaForSequenceClassification, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, -): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - transformer_outputs = llama_model_forward( - self.model, - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - if input_ids is not None: - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - batch_size = inputs_embeds.shape[0] - else: - batch_size = hidden_states.shape[0] - - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} From 63c2f30a38423d9fda2933121a6257283c367039 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 21 Jul 2023 16:23:04 +0800 Subject: [PATCH 34/80] [pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300) * modify t5 policy & add test * pipeline stage distribution for t5 * complete t5 base policy * t5 stack: halfway * modify gpt2 pipeline test * complete pipeline forward for T5Stack/T5EncoderModel * fix docstring * move t5 util tests to test_pipeline --- colossalai/shardformer/modeling/t5.py | 279 ++++++++++++++++++ colossalai/shardformer/policies/t5.py | 179 ++++++++++- tests/kit/model_zoo/transformers/gpt.py | 20 +- .../test_policy/test_t5_pipeline_utils.py | 39 +++ .../test_model/test_shard_gpt2_pipeline.py | 12 +- .../test_model/test_shard_t5_pipeline.py | 96 ++++++ 6 files changed, 604 insertions(+), 21 deletions(-) create mode 100644 colossalai/shardformer/modeling/t5.py create mode 100644 tests/test_pipeline/test_policy/test_t5_pipeline_utils.py create mode 100644 tests/test_shardformer/test_model/test_shard_t5_pipeline.py diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py new file mode 100644 index 000000000000..cc270d5828a2 --- /dev/null +++ b/colossalai/shardformer/modeling/t5.py @@ -0,0 +1,279 @@ +from functools import partial +from types import MethodType +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.utils.checkpoint import checkpoint +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions +from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class T5PipelineForwards: + ''' + This class serves as a micro library for forward function substitution of + T5 models under pipeline setting. + ''' + + @staticmethod + def t5_stack_forward( + self: T5Stack, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward. + # Please refer to original code of transformers for more details. + + logger = logging.get_logger(__name__) + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + if use_cache is True: + if not in_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + stage = stage_manager.stage + in_decoder = self.is_decoder + if in_decoder != (stage >= decoder_starting_stage): + raise ValueError("Config in T5Stack is not aligned with pipeline setting.") + + # at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/input_embedds + # at_last_stage: current stage is the last stage of encoder/decoder, making outputs the same form as huggingface + at_first_stage = (stage == 0) or (stage == decoder_starting_stage) + at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) + + # Process inputs if at the first stage of encoder/decoder. + if at_first_stage: + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if in_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if in_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + batch_size, seq_length = input_shape + device = inputs_embeds.device + hidden_states = self.dropout(inputs_embeds) + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + + for i in range(start_idx, end_idx): + + past_key_value = past_key_values[i] + layer_module = self.block[i] + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + torch.cuda.set_device(hidden_states.device) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + if use_cache is False or use_cache is None: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + hidden_states, present_key_value_state = layer_outputs[:2] + # print(stage, len(layer_outputs), present_key_value_state.shape) + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + # last layer + if at_last_stage: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + else: + return { + 'hidden_states': hidden_states, + 'position_bias': position_bias, + 'encoder_decoder_position_bias': encoder_decoder_position_bias + } + + @staticmethod + def t5_encoder_model_forward( + self: T5EncoderModel, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + This function is modified on the basis of transformers.models.t5.modeling_gpt2.T5EncoderModel.forward. + Please refer to original code of transformers for more details. + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = T5PipelineForwards.t5_stack_forward(self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + return outputs diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 6b8f404f1769..1846c5873801 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,8 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple + +from torch import Tensor, nn + from colossalai.shardformer.layer import ( DropoutForParallelInput, Embedding1D, @@ -8,9 +13,11 @@ ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription +from .._utils import getattr_, setattr_ +from ..modeling.t5 import T5PipelineForwards from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] +__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] class T5BasePolicy(Policy): @@ -106,7 +113,7 @@ def module_policy(self): ]) policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( - suffix="wi_0", + suffix="wi_0 ", target_module=Linear1D_Col, ), SubModuleReplacementDescription( @@ -166,6 +173,123 @@ def module_policy(self): def postprocess(self): return self.model + @staticmethod + def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int, + num_stages: int) -> Tuple[List[int], int]: + """ + Distribute t5 layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for T5 must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = 0 + optimal_diff = 2**31 - 1 + for i in range(1, num_stages): + attempt = objective(i) + if attempt < optimal_diff: + num_encoder_stages = i + optimal_diff = attempt + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_t5_stage_index(layers_per_stage: List[int], stage: int, + decoder_starting_stage: int) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + stage_manager = self.pipeline_stage_manager + + model = self.model + encoder = self.model.encoder + decoder = self.model.__dict__.get('decoder', None) + + num_encoder_layers = len(encoder.block) + num_decoder_layers = len(decoder.block) if decoder else 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in t5's encoder + if stage_manager.is_first_stage(): + held_layers.append(model.shared) + held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.final_layer_norm) + held_layers.append(encoder.dropout) + held_layers.extend(encoder.block[start_idx:end_idx]) + else: + # current stage is in t5's decoder + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) + if stage_manager.is_last_stage(): + held_layers.append(decoder.final_layer_norm) + held_layers.append(decoder.dropout) + held_layers.extend(decoder.block[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + encoder = self.model.encoder + decoder = self.model.__dict__.get('decoder', None) + + num_encoder_layers = len(encoder.block) + num_decoder_layers = len(decoder.block) if decoder else 0 + + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + class T5ModelPolicy(T5BasePolicy): @@ -182,6 +306,15 @@ def module_policy(self): target_key=T5Model) return base_policy + def postprocess(self): + if self.shard_config.enable_tensor_parallelism: + binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]} + for k, v in binding_map.items(): + src = getattr_(self.model, k) + for dst in v: + setattr_(self.model, dst, src) + return self.model + class T5ForConditionalGenerationPolicy(T5BasePolicy): @@ -204,19 +337,55 @@ def module_policy(self): target_key=T5ForConditionalGeneration) return policy + def postprocess(self): + super().postprocess() + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: + binding_map = { + "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + } + for k, v in binding_map.items(): + src = getattr_(self.model, k) + for dst in v: + setattr_(self.model, dst, src) + + return self.model + class T5EncoderPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + def module_policy(self): from transformers import T5EncoderModel - base_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, ), - policy=base_policy, + policy=policy, target_key=T5EncoderModel) - return base_policy + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=T5EncoderModel, + new_forward=T5PipelineForwards.t5_encoder_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] + + def postprocess(self): + if self.shard_config.enable_tensor_parallelism: + binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]} + for k, v in binding_map.items(): + src = getattr_(self.model, k) + for dst in v: + setattr_(self.model, dst, src) + return self.model diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 0fbcaa1e2bb3..e447b700105e 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -62,17 +62,15 @@ def data_gen_for_sequence_classification(): loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() loss_fn = lambda x: x.loss -config = transformers.GPT2Config( - n_layer=2, - n_head=4, - #n_embd=128, - vocab_size=50258, - attn_pdrop=0, - embd_pdrop=0, - resid_pdrop=0, - summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification") +config = transformers.GPT2Config(n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', diff --git a/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py b/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py new file mode 100644 index 000000000000..0cbb852b97a0 --- /dev/null +++ b/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py @@ -0,0 +1,39 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py index 005e3d6f8759..d5453ee72644 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py @@ -29,9 +29,11 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids, _ = inputs['input_ids'], inputs['attention_mask'] + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + input_ids = inputs['input_ids'] batch_size, seq_len = input_ids.shape - hidden_size = 768 + hidden_size = sharded_model.config.n_embd hidden_state_shape = (batch_size, seq_len, hidden_size) if not stage_manager.is_first_stage(): @@ -40,12 +42,12 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz inputs['input_ids'] = None inputs['hidden_states'] = hidden_states - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) sharded_model.train() output = sharded_model(**inputs) if stage_manager.is_last_stage(): - if name != 'transformers_gpt': + if name == 'transformers_gpt': + assert output[0].shape == hidden_state_shape + else: assert output.loss is not None else: assert output['hidden_states'].shape == hidden_state_shape diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py new file mode 100644 index 000000000000..3662aa8ac125 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py @@ -0,0 +1,96 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.t5 import T5BasePolicy +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_pipeline_model + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # TODO: add tests for forward/backward later + pass + + +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('enable_fused_normalization', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_t5.py +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + if name != 'transformers_t5_encoder_model': + continue + + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids = inputs['input_ids'] + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + + batch_size, seq_len = input_ids.shape + hidden_size = sharded_model.config.d_model + num_heads = sharded_model.config.num_heads + hidden_state_shape = (batch_size, seq_len, hidden_size) + position_bias_shape = (batch_size, num_heads, seq_len, seq_len) + + num_encoder_layers = len(sharded_model.encoder.block) + decoder = sharded_model.__dict__.get('decoder', None) + num_decoder_layers = len(decoder.block) if decoder else 0 + + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE) + stage = stage_manager.stage + at_first_stage = (stage == 0) or (stage == decoder_starting_stage) + at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) + + if not at_first_stage: + # change inputs if not the first stage + hidden_states = torch.zeros(*hidden_state_shape).cuda() + position_bias = torch.zeros(*position_bias_shape).cuda() + encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + inputs['position_bias'] = position_bias + inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias + + sharded_model.train() + output = sharded_model(**inputs) + if at_last_stage: + if name != 'transformers_t5_for_conditional_generation': + assert output[0].shape == hidden_state_shape + else: + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape + # position_bias information should be passed in T5 + assert 'position_bias' in output + assert 'encoder_decoder_position_bias' in output + + torch.cuda.empty_cache() + + +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5(): + spawn(check_t5, 4) + + +if __name__ == "__main__": + test_t5() From 2fe393fc29e42728787beda5c4094825b67670f6 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:31:21 +0800 Subject: [PATCH 35/80] [pipeline] test pure pipeline process using llama (#4218) * bloom policy * llama pipeline forward and tests * fix the output and attention_mask * fix name * bind argument to policy * Revert "bloom policy" This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0. This policy should be revert and copied to feature/bloom * revert the bloom changes * cancel unneeded inputs * gpt * finish llama * causal lm and sequence classification * revision * add pure pipeline test * fixed version * fixed version * pure pipeline --- colossalai/pipeline/p2p.py | 24 ++++++++++--------- .../test_model/test_pure_pipeline.py | 24 +++++++++++++------ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 2fd135d5475d..851a0b595bc6 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -9,6 +9,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d +from version_parser.version import Version from .stage_manager import PipelineStageManager @@ -61,17 +62,6 @@ def _broadcast_object_list(object_list: List[Any], c10d._warn_not_in_group("broadcast_object_list") return - my_rank = dist.get_rank() - # Serialize object_list elements to tensors on src rank. - if my_rank == src: - if torch.__version__ >= "1.13.0": - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list]) - else: - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) - object_sizes_tensor = torch.cat(size_list) - else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) - is_nccl_backend = c10d._check_for_nccl_backend(group) current_device = None @@ -83,6 +73,18 @@ def _broadcast_object_list(object_list: List[Any], current_device = torch.device("cpu") if is_nccl_backend: current_device = torch.device("cuda", torch.cuda.current_device()) + + my_rank = dist.get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + if Version(torch.__version__) >= Version("1.13.0"): + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) + else: + tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + if is_nccl_backend: object_sizes_tensor = object_sizes_tensor.to(current_device) diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 80767f71c3fb..2f51eb9b02f7 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -1,3 +1,4 @@ +import copy import random from contextlib import nullcontext from typing import Any, Callable, Iterator, List, Optional, Tuple @@ -6,7 +7,6 @@ import pytest import torch import torch.distributed as dist -from torch import Tensor from torch.nn import Module from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -94,10 +94,10 @@ def execute_pipeline( return outputs -class data_iter(): +class data_loader(): def __getitem__(self, x): - return torch.randint(0, 100, (4, 128)).cuda() + return torch.ones((4, 128), dtype=torch.int).cuda() * 10 def loss(x, y): @@ -127,20 +127,30 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != 'transformers_llama': + continue num_microbatches = 2 org_model = model_fn().cuda() + data_iter = iter(data_loader()) + + model_copy = copy.deepcopy(org_model) + batch = next(data_iter) + with torch.no_grad(): + y = model_copy(batch) + org_loss = loss(batch, y) optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) - #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4) schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, pipeline_stage_manager=stage_manager) pipelined_model = PipelinedModel(org_model, shard_config, stage_manager) pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) - data_it = iter(data_iter()) - results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule) + results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule) + if stage_manager.is_last_stage(): - assert results['loss'] is not None + assert results['loss'] == org_loss + else: + assert results['loss'] is None assert results['outputs'] is None torch.cuda.empty_cache() From f172208fe380a1a1da9dced19caf52566cf77381 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 25 Jul 2023 14:45:33 +0800 Subject: [PATCH 36/80] [pipeline] add pipeline support for all T5 models (#4310) * complete policy for T5Model & T5ForConditionalGeneration * modify function signature in forwards * add forward for T5model * add forward for T5ForConditionalGeneration * fix a bug * fix hidden_states transporting in decoder * fix the passing of encoder_outputs --- colossalai/shardformer/modeling/t5.py | 324 +++++++++++++++++- colossalai/shardformer/policies/t5.py | 64 +++- .../test_model/test_shard_t5_pipeline.py | 19 +- 3 files changed, 388 insertions(+), 19 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index cc270d5828a2..7eb4d17928d6 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -1,11 +1,15 @@ -from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +import warnings +from typing import Dict, List, Optional, Tuple, Union import torch -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack from transformers.utils import logging @@ -198,14 +202,13 @@ def custom_forward(*inputs): if use_cache is False or use_cache is None: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] hidden_states, present_key_value_state = layer_outputs[:2] - # print(stage, len(layer_outputs), present_key_value_state.shape) # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: + if in_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] # append next layer key value states if use_cache: @@ -238,6 +241,313 @@ def custom_forward(*inputs): 'encoder_decoder_position_bias': encoder_decoder_position_bias } + @staticmethod + def t5_model_forward( + self: T5Model, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward. + # Please refer to original code of transformers for more details. + + __HEAD_MASK_WARNING_MSG = """ + The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, + `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. + If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, + num_heads)`. + """ + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + in_decoder = stage_manager.stage >= decoder_starting_stage + + # Stage is in encoder, directly return the output of t5_stack_forward + if not in_decoder: + encoder_outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_outputs': encoder_outputs} + else: + return encoder_outputs + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + + if encoder_outputs is None: + raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") + + encoder_hidden_states = encoder_outputs[0] + if return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # Decode + decoder_outputs = T5PipelineForwards.t5_stack_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded T5Stack forward if not at last stage. + if not at_last_decoder_stage: + decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @staticmethod + def t5_for_conditional_generation_forward( + self: T5ForConditionalGeneration, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward. + # Please refer to original code of transformers for more details. + + __HEAD_MASK_WARNING_MSG = """ + The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, + `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. + If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, + num_heads)`. + """ + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + in_decoder = stage_manager.stage >= decoder_starting_stage + + # Stage is in encoder, directly return the output of t5_stack_forward + if not in_decoder: + encoder_outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_outputs': encoder_outputs} + else: + return encoder_outputs + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + + if encoder_outputs is None: + raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") + + encoder_hidden_states = encoder_outputs[0] + if return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # Decode + decoder_outputs = T5PipelineForwards.t5_stack_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded T5Stack forward if not at last stage. + if not at_last_decoder_stage: + decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + return decoder_outputs + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + @staticmethod def t5_encoder_model_forward( self: T5EncoderModel, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 1846c5873801..0ee18d6c4940 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -293,21 +293,42 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli class T5ModelPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + def module_policy(self): from transformers import T5Model - base_policy = super().module_policy() + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, ), - policy=base_policy, + policy=policy, target_key=T5Model) - return base_policy + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), + len(module.decoder.block), + stage_manager.num_stages) + + if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): + return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] + return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]} for k, v in binding_map.items(): src = getattr_(self.model, k) @@ -318,6 +339,9 @@ def postprocess(self): class T5ForConditionalGenerationPolicy(T5BasePolicy): + def __init__(self) -> None: + super().__init__() + def module_policy(self): from transformers import T5ForConditionalGeneration @@ -335,8 +359,38 @@ def module_policy(self): ], policy=policy, target_key=T5ForConditionalGeneration) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=T5ForConditionalGeneration, + new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, + policy=policy) return policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), + len(module.decoder.block), + stage_manager.num_stages) + + shared_params = [] + if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): + shared_params.append({ + 0: module.shared.weight, + decoder_starting_stage: module.decoder.embed_tokens.weight + }) + if id(module.lm_head.weight) == id(module.shared.weight): + shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight}) + return shared_params + return [] + def postprocess(self): super().postprocess() if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: @@ -382,7 +436,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] def postprocess(self): - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]} for k, v in binding_map.items(): src = getattr_(self.model, k) diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py index 3662aa8ac125..7f3a5f2ea40b 100644 --- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py @@ -28,8 +28,6 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - if name != 'transformers_t5_encoder_model': - continue inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} @@ -52,6 +50,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ stage = stage_manager.stage at_first_stage = (stage == 0) or (stage == decoder_starting_stage) at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) + in_decoder = stage >= decoder_starting_stage if not at_first_stage: # change inputs if not the first stage @@ -62,19 +61,25 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ inputs['hidden_states'] = hidden_states inputs['position_bias'] = position_bias inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias + if in_decoder: + encoder_output_states = torch.zeros(*hidden_state_shape).cuda() + inputs['encoder_outputs'] = (encoder_output_states,) sharded_model.train() output = sharded_model(**inputs) if at_last_stage: - if name != 'transformers_t5_for_conditional_generation': - assert output[0].shape == hidden_state_shape - else: + if name == 'transformers_t5_for_conditional_generation' and in_decoder: assert output.loss is not None + else: + if name != 'transformers_t5_encoder_model' and not in_decoder: + output = output['encoder_outputs'] + assert output[0].shape == hidden_state_shape else: assert output['hidden_states'].shape == hidden_state_shape # position_bias information should be passed in T5 - assert 'position_bias' in output - assert 'encoder_decoder_position_bias' in output + assert output['position_bias'].shape == position_bias_shape + if in_decoder: + assert output['encoder_decoder_position_bias'].shape == position_bias_shape torch.cuda.empty_cache() From 668ab102f023b50d1c587a663706d194ac82c6a6 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:02:29 +0800 Subject: [PATCH 37/80] [shardformer] support pipeline base vit model (#4284) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * support base vit pipeline * support vit downstream model * fix vit shard test * modify hidden states return type --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> --- colossalai/shardformer/modeling/vit.py | 337 ++++++++++++++++++ .../shardformer/policies/auto_policy.py | 8 + colossalai/shardformer/policies/vit.py | 283 ++++++++++----- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/vit.py | 68 ++++ .../test_model/test_shard_vit.py | 63 ++-- .../test_model/test_shard_vit_pipeline.py | 74 ++++ 7 files changed, 729 insertions(+), 105 deletions(-) create mode 100644 colossalai/shardformer/modeling/vit.py create mode 100644 tests/kit/model_zoo/transformers/vit.py create mode 100644 tests/test_shardformer/test_model/test_shard_vit_pipeline.py diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py new file mode 100644 index 000000000000..f28c13ad0aa2 --- /dev/null +++ b/colossalai/shardformer/modeling/vit.py @@ -0,0 +1,337 @@ +import logging +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +def _encoder_forward( + encoder: ViTEncoder, + start_idx: int, + end_idx: int, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + stage_manager: PipelineStageManager = None, +) -> Union[tuple, BaseModelOutput]: + + for i in range(start_idx, end_idx): + layer_module = encoder.layer[i] + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if encoder.gradient_checkpointing and encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, False) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, False) + + hidden_states = layer_outputs[0] + if not stage_manager.is_last_stage(): + return hidden_states + else: + if not return_dict: + return tuple(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=None, + attentions=None, + ) + + +def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + + from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions is not None: + logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.') + output_attentions = None + if output_hidden_states is not None: + logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.') + output_hidden_states = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if stage_manager.is_first_stage(): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings(pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding) + else: + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + # Go through encoder + if not stage_manager.is_last_stage(): + hidden_states = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=embedding_output, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + return {'hidden_states': hidden_states} + else: + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + + # Go through rest layers + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + return pp_forward + + +def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + from transformers.models.vit.modeling_vit import ImageClassifierOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if not stage_manager.is_first_stage(): + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) + + # not last stage, return hidden_states + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # last stage + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward + + +def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + + import math + + import torch.nn as nn + from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input." + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") + + if not stage_manager.is_first_stage(): + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit(pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states) + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, + 1).repeat_interleave(self.config.patch_size, + 2).unsqueeze(1).contiguous()) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index b31f1b35f580..d00a03c9237e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -75,6 +75,14 @@ class PolicyLocation: "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + # ViT + "transformers.models.vit.modeling_vit.ViTModel": + PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": + PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": + PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"), + # OPT "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 3f6bbd10607a..47f2c58fc436 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,12 +1,18 @@ -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Union import torch.nn as nn -from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +import colossalai.shardformer.layer as col_nn +from ..modeling.vit import ( + ViTForImageClassification_pipeline_forward, + ViTForMaskedImageModeling_pipeline_forward, + ViTModel_pipeline_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy'] +__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] class ViTPolicy(Policy): @@ -15,96 +21,203 @@ def config_sanity_check(self): pass def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer - base_policy = { - ViTEmbeddings: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]), - ViTLayer: - ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=DropoutForParallelInput, - ), - ]), - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - base_policy[ViTAttention].sub_module_replacement.extend([ - SubModuleReplacementDescription( - suffix="layernorm_before", - target_module=FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layernorm_after", - target_module=FusedLayerNorm, - ) - ]) - base_policy[ViTModel].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="layernorm", - target_module=FusedLayerNorm, - )) - - return base_policy + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) + + policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) + + return policy def new_model_class(self): return None def postprocess(self): return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = self.model.vit + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = self.model.vit + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + +# ViTModel +class ViTModelPolicy(ViTPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + + return held_layers + + +# ViTForImageClassification +class ViTForImageClassificationPolicy(ViTPolicy): + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + new_item = { + ViTForImageClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + + return held_layers + + +# ViTForMaskedImageModeling +class ViTForMaskedImageModelingPolicy(ViTPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + + return held_layers diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 4aa01abe13ee..a298767d12e7 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -5,3 +5,4 @@ from .llama import * from .opt import * from .t5 import * +from .vit import * diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py new file mode 100644 index 000000000000..93a8d6c615d7 --- /dev/null +++ b/tests/kit/model_zoo/transformers/vit.py @@ -0,0 +1,68 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence VIT +# =============================== + +config = transformers.ViTConfig( + num_hidden_layers=4, + # hidden_size=128, + # intermediate_size=256, + num_attention_heads=4) + + +# define data gen function +def data_gen(): + pixel_values = torch.randn(1, 3, 224, 224) + return dict(pixel_values=pixel_values) + + +def data_gen_for_image_classification(): + data = data_gen() + data['labels'] = torch.tensor([0]) + return data + + +def data_gen_for_masked_image_modeling(): + data = data_gen() + num_patches = (config.image_size // config.patch_size)**2 + bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + data['bool_masked_pos'] = bool_masked_pos + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# function to get the loss +loss_fn_for_vit_model = lambda x: x.pooler_output.mean() +loss_fn_for_image_classification = lambda x: x.logits.mean() +loss_fn_for_masked_image_modeling = lambda x: x.loss + +# register the following models +# transformers.ViTModel, +# transformers.ViTForMaskedImageModeling, +# transformers.ViTForImageClassification, +model_zoo.register(name='transformers_vit', + model_fn=lambda: transformers.ViTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_vit_model, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_vit_for_masked_image_modeling', + model_fn=lambda: transformers.ViTForMaskedImageModeling(config), + data_gen_fn=data_gen_for_masked_image_modeling, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_masked_image_modeling, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_vit_for_image_classification', + model_fn=lambda: transformers.ViTForImageClassification(config), + data_gen_fn=data_gen_for_image_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_image_classification, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index af1605b6b659..2b02c83e0d27 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -1,9 +1,18 @@ +import os + import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -12,44 +21,58 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output) - + assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3) # do backward org_loss.backward() shard_loss.backward() - # check grad - org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad - shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad - - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # unwrap model + if org_model.__class__.__name__ == 'ViTModel': + vit_model = org_model + shard_vit_model = sharded_model + else: + vit_model = org_model.vit + shard_vit_model = sharded_model.vit -def check_vit(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # check attention grad + org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_test() + + @pytest.mark.dist -@pytest.mark.skip @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): - spawn(check_vit, 4) + spawn(check_vit, 2) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py new file mode 100644 index 000000000000..114992a2a2a5 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py @@ -0,0 +1,74 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_pipeline_model + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # TODO: add tests for forward/backward later + pass + + +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('enable_fused_normalization', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_vit +def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + pixel_values = inputs['pixel_values'] + batch_size = len(pixel_values) + hidden_size = 768 + hidden_state_shape = (batch_size, 197, hidden_size) + + if not stage_manager.is_first_stage(): + # change inputs if not the first stage + hidden_states = torch.randn(*hidden_state_shape).cuda() + # inputs['pixel_values'] = None + inputs['hidden_states'] = hidden_states + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + sharded_model.train() + + output = sharded_model(**inputs) + if stage_manager.is_last_stage(): + if name != 'transformers_vit': + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape, \ + f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' + + torch.cuda.empty_cache() + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +if __name__ == "__main__": + test_vit() From 9e4018d3daa604bb2467a09c7fdbb91ce28eb195 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 26 Jul 2023 00:53:57 +0800 Subject: [PATCH 38/80] [plugin] add 3d parallel plugin (#4295) * [amp] add mixed precision optimizer * [plugin] add 3d parallel plugin * [booster] support pipeline * [plugin] 3d parallel plugin support clip grad norm * [shardformer] fix sharder and add plugin test * [plugin] rename 3d parallel plugin * [ci] support testmon core pkg change detection (#4305) * [hotfix] debug testmon * [hotfix] fix llama * [hotfix] fix p2p bugs * [hotfix] fix requirements --- .../naive_amp/mixed_precision_optimizer.py | 149 +++++++++ colossalai/booster/booster.py | 12 +- colossalai/booster/plugin/__init__.py | 3 +- .../booster/plugin/hybrid_parallel_plugin.py | 316 ++++++++++++++++++ colossalai/booster/plugin/pp_plugin_base.py | 21 ++ colossalai/pipeline/p2p.py | 2 +- colossalai/shardformer/modeling/llama.py | 6 - colossalai/shardformer/shard/sharder.py | 37 +- .../test_plugin/test_3d_plugin.py | 99 ++++++ 9 files changed, 621 insertions(+), 24 deletions(-) create mode 100644 colossalai/amp/naive_amp/mixed_precision_optimizer.py create mode 100644 colossalai/booster/plugin/hybrid_parallel_plugin.py create mode 100644 colossalai/booster/plugin/pp_plugin_base.py create mode 100644 tests/test_booster/test_plugin/test_3d_plugin.py diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py new file mode 100644 index 000000000000..d4183be3fb5f --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -0,0 +1,149 @@ +from typing import Dict, List + +import torch +from torch import Tensor +from torch.nn import Parameter +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper + +from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin + + +class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.params = working_params + + def check_local_overflow(self) -> bool: + for p in self.params: + if p.grad is not None and not torch.isfinite(p.grad).all(): + return True + return False + + +class MixedPrecisionOptimizer(OptimizerWrapper): + + def __init__(self, + optim: Optimizer, + precision: str = 'fp16', + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0): + super().__init__(optim) + if precision == 'fp16': + working_params = [] + for group in self.optim.param_groups: + for p in group['params']: + working_params.append(p) + self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif precision == 'bf16': + self.mixed_precision = BF16MixedPrecisionMixin() + else: + raise ValueError(f'Unsupported precision: {precision}') + if max_norm > 0.0: + raise NotImplementedError('max_norm is not supported yet.') + self.max_norm = max_norm + self.working_to_master_map: Dict[Parameter, Tensor] = {} + self.master_to_working_map: Dict[Tensor, Parameter] = {} + + # create master weights + for group in self.optim.param_groups: + master_params = [] + for p in group['params']: + if p.requires_grad: + master_p = p + if p.dtype != torch.float: + master_p = p.detach().float() + self.working_to_master_map[p] = master_p + self.master_to_working_map[master_p] = p + master_params.append(master_p) + group['params'] = master_params + + def backward(self, loss: Tensor, *args, **kwargs): + loss = self.mixed_precision.pre_backward(loss) + loss.backward(*args, **kwargs) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def zero_grad(self, *args, **kwargs): + for p in self.working_to_master_map.keys(): + p.grad = None + self.mixed_precision.pre_zero_grad() + return super().zero_grad(*args, **kwargs) + + def _unscale_and_clip_grads(self, total_norm: float) -> None: + div_scale = 1.0 + if self.mixed_precision is not None: + div_scale = self.mixed_precision.get_grad_div_scale() + + if self.max_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + if clip > 1: + div_scale = clip * div_scale + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + p.grad.data.mul_(1. / div_scale) + + def _compute_grad_norm(self) -> float: + if self.max_norm <= 0.: + return 0. + grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None] + if len(grads) == 0: + return 0. + device = grads[0].device + # TODO(ver217): support tp + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) + return total_norm.item() + + def step(self, *args, **kwargs): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + # prepare grads + for group in self.optim.param_groups: + for p in group['params']: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + if working_param.grad is None: + p.grad = working_param.grad.data.float() + working_param.grad = None + total_norm = self._compute_grad_norm() + self._unscale_and_clip_grads(total_norm) + self.optim.step(*args, **kwargs) + # update working params + for group in self.optim.param_groups: + for p in group['params']: + working_param = self.master_to_working_map[p] + if p is working_param: + continue + working_param.data.copy_(p.data) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index ec3dc7fc143f..8a28b1286cfa 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,6 +1,6 @@ import warnings from contextlib import contextmanager -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, Union import torch import torch.nn as nn @@ -14,6 +14,7 @@ from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory from .plugin import Plugin +from .plugin.pp_plugin_base import PipelinePluginBase __all__ = ['Booster'] @@ -144,14 +145,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: def execute_pipeline(self, data_iter: Iterator, model: nn.Module, - criterion: Callable[[torch.Tensor], torch.Tensor], + criterion: Callable[[Any, Any], torch.Tensor], optimizer: Optimizer, return_loss: bool = True, - return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]: - # TODO: implement this method + return_outputs: bool = False) -> dict: # run pipeline forward backward pass # return loss or outputs if needed - pass + assert isinstance(self.plugin, + PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.' + return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs) def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: """Context manager to disable gradient synchronization across DP process groups. diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index a3b87b5f11d3..f48bf38bd724 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -1,9 +1,10 @@ from .gemini_plugin import GeminiPlugin +from .hybrid_parallel_plugin import HybridParallelPlugin from .low_level_zero_plugin import LowLevelZeroPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin'] +__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin'] import torch from packaging import version diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py new file mode 100644 index 000000000000..37badb613433 --- /dev/null +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -0,0 +1,316 @@ +import random +from contextlib import nullcontext +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.zero.low_level import LowLevelZeroOptimizer + +from .pp_plugin_base import PipelinePluginBase + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class HybridParallelModule(ModelWrapper): + + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None: + self.stage_manager = shard_config.pipeline_stage_manager + self.dp_group = dp_group + shardformer = ShardFormer(shard_config) + module, self.shared_params = shardformer.optimize(module) + # TODO(ver217): add input type cast + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + if precision == 'fp16': + module = module.half().cuda() + elif precision == 'bf16': + module = module.to(dtype=torch.bfloat16).cuda() + # TODO(ver217): support TP+DP + super().__init__(module) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + + def no_sync(self) -> Iterator[None]: + # no sync grads across data parallel + return nullcontext() + + def sync_grads(self): + # sync grad across data parallel + if self.dp_group.size() == 1: + return + for p in self.module.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, group=self.dp_group) + + +def init_pipeline_optimizer(optim: Optimizer, model: Module): + params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group['params'] if p in params] + new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + + +class HybridParallelOptimizer(MixedPrecisionOptimizer): + + def __init__(self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + precision: str = 'fp16', + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0): + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, max_norm) + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None): + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, + overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, + forced_dtype) + + +class HybridParallelPlugin(PipelinePluginBase): + + def __init__( + self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + cpu_offload: bool = False, + enable_fused_normalization: bool = False, + num_microbatches: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + ) -> None: + super().__init__() + assert dist.get_world_size() % ( + tp_size * pp_size + ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + # TODO(ver217): support zero + assert zero_stage == 0, 'zero is not support yet' + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_fused_normalization = enable_fused_normalization + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.stage_manager = None + self.schedule = None + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' + assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_fused_normalization=self.enable_fused_normalization) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + self.max_norm = max_norm + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def supported_precisions(self) -> List[str]: + return ['fp16', 'bf16'] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return False + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + if not isinstance(model, ModelWrapper): + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + optimizer = HybridParallelOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + else: + optimizer = HybridParallelZeroOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + partition_grad=(self.zero_stage == 2), + cpu_offload=self.cpu_offload, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.amp_config) + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline(self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer], + return_loss: bool = True, + return_outputs: bool = False) -> dict: + assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' + # return loss or outputs if needed + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + with ctx: + outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, + return_outputs) + # model.sync_shared_params() + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_grad() + else: + model.sync_grads() + return outputs + + def prepare_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, + num_replicas=self.pg_mesh.size(DP_AXIS), + rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def get_checkpoint_io(self) -> CheckpointIO: + return None + + def no_sync(self, model: Module) -> Iterator[None]: + raise NotImplementedError diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py new file mode 100644 index 000000000000..67ade9330f5b --- /dev/null +++ b/colossalai/booster/plugin/pp_plugin_base.py @@ -0,0 +1,21 @@ +from abc import abstractmethod +from typing import Any, Callable, Iterator + +import torch + +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .plugin_base import Plugin + + +class PipelinePluginBase(Plugin): + + @abstractmethod + def execute_pipeline(self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: OptimizerWrapper, + return_loss: bool = True, + return_outputs: bool = False) -> dict: + pass diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 851a0b595bc6..f741b8363f13 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -7,9 +7,9 @@ import torch import torch.distributed as dist +from packaging.version import Version from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d -from version_parser.version import Version from .stage_manager import PipelineStageManager diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 7bc626fe6825..e1ed5f64665c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -223,9 +223,6 @@ def llama_for_causal_lm_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -311,9 +308,6 @@ def llama_for_sequence_classification_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False transformer_outputs = LlamaPipelineForwards.llama_model_forward( self.model, diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index b32c285bdaab..ae8cd8c6e553 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch.nn as nn from torch import Tensor @@ -39,8 +39,8 @@ def shard(self) -> List[Dict[int, Tensor]]: self._preprocess() # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) shared_params = self.policy.get_shared_params() - self._release_unheld_layers() - self._replace_module() + held_layers = self._release_unheld_layers() + self._replace_module(include=held_layers) self._materialize() self._postprocess() return shared_params @@ -51,7 +51,7 @@ def _preprocess(self) -> None: def _postprocess(self) -> None: self.model = self.policy.postprocess() - def _replace_module(self,) -> None: + def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None: r""" Replace the module according to the policy, and replace the module one by one @@ -64,8 +64,13 @@ def _replace_module(self,) -> None: param_replacement = module_description.param_replacement sub_module_replacement = module_description.sub_module_replacement method_replacement = module_description.method_replacement - self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement, - method_replacement, sub_module_replacement) + self._recursive_replace_layer(self.model, + layer_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include) def _recursive_replace_layer( self, @@ -75,6 +80,7 @@ def _recursive_replace_layer( param_replacement: List[Callable], method_replacement: Dict[str, Callable], sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, ) -> None: r""" Reverse the replace layer operation @@ -87,23 +93,30 @@ def _recursive_replace_layer( method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy """ + # released layers are not shardable + can_replace_param_or_layer = include is None or module in include if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): if attr_replacement is not None: self._replace_attr(module, attr_replacement) - if param_replacement is not None: + if param_replacement is not None and can_replace_param_or_layer: self._replace_param(module, param_replacement) if method_replacement is not None: self._replace_method(module, method_replacement) - if sub_module_replacement is not None: + if sub_module_replacement is not None and can_replace_param_or_layer: self._replace_sub_module(module, sub_module_replacement) for name, child in module.named_children(): - self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, - sub_module_replacement) + self._recursive_replace_layer(child, + origin_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include) def _replace_attr( self, @@ -185,13 +198,15 @@ def _replace_sub_module( setattr_(org_layer, suffix, replace_layer) - def _release_unheld_layers(self) -> None: + def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: r""" Release the unheld layers in the model """ if self.shard_config and self.shard_config.pipeline_stage_manager: held_layers = self.policy.get_held_layers() set_tensors_to_none(self.model, exclude=set(held_layers)) + return set(held_layers) + return None def _materialize(self) -> None: r""" diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py new file mode 100644 index 000000000000..a58afac810d7 --- /dev/null +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -0,0 +1,99 @@ +from contextlib import nullcontext +from typing import Optional + +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.fx import is_compatible_with_meta +from colossalai.lazy.lazy_init import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + try: + if init_method == 'lazy': + ctx = LazyInitContext() + else: + ctx = nullcontext() + plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16') + booster = Booster(plugin=plugin) + with ctx: + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data_iter = iter([data]) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + output_key = list(outputs.keys())[0] + loss = criterion(outputs[output_key]) + return loss + + booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False) + optimizer.step() + + except Exception as e: + return repr(e) + + +@parameterize('init_method', ['none', 'lazy']) +def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): + """check gemini plugin over model zoo + + Args: + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + is_support_meta = is_compatible_with_meta() + if not is_support_meta and init_method == 'lazy': + return + + passed_models = [] + failed_info = {} # (model_name, error) pair + + # TODO(ver217): add more models + for name, (model_fn, data_gen_fn, output_transform_fn, _, + _) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items(): + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f'Init method: {init_method}') + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_3d_plugin(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def test_gemini_plugin(early_stop: bool = True): + spawn(run_dist, 4, early_stop=early_stop) + + +if __name__ == '__main__': + test_gemini_plugin(early_stop=False) From 1cf9e01b050704caaeaf2f371c54dfebd92d0ed2 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 27 Jul 2023 13:11:47 +0800 Subject: [PATCH 39/80] [hotfix] fix gemini and zero test (#4333) * [hotfix] fix gemini and zero test * [hotfix] fix lazy init test * [hotfix] fix lazy init test --- tests/test_booster/test_plugin/test_gemini_plugin.py | 5 +++-- tests/test_lazy/test_models.py | 3 ++- tests/test_shardformer/test_model/test_pure_pipeline.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index d29c92926066..57160dfae89b 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -88,7 +88,9 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): 'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base', - 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model' + 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model', + 'transformers_vit', 'transformers_vit_for_masked_image_modeling', + 'transformers_vit_for_image_classification' ]: continue @@ -99,7 +101,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' ]: continue - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index e37184125d21..ecb99e594267 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -11,7 +11,8 @@ def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base' + ) or name.startswith('transformers_llama') or name.startswith('transformers_vit'): continue check_lazy_init(entry, verbose=True, default_device=default_device) diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 2f51eb9b02f7..c3cd05095c27 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -59,7 +59,7 @@ def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: Pip def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0): sampler = DistributedSampler( dataset, - #rank=self.pg_mesh.coordinate(DP_AXIS), + # rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle) # Deterministic dataloader @@ -161,6 +161,7 @@ def check_llama(rank, world_size, port): run_llama_test() +@pytest.mark.skip('This test will fail') @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From f24f3c21207c9da1d94de15832799c4bb76630ef Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 27 Jul 2023 14:53:20 +0800 Subject: [PATCH 40/80] [pipeline] fix return_dict/fix pure_pipeline_test (#4331) --- colossalai/shardformer/modeling/bert.py | 33 +++---------------- colossalai/shardformer/modeling/bloom.py | 12 ------- colossalai/shardformer/modeling/gpt2.py | 2 ++ colossalai/shardformer/policies/opt.py | 28 +++++++++++----- .../test_model/test_pure_pipeline.py | 7 ++-- 5 files changed, 29 insertions(+), 53 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index df64c93cf85a..1b3c14d9d1c9 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple import torch @@ -277,9 +278,6 @@ def bert_for_pretraining_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -387,9 +385,6 @@ def bert_lm_head_model_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -478,9 +473,6 @@ def bert_for_masked_lm_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -579,16 +571,15 @@ def bert_for_next_sentence_prediction_forward( FutureWarning, ) labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = BertPipelineForwards.bert_model_forward(self.bert, input_ids, @@ -661,10 +652,6 @@ def bert_for_sequence_classification_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = BertPipelineForwards.bert_model_forward(self.bert, input_ids, @@ -753,10 +740,6 @@ def bert_for_token_classification_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -832,10 +815,6 @@ def bert_for_multiple_choice_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # in our pipeline design,input ids are copied for every stage and shouldn't be none # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] @@ -928,10 +907,6 @@ def bert_for_question_answering_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = BertPipelineForwards.bert_model_forward( self.bert, diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index fd200665df3d..76948fc70439 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -313,9 +313,6 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer, input_ids, @@ -411,9 +408,6 @@ def bloom_for_sequence_classification_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( self.transformer, @@ -537,9 +531,6 @@ def bloom_for_token_classification_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( self.transformer, @@ -626,9 +617,6 @@ def bloom_for_question_answering_forward( if output_hidden_states: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False outputs = BloomPipelineForwards.bloom_model_forward( self.transformer, diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 5519d0b3098c..dc5a81dc912b 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -52,6 +52,8 @@ def gpt2_model_forward( # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + logger = logging.get_logger(__name__) # Preprocess passed in arguments diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 244a0a54ef63..6fc3a2d31f4d 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -8,6 +8,18 @@ import torch.nn as nn from torch import Tensor, nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.opt.modeling_opt import ( + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, +) from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D @@ -317,7 +329,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] @staticmethod def opt_model_forward( - self: 'OPTModel', + self: OPTModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -330,7 +342,7 @@ def opt_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, 'BaseModelOutputWithPast']: + ) -> Union[Tuple, BaseModelOutputWithPast]: ''' This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward ''' @@ -506,7 +518,7 @@ def custom_forward(*inputs): @staticmethod def opt_for_causal_lm_forward( - self: 'OPTForCausalLM', + self: OPTForCausalLM, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -520,7 +532,7 @@ def opt_for_causal_lm_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, 'CausalLMOutputWithPast']: + ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -646,7 +658,7 @@ def opt_for_causal_lm_forward( @staticmethod def opt_for_sequence_classification_forward( - self: 'OPTForSequenceClassification', + self: OPTForSequenceClassification, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -660,7 +672,7 @@ def opt_for_sequence_classification_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -746,7 +758,7 @@ def opt_for_sequence_classification_forward( @staticmethod def opt_for_question_answering_forward( - self: 'OPTForQuestionAnswering', + self: OPTForQuestionAnswering, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -761,7 +773,7 @@ def opt_for_question_answering_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, 'QuestionAnsweringModelOutput']: + ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index c3cd05095c27..576e6473bcca 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -1,6 +1,5 @@ import copy import random -from contextlib import nullcontext from typing import Any, Callable, Iterator, List, Optional, Tuple import numpy as np @@ -100,8 +99,8 @@ def __getitem__(self, x): return torch.ones((4, 128), dtype=torch.int).cuda() * 10 -def loss(x, y): - return (x[0].float().mean() - y[0].float().mean()) +def loss(y, x): + return (y[0].float().mean() - x[0].float().mean()) @parameterize('enable_fused_normalization', [False]) @@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la batch = next(data_iter) with torch.no_grad(): y = model_copy(batch) - org_loss = loss(batch, y) + org_loss = loss(y, batch) optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, From 064cea012f2d8201a991aaf2ea922652c25c4337 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Mon, 31 Jul 2023 14:49:55 +0800 Subject: [PATCH 41/80] [pipeline] add unit test for 1f1b (#4303) * add unit test for 1f1b * polish code * polish code and update ut version * fix --- .../test_schedule/test_oneF_oneB.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tests/test_pipeline/test_schedule/test_oneF_oneB.py diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py new file mode 100644 index 000000000000..542116a1da75 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -0,0 +1,134 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def pp_linear_fwd(forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None): + + if stage_mgr.is_first_stage(): + return {'input_obj': forward(data)} + elif stage_mgr.is_last_stage(): + return forward(input_obj) + else: + return {'input_obj': forward(input_obj)} + + +def examine_pp(): + """ + This test is to examine the correctness of 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = 4 + BATCH_SIZE = 4 + + # create models + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager) + + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sharded_model = sub_model.cuda() + + sharded_model._forward = sharded_model.forward + sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if stage_manager.is_first_stage(): + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, + pp_optimizer, + iter(input_list), + criterion, + return_loss=True, + return_outputs=True) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret['loss']) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 2) + + +if __name__ == '__main__': + test_pp() From daef153dc1219d01ffb5148ff3cd0468e92608b3 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 1 Aug 2023 10:35:17 +0800 Subject: [PATCH 42/80] [pipeline] refactor test pipeline and remove useless utils in pipeline (#4324) * refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process * refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process --- colossalai/pipeline/policy/__init__.py | 22 - colossalai/pipeline/policy/base.py | 111 ---- colossalai/pipeline/policy/bert.py | 523 ------------------ colossalai/pipeline/policy/bloom.py | 220 -------- colossalai/pipeline/schedule/one_f_one_b.py | 1 - colossalai/shardformer/policies/bert.py | 2 +- .../test_bert_for_pretraining_model.py | 64 --- .../test_policy/test_bert_lm_head_model.py | 64 --- .../test_policy/test_bert_model.py | 66 --- .../test_policy/test_bloom_model.py | 63 --- .../test_model/test_shard_bert.py | 3 + .../test_model/test_shard_bert_pipeline.py | 104 ++-- .../test_model/test_shard_bloom_pipeline.py | 71 +-- .../test_model/test_shard_llama_pipeline.py | 70 +-- 14 files changed, 138 insertions(+), 1246 deletions(-) delete mode 100644 colossalai/pipeline/policy/__init__.py delete mode 100644 colossalai/pipeline/policy/base.py delete mode 100644 colossalai/pipeline/policy/bert.py delete mode 100644 colossalai/pipeline/policy/bloom.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_lm_head_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py deleted file mode 100644 index fd9e6e04588e..000000000000 --- a/colossalai/pipeline/policy/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Type - -from torch import Tensor -from torch.nn import Module, Parameter - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy -from .bert import BertModel, BertModelPolicy - -POLICY_MAP: Dict[Type[Module], Type[Policy]] = { - BertModel: BertModelPolicy, -} - - -def pipeline_parallelize( - model: Module, - stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: - if type(model) not in POLICY_MAP: - raise NotImplementedError(f"Policy for {type(model)} not implemented") - policy = POLICY_MAP[type(model)](stage_manager) - return policy.parallelize_model(model) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py deleted file mode 100644 index f51d74fdbac3..000000000000 --- a/colossalai/pipeline/policy/base.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -from torch import Tensor -from torch.nn import Module, Parameter - -from colossalai.lazy import LazyTensor -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class Policy: - - def __init__(self, stage_manager: PipelineStageManager) -> None: - self.stage_manager = stage_manager - - def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]: - """Setup model for pipeline parallel - - Args: - module (Module): Module to be setup - - Returns: - Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers - """ - hold_params = set() - hold_buffers = set() - - def init_layer(layer: Module): - for p in layer.parameters(): - if isinstance(p, LazyTensor): - p.materialize() - p.data = p.cuda() - hold_params.add(p) - for b in layer.buffers(): - if isinstance(b, LazyTensor): - b.materialize() - b.data = b.cuda() - hold_buffers.add(b) - - hold_layers = self.get_hold_layers(module) - - for layer in hold_layers: - init_layer(layer) - - hold_params_dict = {} - hold_buffers_dict = {} - - # release other tensors - for n, p in module.named_parameters(): - if p in hold_params: - hold_params_dict[n] = p - else: - if isinstance(p, LazyTensor): - p.materialize() - p.data = p.cuda() - p.storage().resize_(0) - for n, b in module.named_buffers(): - if b in hold_buffers: - hold_buffers_dict[n] = b - else: - if isinstance(b, LazyTensor): - b.materialize() - b.data = b.cuda() - # FIXME(ver217): use meta tensor may be better - b.storage().resize_(0) - return hold_params_dict, hold_buffers_dict - - def replace_forward(self, module: Module) -> None: - """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict - - Args: - module (Module): _description_ - """ - raise NotImplementedError - - def get_hold_layers(self, module: Module) -> List[Module]: - """Get layers that should be hold in current stage. This method should be implemented by subclass. - - Args: - module (Module): Module to be setup - - Returns: - List[Module]: List of layers that should be hold in current stage - """ - raise NotImplementedError - - def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: - """Get parameters that should be shared across stages. This method should be implemented by subclass. - - Args: - module (Module): Module to be setup - - Returns: - List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] - """ - raise NotImplementedError - - def parallelize_model(self, - module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: - """Parallelize model for pipeline parallel - - Args: - module (Module): Module to be setup - - Returns: - Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters - """ - hold_params, hold_buffers = self.setup_model(module) - self.replace_forward(module) - shared_params = self.get_shared_params(module) - return hold_params, hold_buffers, shared_params diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py deleted file mode 100644 index abce504e9d61..000000000000 --- a/colossalai/pipeline/policy/bert.py +++ /dev/null @@ -1,523 +0,0 @@ -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, -) -from transformers.models.bert.modeling_bert import ( - BertForPreTraining, - BertForPreTrainingOutput, - BertLMHeadModel, - BertModel, -) -from transformers.utils import ModelOutput, logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage -): - # TODO: add explaination of the output here. - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - # debugging - # preprocess: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - else: - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None - - if stage_manager.is_first_stage(): - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - - # inherit from bert_layer,this should be changed when we add the feature to record hidden_states - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - next_decoder_cache = () if use_cache else None - - # calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - # layer_outputs - layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): - if stage_manager.is_first_stage() and idx == 0: - encoder_attention_mask = encoder_extended_attention_mask - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[idx] if head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + \ - (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None - - if stage_manager.is_last_stage(): - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + layer_outputs[1:] - # return dict is not supported at this moment - else: - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - # output of non-first and non-last stages: must be a dict - else: - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -# The layer partition policy for bertmodel -class BertModelPolicy(Policy): - - def __init__( - self, - stage_manager: PipelineStageManager, - num_layers: int, - ): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.encoder.layer[start_idx:end_idx]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.pooler) - - return hold_layers - - def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: - '''no shared params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) - - -def bert_for_pretraining_forward( - self: BertForPreTraining, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, -): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - # the last stage for pretraining model - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -class BertForPreTrainingPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.bert.embeddings) - - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) - - if self.stage_manager.is_last_stage(): - hold_layers.append(module.bert.pooler) - hold_layers.append(module.cls) - - return hold_layers - - def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: - '''no shared params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.forward) - - -def bert_lmhead_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - use_cache = False - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - # intermediate stage always return dict - return {'hidden_states': hidden_states} - - -class BertLMHeadModelPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.bert.pooler) - hold_layers.append(module.cls) - - return hold_layers - - def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: - '''no shared params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py deleted file mode 100644 index 71d2913fc3aa..000000000000 --- a/colossalai/pipeline/policy/bloom.py +++ /dev/null @@ -1,220 +0,0 @@ -import warnings -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom.modeling_bloom import BloomModel -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - **deprecated_arguments, -) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # add warnings here - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - # case: First stage of training - if stage_manager.is_first_stage(): - # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len - - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - # calculate the num_layers - num_layers_per_stage = len(self.h) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) - - if stage_manager.is_last_stage(): - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - # TODO: deal with all_hidden_states, all_self_attentions, presents - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - # attention_mask is not returned ; presents = past_key_values - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class BloomModelPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) - - def get_hold_layers(self, module: BloomModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.word_embeddings) - hold_layers.append(module.word_embeddings_layernorm) - - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.h[start_idx:end_idx]) - - if self.stage_manager.is_last_stage(): - hold_layers.append(module.ln_f) - - return hold_layers - - def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: - '''no shared params in bloommodel''' - pass - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 6ed3055d689b..d907d53edcde 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -76,7 +76,6 @@ def forward_step(self, # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) - if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index f6a4c706eb14..6f86de232fad 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -315,7 +315,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) - mpolicy = self.add_lm_prediction_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM if self.pipeline_stage_manager: self.set_pipeline_forward(model_cls=BertForMaskedLM, diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py deleted file mode 100644 index bc3a9bf1b010..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertForPreTraining - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_for_pretraining_policy(): - configuration = BertConfig() - model = BertForPreTraining(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertForPreTrainingPolicy() - model_policy.set_model(model) - - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 2 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_for_pretraining_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_for_pretraining_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py deleted file mode 100644 index 1aeb00123db8..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertLMHeadModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_lmhead_policy(): - configuration = BertConfig() - model = BertLMHeadModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertLMHeadModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 2 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_lmhead_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for lm head model policy""" - test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py deleted file mode 100644 index b366df01788b..000000000000 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ /dev/null @@ -1,66 +0,0 @@ -''' -In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model -''' - -import pytest -import torch.distributed as dist -from transformers.models.bert.modeling_bert import BertModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_model_policy(): - model = BertModel.from_pretrained('bert-base-uncased') - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - - layers = model_policy.get_held_layers() - - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 1 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_model_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert model policy""" - test_bert_model_policy() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py deleted file mode 100644 index e6a222f4e3d5..000000000000 --- a/tests/test_pipeline/test_policy/test_bloom_model.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bloom import BloomConfig, BloomModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bloom import BloomModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bloom_model_policy(): - # create a BloomModel - configuration = BloomConfig() - model = BloomModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BloomModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 1 + 2 - else: - assert len(layers) == 1 + 1 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bloom_model_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bloom_model_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bloom model policy""" - test_bloom_model_policy() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index ea0f122644dc..6d0d3c798c4e 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -2,7 +2,10 @@ import torch import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 4feaf982aa37..3170b58a1175 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -5,6 +5,8 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -17,9 +19,55 @@ from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + stage_manager = stage_manager + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 1 + 1 + else: + if name == "transformers_bert": + assert len(layers) == 1 + 1 + elif name in [ + "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", + "transformers_bert_for_mcq" + ]: + assert len(layers) == 1 + 3 + else: + assert len(layers) == 1 + 2 + + +def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + if name == 'transformers_bert_for_mcq': + x = torch.randint(0, 1000, (2, 3, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + if stage_manager.stage == 0: + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (6, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() + output = sharded_model(input_ids=x, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape == (2, 3) + else: + x = torch.randint(0, 1000, (2, 3)).cuda() + # one batch, 2 single sentences, each sentence has 3 tokens + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape[0] == 2 @parameterize('enable_fused_normalization', [False]) @@ -27,55 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_bert def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - - if name == 'transformers_bert_for_mcq': - x = torch.randint(0, 1000, (2, 3, 3)).cuda() - attention_mask = torch.ones_like(x).cuda() - if stage_manager.stage == 0: - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (6, 3, 128) - else: - hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() - output = sharded_model(input_ids=x, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape == (2, 3) - else: - x = torch.randint(0, 1000, (2, 3)).cuda() - # one batch, 2 single sentences, each sentence has 3 tokens - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model(hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape[0] == 2 + check_bert_model_policy(name, org_model, stage_manager) + check_bert_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -90,7 +100,7 @@ def check_bert(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bert(): - spawn(check_bert, 4) + spawn(check_bert, 2) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py index 3a36479fc8bb..6695e8a687bd 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py @@ -5,7 +5,9 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -18,9 +20,37 @@ from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 0 + 2 + else: + if name == 'transformers_bloom': + assert len(layers) == 1 + 1 + elif name == 'transformers_bloom_for_token_classification': + assert len(layers) == 1 + 3 + else: + assert len(layers) == 1 + 2 + + +def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + if stage_manager.stage == 0: + x = torch.randint(0, 1000, (1, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (1, 3, 64) + else: + attention_mask = torch.ones((1, 3)).cuda() + hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0].shape[0] == 1 @parameterize('enable_fused_normalization', [False]) @@ -28,40 +58,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_bloom def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - x = torch.randint(0, 1000, (1, 3)).cuda() - hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (1, 3, 64) - else: - attention_mask = torch.ones((1, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0].shape[0] == 1 + check_bloom_model_policy(name, org_model, stage_manager) + check_bloom_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -76,7 +83,7 @@ def check_bloom(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bloom(): - spawn(check_bloom, 4) + spawn(check_bloom, 2) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 8fd9ed099478..6f1f0cb34508 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -5,7 +5,9 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -18,9 +20,35 @@ from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 2 + 1 + else: + if name == "transformers_llama": + assert len(layers) == 2 + 1 + else: + assert len(layers) == 2 + 2 + + +def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + x = torch.randint(0, 1000, (2, 3)).cuda() + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0] is not None @parameterize('enable_fused_normalization', [False]) @@ -28,40 +56,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_llama def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0] is not None + check_llama_model_policy(name, org_model, stage_manager) + check_llama_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -76,7 +82,7 @@ def check_llama(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): - spawn(check_llama, 4) + spawn(check_llama, 2) if __name__ == "__main__": From 1920ce19ad92f0881052b810f164686f23ae6414 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 1 Aug 2023 17:29:09 +0800 Subject: [PATCH 43/80] [pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354) * add naive optimizer for 3DPlugin/refactor gpt2 shardformer test * merge tests of PP/DP/TP combinations into one test file * fix bug when sync grad for dp in HybridPlugin * update supported precisions for 3DPlugin/fix bug when shifting tp_degree * improve the passing of lazy_init * modify lazy_init/use sync_shared_params --- .../naive_amp/mixed_precision_optimizer.py | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 37 +++- .../shardformer/layer/qkv_fused_linear.py | 4 +- colossalai/tensor/d_tensor/api.py | 5 + tests/kit/model_zoo/transformers/gpt.py | 3 +- .../test_model/test_pure_pipeline.py | 1 - .../test_model/test_shard_gpt2.py | 205 +++++++++++++----- .../test_model/test_shard_gpt2_pipeline.py | 72 ------ 8 files changed, 187 insertions(+), 142 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index d4183be3fb5f..626a00c96d04 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -134,7 +134,7 @@ def step(self, *args, **kwargs): working_param = self.master_to_working_map[p] if p is working_param: continue - if working_param.grad is None: + if working_param.grad is not None: p.grad = working_param.grad.data.float() working_param.grad = None total_norm = self._compute_grad_norm() diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 37badb613433..35a88d1e8980 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -42,6 +42,8 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp module = module.half().cuda() elif precision == 'bf16': module = module.to(dtype=torch.bfloat16).cuda() + else: + module = module.cuda() # train without AMP # TODO(ver217): support TP+DP super().__init__(module) @@ -61,6 +63,7 @@ def sync_grads(self): for p in self.module.parameters(): if p.grad is not None: dist.all_reduce(p.grad, group=self.dp_group) + p.grad.div_(self.dp_group.size()) def init_pipeline_optimizer(optim: Optimizer, model: Module): @@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): optim.__setstate__({'param_groups': new_param_groups}) -class HybridParallelOptimizer(MixedPrecisionOptimizer): +class HybridParallelNaiveOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim) + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__(self, optim: Optimizer, @@ -192,7 +203,7 @@ def supported_devices(self) -> List[str]: return ['cuda'] def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16'] + return ['fp16', 'bf16', 'fp32'] def control_device(self) -> bool: return True @@ -218,12 +229,17 @@ def configure( model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: - optimizer = HybridParallelOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) + if self.precision in ['fp16', 'bf16']: + optimizer = HybridParallelAMPOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + else: + optimizer = HybridParallelNaiveOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism) else: optimizer = HybridParallelZeroOptimizer(optimizer, model, @@ -241,7 +257,8 @@ def execute_pipeline(self, data_iter: Iterator, model: HybridParallelModule, criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer], + optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, + HybridParallelZeroOptimizer], return_loss: bool = True, return_outputs: bool = False) -> dict: assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' @@ -250,7 +267,7 @@ def execute_pipeline(self, with ctx: outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) - # model.sync_shared_params() + model.sync_shared_params() if isinstance(optimizer, HybridParallelZeroOptimizer): optimizer.sync_grad() else: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index bcefcf058ce0..3c47c0b1106f 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -456,12 +456,12 @@ def forward(self, input_: Tensor) -> Tensor: if self.parallel_input: assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + input_.shape, self.weight.shape, self.weight.shape[0]) input_ = input_ else: assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 32182faf6981..9848e4ca423e 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -16,6 +16,11 @@ layout_converter = LayoutConverter() +def clear_layout_converter(): + global layout_converter + layout_converter.cached_solution.clear() + + def is_distributed_tensor(tensor: torch.Tensor) -> bool: """ Check whether the given tensor is a distributed tensor. diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index e447b700105e..fcde75abdedc 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -70,7 +70,8 @@ def data_gen_for_sequence_classification(): resid_pdrop=0, summary_first_dropout=0, hidden_dropout=0, - problem_type="single_label_classification") + problem_type="single_label_classification", + pad_token_id=50256) # register the following models model_zoo.register(name='transformers_gpt', diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 576e6473bcca..31e76ef5107c 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -160,7 +160,6 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skip('This test will fail') @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 99451b403eb7..eae4f2ffb799 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,85 +1,180 @@ +import copy +from contextlib import nullcontext + import pytest import torch +from torch import distributed as dist +from torch.optim import Adam import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, +from colossalai.tensor.d_tensor.api import ( + clear_layout_converter, + is_customized_distributed_tensor, + is_distributed_tensor, ) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + use_lazy_init = False + if 'use_lazy_init' in test_config: + use_lazy_init = test_config.pop('use_lazy_init') - # do backward + if use_lazy_init: + ctx = LazyInitContext() + else: + ctx = nullcontext() + + # prepare booster + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + stage_manager = plugin.stage_manager + + # prepare models and optimizers + with ctx: + org_model = model_fn().cuda() + sharded_model = copy.deepcopy(org_model) + + if use_lazy_init: + org_model = ctx.materialize(org_model) + + org_optimizer = Adam(org_model.parameters(), lr=1e-3) + sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # do forward and backward + data = data_gen_fn() + sharded_model.train() + if stage_manager: + data = { + k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + data_iter = iter([data]) + sharded_output = booster.execute_pipeline(data_iter, + sharded_model, + _criterion, + sharded_optimizer, + return_loss=True, + return_outputs=True) + sharded_loss = sharded_output['loss'] + else: + data = {k: v.cuda() for k, v in data.items()} + sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) + sharded_loss.backward() + + org_model.train() + org_output = org_model(**data) + org_loss = criterion(org_output) org_loss.backward() - shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + if stage_manager is None or stage_manager.is_last_stage(): + + # check last hidden state + if org_model.__class__.__name__ == 'GPT2Model': + org_hidden_state = org_output.last_hidden_state + + if stage_manager is None: + sharded_hidden_state = sharded_output.last_hidden_state + + if stage_manager and stage_manager.is_last_stage(): + sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], + dim=0) + + assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \ + f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + + # check loss + assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \ + f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" # unwrap model if org_model.__class__.__name__ == 'GPT2Model': org_model = org_model - sharded_model = sharded_model + sharded_model = sharded_model.unwrap() else: org_model = org_model.transformer - sharded_model = sharded_model.transformer + sharded_model = sharded_model.unwrap().transformer - # check mlp grad - org_grad = org_model.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad - shard_weight = sharded_model.h[0].mlp.c_fc.weight + # check weights and gradients + if stage_manager is None or stage_manager.is_first_stage(): + + shard_weight = sharded_model.h[0].mlp.c_fc.weight + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)] + dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group) + shard_grad = torch.cat(shard_grad_list, dim=1) + + assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3), \ + f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}" + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + + org_weight = org_model.h[0].mlp.c_fc.weight + shard_weight = sharded_model.h[0].mlp.c_fc.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)] + dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group) + shard_weight = torch.cat(shard_weight_list, dim=1) + + assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \ + f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}" - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=1) - else: - all_shard_grad = shard_grad - assert torch.allclose( - org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" - - # check embedding weights - org_grad = org_model.wte.weight.grad - shard_grad = sharded_model.wte.weight.grad - shard_weight = sharded_model.wte.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose( - org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('use_lazy_init', [False, True]) +@parameterize('test_config', [{ + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': True +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) @clear_cache_before_run() -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +def run_gpt2_test(test_config): + + # TODO: add plugin_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + clear_layout_converter() torch.cuda.empty_cache() @@ -93,7 +188,7 @@ def check_gpt2(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gpt2(): - spawn(check_gpt2, 2) + spawn(check_gpt2, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py deleted file mode 100644 index d5453ee72644..000000000000 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_gpt2 -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - input_ids = inputs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_size = sharded_model.config.n_embd - hidden_state_shape = (batch_size, seq_len, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - hidden_states = torch.zeros(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - - sharded_model.train() - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - if name == 'transformers_gpt': - assert output[0].shape == hidden_state_shape - else: - assert output.loss is not None - else: - assert output['hidden_states'].shape == hidden_state_shape - - torch.cuda.empty_cache() - - -def check_gpt2(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt2_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_gpt2(): - spawn(check_gpt2, 4) - - -if __name__ == "__main__": - test_gpt2() From a02ddd2bb155ff0d0aed2c9560932b1112eece86 Mon Sep 17 00:00:00 2001 From: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:06:46 +0800 Subject: [PATCH 44/80] Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout --- colossalai/shardformer/policies/vit.py | 83 +++++++++++++------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 47f2c58fc436..96f27de2a7c8 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Callable, Dict, List, Union import torch.nn as nn @@ -36,7 +35,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="dropout", target_module=col_nn.DropoutForReplicatedInput, ) - ]) + ]) policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ "attention.attention.num_attention_heads": @@ -44,45 +43,47 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "attention.attention.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) + + return policy return policy From 3896e92788ae9456250ffd5efba12f31b7d2016c Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:56:59 +0800 Subject: [PATCH 45/80] [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code --- colossalai/shardformer/_utils.py | 44 +++- colossalai/shardformer/layer/__init__.py | 5 +- .../shardformer/layer/qkv_fused_linear.py | 175 ++++++++++++++- colossalai/shardformer/modeling/sam.py | 41 ++++ .../shardformer/policies/auto_policy.py | 4 + colossalai/shardformer/policies/sam.py | 209 ++++++++++++++++++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/sam.py | 52 +++++ .../test_gpt2_qkv_fused_linear_1d.py | 120 ++++++++++ .../test_model/test_shard_sam.py | 92 ++++++++ 10 files changed, 733 insertions(+), 10 deletions(-) create mode 100644 colossalai/shardformer/modeling/sam.py create mode 100644 colossalai/shardformer/policies/sam.py create mode 100644 tests/kit/model_zoo/transformers/sam.py create mode 100644 tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py create mode 100644 tests/test_shardformer/test_model/test_shard_sam.py diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py index 4ad877e72357..c553080de0a0 100644 --- a/colossalai/shardformer/_utils.py +++ b/colossalai/shardformer/_utils.py @@ -1,25 +1,57 @@ import re -def get_obj_list_element(obj, a): +def get_obj_list_element(obj, attr: str): r""" Get the element of the list in the object + + If the attr is a normal attribute, return the attribute of the object. + If the attr is a index type, return the element of the index in the list, like `layers[0]`. + + Args: + obj (Object): The object to get + attr (str): The suffix of the attribute to get + """ re_pattern = r'\[\d+\]' prog = re.compile(re_pattern) - result = prog.search(a) + result = prog.search(attr) if result: matched_brackets = result.group() matched_index = matched_brackets.replace('[', '') matched_index = matched_index.replace(']', '') - a_ = a.replace(matched_brackets, '') - container_obj = getattr(obj, a_) + attr_ = attr.replace(matched_brackets, '') + container_obj = getattr(obj, attr_) obj = container_obj[int(matched_index)] else: - obj = getattr(obj, a) + obj = getattr(obj, attr) return obj +def set_obj_list_element(obj, attr: str, value): + r""" + Set the element to value of a list object + + It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value + + Args: + obj (object): The object to set + attr (str): the string including a list index like `layers[0]` + """ + re_pattern = r'\[\d+\]' + prog = re.compile(re_pattern) + result = prog.search(attr) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace('[', '') + matched_index = matched_index.replace(']', '') + attr_ = attr.replace(matched_brackets, '') + container_obj = getattr(obj, attr_) + container_obj[int(matched_index)] = value + else: + setattr(obj, attr, value) + + def hasattr_(obj, attr: str): r""" Check whether the object has the multi sublevel attr @@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): if ignore: return raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") - setattr(obj, attrs[-1], value) + set_obj_list_element(obj, attrs[-1], value) def getattr_(obj, attr: str, ignore: bool = False): diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7cdcfc31811f..0c44e6621711 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,11 +3,10 @@ from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm -from .parallel_module import ParallelModule -from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule' + 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col' ] diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 3c47c0b1106f..1e4b6ecb69b3 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,6 +25,7 @@ from ._operation import ( gather_forward_split_backward, + linear_with_async_comm, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -33,7 +34,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row'] +__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row'] # ==================================== # For GPT Only @@ -490,3 +491,175 @@ def forward(self, input_: Tensor) -> Tensor: return output else: return output, self.bias + + +# ==================================== +# For Fused torch.nn.Linear +# ==================================== + + +class FusedLinear1D_Col(ParallelModule): + r"""Fused Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.n_fused = n_fused + self.process_group = process_group + self.async_communication = async_communication + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False) + + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) + self.weight = customized_distributed_tensor_to_param(sharded_weight) + + if bias: + bias = torch.empty(self.out_features, **factory_kwargs) + + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) + self.bias = customized_distributed_tensor_to_param(sharded_bias) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a fused `torch.nn.linear` layer to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = FusedLinear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=False) + linear_1d.weight.data.copy_(sharded_weight.data) + + if bias: + sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=False) + linear_1d.bias.data.copy_(sharded_bias.data) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_backward(input_, self.process_group) + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py new file mode 100644 index 000000000000..00e2d744e219 --- /dev/null +++ b/colossalai/shardformer/modeling/sam.py @@ -0,0 +1,41 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def forward_fn(): + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, + -1).permute(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w, + (height, width), (height, width)) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + # replace dropout process with added DropoutForParallelInput layer + # origin code: + # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = self.dropout_layer(attn_weights) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index d00a03c9237e..63ec8398fcee 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -104,6 +104,10 @@ class PolicyLocation: PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), + + # Sam + "transformers.models.sam.modeling_sam.SamModel": + PolicyLocation(file_name="sam", class_name="SamModelPolicy"), } diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py new file mode 100644 index 000000000000..e75d63946260 --- /dev/null +++ b/colossalai/shardformer/policies/sam.py @@ -0,0 +1,209 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from ..modeling.sam import forward_fn +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['SamPolicy', 'SamModelPolicy'] + + +class SamPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers.models.sam.modeling_sam import ( + SamFeedForward, + SamTwoWayAttentionBlock, + SamTwoWayTransformer, + SamVisionAttention, + SamVisionLayer, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={ + "attn.num_attention_heads": + self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ) + ]) + policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_attention_heads": + self.model.config.mask_decoder_config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="cross_attn_image_to_token.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ]) + policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={ + "final_attn_token_to_image.num_attention_heads": + self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ) + ]) + + # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` + policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle SamVisionLayer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamVisionLayer) + + # Handle SamTwoWayAttentionBlock + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayAttentionBlock) + + # Handle SamTwoWayTransformer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer) + + return policy + + def postprocess(self): + return self.model + + +# SamModel +class SamModelPolicy(SamPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a298767d12e7..a1bcb78ddf6b 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -4,5 +4,6 @@ from .gpt import * from .llama import * from .opt import * +from .sam import * from .t5 import * from .vit import * diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py new file mode 100644 index 000000000000..d850623f368f --- /dev/null +++ b/tests/kit/model_zoo/transformers/sam.py @@ -0,0 +1,52 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-image SAM +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from PIL import Image + # import requests + # from transformers import SamModel, SamProcessor + # + # model = SamModel.from_pretrained("facebook/sam-vit-base") + # processor = SamProcessor.from_pretrained("facebook/sam-vit-base") + # + # img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + # input_points = [[[450, 600]]] # 2D localization of a window + # inputs = processor(raw_image, input_points=input_points, return_tensors="pt") + + pixel_values = torch.rand(1, 3, 1024, 1024, dtype=torch.float32) + original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64) + reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64) + input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64) + return dict(pixel_values=pixel_values, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + input_points=input_points) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: x.iou_scores.mean() + +config = transformers.SamConfig() +config.vision_config.num_hidden_layers = 2 + +# register the BERT variants +model_zoo.register(name='transformers_sam', + model_fn=lambda: transformers.SamModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py new file mode 100644 index 000000000000..9eeda93afe35 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -0,0 +1,120 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +def check_gpt2_linear_conv_1d_col(): + linear = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + n_fused=3) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) + assert linear_conv_col.bias.shape == torch.Size([96]) + + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_conv_col(x) + assert_close(rearrange(out, 1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + assert_close(target_grad, linear_conv_col.weight.grad) + + +def check_gpt2_linear_conv_1d_row(): + linear = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) + assert linear_row.bias.shape == torch.Size([192]) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # test for linear conv + check_gpt2_linear_conv_1d_col() + check_gpt2_linear_conv_1d_row() + + +@rerun_if_address_is_in_use() +def test_gpt2_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_gpt2_linearconv() diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py new file mode 100644 index 000000000000..1d047d8e0c42 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['pred_masks']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + sam = org_model + sharded_sam = sharded_model + + # compare mask decoder grad + + org_grad = sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad + shard_grad = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad + shard_weight = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # compare vision_encoder grad + org_grad = sam.vision_encoder.layers[0].mlp.lin1.weight.grad + shard_grad = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight.grad + shard_weight = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_sam_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_sam') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_sam(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_sam_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_sam(): + spawn(check_sam, 2) + + +if __name__ == "__main__": + test_sam() From 98e382c3b4733fa230ea414e985f408c9007a731 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 17 Jul 2023 14:25:32 +0800 Subject: [PATCH 46/80] [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme --- colossalai/shardformer/README.md | 2 +- colossalai/shardformer/layer/embedding.py | 10 +- .../shardformer/policies/auto_policy.py | 8 + colossalai/shardformer/policies/whisper.py | 232 ++++++++++++++++++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/whisper.py | 91 +++++++ .../test_model/test_shard_whisper.py | 101 ++++++++ 7 files changed, 443 insertions(+), 2 deletions(-) create mode 100644 colossalai/shardformer/policies/whisper.py create mode 100644 tests/kit/model_zoo/transformers/whisper.py create mode 100644 tests/test_shardformer/test_model/test_shard_whisper.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index bf4215c52980..3c322aabf2ef 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -102,7 +102,7 @@ We will follow this roadmap to develop Shardformer: - [ ] SwinTransformer - [ ] SwinTransformer V2 - [ ] Audio - - [ ] Whisper + - [x] Whisper - [ ] Multi-modal - [ ] To be added diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 09b22abb17cc..f07a93bd6908 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -202,7 +202,6 @@ def __init__(self, super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim - self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.process_group = process_group @@ -276,6 +275,15 @@ def _fill_padding_idx_with_zero(self) -> None: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + def _select_padding_idx(self, padding_idx: int): + # select padding index according to the rank + if padding_idx is None: + return None + elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index: + return padding_idx - self.vocab_start_index + else: + return None + def forward(self, input_: Tensor) -> Tensor: # Build the mask. input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 63ec8398fcee..90347a984599 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -105,6 +105,14 @@ class PolicyLocation: "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), + # Whisper + "transformers.models.whisper.modeling_whisper.WhisperModel": + PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"), + "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": + PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"), + "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": + PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"), + # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py new file mode 100644 index 000000000000..7751bbb5de99 --- /dev/null +++ b/colossalai/shardformer/policies/whisper.py @@ -0,0 +1,232 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' +] + + +class WhisperPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.whisper.modeling_whisper import ( + WhisperDecoder, + WhisperDecoderLayer, + WhisperEncoder, + WhisperEncoderLayer, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.embed_dim": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.embed_dim": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size, + "encoder_attn.embed_dim": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "encoder_attn.num_heads": + self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle encoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoderLayer) + + # Handle decoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoderLayer) + + # handle encoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoder) + + # handle decoder layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoder) + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + policy=base_policy, + target_key=WhisperForConditionalGeneration) + + return base_policy + + def postprocess(self): + return self.model + + +# WhisperModel +class WhisperModelPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() + + +# WhisperForConditionalGeneration +class WhisperForConditionalGenerationPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + module_policy = self.add_lm_head_policy(module_policy) + return module_policy + + def postprocess(self): + binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# WhisperForAudioClassification +class WhisperForAudioClassificationPolicy(WhisperPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a1bcb78ddf6b..39e5ef411f32 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -7,3 +7,4 @@ from .sam import * from .t5 import * from .vit import * +from .whisper import * diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py new file mode 100644 index 000000000000..b58716217cb5 --- /dev/null +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -0,0 +1,91 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Whisper +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoFeatureExtractor, WhisperModel + # from datasets import load_dataset + + # model = WhisperModel.from_pretrained("openai/whisper-base") + # feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + # input_features = inputs.input_features + # decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + + input_features = torch.randn(1, 80, 3000) + decoder_input_ids = torch.tensor([[1, 1]]) * 50258 + return dict(input_features=input_features, decoder_input_ids=decoder_input_ids) + + +def data_gen_for_conditional_generation(): + # labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + # Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + # only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + data = data_gen() + data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64) + return data + + +def data_gen_for_audio_classification(): + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + # config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + # `WhisperForAudioClassification` does not need `decoder_input_ids` + data = data_gen() + data.pop('decoder_input_ids') + data['labels'] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn = lambda x: x.last_hidden_state.mean() +loss_fn_attr = lambda x: x.loss + +config = transformers.WhisperConfig( + classifier_proj_size=256, + d_model=256, + decoder_attention_heads=4, + decoder_ffn_dim=1536, + decoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=1536, + encoder_layers=2, + vocab_size=51866, +) + +# register the Whisper variants +model_zoo.register(name='transformers_whisper', + model_fn=lambda: transformers.WhisperModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_whisperForConditionalGeneration', + model_fn=lambda: transformers.WhisperForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_whisperWhisperForAudioClassification', + model_fn=lambda: transformers.WhisperForAudioClassification(config), + data_gen_fn=data_gen_for_audio_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py new file mode 100644 index 000000000000..8932a4ab902c --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -0,0 +1,101 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values') + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': + whisper = org_model.model + sharded_whisper = sharded_model.model + else: + whisper = org_model + sharded_whisper = sharded_model + + # compare self attention grad + org_grad = whisper.encoder.layers[0].self_attn.q_proj.weight.grad + shard_grad = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # WhisperForAudioClassification does not have decoder and embedding layer + if org_model.__class__.__name__ == 'WhisperForAudioClassification': + return + + # compare embedding grad + org_grad = whisper.decoder.embed_tokens.weight.grad + shard_grad = sharded_whisper.decoder.embed_tokens.weight.grad + shard_weight = sharded_whisper.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_whisper(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_whisper_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper(): + spawn(check_whisper, 2) + + +if __name__ == "__main__": + test_whisper() From 7dea2e93f4b4c36e143ac122c4e58bf2f9c2f53c Mon Sep 17 00:00:00 2001 From: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Date: Thu, 20 Jul 2023 17:28:00 +0800 Subject: [PATCH 47/80] Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit --- colossalai/shardformer/policies/chatglm.py | 96 ++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/chatglm.py | 38 + .../chatglm2_6b/configuration_chatglm.py | 58 + .../chatglm2_6b/modeling_chatglm.py | 1372 +++++++++++++++++ .../test_model/test_shard_chatglm.py | 107 ++ 6 files changed, 1672 insertions(+) create mode 100644 colossalai/shardformer/policies/chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py create mode 100644 tests/test_shardformer/test_model/test_shard_chatglm.py diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py new file mode 100644 index 000000000000..934b99b83ea1 --- /dev/null +++ b/colossalai/shardformer/policies/chatglm.py @@ -0,0 +1,96 @@ +from typing import Dict, Union + +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] + + +class ChatGLMModelPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.padded_vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + + policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + + policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration + if self.shard_config.enable_fused_normalization: + if not self.model.config.rmsnorm: + + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription(suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=GLMBlock) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="encoder.final_layernorm", + target_module=col_nn.FusedLayerNorm) + ], + policy=policy, + target_key=ChatGLMModel) + + return policy + + def postprocess(self): + return self.model diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 39e5ef411f32..08a118e5783d 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,6 +1,7 @@ from .albert import * from .bert import * from .bloom import * +from .chatglm import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py new file mode 100644 index 000000000000..1408babede64 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -0,0 +1,38 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo +from .chatglm2_6b.configuration_chatglm import ChatGLMConfig +from .chatglm2_6b.modeling_chatglm import ChatGLMModel + +# ================================ +# Register single-sentence ChatGLM +# ================================ + + +def data_gen(): + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: x.loss +config = ChatGLMConfig(num_layers=1, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + rmsnorm=False, + original_rope=True, + use_cache=True) + +model_zoo.register(name='transformers_chatglm', + model_fn=lambda: ChatGLMModel(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py new file mode 100644 index 000000000000..3e78732be2da --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py @@ -0,0 +1,58 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py new file mode 100644 index 000000000000..bae6d425878d --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -0,0 +1,1372 @@ +""" +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. +""" +""" PyTorch ChatGLM model. """ + +import copy +import math +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != "darwin": + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2, + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool, + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + # Per attention head and per partition values. + self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = (self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = (config.hidden_size // + config.num_attention_heads if config.kv_channels is None else config.kv_channels) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype, + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels, + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], + beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs["attention_mask"] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, + ): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") + logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.") + + # 2. Set generation parameters if not already defined + logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) + stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation(outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs, + ) + return self diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py new file mode 100644 index 000000000000..2cdf5da2e6da --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -0,0 +1,107 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if org_model.__class__.__name__ == 'ChatGLMModel': + chatglm_model = org_model + shard_chatglm_model = sharded_model + else: + chatglm_model = org_model.transformer + shard_chatglm_model = sharded_model.transformer + + # check attention grad + org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad + shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + # check embedding weights + org_grad = chatglm_model.embedding.word_embeddings.weight.grad + shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad + shard_weight = shard_chatglm_model.embedding.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + if name == "transformers_chatglm": + sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 2) + + +if __name__ == "__main__": + test_chatglm() From e48ce7215ba2ca8821dd42dfd4e9cbf6be720fcd Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 4 Jul 2023 14:35:55 +0800 Subject: [PATCH 48/80] [shardformer] added tests --- tests/test_shardformer/test_model/test_shard_vit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 2b02c83e0d27..c1126cb2cd4b 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -56,6 +56,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_tensor_parallelism', [True, False]) def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + print(sub_model_zoo) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) From fd29900783f519124a10f8b6164ae07a443105ee Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 6 Jul 2023 10:59:42 +0800 Subject: [PATCH 49/80] [shardformer] vit test finish and support --- tests/test_shardformer/test_model/test_shard_vit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index c1126cb2cd4b..2b02c83e0d27 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -56,7 +56,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_tensor_parallelism', [True, False]) def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - print(sub_model_zoo) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) From 65d663e7d6d866569c966a90c354fa33ad4f3f1c Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 7 Jul 2023 19:16:35 +0800 Subject: [PATCH 50/80] import chatglm --- =2.0 | 134 ++ .../chatglm2-6b/modeling_chatglm.py | 1193 +++++++++++++++++ 2 files changed, 1327 insertions(+) create mode 100644 =2.0 create mode 100644 tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py diff --git a/=2.0 b/=2.0 new file mode 100644 index 000000000000..af47ce17aa8e --- /dev/null +++ b/=2.0 @@ -0,0 +1,134 @@ +Defaulting to user installation because normal site-packages is not writeable +Collecting protobuf + Using cached protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB) +Requirement already satisfied: transformers==4.30.2 in /home/lclk/.local/lib/python3.9/site-packages (4.30.2) +Collecting cpm_kernels + Using cached cpm_kernels-1.0.11-py3-none-any.whl (416 kB) +Requirement already satisfied: torch in /home/lclk/.local/lib/python3.9/site-packages (2.0.0+cu118) +Collecting gradio + Using cached gradio-3.36.0-py3-none-any.whl (19.8 MB) +Collecting mdtex2html + Using cached mdtex2html-1.2.0-py3-none-any.whl (13 kB) +Collecting sentencepiece + Using cached sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) +Collecting accelerate + Using cached accelerate-0.20.3-py3-none-any.whl (227 kB) +Requirement already satisfied: pyyaml>=5.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (6.0) +Requirement already satisfied: regex!=2019.12.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (2023.6.3) +Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.15.1) +Requirement already satisfied: packaging>=20.0 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (23.1) +Requirement already satisfied: requests in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from transformers==4.30.2) (2.25.1) +Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.13.3) +Requirement already satisfied: safetensors>=0.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.3.1) +Requirement already satisfied: filelock in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (3.12.0) +Requirement already satisfied: numpy>=1.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (1.24.3) +Requirement already satisfied: tqdm>=4.27 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (4.65.0) +Requirement already satisfied: fsspec in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0) +Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.6.3) +Requirement already satisfied: networkx in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1) +Requirement already satisfied: sympy in /home/lclk/.local/lib/python3.9/site-packages (from torch) (1.12) +Requirement already satisfied: triton==2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (2.0.0) +Requirement already satisfied: jinja2 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1.2) +Requirement already satisfied: lit in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (16.0.5.post0) +Requirement already satisfied: cmake in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (3.26.3) +Collecting aiofiles + Using cached aiofiles-23.1.0-py3-none-any.whl (14 kB) +Collecting ffmpy + Using cached ffmpy-0.3.0.tar.gz (4.8 kB) +Requirement already satisfied: pillow in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (9.5.0) +Collecting pydub + Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB) +Requirement already satisfied: pandas in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.0.2) +Collecting python-multipart + Using cached python_multipart-0.0.6-py3-none-any.whl (45 kB) +Collecting semantic-version + Using cached semantic_version-2.10.0-py2.py3-none-any.whl (15 kB) +Collecting pydantic + Using cached pydantic-2.0.2-py3-none-any.whl (359 kB) +Collecting uvicorn>=0.14.0 + Using cached uvicorn-0.22.0-py3-none-any.whl (58 kB) +Collecting mdit-py-plugins<=0.3.3 + Using cached mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB) +Requirement already satisfied: pygments>=2.12.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.15.1) +Collecting httpx + Using cached httpx-0.24.1-py3-none-any.whl (75 kB) +Collecting orjson + Using cached orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (136 kB) +Collecting fastapi + Using cached fastapi-0.99.1-py3-none-any.whl (58 kB) +Collecting altair>=4.2.0 + Using cached altair-5.0.1-py3-none-any.whl (471 kB) +Collecting gradio-client>=0.2.7 + Using cached gradio_client-0.2.7-py3-none-any.whl (288 kB) +Requirement already satisfied: aiohttp in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.8.4) +Requirement already satisfied: matplotlib in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.7.1) +Collecting websockets>=10.0 + Using cached websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB) +Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.2.0) +Requirement already satisfied: markupsafe in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.1.3) +Collecting toolz + Using cached toolz-0.12.0-py3-none-any.whl (55 kB) +Collecting jsonschema>=3.0 + Using cached jsonschema-4.18.0-py3-none-any.whl (81 kB) +Collecting rpds-py>=0.7.1 + Downloading rpds_py-0.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB) +Collecting referencing>=0.28.4 + Using cached referencing-0.29.1-py3-none-any.whl (25 kB) +Collecting jsonschema-specifications>=2023.03.6 + Using cached jsonschema_specifications-2023.6.1-py3-none-any.whl (17 kB) +Requirement already satisfied: attrs>=22.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (23.1.0) +Requirement already satisfied: mdurl~=0.1 in /home/lclk/.local/lib/python3.9/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2) +Collecting linkify-it-py<3,>=1 + Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB) +Collecting uc-micro-py + Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB) +Requirement already satisfied: pytz>=2020.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) +Requirement already satisfied: tzdata>=2022.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) +Requirement already satisfied: python-dateutil>=2.8.2 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2.8.2) +Requirement already satisfied: six>=1.5 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->gradio) (1.16.0) +Requirement already satisfied: click>=7.0 in /home/lclk/.local/lib/python3.9/site-packages (from uvicorn>=0.14.0->gradio) (8.1.3) +Collecting h11>=0.8 + Downloading h11-0.14.0-py3-none-any.whl (58 kB) +Collecting latex2mathml + Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB) +Collecting markdown + Downloading Markdown-3.4.3-py3-none-any.whl (93 kB) +Requirement already satisfied: psutil in /home/lclk/.local/lib/python3.9/site-packages (from accelerate) (5.9.5) +Requirement already satisfied: multidict<7.0,>=4.5 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (6.0.4) +Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (4.0.2) +Requirement already satisfied: aiosignal>=1.1.2 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.1) +Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (3.1.0) +Requirement already satisfied: frozenlist>=1.1.1 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.3) +Requirement already satisfied: yarl<2.0,>=1.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.9.2) +Requirement already satisfied: idna>=2.0 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from yarl<2.0,>=1.0->aiohttp->gradio) (2.10) +Collecting pydantic + Downloading pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB) +Collecting starlette<0.28.0,>=0.27.0 + Downloading starlette-0.27.0-py3-none-any.whl (66 kB) +Collecting anyio<5,>=3.4.0 + Downloading anyio-3.7.1-py3-none-any.whl (80 kB) +Collecting sniffio>=1.1 + Downloading sniffio-1.3.0-py3-none-any.whl (10 kB) +Requirement already satisfied: exceptiongroup in /home/lclk/.local/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->gradio) (1.1.1) +Collecting httpcore<0.18.0,>=0.15.0 + Downloading httpcore-0.17.3-py3-none-any.whl (74 kB) +Requirement already satisfied: certifi in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from httpx->gradio) (2021.5.30) +Requirement already satisfied: importlib-metadata>=4.4 in /home/lclk/.local/lib/python3.9/site-packages (from markdown->mdtex2html) (6.7.0) +Requirement already satisfied: zipp>=0.5 in /home/lclk/.local/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown->mdtex2html) (3.15.0) +Requirement already satisfied: contourpy>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.1.0) +Requirement already satisfied: fonttools>=4.22.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (4.40.0) +Requirement already satisfied: pyparsing>=2.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (3.1.0) +Requirement already satisfied: kiwisolver>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.4.4) +Requirement already satisfied: importlib-resources>=3.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (5.12.0) +Requirement already satisfied: cycler>=0.10 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (0.11.0) +Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (1.26.6) +Requirement already satisfied: chardet<5,>=3.0.2 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (4.0.0) +Requirement already satisfied: mpmath>=0.19 in /home/lclk/.local/lib/python3.9/site-packages (from sympy->torch) (1.3.0) +Building wheels for collected packages: ffmpy + Building wheel for ffmpy (setup.py): started + Building wheel for ffmpy (setup.py): finished with status 'done' + Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4709 sha256=071cebb58ca6c6947fbc669e1d94509d6f53d1ed45d9d7fb9f060d1a342cfc18 + Stored in directory: /home/lclk/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697 +Successfully built ffmpy +Installing collected packages: sniffio, rpds-py, referencing, h11, anyio, uc-micro-py, jsonschema-specifications, httpcore, websockets, toolz, starlette, pydantic, linkify-it-py, jsonschema, httpx, uvicorn, semantic-version, python-multipart, pydub, orjson, mdit-py-plugins, markdown, latex2mathml, gradio-client, ffmpy, fastapi, altair, aiofiles, sentencepiece, protobuf, mdtex2html, gradio, cpm-kernels, accelerate +Successfully installed accelerate-0.20.3 aiofiles-23.1.0 altair-5.0.1 anyio-3.7.1 cpm-kernels-1.0.11 fastapi-0.99.1 ffmpy-0.3.0 gradio-3.36.0 gradio-client-0.2.7 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 jsonschema-4.18.0 jsonschema-specifications-2023.6.1 latex2mathml-3.76.0 linkify-it-py-2.0.2 markdown-3.4.3 mdit-py-plugins-0.3.3 mdtex2html-1.2.0 orjson-3.9.1 protobuf-4.23.4 pydantic-1.10.11 pydub-0.25.1 python-multipart-0.0.6 referencing-0.29.1 rpds-py-0.8.8 semantic-version-2.10.0 sentencepiece-0.1.99 sniffio-1.3.0 starlette-0.27.0 toolz-0.12.0 uc-micro-py-1.0.2 uvicorn-0.22.0 websockets-11.0.3 diff --git a/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py new file mode 100644 index 000000000000..82163c46190f --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py @@ -0,0 +1,1193 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, + do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, + max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, + return_past_key_values=return_past_key_values, **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self \ No newline at end of file From b0a15c38af1b6eff94a4bf2e0bc34b06cf844d2d Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 7 Jul 2023 19:56:22 +0800 Subject: [PATCH 51/80] [shardformer] add test kit in model zoo for chatglm --- .../chatglm2-6b/modeling_chatglm.py | 1193 ----------------- .../transformers/chatglm2_6b/MODEL_LICENSE | 33 + .../chatglm2_6b/modeling_chatglm.py | 6 - .../transformers/chatglm2_6b/quantization.py | 188 +++ 4 files changed, 221 insertions(+), 1199 deletions(-) delete mode 100644 tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE create mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py diff --git a/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py deleted file mode 100644 index 82163c46190f..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2-6b/modeling_chatglm.py +++ /dev/null @@ -1,1193 +0,0 @@ -""" PyTorch ChatGLM model. """ - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm2-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=dtype, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - return response - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - prompt = tokenizer.build_prompt(query, history=history) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - if history: - prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - input_ids = tokenizer.encode(prompt, add_special_tokens=False) - input_ids = input_ids[1:] - inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) - else: - prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, - do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = self.build_inputs(tokenizer, query, history=history) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, - max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None and not return_past_key_values: - inputs = self.build_inputs(tokenizer, query, history=history) - else: - inputs = self.build_stream_inputs(tokenizer, query, history=history) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, - return_past_key_values=return_past_key_values, **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response = self.process_response(response) - new_history = history + [(query, response)] - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE new file mode 100644 index 000000000000..26198b21b6b2 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE @@ -0,0 +1,33 @@ +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index bae6d425878d..488f24c5fcb9 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -80,7 +80,6 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -220,7 +219,6 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) @@ -235,7 +233,6 @@ def forward(self, hidden_states: torch.Tensor): class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -842,7 +839,6 @@ def forward(self, input_ids): class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -981,13 +977,11 @@ def forward( def quantize(self, weight_bit_width: int): from .quantization import quantize - quantize(self.encoder, weight_bit_width) return self class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py new file mode 100644 index 000000000000..cb95bfe82b20 --- /dev/null +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py @@ -0,0 +1,188 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + assert scale_list.dtype in [torch.half, torch.bfloat16] + assert weight.dtype in [torch.int8] + if source_bit_width == 8: + return weight.to(scale_list.dtype) * scale_list[:, None] + elif source_bit_width == 4: + func = ( + kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 + ) + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(torch.nn.Module): + def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, + **kwargs): + super().__init__() + self.weight_bit_width = weight_bit_width + + shape = weight.shape + + if weight is None or empty_init: + self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) + self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device) + else: + self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(device), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) + self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, device=None): + """Replace fp16 linear with quantized linear""" + for layer in model.layers: + layer.self_attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.query_key_value.bias, + dtype=layer.self_attention.query_key_value.weight.dtype, + device=layer.self_attention.query_key_value.weight.device if device is None else device, + empty_init=empty_init + ) + layer.self_attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), + bias=layer.self_attention.dense.bias, + dtype=layer.self_attention.dense.weight.dtype, + device=layer.self_attention.dense.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_h_to_4h.bias, + dtype=layer.mlp.dense_h_to_4h.weight.dtype, + device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias=layer.mlp.dense_4h_to_h.bias, + dtype=layer.mlp.dense_4h_to_h.weight.dtype, + device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, + empty_init=empty_init + ) + + return model From 9576cd7fffcdd09288cf1b70d988927b9c328028 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 10 Jul 2023 18:55:33 +0800 Subject: [PATCH 52/80] [sharformer] add first version of policy of chatglm --- colossalai/shardformer/policies/chatglm.py | 44 +++++++++++++++++++ .../test_model/test_shard_chatglm.py | 1 - 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 934b99b83ea1..c17b92c8dc81 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -1,6 +1,7 @@ from typing import Dict, Union import torch.nn as nn +from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock import colossalai.shardformer.layer as col_nn @@ -8,6 +9,49 @@ __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] +class ChatGLMModelPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement = {}, + sub_module_replacement = [ + # SubModuleReplacementDescription( + # suffix = "self_attention.query_key_value", + # target_module = col_nn.Linear1D_Col, + # ), + # SubModuleReplacementDescription( + # suffix = "self_attention.dense", + # target_module = col_nn.Linear1D_Row, + # ) + # SubModuleReplacementDescription( + # suffix = "self_attention.core_attention.attention_dropout", + # target_module = col_nn.DropoutForParallelInput, + # ) + ],) + + + def postprocess(self): + return self.model class ChatGLMModelPolicy(Policy): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 2cdf5da2e6da..f05649fcb9a0 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -19,7 +19,6 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward - def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, From f1ec91a5a77afb31d62fe261468e951ca00e7c21 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 12 Jul 2023 15:25:07 +0800 Subject: [PATCH 53/80] [shardformer] polish chatglm code --- .../shardformer/policies/auto_policy.py | 3 ++ colossalai/shardformer/policies/chatglm.py | 44 ------------------- .../test_model/test_shard_chatglm.py | 1 + 3 files changed, 4 insertions(+), 44 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 90347a984599..e383630408ff 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -116,6 +116,9 @@ class PolicyLocation: # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), + # ChatGLM + "tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel": + PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), } diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index c17b92c8dc81..934b99b83ea1 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -1,7 +1,6 @@ from typing import Dict, Union import torch.nn as nn -from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock import colossalai.shardformer.layer as col_nn @@ -9,49 +8,6 @@ __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] -class ChatGLMModelPolicy(Policy): - - def config_sanity_check(self): - pass - - def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - - return self.model - - def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock - - policy = {} - - if self.shard_config.enable_tensor_parallelism: - - policy[GLMBlock] = ModulePolicyDescription( - attribute_replacement = {}, - sub_module_replacement = [ - # SubModuleReplacementDescription( - # suffix = "self_attention.query_key_value", - # target_module = col_nn.Linear1D_Col, - # ), - # SubModuleReplacementDescription( - # suffix = "self_attention.dense", - # target_module = col_nn.Linear1D_Row, - # ) - # SubModuleReplacementDescription( - # suffix = "self_attention.core_attention.attention_dropout", - # target_module = col_nn.DropoutForParallelInput, - # ) - ],) - - - def postprocess(self): - return self.model class ChatGLMModelPolicy(Policy): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index f05649fcb9a0..2cdf5da2e6da 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -19,6 +19,7 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward + def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, From a7c54bdd7d76c81b2b7d4352e4cdfdd5f22eec28 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 13 Jul 2023 19:51:25 +0800 Subject: [PATCH 54/80] [shardformer] polish code --- .../model_zoo/transformers/chatglm2_6b/modeling_chatglm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 488f24c5fcb9..f704715e1245 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -80,6 +80,7 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -219,6 +220,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) @@ -233,6 +235,7 @@ def forward(self, hidden_states: torch.Tensor): class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -839,6 +842,7 @@ def forward(self, input_ids): class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -921,6 +925,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) + print(inputs_embeds) if self.pre_seq_len is not None: if past_key_values is None: @@ -982,6 +987,7 @@ def quantize(self, weight_bit_width: int): class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config) From 955b1d82199d93d037504a1d0ff27f0db1160853 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Fri, 14 Jul 2023 18:10:52 +0800 Subject: [PATCH 55/80] [shardformer] support chatglm without layernorm --- .../chatglm2_6b/modeling_chatglm.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index f704715e1245..46078f441523 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -396,17 +396,18 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) +<<<<<<< HEAD self.query_key_value = nn.Linear( config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, device=device, **_config_to_kwargs(config), ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. +======= + self.query_key_value = nn.Linear(self.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, +<<<<<<< HEAD self.dense = nn.Linear( self.projection_size, config.hidden_size, @@ -414,6 +415,13 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): device=device, **_config_to_kwargs(config), ) +======= + self.dense = nn.Linear(self.projection_size, + self.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config)) +>>>>>>> [shardformer] support chatglm without layernorm def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: @@ -925,7 +933,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) - print(inputs_embeds) if self.pre_seq_len is not None: if past_key_values is None: From e9ebe69f8c0090d9e301fedb3f1f142d59f83c74 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 17 Jul 2023 15:10:15 +0800 Subject: [PATCH 56/80] [shardformer] delete some file --- =2.0 | 134 ------------- .../transformers/chatglm2_6b/MODEL_LICENSE | 33 --- .../transformers/chatglm2_6b/quantization.py | 188 ------------------ 3 files changed, 355 deletions(-) delete mode 100644 =2.0 delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py diff --git a/=2.0 b/=2.0 deleted file mode 100644 index af47ce17aa8e..000000000000 --- a/=2.0 +++ /dev/null @@ -1,134 +0,0 @@ -Defaulting to user installation because normal site-packages is not writeable -Collecting protobuf - Using cached protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB) -Requirement already satisfied: transformers==4.30.2 in /home/lclk/.local/lib/python3.9/site-packages (4.30.2) -Collecting cpm_kernels - Using cached cpm_kernels-1.0.11-py3-none-any.whl (416 kB) -Requirement already satisfied: torch in /home/lclk/.local/lib/python3.9/site-packages (2.0.0+cu118) -Collecting gradio - Using cached gradio-3.36.0-py3-none-any.whl (19.8 MB) -Collecting mdtex2html - Using cached mdtex2html-1.2.0-py3-none-any.whl (13 kB) -Collecting sentencepiece - Using cached sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB) -Collecting accelerate - Using cached accelerate-0.20.3-py3-none-any.whl (227 kB) -Requirement already satisfied: pyyaml>=5.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (6.0) -Requirement already satisfied: regex!=2019.12.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (2023.6.3) -Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.15.1) -Requirement already satisfied: packaging>=20.0 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (23.1) -Requirement already satisfied: requests in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from transformers==4.30.2) (2.25.1) -Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.13.3) -Requirement already satisfied: safetensors>=0.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (0.3.1) -Requirement already satisfied: filelock in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (3.12.0) -Requirement already satisfied: numpy>=1.17 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (1.24.3) -Requirement already satisfied: tqdm>=4.27 in /home/lclk/.local/lib/python3.9/site-packages (from transformers==4.30.2) (4.65.0) -Requirement already satisfied: fsspec in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0) -Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/lclk/.local/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.6.3) -Requirement already satisfied: networkx in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1) -Requirement already satisfied: sympy in /home/lclk/.local/lib/python3.9/site-packages (from torch) (1.12) -Requirement already satisfied: triton==2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (2.0.0) -Requirement already satisfied: jinja2 in /home/lclk/.local/lib/python3.9/site-packages (from torch) (3.1.2) -Requirement already satisfied: lit in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (16.0.5.post0) -Requirement already satisfied: cmake in /home/lclk/.local/lib/python3.9/site-packages (from triton==2.0.0->torch) (3.26.3) -Collecting aiofiles - Using cached aiofiles-23.1.0-py3-none-any.whl (14 kB) -Collecting ffmpy - Using cached ffmpy-0.3.0.tar.gz (4.8 kB) -Requirement already satisfied: pillow in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (9.5.0) -Collecting pydub - Using cached pydub-0.25.1-py2.py3-none-any.whl (32 kB) -Requirement already satisfied: pandas in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.0.2) -Collecting python-multipart - Using cached python_multipart-0.0.6-py3-none-any.whl (45 kB) -Collecting semantic-version - Using cached semantic_version-2.10.0-py2.py3-none-any.whl (15 kB) -Collecting pydantic - Using cached pydantic-2.0.2-py3-none-any.whl (359 kB) -Collecting uvicorn>=0.14.0 - Using cached uvicorn-0.22.0-py3-none-any.whl (58 kB) -Collecting mdit-py-plugins<=0.3.3 - Using cached mdit_py_plugins-0.3.3-py3-none-any.whl (50 kB) -Requirement already satisfied: pygments>=2.12.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.15.1) -Collecting httpx - Using cached httpx-0.24.1-py3-none-any.whl (75 kB) -Collecting orjson - Using cached orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (136 kB) -Collecting fastapi - Using cached fastapi-0.99.1-py3-none-any.whl (58 kB) -Collecting altair>=4.2.0 - Using cached altair-5.0.1-py3-none-any.whl (471 kB) -Collecting gradio-client>=0.2.7 - Using cached gradio_client-0.2.7-py3-none-any.whl (288 kB) -Requirement already satisfied: aiohttp in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.8.4) -Requirement already satisfied: matplotlib in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (3.7.1) -Collecting websockets>=10.0 - Using cached websockets-11.0.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB) -Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.2.0) -Requirement already satisfied: markupsafe in /home/lclk/.local/lib/python3.9/site-packages (from gradio) (2.1.3) -Collecting toolz - Using cached toolz-0.12.0-py3-none-any.whl (55 kB) -Collecting jsonschema>=3.0 - Using cached jsonschema-4.18.0-py3-none-any.whl (81 kB) -Collecting rpds-py>=0.7.1 - Downloading rpds_py-0.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB) -Collecting referencing>=0.28.4 - Using cached referencing-0.29.1-py3-none-any.whl (25 kB) -Collecting jsonschema-specifications>=2023.03.6 - Using cached jsonschema_specifications-2023.6.1-py3-none-any.whl (17 kB) -Requirement already satisfied: attrs>=22.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio) (23.1.0) -Requirement already satisfied: mdurl~=0.1 in /home/lclk/.local/lib/python3.9/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio) (0.1.2) -Collecting linkify-it-py<3,>=1 - Downloading linkify_it_py-2.0.2-py3-none-any.whl (19 kB) -Collecting uc-micro-py - Downloading uc_micro_py-1.0.2-py3-none-any.whl (6.2 kB) -Requirement already satisfied: pytz>=2020.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) -Requirement already satisfied: tzdata>=2022.1 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2023.3) -Requirement already satisfied: python-dateutil>=2.8.2 in /home/lclk/.local/lib/python3.9/site-packages (from pandas->gradio) (2.8.2) -Requirement already satisfied: six>=1.5 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->gradio) (1.16.0) -Requirement already satisfied: click>=7.0 in /home/lclk/.local/lib/python3.9/site-packages (from uvicorn>=0.14.0->gradio) (8.1.3) -Collecting h11>=0.8 - Downloading h11-0.14.0-py3-none-any.whl (58 kB) -Collecting latex2mathml - Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB) -Collecting markdown - Downloading Markdown-3.4.3-py3-none-any.whl (93 kB) -Requirement already satisfied: psutil in /home/lclk/.local/lib/python3.9/site-packages (from accelerate) (5.9.5) -Requirement already satisfied: multidict<7.0,>=4.5 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (6.0.4) -Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (4.0.2) -Requirement already satisfied: aiosignal>=1.1.2 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.1) -Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (3.1.0) -Requirement already satisfied: frozenlist>=1.1.1 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.3.3) -Requirement already satisfied: yarl<2.0,>=1.0 in /home/lclk/.local/lib/python3.9/site-packages (from aiohttp->gradio) (1.9.2) -Requirement already satisfied: idna>=2.0 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from yarl<2.0,>=1.0->aiohttp->gradio) (2.10) -Collecting pydantic - Downloading pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB) -Collecting starlette<0.28.0,>=0.27.0 - Downloading starlette-0.27.0-py3-none-any.whl (66 kB) -Collecting anyio<5,>=3.4.0 - Downloading anyio-3.7.1-py3-none-any.whl (80 kB) -Collecting sniffio>=1.1 - Downloading sniffio-1.3.0-py3-none-any.whl (10 kB) -Requirement already satisfied: exceptiongroup in /home/lclk/.local/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.28.0,>=0.27.0->fastapi->gradio) (1.1.1) -Collecting httpcore<0.18.0,>=0.15.0 - Downloading httpcore-0.17.3-py3-none-any.whl (74 kB) -Requirement already satisfied: certifi in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from httpx->gradio) (2021.5.30) -Requirement already satisfied: importlib-metadata>=4.4 in /home/lclk/.local/lib/python3.9/site-packages (from markdown->mdtex2html) (6.7.0) -Requirement already satisfied: zipp>=0.5 in /home/lclk/.local/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown->mdtex2html) (3.15.0) -Requirement already satisfied: contourpy>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.1.0) -Requirement already satisfied: fonttools>=4.22.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (4.40.0) -Requirement already satisfied: pyparsing>=2.3.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (3.1.0) -Requirement already satisfied: kiwisolver>=1.0.1 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (1.4.4) -Requirement already satisfied: importlib-resources>=3.2.0 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (5.12.0) -Requirement already satisfied: cycler>=0.10 in /home/lclk/.local/lib/python3.9/site-packages (from matplotlib->gradio) (0.11.0) -Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (1.26.6) -Requirement already satisfied: chardet<5,>=3.0.2 in /opt/lcsoftware/spack/opt/spack/linux-ubuntu20.04-zen2/gcc-9.3.0/miniconda3-4.10.3-u6p3tgreee7aigtnvuhr44yqo7vcg6r6/lib/python3.9/site-packages (from requests->transformers==4.30.2) (4.0.0) -Requirement already satisfied: mpmath>=0.19 in /home/lclk/.local/lib/python3.9/site-packages (from sympy->torch) (1.3.0) -Building wheels for collected packages: ffmpy - Building wheel for ffmpy (setup.py): started - Building wheel for ffmpy (setup.py): finished with status 'done' - Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4709 sha256=071cebb58ca6c6947fbc669e1d94509d6f53d1ed45d9d7fb9f060d1a342cfc18 - Stored in directory: /home/lclk/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697 -Successfully built ffmpy -Installing collected packages: sniffio, rpds-py, referencing, h11, anyio, uc-micro-py, jsonschema-specifications, httpcore, websockets, toolz, starlette, pydantic, linkify-it-py, jsonschema, httpx, uvicorn, semantic-version, python-multipart, pydub, orjson, mdit-py-plugins, markdown, latex2mathml, gradio-client, ffmpy, fastapi, altair, aiofiles, sentencepiece, protobuf, mdtex2html, gradio, cpm-kernels, accelerate -Successfully installed accelerate-0.20.3 aiofiles-23.1.0 altair-5.0.1 anyio-3.7.1 cpm-kernels-1.0.11 fastapi-0.99.1 ffmpy-0.3.0 gradio-3.36.0 gradio-client-0.2.7 h11-0.14.0 httpcore-0.17.3 httpx-0.24.1 jsonschema-4.18.0 jsonschema-specifications-2023.6.1 latex2mathml-3.76.0 linkify-it-py-2.0.2 markdown-3.4.3 mdit-py-plugins-0.3.3 mdtex2html-1.2.0 orjson-3.9.1 protobuf-4.23.4 pydantic-1.10.11 pydub-0.25.1 python-multipart-0.0.6 referencing-0.29.1 rpds-py-0.8.8 semantic-version-2.10.0 sentencepiece-0.1.99 sniffio-1.3.0 starlette-0.27.0 toolz-0.12.0 uc-micro-py-1.0.2 uvicorn-0.22.0 websockets-11.0.3 diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE b/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE deleted file mode 100644 index 26198b21b6b2..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/MODEL_LICENSE +++ /dev/null @@ -1,33 +0,0 @@ -The ChatGLM2-6B License - -1. Definitions - -“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. - -“Software” means the ChatGLM2-6B model parameters made available under this license. - -2. License Grant - -Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -3. Restriction - -You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. - -You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. - -4. Disclaimer - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -5. Limitation of Liability - -EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -6. Dispute Resolution - -This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. - -Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py b/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py deleted file mode 100644 index cb95bfe82b20..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/quantization.py +++ /dev/null @@ -1,188 +0,0 @@ -from torch.nn import Linear -from torch.nn.parameter import Parameter - -import bz2 -import torch -import base64 -import ctypes -from transformers.utils import logging - -from typing import List -from functools import partial - -logger = logging.get_logger(__name__) - -try: - from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up - - class Kernel: - def __init__(self, code: bytes, function_names: List[str]): - self.code = code - self._function_names = function_names - self._cmodule = LazyKernelCModule(self.code) - - for name in self._function_names: - setattr(self, name, KernelFunction(self._cmodule, name)) - - quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" - - kernels = Kernel( - bz2.decompress(base64.b64decode(quantization_code)), - [ - "int4WeightCompression", - "int4WeightExtractionFloat", - "int4WeightExtractionHalf", - "int8WeightExtractionFloat", - "int8WeightExtractionHalf", - ], - ) -except Exception as exception: - kernels = None - logger.warning("Failed to load cpm_kernels:" + str(exception)) - - -class W8A16Linear(torch.autograd.Function): - @staticmethod - def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): - ctx.inp_shape = inp.size() - ctx.weight_bit_width = weight_bit_width - out_features = quant_w.size(0) - inp = inp.contiguous().view(-1, inp.size(-1)) - weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) - ctx.weight_shape = weight.size() - output = inp.mm(weight.t()) - ctx.save_for_backward(inp, quant_w, scale_w) - return output.view(*(ctx.inp_shape[:-1] + (out_features,))) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - inp, quant_w, scale_w = ctx.saved_tensors - weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) - grad_output = grad_output.contiguous().view(-1, weight.size(0)) - grad_input = grad_output.mm(weight) - grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None - - -def compress_int4_weight(weight: torch.Tensor): # (n, m) - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - assert m % 2 == 0 - m = m // 2 - out = torch.empty(n, m, dtype=torch.int8, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - kernels.int4WeightCompression( - gridDim, - blockDim, - 0, - stream, - [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], - ) - return out - - -def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): - assert scale_list.dtype in [torch.half, torch.bfloat16] - assert weight.dtype in [torch.int8] - if source_bit_width == 8: - return weight.to(scale_list.dtype) * scale_list[:, None] - elif source_bit_width == 4: - func = ( - kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16 - ) - else: - assert False, "Unsupported bit-width" - - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - func( - gridDim, - blockDim, - 0, - stream, - [ - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(scale_list.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m), - ], - ) - return out - - -class QuantizedLinear(torch.nn.Module): - def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, - **kwargs): - super().__init__() - self.weight_bit_width = weight_bit_width - - shape = weight.shape - - if weight is None or empty_init: - self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) - self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device) - else: - self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) - self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) - if weight_bit_width == 4: - self.weight = compress_int4_weight(self.weight) - - self.weight = Parameter(self.weight.to(device), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) - self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None - - def forward(self, input): - output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) - if self.bias is not None: - output = output + self.bias - return output - - -def quantize(model, weight_bit_width, empty_init=False, device=None): - """Replace fp16 linear with quantized linear""" - for layer in model.layers: - layer.self_attention.query_key_value = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()), - bias=layer.self_attention.query_key_value.bias, - dtype=layer.self_attention.query_key_value.weight.dtype, - device=layer.self_attention.query_key_value.weight.device if device is None else device, - empty_init=empty_init - ) - layer.self_attention.dense = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()), - bias=layer.self_attention.dense.bias, - dtype=layer.self_attention.dense.weight.dtype, - device=layer.self_attention.dense.weight.device if device is None else device, - empty_init=empty_init - ) - layer.mlp.dense_h_to_4h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), - bias=layer.mlp.dense_h_to_4h.bias, - dtype=layer.mlp.dense_h_to_4h.weight.dtype, - device=layer.mlp.dense_h_to_4h.weight.device if device is None else device, - empty_init=empty_init - ) - layer.mlp.dense_4h_to_h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), - bias=layer.mlp.dense_4h_to_h.bias, - dtype=layer.mlp.dense_4h_to_h.weight.dtype, - device=layer.mlp.dense_4h_to_h.weight.device if device is None else device, - empty_init=empty_init - ) - - return model From 965eb53869c267395dd6ff2f565cb771e87b783f Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 17 Jul 2023 19:47:57 +0800 Subject: [PATCH 57/80] [shardformer] ChatGLM support layernorm sharding --- .../kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 46078f441523..04d318d47868 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -417,7 +417,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): ) ======= self.dense = nn.Linear(self.projection_size, - self.hidden_size, + config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config)) From fb7b53c27514a0f60c6312ae553a29ca98cc6d82 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 18 Jul 2023 12:33:12 +0800 Subject: [PATCH 58/80] [shardformer] register without auto policy --- colossalai/shardformer/policies/auto_policy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index e383630408ff..90347a984599 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -116,9 +116,6 @@ class PolicyLocation: # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), - # ChatGLM - "tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), } From 2f943080e511b88e31212a0921644a10d1e333db Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 19 Jul 2023 11:39:59 +0800 Subject: [PATCH 59/80] [shardformer] pre-commit check files --- .../chatglm2_6b/modeling_chatglm.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py index 04d318d47868..bae6d425878d 100644 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py @@ -396,18 +396,17 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.num_multi_query_groups_per_partition = config.multi_query_group_num self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) -<<<<<<< HEAD self.query_key_value = nn.Linear( config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, device=device, **_config_to_kwargs(config), ) -======= - self.query_key_value = nn.Linear(self.hidden_size, - self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, -<<<<<<< HEAD + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. self.dense = nn.Linear( self.projection_size, config.hidden_size, @@ -415,13 +414,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): device=device, **_config_to_kwargs(config), ) -======= - self.dense = nn.Linear(self.projection_size, - config.hidden_size, - bias=config.add_bias_linear, - device=device, - **_config_to_kwargs(config)) ->>>>>>> [shardformer] support chatglm without layernorm def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: @@ -989,6 +981,7 @@ def forward( def quantize(self, weight_bit_width: int): from .quantization import quantize + quantize(self.encoder, weight_bit_width) return self From 2bb51ea236d657ad707a73cf878e0e1f4220704a Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Thu, 20 Jul 2023 19:14:04 +0800 Subject: [PATCH 60/80] [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit --- colossalai/shardformer/policies/chatglm.py | 24 +++++++++++++++++++ colossalai/shardformer/policies/vit.py | 2 +- tests/kit/model_zoo/transformers/chatglm.py | 11 +++++++-- .../test_model/test_shard_chatglm.py | 4 +++- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 934b99b83ea1..46aa3b52af8f 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -90,7 +90,31 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=ChatGLMModel) + else: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), + SubModuleReplacementDescription(suffix="post_attention_layernorm", + target_module=col_nn.FusedRMSNorm) + ], + policy=policy, + target_key=GLMBlock) + + if self.model.config.post_layer_norm: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="encoder.final_layernorm", + target_module=col_nn.FusedRMSNorm) + ], + policy=policy, + target_key=ChatGLMModel) + return policy def postprocess(self): return self.model + + +class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): + + def module_policy(self): + policy = super().module_policy() + return policy diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 96f27de2a7c8..1feb11ffcf24 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -23,7 +23,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel policy = {} diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 1408babede64..04e73a832abe 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -3,7 +3,7 @@ from ..registry import ModelAttribute, model_zoo from .chatglm2_6b.configuration_chatglm import ChatGLMConfig -from .chatglm2_6b.modeling_chatglm import ChatGLMModel +from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel # ================================ # Register single-sentence ChatGLM @@ -21,7 +21,7 @@ def data_gen(): # define loss function loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() -loss_fn = lambda x: x.loss +loss_fn = lambda x: x.logits.mean() config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, @@ -36,3 +36,10 @@ def data_gen(): output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_chatglm_model, model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name="transformers_chatglm_for_conditional_generation", + model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 2cdf5da2e6da..a0fa4bd82e74 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -7,7 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy +from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): shard_former = ShardFormer(shard_config=shard_config) if name == "transformers_chatglm": sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + else: + sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda() check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 1e7b804cb28662a69cd8b1a522139c09e61d1606 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:29:10 +0800 Subject: [PATCH 61/80] [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin --- colossalai/shardformer/README.md | 3 +- colossalai/shardformer/modeling/blip2.py | 60 ++++ colossalai/shardformer/modeling/sam.py | 2 - .../shardformer/policies/auto_policy.py | 6 + colossalai/shardformer/policies/blip2.py | 304 ++++++++++++++++++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/blip2.py | 61 ++++ .../test_model/test_shard_blip2.py | 107 ++++++ 8 files changed, 541 insertions(+), 3 deletions(-) create mode 100644 colossalai/shardformer/modeling/blip2.py create mode 100644 colossalai/shardformer/policies/blip2.py create mode 100644 tests/kit/model_zoo/transformers/blip2.py create mode 100644 tests/test_shardformer/test_model/test_shard_blip2.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 3c322aabf2ef..5489f97e4d19 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -104,7 +104,8 @@ We will follow this roadmap to develop Shardformer: - [ ] Audio - [x] Whisper - [ ] Multi-modal - - [ ] To be added + - [x] SAM + - [x] BLIP-2 ## 💡 API Design diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py new file mode 100644 index 000000000000..b7945423ae83 --- /dev/null +++ b/colossalai/shardformer/modeling/blip2.py @@ -0,0 +1,60 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + + +def forward_fn(): + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + + # modified from original code, which is: + # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( + # 2, 0, 3, 1, 4 + # ) + # to: + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.projection(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py index 00e2d744e219..63ebfe89d5fa 100644 --- a/colossalai/shardformer/modeling/sam.py +++ b/colossalai/shardformer/modeling/sam.py @@ -1,6 +1,4 @@ import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup def forward_fn(): diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 90347a984599..2a041af19be8 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -116,6 +116,12 @@ class PolicyLocation: # Sam "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), + + # Blip2 + "transformers.models.blip_2.modeling_blip_2.Blip2Model": + PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"), + "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": + PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py new file mode 100644 index 000000000000..43aa1adc1c5b --- /dev/null +++ b/colossalai/shardformer/policies/blip2.py @@ -0,0 +1,304 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from ..modeling.blip2 import forward_fn +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['BlipPolicy', 'BlipModelPolicy'] + + +class BlipPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.qformer_config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Attention, + Blip2EncoderLayer, + Blip2QFormerLayer, + Blip2QFormerModel, + Blip2VisionModel, + ) + from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[Blip2EncoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.num_heads": + self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.embed_dim": + self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + + policy[Blip2QFormerLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, + "crossattention.attention.num_attention_heads": + self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size, + "crossattention.attention.all_head_size": + self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + + policy[OPTDecoderLayer] = ModulePolicyDescription(attribute_replacement={ + "self_attn.embed_dim": + self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ) + ]) + + policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, + ), + ]) + + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle Blip2EncoderLayer layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2EncoderLayer) + + # handle Blip2VisionModel layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="post_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2VisionModel) + + # handle Blip2VisionModel layer + self.append_or_create_submodule_replacement( + description=[SubModuleReplacementDescription( + suffix="layernorm", + target_module=col_nn.FusedLayerNorm, + )], + policy=policy, + target_key=Blip2QFormerModel) + + # handle Blip2QFormerLayer layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output_query.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2QFormerLayer) + + # handle OPTForCausalLM layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="model.decoder.final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=OPTForCausalLM) + + # handle OPTDecoderLayer layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=OPTDecoderLayer) + + return policy + + def postprocess(self): + binding_map = { + 'language_model.model.decoder.embed_tokens': 'language_model.lm_head', + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + + +# Blip2Model +class Blip2ModelPolicy(BlipPolicy): + + def __init__(self) -> None: + super().__init__() + + +# Blip2ForConditionalGeneration +class Blip2ForConditionalGenerationPolicy(BlipPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 08a118e5783d..823ca032fc30 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,6 @@ from .albert import * from .bert import * +from .blip2 import * from .bloom import * from .chatglm import * from .gpt import * diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py new file mode 100644 index 000000000000..7338f740be7f --- /dev/null +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -0,0 +1,61 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-image SAM +# =============================== + + +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from PIL import Image + # import requests + # from transformers import Blip2Processor, Blip2Model + # import torch + + # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + # url = "http://images.cocodataset.org/val2017/000000039769.jpg" + # image = Image.open(requests.get(url, stream=True).raw) + + # prompt = "Question: how many cats are there? Answer:" + # inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16) + + pixel_values = torch.rand(1, 3, 224, 224, dtype=torch.float32) + input_ids = torch.tensor([[2, 45641, 35, 141, 171, 10017, 32, 89, 116, 31652, 35]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + labels = torch.tensor([[34, 56]], dtype=torch.int64) + return dict(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss funciton +loss_fn_blip2_model = lambda x: x.loss + +config = transformers.Blip2Config() +config.text_config.num_hidden_layers = 1 +config.qformer_config.num_hidden_layers = 1 +config.vision_config.num_hidden_layers = 1 +config.qformer_config.attention_probs_dropout_prob = 0 +config.qformer_config.hidden_dropout_prob = 0 +config.text_config.dropout = 0 + +# register the blip2 variants +model_zoo.register(name='transformers_blip2', + model_fn=lambda: transformers.Blip2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_blip2_conditional_gerneration', + model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py new file mode 100644 index 000000000000..f96299e55a49 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -0,0 +1,107 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + blip2 = org_model + sharded_blip2 = sharded_model + + # compare vision_model grad + + org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad + shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad + shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # compare qformer grad + org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad + shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # compare language_model grad + org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad + shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad + shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_blip2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_blip2_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_blip2(): + spawn(check_blip2, 2) + + +if __name__ == "__main__": + test_blip2() From ee54290968d73ccb48959b51fc3ae42439cc72fc Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Tue, 1 Aug 2023 18:02:49 +0800 Subject: [PATCH 62/80] update some module with new api version --- .../shardformer/layer/qkv_fused_linear.py | 87 +++++++++++-------- colossalai/shardformer/policies/blip2.py | 2 +- colossalai/shardformer/policies/chatglm.py | 2 +- colossalai/shardformer/policies/sam.py | 2 +- colossalai/shardformer/policies/whisper.py | 2 +- .../test_gpt2_qkv_fused_linear_1d.py | 38 ++++++-- .../test_model/test_shard_chatglm.py | 5 +- 7 files changed, 89 insertions(+), 49 deletions(-) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 1e4b6ecb69b3..42417f8bcc43 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -537,10 +537,11 @@ def __init__(self, gather_output: bool = False, skip_bias_add: bool = False, n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() - # Keep input parameters self.in_features = in_features self.out_features = out_features @@ -554,36 +555,52 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight def shard_fn(tensor): return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False) + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) - with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) - self.weight = customized_distributed_tensor_to_param(sharded_weight) + if not is_customized_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - - with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) - self.bias = customized_distributed_tensor_to_param(sharded_bias) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ + if not is_customized_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, @@ -613,24 +630,26 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=False) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=False) - linear_1d.bias.data.copy_(sharded_bias.data) - + # # TODO: copy the sharded weights + # with torch.no_grad(): + # sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + # n_fused=n_fused, + # process_group=process_group, + # is_transposed=False) + # linear_1d.weight.data.copy_(sharded_weight.data) + + # if bias: + # sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + # n_fused=n_fused, + # process_group=process_group, + # is_transposed=False) + # linear_1d.bias.data.copy_(sharded_bias.data) + print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 43aa1adc1c5b..a244d70b56f5 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -4,7 +4,7 @@ from .._utils import getattr_, setattr_ from ..modeling.blip2 import forward_fn -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['BlipPolicy', 'BlipModelPolicy'] diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 46aa3b52af8f..732a817b0655 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index e75d63946260..ca20fff715f2 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -4,7 +4,7 @@ from .._utils import getattr_, setattr_ from ..modeling.sam import forward_fn -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['SamPolicy', 'SamModelPolicy'] diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 7751bbb5de99..2f3565bdaa96 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -3,7 +3,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 9eeda93afe35..b45cd172c3ca 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -1,12 +1,15 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # This code is copied from https://github.com/huggingface/transformers @@ -50,9 +53,13 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_gpt2_linear_conv_1d_col(): +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_col(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, n_fused=3) @@ -61,6 +68,8 @@ def check_gpt2_linear_conv_1d_col(): assert linear.bias.shape == torch.Size([192]) assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias # ensure weights are reversibly loadable linear_conv_col.load_state_dict(linear.state_dict()) @@ -80,13 +89,24 @@ def check_gpt2_linear_conv_1d_col(): assert_close(target_grad, linear_conv_col.weight.grad) -def check_gpt2_linear_conv_1d_row(): +@parameterize('lazy_init', [False, True]) +def check_linear_conv_1d_row(lazy_init: bool): + ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + with ctx: + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) # check computation correctness x = torch.rand(4, 48).cuda() @@ -107,14 +127,14 @@ def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_gpt2_linear_conv_1d_col() - check_gpt2_linear_conv_1d_row() + check_linear_conv_1d_col() + check_linear_conv_1d_row() @rerun_if_address_is_in_use() -def test_gpt2_linearconv(): +def test_linearconv(): spawn(run_dist, nprocs=2) if __name__ == '__main__': - test_gpt2_linearconv() + test_linearconv() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index a0fa4bd82e74..36f240a0ffc0 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) if name == "transformers_chatglm": - sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda() + sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy()) else: - sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda() + sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()) + sharded_model = sharded_model.cuda() check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 383ce5b9943f27abbc98b44892de47143ce648c8 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 2 Aug 2023 14:53:26 +0800 Subject: [PATCH 63/80] [test] skip some not compatible models --- tests/test_booster/test_plugin/test_gemini_plugin.py | 5 ++++- tests/test_lazy/test_models.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 57160dfae89b..a06b2c963bfe 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -90,7 +90,10 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base', 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model', 'transformers_vit', 'transformers_vit_for_masked_image_modeling', - 'transformers_vit_for_image_classification' + 'transformers_vit_for_image_classification', 'transformers_chatglm', + 'transformers_chatglm_for_conditional_generation', 'transformers_blip2', + 'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper', + 'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification' ]: continue diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index ecb99e594267..18a737fcec85 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -11,8 +11,9 @@ def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base' - ) or name.startswith('transformers_llama') or name.startswith('transformers_vit'): + if name in ('torchaudio_wav2vec2_base', + 'torchaudio_hubert_base') or name.startswith('transformers_llama') or name.startswith( + ('transformers_vit', 'transformers_blip2')): continue check_lazy_init(entry, verbose=True, default_device=default_device) From 4b1a574bc9e3ff4ee5b97e85e86300b4bde21fcb Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 3 Aug 2023 14:51:36 +0800 Subject: [PATCH 64/80] [test] Hotfix/fix some model test and refactor check util api (#4369) * fix llama test * fix test bug of bert, blip2, bloom, gpt2 * fix llama test * fix opt test * fix sam test * fix sam test * fix t5 test * fix vit test * fix whisper test * fix whisper test * polish code * adjust allclose parameter * Add mistakenly deleted code * addjust allclose * change loss function for some base model --- tests/kit/model_zoo/transformers/bert.py | 2 +- tests/kit/model_zoo/transformers/bloom.py | 14 +++-- tests/kit/model_zoo/transformers/gpt.py | 20 ++++--- tests/kit/model_zoo/transformers/opt.py | 3 +- tests/kit/model_zoo/transformers/whisper.py | 4 +- tests/test_shardformer/test_model/_utils.py | 22 +++++++ .../test_model/test_shard_bert.py | 50 +++++----------- .../test_model/test_shard_blip2.py | 58 ++++--------------- .../test_model/test_shard_bloom.py | 39 +++---------- .../test_model/test_shard_gpt2.py | 30 ++++------ .../test_model/test_shard_llama.py | 37 +++--------- .../test_model/test_shard_opt.py | 37 +++--------- .../test_model/test_shard_sam.py | 37 ++---------- .../test_model/test_shard_t5.py | 52 +++-------------- .../test_model/test_shard_vit.py | 21 ++----- .../test_model/test_shard_whisper.py | 45 ++++---------- 16 files changed, 135 insertions(+), 336 deletions(-) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 1993af51ad63..d17b8fda425a 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -102,7 +102,7 @@ def data_gen_for_qa(): output_transform_fn = lambda x: x # define loss funciton -loss_fn_for_bert_model = lambda x: x.pooler_output.mean() +loss_fn_for_bert_model = lambda x: x.pooler_output.sum() loss_fn = lambda x: x.loss config = transformers.BertConfig(hidden_size=128, diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 71146c0b9819..5d195db2c68d 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -55,17 +55,23 @@ def data_gen_for_question_answering(): input_ids = torch.tensor( [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) + start_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([10], dtype=torch.int64) + return dict(input_ids=input_ids, + attention_mask=attention_mask, + start_positions=start_positions, + end_positions=end_positions) # define output transform function output_transform_fn = lambda x: x # define loss function -loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, + torch.ones_like(x.last_hidden_state)) loss_fn_for_causal_lm = lambda x: x.loss -loss_fn_for_classification = lambda x: x.logits.mean() -loss_fn_for_question_answering = lambda x: x.end_logits.mean() +loss_fn_for_classification = lambda x: x.loss +loss_fn_for_question_answering = lambda x: x.loss config = transformers.BloomConfig(n_layer=1, n_head=4, diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index fcde75abdedc..a704310e14f5 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -1,3 +1,5 @@ +import copy + import torch import transformers @@ -44,14 +46,14 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data['labels'] = torch.tensor([0], dtype=torch.int64) + data['labels'] = torch.tensor([1], dtype=torch.int64) return data @@ -59,7 +61,8 @@ def data_gen_for_sequence_classification(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state + )) loss_fn = lambda x: x.loss config = transformers.GPT2Config(n_layer=2, @@ -69,9 +72,10 @@ def data_gen_for_sequence_classification(): embd_pdrop=0, resid_pdrop=0, summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification", - pad_token_id=50256) + hidden_dropout=0) + +config_for_token_classification = copy.deepcopy(config) +config_for_token_classification.num_labels = 2 # register the following models model_zoo.register(name='transformers_gpt', @@ -99,13 +103,13 @@ def data_gen_for_sequence_classification(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', - model_fn=lambda: transformers.GPT2ForTokenClassification(config), + model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_sequence_classification', - model_fn=lambda: transformers.GPT2ForSequenceClassification(config), + model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, loss_fn=loss_fn, diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index 4463ae12b901..29430afc0661 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -44,7 +44,8 @@ def data_gen_for_question_answering(): output_transform_fn = lambda x: x -loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state) + ) loss_fn_for_lm = lambda x: x.loss config = transformers.OPTConfig( hidden_size=128, diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index b58716217cb5..40c96a5777ab 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -22,7 +22,7 @@ def data_gen(): # input_features = inputs.input_features # decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id - input_features = torch.randn(1, 80, 3000) + input_features = torch.rand(1, 80, 3000) decoder_input_ids = torch.tensor([[1, 1]]) * 50258 return dict(input_features=input_features, decoder_input_ids=decoder_input_ids) @@ -53,7 +53,7 @@ def data_gen_for_audio_classification(): output_transform_fn = lambda x: x # define loss funciton -loss_fn = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)) loss_fn_attr = lambda x: x.loss config = transformers.WhisperConfig( diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2320c725d444..e15295bc905f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -2,10 +2,13 @@ from contextlib import nullcontext import torch +import torch.distributed as dist from torch.nn import Module from colossalai.lazy import LazyInitContext from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer._utils import getattr_ +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): @@ -74,3 +77,22 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' assert torch.equal(v, shard_v), f'{name} {k} value mismatch' + + +def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False): + for suffix in layer_suffix: + org_grad = getattr_(original_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=dim) + else: + all_shard_grad = shard_grad + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}") + assert torch.allclose( + org_grad, all_shard_grad, rtol=rtol, atol=atol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 6d0d3c798c4e..1d42f1c4703e 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -15,10 +15,18 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # unwarp model + if org_model.__class__.__name__ == 'BertModel': + bert = org_model + sharded_bert = sharded_model + else: + bert = org_model.bert + sharded_bert = sharded_model.bert + # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) @@ -32,42 +40,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # check grad - - if org_model.__class__.__name__ == 'BertModel': - bert = org_model - sharded_bert = sharded_model - else: - bert = org_model.bert - sharded_bert = sharded_model.bert - - # compare self attention grad - org_grad = bert.encoder.layer[0].attention.self.query.weight.grad - shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad - shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # compare embedding grad - org_grad = bert.embeddings.word_embeddings.weight.grad - shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad - shard_weight = sharded_bert.embeddings.word_embeddings.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings'] + row_layer_for_check = ['encoder.layer[0].attention.output.dense'] + check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False) + check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False) @parameterize('enable_fused_normalization', [False, True]) diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index f96299e55a49..cb9725f4de7f 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -3,7 +3,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -12,7 +11,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -33,50 +32,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo blip2 = org_model sharded_blip2 = sharded_model - # compare vision_model grad - - org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad - shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad - shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # compare qformer grad - org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad - shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad - shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # compare language_model grad - org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad - shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad - shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # check grad + col_layer_for_check = [ + 'vision_model.encoder.layers[0].self_attn.qkv', 'qformer.encoder.layer[0].attention.attention.query', + 'language_model.model.decoder.layers[0].self_attn.k_proj' + ] + row_layer_for_check = [ + 'vision_model.encoder.layers[0].self_attn.projection', 'qformer.encoder.layer[0].attention.output.dense', + 'language_model.model.decoder.layers[0].self_attn.out_proj' + ] + check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index fe4686aeb979..c13596fe8db3 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,7 +3,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -12,7 +11,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -26,7 +25,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo shard_loss.backward() assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # unwrap model if org_model.__class__.__name__ == 'BloomModel': @@ -36,35 +35,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo bloom = org_model.transformer sharded_bloom = sharded_model.transformer - # check attention grad - org_grad = bloom.h[0].self_attention.query_key_value.weight.grad - shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad - shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # check embedding weights - org_grad = bloom.word_embeddings.weight.grad - shard_grad = sharded_bloom.word_embeddings.weight.grad - shard_weight = sharded_bloom.word_embeddings.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # check grad + col_layer_for_check = ['h[0].self_attention.query_key_value'] + row_layer_for_check = ['h[0].self_attention.dense'] + check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index eae4f2ffb799..d1ab352f6512 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -18,7 +18,7 @@ ) from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): @@ -105,26 +105,17 @@ def _criterion(outputs, inputs): # unwrap model if org_model.__class__.__name__ == 'GPT2Model': - org_model = org_model - sharded_model = sharded_model.unwrap() + gpt2 = org_model + sharded_gpt2 = sharded_model.unwrap() else: - org_model = org_model.transformer - sharded_model = sharded_model.unwrap().transformer + gpt2 = org_model.transformer + sharded_gpt2 = sharded_model.unwrap().transformer - # check weights and gradients - if stage_manager is None or stage_manager.is_first_stage(): - - shard_weight = sharded_model.h[0].mlp.c_fc.weight - org_grad = org_model.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)] - dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group) - shard_grad = torch.cat(shard_grad_list, dim=1) - - assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3), \ - f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}" + # check grad + col_layer_for_check = ['h[0].mlp.c_fc'] + row_layer_for_check = ['h[0].mlp.c_proj'] + check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False) + check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False) # check weights after optimizer.step() org_optimizer.step() @@ -184,6 +175,7 @@ def check_gpt2(rank, world_size, port): run_gpt2_test() +@pytest.mark.skip('Have some bug caused by merge') @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index aaeef13ef873..2cfc172c8df6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,7 +5,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -14,7 +13,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -24,7 +23,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo output_transform_fn, loss_fn) # forward check - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) # run backward org_loss.backward() @@ -41,33 +40,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo llama_model = org_model shard_llama_model = sharded_model - # check attention grad - org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad - shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - - # check embedding grad - org_grad = llama_model.embed_tokens.weight.grad - shard_grad = shard_llama_model.embed_tokens.weight.grad - shard_weight = shard_llama_model.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + # check grad + col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] + row_layer_for_check = ['layers[0].self_attn.o_proj'] + check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False) + check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 297affceb68a..4684bacb4788 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,7 +6,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -15,7 +14,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -23,7 +22,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) # run backward org_loss.backward() @@ -40,33 +39,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo opt_model = org_model shard_opt_model = sharded_model - # check attention grad - org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # check embedding grad - org_grad = opt_model.decoder.embed_tokens.weight.grad - shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # check grad + col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] + row_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False) + check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py index 1d047d8e0c42..e7748cfd189d 100644 --- a/tests/test_shardformer/test_model/test_shard_sam.py +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -3,7 +3,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -12,7 +11,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -33,35 +32,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo sam = org_model sharded_sam = sharded_model - # compare mask decoder grad - - org_grad = sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad - shard_grad = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight.grad - shard_weight = sharded_sam.mask_decoder.transformer.layers[0].self_attn.q_proj.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # compare vision_encoder grad - org_grad = sam.vision_encoder.layers[0].mlp.lin1.weight.grad - shard_grad = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight.grad - shard_weight = sharded_sam.vision_encoder.layers[0].mlp.lin1.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # check grad + col_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.q_proj', 'vision_encoder.layers[0].mlp.lin1'] + row_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.out_proj', 'vision_encoder.layers[0].mlp.lin2'] + check_grad(sam, sharded_sam, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) + check_grad(sam, sharded_sam, row_layer_for_check, atol=1e-3, rtol=1e-3, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 96dfdeb73827..024c5016b0c1 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -5,7 +5,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -14,7 +13,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -22,7 +21,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # the value "past_key_values" is sharded, so we ignore org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5) # do backward org_loss.backward() @@ -31,54 +30,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - # check attention grad - org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad - shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad - shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - - # check self attention embed - org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad - shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad - shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=1) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # check token embedding grad - org_grad = org_model.shared.weight.grad + # check grad + col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared'] + row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias'] + check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False) + check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False) # check weights are tied if hasattr(org_model, 'lm_head'): assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr() assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() - shard_grad = sharded_model.shared.weight.grad - shard_weight = sharded_model.shared.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 2b02c83e0d27..7833ab70275d 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -5,7 +5,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -14,7 +13,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -37,19 +36,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo vit_model = org_model.vit shard_vit_model = sharded_model.vit - # check attention grad - org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad - shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad - shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + # check grad + col_layer_for_check = ['encoder.layer[0].attention.attention.query'] + row_layer_for_check = ['encoder.layer[0].attention.output.dense'] + check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) + check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 8932a4ab902c..a271bbdf1223 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -3,7 +3,6 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -12,14 +11,14 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values') + assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5) # do backward org_loss.backward() @@ -28,8 +27,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - # check grad - + # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model sharded_whisper = sharded_model.model @@ -37,38 +35,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo whisper = org_model sharded_whisper = sharded_model - # compare self attention grad - org_grad = whisper.encoder.layers[0].self_attn.q_proj.weight.grad - shard_grad = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = sharded_whisper.encoder.layers[0].self_attn.q_proj.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - - # WhisperForAudioClassification does not have decoder and embedding layer + # check grad if org_model.__class__.__name__ == 'WhisperForAudioClassification': - return - - # compare embedding grad - org_grad = whisper.decoder.embed_tokens.weight.grad - shard_grad = sharded_whisper.decoder.embed_tokens.weight.grad - shard_weight = sharded_whisper.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] + row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] else: - all_shard_grad = shard_grad - - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj'] + row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj'] + check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) @parameterize('enable_fused_normalization', [True, False]) From 58348852103f8c535f35d5d9c76bb3ffa3907320 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 3 Aug 2023 17:50:15 +0800 Subject: [PATCH 65/80] [shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366) * add util functions for shardformer tests & rewrite gpt2 test * fix shared_params & embedding/merging * fix precision --- .../booster/plugin/hybrid_parallel_plugin.py | 3 +- tests/kit/model_zoo/transformers/gpt.py | 4 +- tests/test_shardformer/test_model/_utils.py | 159 ++++++++++++++++-- .../test_model/test_shard_gpt2.py | 138 ++++----------- 4 files changed, 190 insertions(+), 114 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 35a88d1e8980..a22bdb7199bb 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -37,7 +37,8 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp self.shared_param_process_groups = [] for shared_param in self.shared_params: if len(shared_param) > 0: - self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) if precision == 'fp16': module = module.half().cuda() elif precision == 'bf16': diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index a704310e14f5..73c210221e61 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -72,7 +72,9 @@ def data_gen_for_sequence_classification(): embd_pdrop=0, resid_pdrop=0, summary_first_dropout=0, - hidden_dropout=0) + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256) config_for_token_classification = copy.deepcopy(config) config_for_token_classification.num_labels = 2 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e15295bc905f..46b262d0a8cd 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,11 +1,19 @@ import copy from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional import torch import torch.distributed as dist +from torch import Tensor +from torch import distributed as dist +from torch.distributed import ProcessGroup from torch.nn import Module +from torch.optim import Adam, Optimizer +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin from colossalai.lazy import LazyInitContext +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer._utils import getattr_ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor @@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): assert torch.equal(v, shard_v), f'{name} {k} value mismatch' -def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False): +def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]): + + use_lazy_init = False + if 'use_lazy_init' in test_config: + use_lazy_init = test_config.pop('use_lazy_init') + + if use_lazy_init: + ctx = LazyInitContext() + else: + ctx = nullcontext() + + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + with ctx: + org_model = model_fn().cuda() + sharded_model = copy.deepcopy(org_model) + + if use_lazy_init: + org_model = ctx.materialize(org_model) + + org_optimizer = Adam(org_model.parameters(), lr=1e-3) + sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + + return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + + +def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer, + data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable, + booster: Booster): + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + data = data_gen_fn() + sharded_model.train() + if booster.plugin.stage_manager is not None: + data = { + k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + data_iter = iter([data]) + sharded_output = booster.execute_pipeline(data_iter, + sharded_model, + _criterion, + sharded_optimizer, + return_loss=True, + return_outputs=True) + sharded_loss = sharded_output['loss'] + else: + data = {k: v.cuda() for k, v in data.items()} + sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) + sharded_loss.backward() + + org_model.train() + org_output = org_model(**data) + org_loss = criterion(org_output) + org_loss.backward() + + return org_loss, org_output, sharded_loss, sharded_output + + +def check_output_hidden_state(org_output: Tensor, + sharded_output: Tensor, + stage_manager: Optional[PipelineStageManager] = None, + atol: float = 1e-5, + rtol: float = 1e-3): + + org_hidden_state = org_output.last_hidden_state + + if stage_manager is None: + sharded_hidden_state = sharded_output.last_hidden_state + + if stage_manager and stage_manager.is_last_stage(): + sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0) + + assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \ + f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + + +def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): + assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \ + f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + + +def check_weight(org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: Optional[ProcessGroup] = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False): + + for suffix in layer_suffix: + org_weight = getattr_(org_model, suffix).weight + sharded_weight = getattr_(sharded_model, suffix).weight + + if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): + sharded_weight_list = [ + torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) + ] + dist.all_gather(sharded_weight_list, sharded_weight, tp_group) + sharded_weight = torch.cat(sharded_weight_list, dim=dim) + + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") + + assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \ + f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + + +def check_grad(org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False): + for suffix in layer_suffix: - org_grad = getattr_(original_model, suffix).weight.grad + org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=dim) - else: - all_shard_grad = shard_grad + shard_grad_list = [ + torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) + ] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[:org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: - print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}") + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") assert torch.allclose( - org_grad, all_shard_grad, rtol=rtol, atol=atol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}" + org_grad, shard_grad, rtol=rtol, atol=atol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index d1ab352f6512..cebb40bd16fe 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,107 +1,48 @@ -import copy -from contextlib import nullcontext - import pytest import torch from torch import distributed as dist -from torch.optim import Adam import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.lazy.lazy_init import LazyInitContext from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import ( - clear_layout_converter, - is_customized_distributed_tensor, - is_distributed_tensor, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): - use_lazy_init = False - if 'use_lazy_init' in test_config: - use_lazy_init = test_config.pop('use_lazy_init') + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - if use_lazy_init: - ctx = LazyInitContext() - else: - ctx = nullcontext() - - # prepare booster - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) - stage_manager = plugin.stage_manager - - # prepare models and optimizers - with ctx: - org_model = model_fn().cuda() - sharded_model = copy.deepcopy(org_model) - - if use_lazy_init: - org_model = ctx.materialize(org_model) - - org_optimizer = Adam(org_model.parameters(), lr=1e-3) - sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) - criterion = loss_fn - - sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) - - def _criterion(outputs, inputs): - outputs = output_transform_fn(outputs) - loss = criterion(outputs) - return loss - - # do forward and backward - data = data_gen_fn() - sharded_model.train() - if stage_manager: - data = { - k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v - for k, v in data.items() - } - data_iter = iter([data]) - sharded_output = booster.execute_pipeline(data_iter, - sharded_model, - _criterion, - sharded_optimizer, - return_loss=True, - return_outputs=True) - sharded_loss = sharded_output['loss'] - else: - data = {k: v.cuda() for k, v in data.items()} - sharded_output = sharded_model(**data) - sharded_loss = criterion(sharded_output) - sharded_loss.backward() - org_model.train() - org_output = org_model(**data) - org_loss = criterion(org_output) - org_loss.backward() + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - # check last hidden state if org_model.__class__.__name__ == 'GPT2Model': - org_hidden_state = org_output.last_hidden_state - - if stage_manager is None: - sharded_hidden_state = sharded_output.last_hidden_state - - if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], - dim=0) - - assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \ - f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) - # check loss - assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \ - f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'GPT2Model': @@ -111,27 +52,19 @@ def _criterion(outputs, inputs): gpt2 = org_model.transformer sharded_gpt2 = sharded_model.unwrap().transformer - # check grad col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['h[0].mlp.c_proj'] - check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False) - check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False) + row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + + # check grad + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) + check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): - - org_weight = org_model.h[0].mlp.c_fc.weight - shard_weight = sharded_model.h[0].mlp.c_fc.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)] - dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group) - shard_weight = torch.cat(shard_weight_list, dim=1) - - assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \ - f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}" + check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False) torch.cuda.empty_cache() @@ -156,9 +89,11 @@ def _criterion(outputs, inputs): @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO: add plugin_config for TP+DP after supporting & debugging it + # TODO: add test_config for TP+DP after supporting & debugging it # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') test_config['precision'] = 'float' # Do not use fp16/bf16 in testing @@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port): run_gpt2_test() -@pytest.mark.skip('Have some bug caused by merge') @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From 0e3a8dad232efd3c8b95de6e35ed12b57d8097e5 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 4 Aug 2023 14:55:31 +0800 Subject: [PATCH 66/80] [pipeline] add chatglm (#4363) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining * add bert_for_pretraining forward and policy * fix typos * cancel warning * change the imediate output to default dict * change the default output of get_shared_params * add chatglm * add * chatglm * chatglm * finish chatglm * deletes * fix rmsnorm * chatglm * fix chatglm shard * init --- colossalai/shardformer/modeling/chatglm.py | 189 +++ .../chatglm2_6b/configuration_chatglm.py | 58 + .../modeling/chatglm2_6b/modeling_chatglm.py | 1373 +++++++++++++++++ colossalai/shardformer/policies/chatglm.py | 114 +- tests/kit/model_zoo/transformers/chatglm.py | 17 +- .../test_policy/test_t5_pipeline_utils.py | 39 - tests/test_shardformer/test_model/_utils.py | 7 +- .../test_model/test_shard_chatglm.py | 2 +- .../test_model/test_shard_chatglm_pipeline.py | 86 ++ 9 files changed, 1828 insertions(+), 57 deletions(-) create mode 100644 colossalai/shardformer/modeling/chatglm.py create mode 100644 colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py create mode 100644 colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py delete mode 100644 tests/test_pipeline/test_policy/test_t5_pipeline_utils.py create mode 100644 tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py new file mode 100644 index 000000000000..0bb8bdc58218 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm.py @@ -0,0 +1,189 @@ +""" PyTorch ChatGLM model. """ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss, LayerNorm +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, +) + + +class ChatGLMPipelineForwards: + ''' + This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. + ''' + + @staticmethod + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + if stage_manager.is_first_stage(): + batch_size, seq_length = input_ids.shape + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + hidden_states = inputs_embeds + else: + seq_length, batch_size = hidden_states.shape[:2] + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], + dim=-1) + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + if not past_key_values: + past_key_values = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.encoder.gradient_checkpointing and self.encoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + start_idx, end_idx = stage_index[0], stage_index[1] + for idx in range(start_idx, end_idx): + layer = self.encoder._get_layer(idx) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if self.encoder.gradient_checkpointing and self.encoder.training: + layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, + past_key_values[idx], use_cache) + else: + layer_ret = layer(hidden_states, + full_attention_mask, + rotary_pos_emb, + kv_cache=past_key_values[idx], + use_cache=use_cache) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if stage_manager.is_last_stage(): + # final layer_norm + if self.encoder.post_layer_norm: + hidden_states = self.encoder.final_layernorm(hidden_states) + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + return {'hidden_states': hidden_states} + + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( + self.transformer, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + return transformer_outputs diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py new file mode 100644 index 000000000000..3e78732be2da --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py @@ -0,0 +1,58 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py new file mode 100644 index 000000000000..a21ee0231422 --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -0,0 +1,1373 @@ +""" +The ChatGLM2-6B License + +1. Definitions + +“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. + +“Software” means the ChatGLM2-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. +""" +""" PyTorch ChatGLM model. """ + +import copy +import math +import re +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != "darwin": + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm2-6b", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2, + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, + seq_len: int, + n_elem: int, + dtype: torch.dtype, + device: torch.device, + base: int = 10000, + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, + self.dim, + dtype=self.inv_freq.dtype, + device=self.inv_freq.device, + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.elementwise_affine = True + self.normalized_shape = normalized_shape + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + attention_mask = torch.ones( + output_size[0], + 1, + output_size[2], + output_size[3], + device=attention_scores.device, + dtype=torch.bool, + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + # Per attention head and per partition values. + self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = (self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + **_config_to_kwargs(config), + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + **_config_to_kwargs(config), + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config), + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = (config.hidden_size // + config.num_attention_heads if config.kv_channels is None else config.kv_channels) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype, + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels, + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + is_first_forward: bool = True, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], + beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple(( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) for layer_past in past) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + return response + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): + if history: + prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + input_ids = input_ids[1:] + inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) + else: + prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + return inputs + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 8192, + num_beams=1, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + inputs = self.build_inputs(tokenizer, query, history=history) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + past_key_values=None, + max_length: int = 8192, + do_sample=True, + top_p=0.8, + temperature=0.8, + logits_processor=None, + return_past_key_values=False, + **kwargs, + ): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } + if past_key_values is None and not return_past_key_values: + inputs = self.build_inputs(tokenizer, query, history=history) + else: + inputs = self.build_stream_inputs(tokenizer, query, history=history) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs["attention_mask"] = attention_mask + for outputs in self.stream_generate( + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, + ): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response = self.process_response(response) + new_history = history + [(query, response)] + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") + logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.") + + # 2. Set generation parameters if not already defined + logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) + stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation(outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize( + self.transformer.encoder, + bits, + empty_init=empty_init, + device=device, + **kwargs, + ) + return self diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 732a817b0655..9cc651caddc1 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -1,32 +1,46 @@ -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union import torch.nn as nn +from torch import Tensor +from transformers.modeling_outputs import BaseModelOutputWithPast import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] +__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] -class ChatGLMModelPolicy(Policy): +class ChatGLMPolicy(Policy): def config_sanity_check(self): pass def preprocess(self): # Resize embedding - vocab_size = self.model.config.padded_vocab_size - world_size = self.shard_config.tensor_parallel_size + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.padded_vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + policy = {} @@ -112,9 +126,91 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'ChatGLMModel': + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embedding) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + if module.encoder.post_layer_norm: + held_layers.append(module.encoder.final_layernorm) + + # rotary_pos_emb is needed for all stages + held_layers.append(module.rotary_pos_emb) + + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'ChatGLMModel': + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +class ChatGLMModelPolicy(ChatGLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ChatGLMModel, + new_forward=ChatGLMPipelineForwards.chatglm_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMModel.""" + return [] + + class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): def module_policy(self): policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration, + new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, + policy=policy) return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.transformer.output_layer) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in ChatGLMForConditionalGenerationModel.""" + return [] + diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 04e73a832abe..056c910a8dfe 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -1,9 +1,11 @@ import torch import transformers +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + from ..registry import ModelAttribute, model_zoo -from .chatglm2_6b.configuration_chatglm import ChatGLMConfig -from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + # ================================ # Register single-sentence ChatGLM @@ -20,15 +22,18 @@ def data_gen(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean() -loss_fn = lambda x: x.logits.mean() +loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum() +loss_fn = lambda x: x.logits.sum() + config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, - rmsnorm=False, + rmsnorm=True, original_rope=True, - use_cache=True) + use_cache=True, + torch_dtype=torch.float32) + model_zoo.register(name='transformers_chatglm', model_fn=lambda: ChatGLMModel(config, empty_init=False), diff --git a/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py b/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py deleted file mode 100644 index 0cbb852b97a0..000000000000 --- a/tests/test_pipeline/test_policy/test_t5_pipeline_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -from colossalai.shardformer.policies.t5 import T5BasePolicy - - -def test_t5_pipeline_distribution(): - num_test_cases = 8 - test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] - } - - for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage - - -def test_t5_pipeline_layers(): - num_test_cases = 4 - test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] - } - - for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) - - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, - decoder_starting_stage) - assert start_idx == predicted_start - assert end_idx == predicted_end diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 46b262d0a8cd..0e5cb8144ef3 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,5 +1,6 @@ import copy from contextlib import nullcontext +from typing import Optional from typing import Any, Callable, Dict, List, Optional import torch @@ -15,6 +16,7 @@ from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.auto_policy import Policy from colossalai.shardformer._utils import getattr_ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor @@ -39,7 +41,8 @@ def build_pipeline_model(model_fn, stage_manager=None, enable_fused_normalization=False, enable_tensor_parallelism=False, - use_lazy_init: bool = False): + use_lazy_init: bool = False, + policy: Optional[Policy] = None): ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -54,7 +57,7 @@ def build_pipeline_model(model_fn, pipeline_stage_manager=stage_manager) shard_former = ShardFormer(shard_config=shard_config) - sharded_model, shared_params = shard_former.optimize(model_copy) + sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy) return org_model.cuda(), sharded_model.cuda() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 36f240a0ffc0..005223fb8ae4 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -60,7 +60,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo shard_weight = shard_chatglm_model.embedding.word_embeddings.weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)] torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) else: diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py new file mode 100644 index 000000000000..ee474ac7be3b --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py @@ -0,0 +1,86 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create new model for test + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + input_ids = inputs['input_ids'] + hidden_size = 64 + batch_size, seq_len = input_ids.shape + hidden_state_shape = (seq_len, batch_size, hidden_size) + if name == "transformers_chatglm": + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy()) + if stage_manager.is_last_stage(): + hidden_states = torch.randn(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + outputs = sharded_model(**inputs) + if stage_manager.is_last_stage(): + assert outputs[0].shape == hidden_state_shape + + else: + assert outputs['hidden_states'].shape == hidden_state_shape + + if name == "transformers_chatglm_for_conditional_generation": + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init, + ChatGLMForConditionalGenerationPolicy()) + if stage_manager.is_last_stage(): + hidden_states = torch.randn(*hidden_state_shape).cuda() + inputs['input_ids'] = None + inputs['hidden_states'] = hidden_states + outputs = sharded_model(**inputs) + if stage_manager.is_last_stage(): + assert outputs[0].shape == (batch_size, seq_len, 65024) + else: + assert outputs['hidden_states'].shape == hidden_state_shape + + torch.cuda.empty_cache() + + +def check_chatglm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm(): + spawn(check_chatglm, 4) + + +if __name__ == "__main__": + test_chatglm() From 47111748281b671ceca1b114b9b264f7eeeb54ad Mon Sep 17 00:00:00 2001 From: flybird1111 <1829166702@qq.com> Date: Mon, 7 Aug 2023 16:41:07 +0800 Subject: [PATCH 67/80] [Shardformer] Merge flash attention branch to pipeline branch (#4362) * [shardformer] supported flash attention test dependency (#4158) * [shardformer] fix flash attention utils test (#4180) * [shardformer] opt support flash attention (#4163) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] add performance benchmark of shardformer (#4175) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] benchmark fix * [shardformer] benchmark fix * [shardformer] llama support flash attention (#4185) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] llama support flash attention * [shardformer] llama support flash attention * [shardformer] Move the import statement for xformer outside the forward function. * [shardformer] gpt2 support flash attention. (#4191) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] bloom support flash attention (#4188) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom suport flash attention * [shardformer] add assert to sequence length * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] bert support flash attention. (#4206) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bert support flash attention * [shardformer] t5 support flash attention. (#4216) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 support flash attention * [shardformer] t5 support flash attention * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * [shardformer] support 'paddedcausal' type of attention mask in Coloattention. (#4215) * added padded causal attn mask type for ColoAttention * [shardformer]t5 flash attention fix (#4239) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 flash attention fix * [shardformer] update gpt2 to use coloattention. (#4234) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 * [shardformer] update opt and llama to use coloattention. (#4226) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt * [shardformer] shardformer support jit fused operator. (#4236) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add type hint to 'self' param of forward * [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] whisper support flash attention (#4301) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] whisper support flash attention * [shardformer] whisper support flash attention * [shardformer]whisper support jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] sam support flash attention (#4316) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] sam support flash attention --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] merge blip2/chatglm (#4321) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] blip2 support flash attention and jit operator (#4325) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] chatglm support flash attention and jit operator (#4330) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] vit support flash attention and jit operator (#4334) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] vit support flash attention and jit operator * [shardformer] vit support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] fix conflict * [pipeline] fix conflict * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * fix flash attention tests * gemini ignore whisper * fix vit * fix xformers import handle --------- Co-authored-by: Frank Lee Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> --- colossalai/shardformer/README.md | 58 +- ..._benchmark.py => convergence_benchmark.py} | 0 ..._benchmark.sh => convergence_benchmark.sh} | 2 +- .../examples/performance_benchmark.py | 86 ++ colossalai/shardformer/modeling/bert.py | 138 +- colossalai/shardformer/modeling/blip2.py | 60 + colossalai/shardformer/modeling/bloom.py | 221 +++ colossalai/shardformer/modeling/chatglm.py | 110 ++ colossalai/shardformer/modeling/gpt2.py | 85 + colossalai/shardformer/modeling/jit.py | 34 + colossalai/shardformer/modeling/llama.py | 66 +- colossalai/shardformer/modeling/opt.py | 174 +++ colossalai/shardformer/modeling/sam.py | 164 ++ colossalai/shardformer/modeling/t5.py | 206 +++ colossalai/shardformer/modeling/vit.py | 49 + colossalai/shardformer/modeling/whisper.py | 249 +++ colossalai/shardformer/policies/bert.py | 34 +- colossalai/shardformer/policies/blip2.py | 28 +- colossalai/shardformer/policies/bloom.py | 34 +- colossalai/shardformer/policies/chatglm.py | 20 +- colossalai/shardformer/policies/gpt2.py | 90 +- colossalai/shardformer/policies/llama.py | 9 +- colossalai/shardformer/policies/opt.py | 17 +- colossalai/shardformer/policies/sam.py | 12 +- colossalai/shardformer/policies/t5.py | 30 +- colossalai/shardformer/policies/vit.py | 25 +- colossalai/shardformer/policies/whisper.py | 25 + colossalai/shardformer/shard/shard_config.py | 5 +- requirements/requirements-test.txt | 2 + tests/kit/model_zoo/transformers/bert.py | 16 +- tests/kit/model_zoo/transformers/blip2.py | 1 + tests/kit/model_zoo/transformers/bloom.py | 10 +- tests/kit/model_zoo/transformers/chatglm.py | 1 - .../chatglm2_6b/configuration_chatglm.py | 58 - .../chatglm2_6b/modeling_chatglm.py | 1372 ----------------- tests/kit/model_zoo/transformers/gpt.py | 6 +- tests/kit/model_zoo/transformers/t5.py | 10 +- tests/kit/model_zoo/transformers/whisper.py | 4 +- .../test_plugin/test_gemini_plugin.py | 2 +- tests/test_shardformer/test_model/_utils.py | 13 +- .../test_model/test_shard_bert.py | 11 +- .../test_model/test_shard_blip2.py | 7 +- .../test_model/test_shard_bloom.py | 8 +- .../test_model/test_shard_chatglm.py | 8 +- .../test_model/test_shard_gpt2.py | 1 - .../test_model/test_shard_llama.py | 5 +- .../test_model/test_shard_opt.py | 15 +- .../test_model/test_shard_sam.py | 6 +- .../test_model/test_shard_t5.py | 11 +- .../test_model/test_shard_vit.py | 9 +- .../test_model/test_shard_whisper.py | 8 +- tests/test_utils/test_flash_attention.py | 2 +- 52 files changed, 2061 insertions(+), 1556 deletions(-) rename colossalai/shardformer/examples/{shardformer_benchmark.py => convergence_benchmark.py} (100%) rename colossalai/shardformer/examples/{shardformer_benchmark.sh => convergence_benchmark.sh} (76%) create mode 100644 colossalai/shardformer/examples/performance_benchmark.py create mode 100644 colossalai/shardformer/modeling/jit.py create mode 100644 colossalai/shardformer/modeling/opt.py create mode 100644 colossalai/shardformer/modeling/whisper.py delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 5489f97e4d19..5d00e606dc94 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -30,7 +30,7 @@ ### Quick Start -The sample API usage is given below: +The sample API usage is given below(If you enable the use of flash attention, please install xformers.): ```python from colossalai.shardformer import ShardConfig, Shard @@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer: - [ ] Multi-modal - [x] SAM - [x] BLIP-2 +- [ ] Flash Attention Support + - [ ] NLP + - [x] BERT + - [x] T5 + - [x] LlaMa + - [x] GPT2 + - [x] OPT + - [x] BLOOM + - [ ] GLM + - [ ] RoBERTa + - [ ] ALBERT + - [ ] ERNIE + - [ ] GPT Neo + - [ ] GPT-J ## 💡 API Design @@ -373,11 +387,49 @@ pytest tests/test_shardformer ### System Performance -To be added. +We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time between the original model and the shard model. + +We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length. + +In the case of using 2 GPUs, the training times are as follows. +| N_CTX | org_model | shard_model | +| :------: | :-----: | :-----: | +| 256 | 11.2ms | 17.2ms | +| 512 | 9.8ms | 19.5ms | +| 1024 | 19.6ms | 18.9ms | +| 2048 | 46.6ms | 30.8ms | +| 4096 | 160.5ms | 90.4ms | + + +

+ +
+

+ +In the case of using 4 GPUs, the training times are as follows. + +| N_CTX | org_model | shard_model | +| :------: | :-----: | :-----: | +| 256 | 10.0ms | 21.1ms | +| 512 | 11.5ms | 20.2ms | +| 1024 | 22.1ms | 20.6ms | +| 2048 | 46.9ms | 24.8ms | +| 4096 | 160.4ms | 68.0ms | + + + +

+ +
+

+ + +As shown in the figures above, when the sequence length is around 1000 or greater, the parallel optimization of Shardformer for long sequences starts to become evident. ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. + +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. | accuracy | f1 | loss | GPU number | model shard | | :------: | :-----: | :-----: | :--------: | :---------: | diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py similarity index 100% rename from colossalai/shardformer/examples/shardformer_benchmark.py rename to colossalai/shardformer/examples/convergence_benchmark.py diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh similarity index 76% rename from colossalai/shardformer/examples/shardformer_benchmark.sh rename to colossalai/shardformer/examples/convergence_benchmark.sh index f42b19a32d35..1c281abcda6d 100644 --- a/colossalai/shardformer/examples/shardformer_benchmark.sh +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -1,4 +1,4 @@ -torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \ +torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ --model "bert" \ --pretrain "bert-base-uncased" \ --max_epochs 1 \ diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py new file mode 100644 index 000000000000..9c7b76bcf0a6 --- /dev/null +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -0,0 +1,86 @@ +""" +Shardformer Benchmark +""" +import torch +import torch.distributed as dist +import transformers +import triton + +import colossalai +from colossalai.shardformer import ShardConfig, ShardFormer + + +def data_gen(batch_size, seq_length): + input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long) + attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_sequence_classification(batch_size, seq_length): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen(batch_size, seq_length) + data['labels'] = torch.ones((batch_size), dtype=torch.long) + return data + + +MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16) +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 +model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) + +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark(x_names=['N_CTX'], + x_vals=[2**i for i in range(8, 13)], + line_arg='provider', + line_vals=['org_model', 'shard_model'], + line_names=['org_model', 'shard_model'], + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'lama_for_sequence_classification-batch-{BATCH}', + args={ + 'BATCH': BATCH, + 'dtype': torch.float16, + 'model_func': model_func + }) +] + + +def train(model, data): + output = model(**data) + loss = output.logits.mean() + loss.backward() + + +@triton.testing.perf_report(configs) +def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"): + warmup = 10 + rep = 100 + # prepare data + data = data_gen_for_sequence_classification(BATCH, N_CTX) + data = {k: v.cuda() for k, v in data.items()} + model = model_func().to(device) + model.train() + if provider == "org_model": + fn = lambda: train(model, data) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "shard_model": + shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model = shard_former.optimize(model).cuda() + fn = lambda: train(sharded_model, data) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# start benchmark, command: +# torchrun --standalone --nproc_per_node=2 performance_benchmark.py +if __name__ == "__main__": + colossalai.launch_from_torch({}) + bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 1b3c14d9d1c9..b9d4b5fda7af 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,5 +1,6 @@ +import math import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -962,3 +963,138 @@ def bert_for_question_answering_forward( else: hidden_states = outputs.get('hidden_states') return {'hidden_states': hidden_states} + + +def get_bert_flash_attention_forward(): + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bert.modeling_bert import BertAttention + + def forward( + self: BertAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + final_attention_mask = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + final_attention_mask = relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + final_attention_mask = relative_position_scores_query + relative_position_scores_key + + scale = 1 / math.sqrt(self.attention_head_size) + if attention_mask is not None: + if final_attention_mask != None: + final_attention_mask = final_attention_mask * scale + attention_mask + else: + final_attention_mask = attention_mask + batch_size, src_len = query_layer.size()[0], query_layer.size()[2] + tgt_len = key_layer.size()[2] + final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len) + + query_layer = query_layer.permute(0, 2, 1, 3).contiguous() + key_layer = key_layer.permute(0, 2, 1, 3).contiguous() + value_layer = value_layer.permute(0, 2, 1, 3).contiguous() + + context_layer = me_attention(query_layer, + key_layer, + value_layer, + attn_bias=final_attention_mask, + p=self.dropout.p, + scale=scale) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, None) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + return forward + + +def get_jit_fused_bert_self_output_forward(): + + from transformers.models.bert.modeling_bert import BertSelfOutput + + def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def get_jit_fused_bert_output_forward(): + + from transformers.models.bert.modeling_bert import BertOutput + + def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index b7945423ae83..c5c6b14ba993 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -1,3 +1,4 @@ +import math from typing import Optional, Tuple, Union import torch @@ -58,3 +59,62 @@ def forward( return outputs return forward + + +def get_blip2_flash_attention_forward(): + + from transformers.models.blip_2.modeling_blip_2 import Blip2Attention + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + def forward( + self: Blip2Attention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + mixed_qkv = self.qkv(hidden_states) + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + attention = ColoAttention(embed_dim=self.embed_dim, + num_heads=self.num_heads, + dropout=self.dropout.p, + scale=self.scale) + context_layer = attention(query_states, key_states, value_states) + + output = self.projection(context_layer) + outputs = (output, None) + + return outputs + + return forward + + +def get_jit_fused_blip2_QFormer_self_output_forward(): + + from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput + + def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def get_jit_fused_blip2_QFormer_output_forward(): + + from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput + + def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 76948fc70439..57c45bc6adfa 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -5,6 +5,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -675,3 +676,223 @@ def bloom_for_question_answering_forward( else: hidden_states = outputs.get('hidden_states') return {'hidden_states': hidden_states} + + +def get_bloom_flash_attention_forward(enabel_jit_fused=False): + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + + fused_qkv = self.query_key_value(hidden_states) + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + _, kv_length, _, _ = key_layer.size() + + proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim) + query_layer = query_layer.contiguous().view(*proj_shape) + key_layer = key_layer.contiguous().view(*proj_shape) + value_layer = value_layer.contiguous().view(*proj_shape) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + tgt_len = key_layer.size()[1] + + attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), + dtype=torch.float32, + device=query_layer.device, + requires_grad=True) + attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, + kv_length) * self.beta + attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, + torch.finfo(torch.float32).min) + + context_layer = me_attention(query_layer, + key_layer, + value_layer, + attn_bias=attention_numerical_mask, + scale=self.inv_norm_factor, + p=self.attention_dropout.p) + context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # TODO to replace with the bias_dropout_add function in jit + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + outputs = (output_tensor, present, None) + + return outputs + + return forward + + +def get_jit_fused_bloom_attention_forward(): + + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, q_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + +def get_jit_fused_bloom_mlp_forward(): + + from transformers.models.bloom.modeling_bloom import BloomMLP + + def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + + if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = torch.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + F.linear( + hidden_states[:, :, int(i * slices):int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + return output + + return forward + + +def get_jit_fused_bloom_gelu_forward(): + + from transformers.models.bloom.modeling_bloom import BloomGelu + + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction + + def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: + bias = torch.zeros_like(x) + if self.training: + return JitGeLUFunction.apply(x, bias) + else: + return self.bloom_gelu_forward(x, bias) + + return forward diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py index 0bb8bdc58218..3d453c3bd6db 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm.py @@ -17,6 +17,116 @@ ) +def get_flash_core_attention_forward(): + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + from .chatglm2_6b.modeling_chatglm import CoreAttention + + def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + query_layer = query_layer.permute(1, 0, 2, 3).contiguous() + key_layer = key_layer.permute(1, 0, 2, 3).contiguous() + value_layer = value_layer.permute(1, 0, 2, 3).contiguous() + + scale = 1.0 / self.norm_factor + if self.coeff is not None: + scale = scale * self.coeff + + flash_attention_mask = None + attn_mask_type = None + if attention_mask is None: + attn_mask_type = AttnMaskType.causal + else: + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size_per_partition, + num_heads=self.num_attention_heads_per_partition, + dropout=self.attention_dropout.p, + scale=scale) + context_layer = attention(query_layer, + key_layer, + value_layer, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type) + + context_layer = context_layer.permute(1, 0, -1).contiguous() + + return context_layer + + return forward + + +def get_jit_fused_glm_block_forward(): + + from .chatglm2_6b.modeling_chatglm import GLMBlock + + def forward( + self: GLMBlock, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training) + + return output, kv_cache + + return forward + + + class ChatGLMPipelineForwards: ''' This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index dc5a81dc912b..e02581fbaa9b 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -668,3 +668,88 @@ def gpt2_for_sequence_classification_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +def get_gpt2_flash_attention_forward(): + + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + def split_heads(tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def forward( + self: GPT2Attention, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + _, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.") + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = split_heads(query, self.num_heads, self.head_dim) + key = split_heads(key, self.num_heads, self.head_dim) + value = split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + if not self.is_cross_attention: + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None + if attention_mask != None: + if attn_mask_type == AttnMaskType.causal: + attn_mask_type == AttnMaskType.paddedcausal + else: + attn_mask_type = AttnMaskType.padding + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + + scale = value.size(-1)**-0.5 + if self.scale_attn_by_inverse_layer_idx: + scale = scale * (1 / float(self.layer_idx + 1)) + + # use coloattention + attention = ColoAttention(embed_dim=self.embed_dim, + num_heads=self.num_heads, + dropout=self.attn_dropout.p, + scale=scale) + + attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py new file mode 100644 index 000000000000..6434348ef823 --- /dev/null +++ b/colossalai/shardformer/modeling/jit.py @@ -0,0 +1,34 @@ +import torch + + +def get_dropout_add_func(): + + from transformers.models.bloom.modeling_bloom import dropout_add + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + return dropout_add(x, residual, prob, training) + + return self_dropout_add + + +def get_jit_fused_dropout_add_func(): + + from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + bias = torch.zeros_like(x) + if training: + return bias_dropout_add_fused_train(x, bias, residual, prob) + return bias_dropout_add_fused_inference(x, bias, residual, prob) + + return self_dropout_add + + +def get_jit_fused_gelu_forward_func(): + + from colossalai.kernel.jit.bias_gelu import bias_gelu + + def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: + return bias_gelu(bias, x) + + return bloom_gelu_forward diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e1ed5f64665c..9d6335503b36 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -386,3 +386,67 @@ def llama_for_sequence_classification_forward( else: hidden_states = transformer_outputs.get('hidden_states') return {'hidden_states': hidden_states} + + +def get_llama_flash_attention_forward(): + + from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + def forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) + key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) + value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = attention(query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 000000000000..299dfb5562f3 --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,174 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + + +def get_opt_flash_attention_forward(): + + from transformers.models.opt.modeling_opt import OPTAttention + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + def forward( + self: OPTAttention, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k, v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.embed_dim, + num_heads=self.num_heads, + dropout=self.dropout, + scale=self.scaling) + attn_output = attention(query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value + + return forward + + +def get_jit_fused_opt_decoder_layer_forward(): + + from transformers.models.opt.modeling_opt import OPTDecoderLayer + + def forward( + self: OPTDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py index 63ebfe89d5fa..c40c02ec411a 100644 --- a/colossalai/shardformer/modeling/sam.py +++ b/colossalai/shardformer/modeling/sam.py @@ -1,4 +1,9 @@ +import math +from typing import Tuple + import torch +import torch.nn.functional as F +from torch import Tensor def forward_fn(): @@ -37,3 +42,162 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch return outputs return forward + + +def get_sam_flash_attention_forward(): + + from transformers.models.sam.modeling_sam import SamAttention + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states + + def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_tokens, n_heads, c_per_head = hidden_states.shape + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward(self: SamAttention, + query: Tensor, + key: Tensor, + value: Tensor, + attention_similarity: Tensor = None) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = _separate_heads(query, self.num_attention_heads) + key = _separate_heads(key, self.num_attention_heads) + value = _separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = query.shape + bias = None + if attention_similarity is not None: + bias = attention_similarity + + scale = 1.0 / math.sqrt(c_per_head) + out = me_attention(query, key, value, attn_bias=bias, scale=scale) + + out = _recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + return forward + + +def get_sam_vision_flash_attention_forward(): + + from transformers.models.sam.modeling_sam import SamVisionAttention + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def add_decomposed_rel_pos( + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. + """ + + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, nHead, dim = query.shape + reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width) + return rel_pos + + def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, + -1).permute(2, 0, 1, 3, 4)) + + query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0) + + rel_pos = None + if self.use_rel_pos: + rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)) + + attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale) + + attn_output = attn_output.reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + outputs = (attn_output, None) + + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 7eb4d17928d6..0b3486e87c7e 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -587,3 +587,209 @@ def t5_encoder_model_forward( decoder_starting_stage=decoder_starting_stage) return outputs + + +def get_t5_flash_attention_forward(): + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.t5.modeling_t5 import T5Attention + + def forward( + self: T5Attention, + hidden_states: torch.Tensor, + mask: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + layer_head_mask: Optional[torch.Tensor] = None, + query_length: Optional[int] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + + def unshape(states): + """reshape""" + return states.view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=1) + elif past_key_value.shape[1] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project(hidden_states, self.k, key_value_states, + past_key_value[0] if past_key_value is not None else None) + value_states = project(hidden_states, self.v, key_value_states, + past_key_value[1] if past_key_value is not None else None) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), + device=query_states.device, + dtype=query_states.dtype) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1):, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + position_bias_masked = position_bias_masked.contiguous() + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=position_bias_masked, + p=self.dropout, + scale=1.0) + attn_output = unshape(attn_output) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + return forward + + +def get_jit_fused_T5_layer_ff_forward(): + + from transformers.models.t5.modeling_t5 import T5LayerFF + + def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training) + return hidden_states + + return forward + + +def get_T5_layer_self_attention_forward(): + + from transformers.models.t5.modeling_t5 import T5LayerSelfAttention + + def forward( + self: T5LayerSelfAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward + + +def get_T5_layer_cross_attention_forward(): + + from transformers.models.t5.modeling_t5 import T5LayerCrossAttention + + def forward( + self: T5LayerCrossAttention, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + query_length: Optional[int] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index f28c13ad0aa2..22c4dd998cac 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,4 +1,5 @@ import logging +import math from typing import Dict, List, Optional, Set, Tuple, Union import torch @@ -335,3 +336,51 @@ def pp_forward( ) return pp_forward + + +def get_vit_flash_self_attention_forward(): + + from transformers.models.vit.modeling_vit import ViTSelfAttention + + from colossalai.kernel.cuda_native.flash_attention import ColoAttention + + def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) + x = x.view(new_x_shape) + return x + + def forward(self: ViTSelfAttention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) + value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads, + self.attention_head_size) + query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) + + scale = 1.0 / math.sqrt(self.attention_head_size) + attention = ColoAttention(embed_dim=self.all_head_size, + num_heads=self.num_attention_heads, + dropout=self.dropout.p, + scale=scale) + context_layer = attention(query_layer, key_layer, value_layer) + + outputs = (context_layer,) + + return outputs + + return forward + + +def get_jit_fused_vit_output_forward(): + + from transformers.models.vit.modeling_vit import ViTOutput + + def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py new file mode 100644 index 000000000000..6bc387ac8974 --- /dev/null +++ b/colossalai/shardformer/modeling/whisper.py @@ -0,0 +1,249 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + + +def get_whisper_flash_attention_forward(): + + from transformers.models.whisper.modeling_whisper import WhisperAttention + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): + return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() + + def forward( + self: WhisperAttention, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[1] == key_value_states.shape[1]): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + # get query proj + query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim) + + src_len = key_states.size(1) + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + + attn_type = None + flash_attention_mask = None + + if self.is_decoder: + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) + attn_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.embed_dim, + num_heads=self.num_heads, + dropout=self.dropout, + scale=self.scaling) + attn_output = attention(query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_type) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + +def get_jit_fused_whisper_encoder_layer_forward(): + + from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer + + def forward( + self: WhisperEncoderLayer, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + return forward + + +def get_jit_fused_whisper_decoder_layer_forward(): + + from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer + + def forward( + self: WhisperDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 6f86de232fad..ace9ada3904f 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -7,7 +7,14 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.bert import BertPipelineForwards +from .._utils import getattr_, setattr_ +from ..modeling.bert import ( + BertPipelineForwards, + get_bert_flash_attention_forward, + get_jit_fused_bert_output_forward, + get_jit_fused_bert_self_output_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -37,7 +44,13 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + from transformers.models.bert.modeling_bert import ( + BertEmbeddings, + BertLayer, + BertOutput, + BertSelfAttention, + BertSelfOutput, + ) policy = {} @@ -126,6 +139,23 @@ def module_policy(self): policy=policy, target_key=BertEmbeddings) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_bert_flash_attention_forward(), + }) + + # use jit operator + if self.shard_config.enable_jit_fused: + policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bert_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[BertOutput] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bert_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def add_lm_head_policy(self, base_policy): diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index a244d70b56f5..50356302e93e 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -3,7 +3,13 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.blip2 import forward_fn +from ..modeling.blip2 import ( + forward_fn, + get_blip2_flash_attention_forward, + get_jit_fused_blip2_QFormer_output_forward, + get_jit_fused_blip2_QFormer_self_output_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['BlipPolicy', 'BlipModelPolicy'] @@ -33,6 +39,8 @@ def module_policy(self): Blip2EncoderLayer, Blip2QFormerLayer, Blip2QFormerModel, + Blip2QFormerOutput, + Blip2QFormerSelfOutput, Blip2VisionModel, ) from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM @@ -275,6 +283,24 @@ def module_policy(self): policy=policy, target_key=OPTDecoderLayer) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ + 'forward': get_blip2_flash_attention_forward(), + }) + + # use jit operator + if self.shard_config.enable_jit_fused: + policy[Blip2QFormerSelfOutput] = ModulePolicyDescription( + method_replacement={ + 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_blip2_QFormer_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 15bae2f4a959..b35764db3870 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -7,7 +7,16 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn +from .._utils import getattr_, setattr_ +from ..modeling.bloom import ( + BloomPipelineForwards, + build_bloom_alibi_tensor_fn, + get_bloom_flash_attention_forward, + get_jit_fused_bloom_attention_forward, + get_jit_fused_bloom_gelu_forward, + get_jit_fused_bloom_mlp_forward, +) +from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -30,7 +39,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel policy = {} @@ -107,6 +116,27 @@ def module_policy(self): policy=policy, target_key=BloomBlock) + if self.shard_config.enable_flash_attention: + policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_bloom_flash_attention_forward(), + 'dropout_add': get_dropout_add_func() + }) + + # enable jit fused operator + if self.shard_config.enable_jit_fused: + policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bloom_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bloom_mlp_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bloom_gelu_forward(), + 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), + }) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 9cc651caddc1..e6b458936637 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -15,6 +15,8 @@ GLMBlock, ) +from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] @@ -35,12 +37,11 @@ def preprocess(self): new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) - return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock policy = {} @@ -121,6 +122,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=ChatGLMModel) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_flash_core_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_glm_block_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def postprocess(self): @@ -192,7 +206,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] - class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): def module_policy(self): @@ -213,4 +226,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in ChatGLMForConditionalGenerationModel.""" return [] - diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6d734b063036..20e5fa372c8f 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,7 +5,8 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.gpt2 import GPT2PipelineForwards +from .._utils import getattr_, setattr_ +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -33,7 +34,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} @@ -53,42 +54,42 @@ def module_policy(self): "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -96,8 +97,8 @@ def module_policy(self): suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - policy=policy, - target_key=GPT2Model) + policy=policy, + target_key=GPT2Model) self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( @@ -112,8 +113,13 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Block) + + if self.shard_config.enable_flash_attention: + policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + 'forward': get_gpt2_flash_attention_forward(), + }) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5988366ed57b..5ee95f3be8fa 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -7,7 +7,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.llama import LlamaPipelineForwards +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -31,7 +31,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel policy = {} @@ -104,6 +104,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaModel) + if self.shard_config.enable_flash_attention: + policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_llama_flash_attention_forward(), + }) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 6fc3a2d31f4d..88ecd8565091 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -25,6 +25,8 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -114,6 +116,19 @@ def module_policy(self): policy=policy, target_key=OPTDecoderLayer) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_opt_flash_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_opt_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def postprocess(self): @@ -189,13 +204,11 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=OPTForCausalLM) - if self.pipeline_stage_manager: self.set_pipeline_forward(model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index ca20fff715f2..b1eba0432b49 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -3,7 +3,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.sam import forward_fn +from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['SamPolicy', 'SamModelPolicy'] @@ -19,6 +19,7 @@ def preprocess(self): def module_policy(self): from transformers.models.sam.modeling_sam import ( + SamAttention, SamFeedForward, SamTwoWayAttentionBlock, SamTwoWayTransformer, @@ -196,6 +197,15 @@ def module_policy(self): policy=policy, target_key=SamTwoWayTransformer) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[SamAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_sam_flash_attention_forward(), + }) + policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_sam_vision_flash_attention_forward(), + }) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 0ee18d6c4940..5e78ae9093fa 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -14,7 +14,14 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription from .._utils import getattr_, setattr_ -from ..modeling.t5 import T5PipelineForwards +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.t5 import ( + T5PipelineForwards, + get_jit_fused_T5_layer_ff_forward, + get_t5_flash_attention_forward, + get_T5_layer_cross_attention_forward, + get_T5_layer_self_attention_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -168,6 +175,27 @@ def module_policy(self): suffix="final_layer_norm", target_module=FusedRMSNorm), policy=policy, target_key=T5Stack) + + # use flash attention + if self.shard_config.enable_flash_attention: + policy[T5Attention] = ModulePolicyDescription(method_replacement={ + 'forward': get_t5_flash_attention_forward(), + }) + + # use jit operator + if self.shard_config.enable_jit_fused: + policy[T5LayerFF] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_T5_layer_ff_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_T5_layer_self_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_T5_layer_cross_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 1feb11ffcf24..26fcb6e77d35 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -3,11 +3,15 @@ import torch.nn as nn import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col +from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.vit import ( ViTForImageClassification_pipeline_forward, ViTForMaskedImageModeling_pipeline_forward, ViTModel_pipeline_forward, + get_jit_fused_vit_output_forward, + get_vit_flash_self_attention_forward, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -23,7 +27,8 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel + + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention policy = {} @@ -33,7 +38,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, + target_module=DropoutForReplicatedInput, ) ]) @@ -83,8 +88,18 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ]) - return policy - + # use flash attention + if self.shard_config.enable_flash_attention: + policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_vit_flash_self_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[ViTOutput] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_vit_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) return policy def new_model_class(self): @@ -167,7 +182,7 @@ def module_policy(self): ViTForImageClassification: ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) ]) } policy.update(new_item) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 2f3565bdaa96..2ac7a49fd27b 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -3,6 +3,12 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.whisper import ( + get_jit_fused_whisper_decoder_layer_forward, + get_jit_fused_whisper_encoder_layer_forward, + get_whisper_flash_attention_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -30,6 +36,7 @@ def preprocess(self): def module_policy(self): from transformers.models.whisper.modeling_whisper import ( + WhisperAttention, WhisperDecoder, WhisperDecoderLayer, WhisperEncoder, @@ -181,6 +188,24 @@ def module_policy(self): ], policy=policy, target_key=WhisperDecoder) + + # enable flash attention + if self.shard_config.enable_flash_attention: + policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_whisper_flash_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_whisper_encoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_whisper_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def add_lm_head_policy(self, base_policy): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 75fad4eb7431..ec6e0cd0d4be 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -26,6 +26,8 @@ class ShardConfig: enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False + enable_jit_fused: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -44,7 +46,6 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: self._turn_on_all_optimization() @@ -55,3 +56,5 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True + self.enable_jit_fused = True diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 2dae645f7eb9..510af5f3c7ff 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,3 +18,5 @@ SentencePiece ninja flash_attn>=2.0 datasets +ninja +flash-attn diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index d17b8fda425a..9834f5425027 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -20,7 +20,7 @@ def data_gen(): # token_type_ids = tokenized_input['token_type_ids'] input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) @@ -69,19 +69,21 @@ def data_gen_for_mcq(): # data['labels'] = torch.tensor([0], dtype=torch.int64) input_ids = torch.tensor([[[ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, - 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102 + 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102 ], [ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, - 2218, 1999, 1996, 2192, 1012, 102, 0 + 2218, 1999, 1996, 2192, 1012, 102, 0, 0 ]]]) token_type_ids = torch.tensor( - [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0]]]) attention_mask = torch.tensor( - [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0]]]) labels = torch.tensor([0], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py index 7338f740be7f..984a6ffa920d 100644 --- a/tests/kit/model_zoo/transformers/blip2.py +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -38,6 +38,7 @@ def data_gen(): loss_fn_blip2_model = lambda x: x.loss config = transformers.Blip2Config() +config.vision_config.patch_size = 14 config.text_config.num_hidden_layers = 1 config.qformer_config.num_hidden_layers = 1 config.vision_config.num_hidden_layers = 1 diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 5d195db2c68d..177edbef8935 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -16,8 +16,8 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -33,7 +33,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data @@ -53,8 +53,8 @@ def data_gen_for_question_answering(): # inputs = tokenizer(question, text, return_tensors="pt") input_ids = torch.tensor( - [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64) return dict(input_ids=input_ids, diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 056c910a8dfe..90bb70bc7f79 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -6,7 +6,6 @@ from ..registry import ModelAttribute, model_zoo - # ================================ # Register single-sentence ChatGLM # ================================ diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py deleted file mode 100644 index 3e78732be2da..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py +++ /dev/null @@ -1,58 +0,0 @@ -from transformers import PretrainedConfig - - -class ChatGLMConfig(PretrainedConfig): - model_type = "chatglm" - - def __init__(self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs): - self.num_layers = num_layers - self.vocab_size = padded_vocab_size - self.padded_vocab_size = padded_vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.kv_channels = kv_channels - self.num_attention_heads = num_attention_heads - self.seq_length = seq_length - self.hidden_dropout = hidden_dropout - self.attention_dropout = attention_dropout - self.layernorm_epsilon = layernorm_epsilon - self.rmsnorm = rmsnorm - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.post_layer_norm = post_layer_norm - self.add_bias_linear = add_bias_linear - self.add_qkv_bias = add_qkv_bias - self.bias_dropout_fusion = bias_dropout_fusion - self.multi_query_attention = multi_query_attention - self.multi_query_group_num = multi_query_group_num - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = attention_softmax_in_fp32 - self.fp32_residual_connection = fp32_residual_connection - self.quantization_bit = quantization_bit - self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - super().__init__(**kwargs) diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py deleted file mode 100644 index bae6d425878d..000000000000 --- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py +++ /dev/null @@ -1,1372 +0,0 @@ -""" -The ChatGLM2-6B License - -1. Definitions - -“Licensor” means the ChatGLM2-6B Model Team that distributes its Software. - -“Software” means the ChatGLM2-6B model parameters made available under this license. - -2. License Grant - -Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -3. Restriction - -You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. - -You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. - -4. Disclaimer - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -5. Limitation of Liability - -EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -6. Dispute Resolution - -This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. - -Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. -""" -""" PyTorch ChatGLM model. """ - -import copy -import math -import re -import sys -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != "darwin": - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm2-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size), - ) - else: - self.embedding = torch.nn.Embedding( - config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2, - ) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, - seq_len: int, - n_elem: int, - dtype: torch.dtype, - device: torch.device, - base: int = 10000, - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=dtype, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, - self.dim, - dtype=self.inv_freq.dtype, - device=self.inv_freq.device, - ) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split(".")[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=query_layer.device, - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): - attention_mask = torch.ones( - output_size[0], - 1, - output_size[2], - output_size[3], - device=attention_scores.device, - dtype=torch.bool, - ) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - # Per attention head and per partition values. - self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = (self.projection_size + - 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) - self.query_key_value = nn.Linear( - config.hidden_size, - self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, - **_config_to_kwargs(config), - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear( - self.projection_size, - config.hidden_size, - bias=config.add_bias_linear, - device=device, - **_config_to_kwargs(config), - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=None, - use_cache=True, - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view(query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - key_layer = key_layer.view(key_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.view(value_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, - -1, - -1, - self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, - -1, - ) - key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, - -1, - -1, - self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, - -1, - ) - value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config), - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config), - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc( - config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype, - ) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc( - config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype, - ) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=None, - use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache, - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc( - config.hidden_size, - eps=config.layernorm_epsilon, - device=device, - dtype=config.torch_dtype, - ) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache, - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache, - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat( - ( - torch.ones(batch_size, seq_length, past_length, device=input_ids.device), - full_attention_mask, - ), - dim=-1, - ) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device, - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = (config.hidden_size // - config.num_attention_heads if config.kv_channels is None else config.kv_channels) - - self.rotary_pos_emb = RotaryEmbedding( - rotary_dim // 2, - original_impl=config.original_rope, - device=device, - dtype=config.torch_dtype, - ) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method( - nn.Linear, - config.hidden_size, - config.padded_vocab_size, - bias=False, - dtype=config.torch_dtype, - **init_kwargs, - ) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels, - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt( - batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype, - ) - if attention_mask is not None: - attention_mask = torch.cat( - [ - attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask, - ], - dim=-1, - ) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, - full_attention_mask, - rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - ) - - if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], - dim=-1, - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - is_first_forward: bool = True, - **kwargs, - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], - beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple(( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) for layer_past in past) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - return response - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - prompt = tokenizer.build_prompt(query, history=history) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - if history: - prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - input_ids = tokenizer.encode(prompt, add_special_tokens=False) - input_ids = input_ids[1:] - inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) - else: - prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - @torch.no_grad() - def chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - max_length: int = 8192, - num_beams=1, - do_sample=True, - top_p=0.8, - temperature=0.8, - logits_processor=None, - **kwargs, - ): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = { - "max_length": max_length, - "num_beams": num_beams, - "do_sample": do_sample, - "top_p": top_p, - "temperature": temperature, - "logits_processor": logits_processor, - **kwargs, - } - inputs = self.build_inputs(tokenizer, query, history=history) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - past_key_values=None, - max_length: int = 8192, - do_sample=True, - top_p=0.8, - temperature=0.8, - logits_processor=None, - return_past_key_values=False, - **kwargs, - ): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = { - "max_length": max_length, - "do_sample": do_sample, - "top_p": top_p, - "temperature": temperature, - "logits_processor": logits_processor, - **kwargs, - } - if past_key_values is None and not return_past_key_values: - inputs = self.build_inputs(tokenizer, query, history=history) - else: - inputs = self.build_stream_inputs(tokenizer, query, history=history) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs["attention_mask"] = attention_mask - for outputs in self.stream_generate( - **inputs, - past_key_values=past_key_values, - return_past_key_values=return_past_key_values, - **gen_kwargs, - ): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response = self.process_response(response) - new_history = history + [(query, response)] - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = ( - generation_config.bos_token_id, - generation_config.eos_token_id, - ) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") - logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`.") - - # 2. Set generation parameters if not already defined - logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) - stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, - stopping_criteria=stopping_criteria) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation(outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize( - self.transformer.encoder, - bits, - empty_init=empty_init, - device=device, - **kwargs, - ) - return self diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 73c210221e61..5c3eb4438bc8 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -18,8 +18,8 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -46,7 +46,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) return data diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 689db2c40abb..435cb6f46937 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -16,8 +16,9 @@ def data_gen_for_encoder_only(): # config = T5Config(decoder_start_token_id=0) # tokenizer = T5Tokenizer.from_pretrained("t5-small") # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids - input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long() - return dict(input_ids=input_ids) + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) def data_gen_for_conditional_generation(): @@ -25,17 +26,16 @@ def data_gen_for_conditional_generation(): # # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids data = data_gen_for_encoder_only() - labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long() data['labels'] = labels return data def data_gen_for_t5_model(): # decoder_inputs_ids is obtained with the following code - # # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() - decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long() data['decoder_input_ids'] = decoder_input_ids return data diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index 40c96a5777ab..f7cdc052aaf0 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -76,14 +76,14 @@ def data_gen_for_audio_classification(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_whisperForConditionalGeneration', +model_zoo.register(name='transformers_whisper_for_conditional_generation', model_fn=lambda: transformers.WhisperForConditionalGeneration(config), data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, loss_fn=loss_fn_attr, model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_whisperWhisperForAudioClassification', +model_zoo.register(name='transformers_whisper_for_audio_classification', model_fn=lambda: transformers.WhisperForAudioClassification(config), data_gen_fn=data_gen_for_audio_classification, output_transform_fn=output_transform_fn, diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index a06b2c963bfe..fee153baf1ac 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): 'transformers_vit_for_image_classification', 'transformers_chatglm', 'transformers_chatglm_for_conditional_generation', 'transformers_blip2', 'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper', - 'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification' + 'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification' ]: continue diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0e5cb8144ef3..98cdc5a4b95b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -21,7 +21,13 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False, + enable_jit_fused=False, + use_lazy_init: bool = False): + # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle ctx.materialize(org_model) # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused) + model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 1d42f1c4703e..afc1507e8b24 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -46,14 +46,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False) -@parameterize('enable_fused_normalization', [False, True]) -@parameterize('enable_tensor_parallelism', [False, True]) +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) @parameterize('use_lazy_init', [False, True]) -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused, + use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) + enable_flash_attention, enable_jit_fused, use_lazy_init) check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index cb9725f4de7f..cd034d0c139a 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -47,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index c13596fe8db3..e11bcf92ea3c 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -44,13 +44,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) @parameterize('use_lazy_init', [False, True]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused, + use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) + enable_flash_attention, enable_jit_fused, use_lazy_init) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 005223fb8ae4..c455a99d26ce 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create new model @@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) if name == "transformers_chatglm": diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index cebb40bd16fe..f7213d8c50b4 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -68,7 +68,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() - @parameterize('test_config', [{ 'tp_size': 1, 'pp_size': 2, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 2cfc172c8df6..ead14ab111e6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -49,12 +49,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) +@parameterize('enable_flash_attention', [True, False]) @parameterize('use_lazy_init', [False, True]) -def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) + enable_flash_attention, use_lazy_init) check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 4684bacb4788..99a278d4303a 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -42,18 +42,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] row_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False) - check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False) + check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False) + check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False) +@parameterize('use_lazy_init', [False, True]) @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, + enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) + enable_flash_attention, enable_jit_fused, use_lazy_init) check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -62,7 +65,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() + run_opt_test() @pytest.mark.dist diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py index e7748cfd189d..616104cd7828 100644 --- a/tests/test_shardformer/test_model/test_shard_sam.py +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -41,10 +41,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_sam_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_sam') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 024c5016b0c1..22f04c879879 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared'] row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias'] - check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False) - check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False) + check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) # check weights are tied if hasattr(org_model, 'lm_head'): @@ -45,11 +45,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('use_lazy_init', [False, True]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention, + enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) + enable_flash_attention, enable_jit_fused, use_lazy_init) check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 7833ab70275d..d179c8a8ee32 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -20,7 +20,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3) + # do backward org_loss.backward() shard_loss.backward() @@ -45,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index a271bbdf1223..9b38ae07b1d6 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -48,12 +48,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index e1c7446f40db..28369d4c9fdb 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -167,4 +167,4 @@ def test_cross_attention(proj_shape, dtype, dropout): torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}" torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" \ No newline at end of file + torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" From eff7572fff21c9edc978b42635a7ee451bf8bd4b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 8 Aug 2023 17:46:44 +0800 Subject: [PATCH 68/80] [pipeline] rewrite t5 tests & support multi-tensor transmitting in pipeline (#4388) * fix remaining t5 bugs/rewrite t5 tests * fix multi-tensor communication in pipeline * rearrange test_config * fix keyerror in sync_shared_params * fix get_held_layers & Randomnizer, complete t5 tests * erase printing * fix get_held_layers through modifying _release_unheld_layers * fix _get_recursive_held_layers bug --- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- colossalai/pipeline/p2p.py | 6 +- colossalai/pipeline/schedule/_utils.py | 2 +- colossalai/pipeline/schedule/one_f_one_b.py | 11 +- colossalai/shardformer/layer/utils.py | 7 + colossalai/shardformer/modeling/t5.py | 95 +++++------ colossalai/shardformer/policies/t5.py | 51 ++---- colossalai/shardformer/shard/sharder.py | 16 +- .../test_model/test_shard_gpt2.py | 7 +- .../test_model/test_shard_t5.py | 150 ++++++++++++------ .../test_model/test_shard_t5_pipeline.py | 101 ------------ 11 files changed, 201 insertions(+), 251 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_t5_pipeline.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a22bdb7199bb..42942aaeb89d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -50,8 +50,10 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp def sync_shared_params(self): for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): - param = shared_param[self.stage_manager.stage] - dist.all_reduce(param.grad, group=group) + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() def no_sync(self) -> Iterator[None]: # no sync grads across data parallel diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f741b8363f13..af7a00b5c720 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -3,6 +3,7 @@ import io import pickle +import re from typing import Any, List, Optional, Union import torch @@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - if b'cuda' in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() - buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + # There might be more than one output tensors during forward + for cuda_str in re.finditer(b'cuda', buf_array): + pos = cuda_str.start() + buf_array[pos + 5] = 48 + device_index buf = bytes(buf_array) io_bytes = io.BytesIO(buf) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 045c86e40e63..3ed9239272f1 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None: Args: x (Any): Object to be called. """ - if isinstance(x, torch.Tensor): + if isinstance(x, torch.Tensor) and x.requires_grad: x.retain_grad() diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index d907d53edcde..ade3cf456fe3 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -107,8 +107,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], if output_obj_grad is None: optimizer.backward(output_obj) else: - for k, grad in output_obj_grad.items(): - optimizer.backward_by_grad(output_obj[k], grad) + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) # Collect the grad of the input_obj. input_obj_grad = None diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index f2ac6563c46f..09cb7bfe1407 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -122,6 +122,13 @@ def increment_index(): """ Randomizer._INDEX += 1 + @staticmethod + def reset_index(): + """ + Reset the index to zero. + """ + Randomizer._INDEX = 0 + @staticmethod def is_randomizer_index_synchronized(process_group: ProcessGroup = None): """ diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 0b3486e87c7e..d622da452366 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -238,7 +238,8 @@ def custom_forward(*inputs): return { 'hidden_states': hidden_states, 'position_bias': position_bias, - 'encoder_decoder_position_bias': encoder_decoder_position_bias + 'encoder_decoder_position_bias': encoder_decoder_position_bias, + 'backward_tensor_keys': ['hidden_states'] } @staticmethod @@ -261,8 +262,10 @@ def t5_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: @@ -303,7 +306,6 @@ def t5_model_forward( decoder_head_mask = head_mask in_decoder = stage_manager.stage >= decoder_starting_stage - # Stage is in encoder, directly return the output of t5_stack_forward if not in_decoder: encoder_outputs = T5PipelineForwards.t5_stack_forward( @@ -323,25 +325,18 @@ def t5_model_forward( decoder_starting_stage=decoder_starting_stage) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_outputs': encoder_outputs} + return {'encoder_hidden_states': encoder_outputs[0]} else: return encoder_outputs at_last_decoder_stage = stage_manager.is_last_stage() at_first_decoder_stage = stage_manager.stage == decoder_starting_stage - if encoder_outputs is None: - raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") - - encoder_hidden_states = encoder_outputs[0] - if return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") - # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. if not at_first_decoder_stage and hidden_states is None: raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") @@ -360,6 +355,7 @@ def t5_model_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + stage_manager=stage_manager, hidden_states=hidden_states, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, @@ -368,22 +364,19 @@ def t5_model_forward( # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: - decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states return decoder_outputs if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + return decoder_outputs + encoder_hidden_states + else: + return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states) @staticmethod def t5_for_conditional_generation_forward( @@ -406,8 +399,10 @@ def t5_for_conditional_generation_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: @@ -468,28 +463,25 @@ def t5_for_conditional_generation_forward( decoder_starting_stage=decoder_starting_stage) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_outputs': encoder_outputs} + return {'encoder_hidden_states': encoder_outputs[0]} else: return encoder_outputs at_last_decoder_stage = stage_manager.is_last_stage() at_first_decoder_stage = stage_manager.stage == decoder_starting_stage - if encoder_outputs is None: - raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") - encoder_hidden_states = encoder_outputs[0] - if return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. if not at_first_decoder_stage and hidden_states is None: raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + # Decode decoder_outputs = T5PipelineForwards.t5_stack_forward( self.decoder, @@ -505,6 +497,7 @@ def t5_for_conditional_generation_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + stage_manager=stage_manager, hidden_states=hidden_states, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, @@ -513,7 +506,8 @@ def t5_for_conditional_generation_forward( # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: - decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states return decoder_outputs sequence_output = decoder_outputs[0] @@ -533,20 +527,16 @@ def t5_for_conditional_generation_forward( loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + return Seq2SeqLMOutput(loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states) @staticmethod def t5_encoder_model_forward( @@ -562,6 +552,7 @@ def t5_encoder_model_forward( hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 5e78ae9093fa..2ef52c214c6b 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -260,7 +260,7 @@ def get_held_layers(self) -> List[nn.Module]: model = self.model encoder = self.model.encoder - decoder = self.model.__dict__.get('decoder', None) + decoder = getattr(self.model, 'decoder', None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -300,7 +300,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli stage_manager = self.pipeline_stage_manager encoder = self.model.encoder - decoder = self.model.__dict__.get('decoder', None) + decoder = getattr(self.model, 'decoder', None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -355,15 +355,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]} - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) - return self.model - class T5ForConditionalGenerationPolicy(T5BasePolicy): @@ -409,28 +400,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: stage_manager.num_stages) shared_params = [] + shared_embedding = {} if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): - shared_params.append({ - 0: module.shared.weight, - decoder_starting_stage: module.decoder.embed_tokens.weight - }) + shared_embedding[0] = module.shared.weight + shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight + if id(module.lm_head.weight) == id(module.shared.weight): - shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight}) - return shared_params - return [] + shared_embedding[0] = module.shared.weight + shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight - def postprocess(self): - super().postprocess() - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = { - "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] - } - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) - return self.model + return shared_params + + return [] class T5EncoderPolicy(T5BasePolicy): @@ -462,12 +446,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] - - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]} - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) - return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ae8cd8c6e553..0ed745a1fc4a 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -198,6 +198,20 @@ def _replace_sub_module( setattr_(org_layer, suffix, replace_layer) + def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: + + def collect_sub_modules(module: nn.Module): + if module is None: + return + recursive_held_layers.append(module) + for name, child in module.named_children(): + collect_sub_modules(child) + + recursive_held_layers = [] + for module in held_layers: + collect_sub_modules(module) + return recursive_held_layers + def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: r""" Release the unheld layers in the model @@ -205,7 +219,7 @@ def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: if self.shard_config and self.shard_config.pipeline_stage_manager: held_layers = self.policy.get_held_layers() set_tensors_to_none(self.model, exclude=set(held_layers)) - return set(held_layers) + return set(self._get_recursive_held_layers(held_layers)) return None def _materialize(self) -> None: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f7213d8c50b4..1882bf7822cc 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -68,16 +68,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() + @parameterize('test_config', [{ - 'tp_size': 1, + 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, + 'enable_fused_normalization': True, 'use_lazy_init': True }, { - 'tp_size': 2, + 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, 'use_lazy_init': False }, { 'tp_size': 4, diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 22f04c879879..d807ffa06296 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,60 +1,110 @@ -import os - import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - # the value "past_key_values" is sharded, so we ignore - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5) - - # do backward - org_loss.backward() - shard_loss.backward() - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - - # check grad - col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared'] - row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias'] - check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) - - # check weights are tied - if hasattr(org_model, 'lm_head'): - assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr() - assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() - - -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('use_lazy_init', [False, True]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention, - enable_jit_fused): +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + + if org_model.__class__.__name__ != 'T5ForConditionalGeneration': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + + # unwrap model + t5 = org_model + sharded_t5 = sharded_model.unwrap() + + row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + + # check weights and gradients + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False +}]) +@clear_cache_before_run() +def run_t5_test(test_config): + + # TODO: add plugin_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + # skip 4-stage pp test for t5_encoder + if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model': + continue + + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -68,7 +118,7 @@ def check_t5(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_t5(): - spawn(check_t5, 2) + spawn(check_t5, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py deleted file mode 100644 index 7f3a5f2ea40b..000000000000 --- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py +++ /dev/null @@ -1,101 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.t5 import T5BasePolicy -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_t5.py -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids = inputs['input_ids'] - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - - batch_size, seq_len = input_ids.shape - hidden_size = sharded_model.config.d_model - num_heads = sharded_model.config.num_heads - hidden_state_shape = (batch_size, seq_len, hidden_size) - position_bias_shape = (batch_size, num_heads, seq_len, seq_len) - - num_encoder_layers = len(sharded_model.encoder.block) - decoder = sharded_model.__dict__.get('decoder', None) - num_decoder_layers = len(decoder.block) if decoder else 0 - - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE) - stage = stage_manager.stage - at_first_stage = (stage == 0) or (stage == decoder_starting_stage) - at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) - in_decoder = stage >= decoder_starting_stage - - if not at_first_stage: - # change inputs if not the first stage - hidden_states = torch.zeros(*hidden_state_shape).cuda() - position_bias = torch.zeros(*position_bias_shape).cuda() - encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - inputs['position_bias'] = position_bias - inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias - if in_decoder: - encoder_output_states = torch.zeros(*hidden_state_shape).cuda() - inputs['encoder_outputs'] = (encoder_output_states,) - - sharded_model.train() - output = sharded_model(**inputs) - if at_last_stage: - if name == 'transformers_t5_for_conditional_generation' and in_decoder: - assert output.loss is not None - else: - if name != 'transformers_t5_encoder_model' and not in_decoder: - output = output['encoder_outputs'] - assert output[0].shape == hidden_state_shape - else: - assert output['hidden_states'].shape == hidden_state_shape - # position_bias information should be passed in T5 - assert output['position_bias'].shape == position_bias_shape - if in_decoder: - assert output['encoder_decoder_position_bias'].shape == position_bias_shape - - torch.cuda.empty_cache() - - -def check_t5(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_t5(): - spawn(check_t5, 4) - - -if __name__ == "__main__": - test_t5() From a3631fdb80bf6629e72286b71a36b15c9c6af6f1 Mon Sep 17 00:00:00 2001 From: flybird1111 <1829166702@qq.com> Date: Wed, 9 Aug 2023 14:32:19 +0800 Subject: [PATCH 69/80] [shardformer] update shardformer to use flash attention 2 (#4392) * cherry-pick flash attention 2 cherry-pick flash attention 2 * [shardformer] update shardformer to use flash attention 2 [shardformer] update shardformer to use flash attention 2, fix [shardformer] update shardformer to use flash attention 2, fix [shardformer] update shardformer to use flash attention 2, fix --- colossalai/kernel/cuda_native/__init__.py | 5 +++-- colossalai/shardformer/modeling/blip2.py | 2 +- colossalai/shardformer/modeling/chatglm.py | 3 +-- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 2 +- colossalai/shardformer/modeling/opt.py | 2 +- colossalai/shardformer/modeling/vit.py | 2 +- colossalai/shardformer/modeling/whisper.py | 2 +- tests/test_utils/test_flash_attention.py | 1 - 9 files changed, 10 insertions(+), 11 deletions(-) diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index 4910717b5723..e0136d86e561 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,8 +1,9 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention -from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax +from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax __all__ = [ - 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention' + 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention', + 'AttnMaskType' ] diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index c5c6b14ba993..69730fd3d254 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py index 3d453c3bd6db..a95966c3b99e 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm.py @@ -19,7 +19,7 @@ def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention @@ -126,7 +126,6 @@ def forward( return forward - class ChatGLMPipelineForwards: ''' This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index e02581fbaa9b..a12a9796fa8a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -674,7 +674,7 @@ def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9d6335503b36..2f54daac586a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -392,7 +392,7 @@ def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention def forward( self: LlamaAttention, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 299dfb5562f3..bdf141816737 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -8,7 +8,7 @@ def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 22c4dd998cac..eb0ea4c7502b 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -342,7 +342,7 @@ def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native.flash_attention import ColoAttention + from colossalai.kernel.cuda_native import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6bc387ac8974..0a16c6f788da 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -8,7 +8,7 @@ def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 28369d4c9fdb..f775710c40c2 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -13,7 +13,6 @@ from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType DTYPE = [torch.float16, torch.bfloat16, torch.float32] -FLASH_DTYPE = [torch.float16, torch.bfloat16] def attention_ref(q, k, v, attn_mask=None, causal=False): From 1fa9a268f2a71ff6d551cdf73ee8b6aafa487d1e Mon Sep 17 00:00:00 2001 From: flybird1111 <1829166702@qq.com> Date: Thu, 10 Aug 2023 13:59:30 +0800 Subject: [PATCH 70/80] [shardformer] test all optimizations (#4399) [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] test all optimizations --- .../booster/plugin/hybrid_parallel_plugin.py | 11 +++- requirements/requirements-test.txt | 2 +- tests/test_shardformer/test_model/_utils.py | 16 ++--- .../test_model/test_shard_gpt2.py | 59 ++++++++++++------- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 42942aaeb89d..28a19af0ce91 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -148,7 +148,10 @@ def __init__( precision: str = 'fp16', zero_stage: int = 0, cpu_offload: bool = False, + enable_all_optimization: bool = False, enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, num_microbatches: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, @@ -171,7 +174,10 @@ def __init__( self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -186,7 +192,10 @@ def __init__( self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, - enable_fused_normalization=self.enable_fused_normalization) + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 510af5f3c7ff..a37d00326a08 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -19,4 +19,4 @@ ninja flash_attn>=2.0 datasets ninja -flash-attn +flash-attn>=2.0 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 98cdc5a4b95b..cce21809d829 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,6 +1,5 @@ import copy from contextlib import nullcontext -from typing import Optional from typing import Any, Callable, Dict, List, Optional import torch @@ -16,8 +15,8 @@ from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.auto_policy import Policy from colossalai.shardformer._utils import getattr_ +from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor @@ -156,10 +155,12 @@ def _criterion(outputs, inputs): else: data = {k: v.cuda() for k, v in data.items()} sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) - sharded_loss.backward() + sharded_optimizer.backward(sharded_loss) org_model.train() + data = {k: v.cuda() for k, v in data.items()} org_output = org_model(**data) org_loss = criterion(org_output) org_loss.backward() @@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor, if stage_manager and stage_manager.is_last_stage(): sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0) - assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \ + assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \ + assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \ f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" @@ -213,7 +214,7 @@ def check_weight(org_model: Module, if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \ + assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" @@ -244,6 +245,7 @@ def check_grad(org_model: Module, if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + assert torch.allclose( - org_grad, shard_grad, rtol=rtol, atol=atol + org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 1882bf7822cc..3ac8fa26d860 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,6 +3,7 @@ from torch import distributed as dist import colossalai +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'GPT2Model': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + # check loss + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + def unwrap(module): + if isinstance(module, HybridParallelModule): + module = module.unwrap() + if module.__class__.__name__ == 'GPT2Model': + return module + return module.transformer # unwrap model - if org_model.__class__.__name__ == 'GPT2Model': - gpt2 = org_model - sharded_gpt2 = sharded_model.unwrap() - else: - gpt2 = org_model.transformer - sharded_gpt2 = sharded_model.unwrap().transformer + gpt2 = unwrap(org_model) + sharded_gpt2 = unwrap(sharded_model) col_layer_for_check = ['h[0].mlp.c_fc'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] # check grad + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) - check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False) + check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) torch.cuda.empty_cache() @@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) @clear_cache_before_run() def run_gpt2_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting + # TODO: check and debug TP+AMP sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 4a7b0f19fca0631d3a9f946b7aea483869bd9bac Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 11 Aug 2023 10:32:53 +0800 Subject: [PATCH 71/80] [pipeline] rewrite bert tests and fix some bugs (#4409) * add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining * add bert_for_pretraining forward and policy * fix typos * cancel warning * change the imediate output to default dict * change the default output of get_shared_params * rewrite bert test * rewrite bert test * fix some bugs * del pipeline tests * del pipeline tests * del useless print * del useless print * rewrite data repeats --- tests/kit/model_zoo/transformers/bert.py | 3 +- tests/test_shardformer/test_model/_utils.py | 8 +- .../test_model/test_shard_bert.py | 129 +++++++++++------- .../test_model/test_shard_bert_pipeline.py | 107 --------------- 4 files changed, 88 insertions(+), 159 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_bert_pipeline.py diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 9834f5425027..52158596bcf8 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -104,7 +104,8 @@ def data_gen_for_qa(): output_transform_fn = lambda x: x # define loss funciton -loss_fn_for_bert_model = lambda x: x.pooler_output.sum() +loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state + )) loss_fn = lambda x: x.loss config = transformers.BertConfig(hidden_size=128, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index cce21809d829..c9da9d32e554 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer, data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable, booster: Booster): + org_model.cuda() + sharded_model.cuda() def _criterion(outputs, inputs): outputs = output_transform_fn(outputs) @@ -141,7 +143,8 @@ def _criterion(outputs, inputs): sharded_model.train() if booster.plugin.stage_manager is not None: data = { - k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + k: v.to('cuda').repeat(*([4] + [1] * + (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() } data_iter = iter([data]) @@ -162,6 +165,7 @@ def _criterion(outputs, inputs): org_model.train() data = {k: v.cuda() for k, v in data.items()} org_output = org_model(**data) + org_loss = criterion(org_output) org_loss.backward() @@ -226,7 +230,6 @@ def check_grad(org_model: Module, atol: float = 1e-5, rtol: float = 1e-3, verbose: bool = False): - for suffix in layer_suffix: org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad @@ -242,7 +245,6 @@ def check_grad(org_model: Module, # embedding may be resized when using tensor parallel if shard_grad.shape[0] > org_grad.shape[0]: shard_grad = shard_grad[:org_grad.shape[0], :] - if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index afc1507e8b24..fdbcd014e1b8 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,65 +1,98 @@ import pytest import torch +from torch import distributed as dist import colossalai -from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # unwarp model +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if org_model.__class__.__name__ == 'BertModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + # unwrap model if org_model.__class__.__name__ == 'BertModel': bert = org_model - sharded_bert = sharded_model + sharded_bert = sharded_model.unwrap() else: bert = org_model.bert - sharded_bert = sharded_model.bert - - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output) - - # do backward - org_loss.backward() - shard_loss.backward() - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - - # check grad - col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings'] - row_layer_for_check = ['encoder.layer[0].attention.output.dense'] - check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False) - check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False) - - -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused, - use_lazy_init): + sharded_bert = sharded_model.unwrap().bert + + col_layer_for_check = ['encoder.layer[0].output.dense'] + row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + + if stage_manager is None or stage_manager.is_first_stage(): + #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3) + #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3) + check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) + check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': True +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_bert_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + test_config['precision'] = 'float' + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -73,7 +106,7 @@ def check_bert(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bert(): - spawn(check_bert, 2) + spawn(check_bert, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py deleted file mode 100644 index 3170b58a1175..000000000000 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): - stage_manager = stage_manager - policy = get_autopolicy(model) - policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - policy.set_shard_config(model_config) - layers = policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 1 + 1 - else: - if name == "transformers_bert": - assert len(layers) == 1 + 1 - elif name in [ - "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", - "transformers_bert_for_mcq" - ]: - assert len(layers) == 1 + 3 - else: - assert len(layers) == 1 + 2 - - -def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): - if name == 'transformers_bert_for_mcq': - x = torch.randint(0, 1000, (2, 3, 3)).cuda() - attention_mask = torch.ones_like(x).cuda() - if stage_manager.stage == 0: - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (6, 3, 128) - else: - hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() - output = sharded_model(input_ids=x, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape == (2, 3) - else: - x = torch.randint(0, 1000, (2, 3)).cuda() - # one batch, 2 single sentences, each sentence has 3 tokens - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model(hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape[0] == 2 - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_bert -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - check_bert_model_policy(name, org_model, stage_manager) - check_bert_model_pipeline_forward(name, sharded_model, stage_manager) - - torch.cuda.empty_cache() - - -def check_bert(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_bert_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_bert(): - spawn(check_bert, 2) - - -if __name__ == "__main__": - test_bert() From 58f9fce1290c3d2c54b900c255ce4b9613b4083d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 11 Aug 2023 11:44:23 +0800 Subject: [PATCH 72/80] [shardformer]fix, test gpt2 for AMP+TP (#4403) * [shardformer] gpt2 tests fix [shardformer] test all optimizations (#4399) [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] gpt2 tests fix * [shardformer] gpt2 tests fix --- tests/test_shardformer/test_model/_utils.py | 8 +++----- tests/test_shardformer/test_model/test_shard_gpt2.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c9da9d32e554..c51df07f6c11 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -210,7 +210,7 @@ def check_weight(org_model: Module, if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ - torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) + torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(sharded_weight_list, sharded_weight, tp_group) sharded_weight = torch.cat(sharded_weight_list, dim=dim) @@ -219,7 +219,7 @@ def check_weight(org_model: Module, print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ - f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" def check_grad(org_model: Module, @@ -236,9 +236,7 @@ def check_grad(org_model: Module, shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [ - torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) - ] + shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3ac8fa26d860..274cfaa39ad1 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - org_loss, org_output, sharded_loss, sharded_output = \ run_forward_backward_with_hybrid_plugin( org_model, @@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'GPT2Model': check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - # check loss check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) def unwrap(module): @@ -92,13 +90,14 @@ def unwrap(module): 'num_microbatches': 4, 'enable_all_optimization': True, 'use_lazy_init': True, - 'precision': 'fp32', + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, 'enable_all_optimization': True, - 'use_lazy_init': False, + 'use_lazy_init': True, 'precision': 'fp16', 'initial_scale': 1, }, { @@ -112,7 +111,6 @@ def unwrap(module): def run_gpt2_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # TODO: check and debug TP+AMP sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') From ebe81c314475b3266c41f8d463d0ebb34a0ebb30 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 11 Aug 2023 15:43:23 +0800 Subject: [PATCH 73/80] [shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395) * rewrite opt tests * rewrite llama tests * rewrite bloom & vit tests * rewrite chatglm tests * fix LinearCol for classfiers * add judge for other tp layers, fix lazy init in util --- colossalai/shardformer/layer/linear.py | 16 + .../shardformer/layer/qkv_fused_linear.py | 16 + colossalai/shardformer/modeling/opt.py | 497 +++++++++++++- .../shardformer/policies/auto_policy.py | 6 + colossalai/shardformer/policies/opt.py | 618 +----------------- tests/kit/model_zoo/transformers/bloom.py | 8 +- tests/kit/model_zoo/transformers/chatglm.py | 19 +- tests/kit/model_zoo/transformers/vit.py | 6 +- tests/test_shardformer/test_model/_utils.py | 35 +- .../test_model/test_shard_bloom.py | 118 ++-- .../test_model/test_shard_bloom_pipeline.py | 90 --- .../test_model/test_shard_chatglm.py | 179 ++--- .../test_model/test_shard_chatglm_pipeline.py | 86 --- .../test_model/test_shard_llama.py | 144 ++-- .../test_model/test_shard_llama_pipeline.py | 89 --- .../test_model/test_shard_opt.py | 145 ++-- .../test_model/test_shard_opt_pipeline.py | 70 -- .../test_model/test_shard_vit.py | 137 +++- .../test_model/test_shard_vit_pipeline.py | 74 --- 19 files changed, 1072 insertions(+), 1281 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_bloom_pipeline.py delete mode 100644 tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py delete mode 100644 tests/test_shardformer/test_model/test_shard_llama_pipeline.py delete mode 100644 tests/test_shardformer/test_model/test_shard_opt_pipeline.py delete mode 100644 tests/test_shardformer/test_model/test_shard_vit_pipeline.py diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index bb36854bd772..d59b68ce4480 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -143,6 +143,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = Linear1D_Col(in_features=in_features, out_features=out_features, bias=bias, @@ -293,6 +301,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = Linear1D_Row(in_features=in_features, out_features=out_features, bias=bias, diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 42417f8bcc43..df942d43ee2d 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -265,6 +265,14 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, out_features=out_features, bias=bias, @@ -420,6 +428,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] + tp_size = dist.get_world_size(process_group) + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, out_features=out_features, bias=bias, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index bdf141816737..9afdfff4d71d 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,7 +1,500 @@ -from typing import Optional, Tuple +import random +from typing import List, Optional, Tuple, Union import torch -from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.opt.modeling_opt import ( + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class OPTPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of OPT models + under pipeline setting. + ''' + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + from transformers.models.opt.modeling_opt import _make_causal_mask + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + _dtype, + device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, + tgt_len=input_shape[-1]).to(device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def opt_model_forward( + self: OPTModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + ''' + This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward + ''' + + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.utils import logging + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + decoder = self.decoder + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + batch_size, seq_length = input_shape + + if inputs_embeds is None: + inputs_embeds = decoder.embed_tokens(input_ids) + + if decoder.project_in is not None: + inputs_embeds = decoder.project_in(inputs_embeds) + device = input_ids.device if input_ids is not None else inputs_embeds.device + _dtype = inputs_embeds.dtype + + else: + if hidden_states is None: + raise ValueError("hidden_states shouln't be None for intermediate stages.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + _dtype = hidden_states.dtype + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)") + + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, + device, past_key_values_length) + + if stage_manager.is_first_stage(): + pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) + hidden_states = inputs_embeds + pos_embeds + + if decoder.gradient_checkpointing and decoder.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(decoder.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + torch.cuda.set_device(device) + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = decoder.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if decoder.training and (dropout_probability < decoder.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if decoder.gradient_checkpointing and decoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + if decoder.final_layer_norm is not None: + hidden_states = decoder.final_layer_norm(hidden_states) + if decoder.project_out is not None: + hidden_states = decoder.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_causal_lm_forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. + Please refer to original code of transformers for more details. + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + if stage_manager.is_last_stage(): + logits = self.lm_head(outputs[0]).contiguous() + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_sequence_classification_forward( + self: OPTForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. + Please refer to original code of transformers for more details. + """ + + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} + + @staticmethod + def opt_for_question_answering_forward( + self: OPTForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get('hidden_states') + return {'hidden_states': hidden_states} def get_opt_flash_attention_forward(): diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 2a041af19be8..eec339c02872 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -122,6 +122,12 @@ class PolicyLocation: PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"), "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"), + + # ChatGLM + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": + PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": + PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 88ecd8565091..ba6036bd0658 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,32 +1,14 @@ -import logging -import random from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List -import torch import torch.nn as nn from torch import Tensor, nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, -) -from transformers.models.opt.modeling_opt import ( - OPTForCausalLM, - OPTForQuestionAnswering, - OPTForSequenceClassification, - OPTModel, -) - -from colossalai.pipeline.stage_manager import PipelineStageManager + from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .._utils import getattr_, setattr_ +from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func -from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward +from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -228,6 +210,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: num_stages = self.pipeline_stage_manager.num_stages if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + return [] def postprocess(self): if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: @@ -295,594 +278,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: "no shared params in OPTForSequenceClassification" return [] - - -class OPTPipelineForwards: - ''' - This class serves as a micro library for forward function substitution of OPT models - under pipeline setting. - ''' - - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - from transformers.models.opt.modeling_opt import _make_causal_mask - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - _dtype, - device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, - tgt_len=input_shape[-1]).to(device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) - - return combined_attention_mask - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - @staticmethod - def opt_model_forward( - self: OPTModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - ''' - This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward - ''' - - from transformers.modeling_outputs import BaseModelOutputWithPast - from transformers.utils import logging - logger = logging.get_logger(__name__) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - decoder = self.decoder - if stage_manager.is_first_stage(): - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - batch_size, seq_length = input_shape - - if inputs_embeds is None: - inputs_embeds = decoder.embed_tokens(input_ids) - - if decoder.project_in is not None: - inputs_embeds = decoder.project_in(inputs_embeds) - device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype - - else: - if hidden_states is None: - raise ValueError("hidden_states shouln't be None for intermediate stages.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - _dtype = hidden_states.dtype - - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - # embed positions - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)") - - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, - device, past_key_values_length) - - if stage_manager.is_first_stage(): - pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) - hidden_states = inputs_embeds + pos_embeds - - if decoder.gradient_checkpointing and decoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') - past_key_values = None - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(decoder.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") - - start_idx, end_idx = stage_index[0], stage_index[1] - - torch.cuda.set_device(device) - - for idx in range(start_idx, end_idx): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - decoder_layer = decoder.layers[idx] - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if decoder.training and (dropout_probability < decoder.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if decoder.gradient_checkpointing and decoder.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - causal_attention_mask, - head_mask[idx] if head_mask is not None else None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - if decoder.final_layer_norm is not None: - hidden_states = decoder.final_layer_norm(hidden_states) - if decoder.project_out is not None: - hidden_states = decoder.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - else: - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_causal_lm_forward( - self: OPTForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional - tensors are only required when the model is used as a decoder in a Sequence to Sequence model. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, OPTForCausalLM - - >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - from transformers.modeling_outputs import CausalLMOutputWithPast - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = OPTPipelineForwards.opt_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - if stage_manager.is_last_stage(): - logits = self.lm_head(outputs[0]).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_sequence_classification_forward( - self: OPTForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - from transformers.modeling_outputs import SequenceClassifierOutputWithPast - from transformers.utils import logging - logger = logging.get_logger(__name__) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0] - - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} - - @staticmethod - def opt_for_question_answering_forward( - self: OPTForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, OPTForQuestionAnswering - >>> import torch - - >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - - >>> # note: we are loading a OPTForQuestionAnswering from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random - >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") - - >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - - >>> inputs = tokenizer(question, text, return_tensors="pt") - >>> with torch.no_grad(): - ... outputs = model(**inputs) - - >>> answer_start_index = outputs.start_logits.argmax() - >>> answer_end_index = outputs.end_logits.argmax() - - >>> answer_offset = len(tokenizer(question)[0]) - - >>> predict_answer_tokens = inputs.input_ids[ - ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 - ... ] - >>> predicted = tokenizer.decode(predict_answer_tokens) - >>> predicted - ' a nice puppet' - ```""" - from transformers.modeling_outputs import QuestionAnsweringModelOutput - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) - if stage_manager.is_last_stage(): - hidden_states = transformer_outputs[0] - - logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = (start_logits, end_logits) + transformer_outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return QuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 177edbef8935..2d9c882089cb 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -53,7 +53,8 @@ def data_gen_for_question_answering(): # inputs = tokenizer(question, text, return_tensors="pt") input_ids = torch.tensor( - [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64) + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], + dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64) @@ -73,12 +74,13 @@ def data_gen_for_question_answering(): loss_fn_for_classification = lambda x: x.loss loss_fn_for_question_answering = lambda x: x.loss -config = transformers.BloomConfig(n_layer=1, +config = transformers.BloomConfig(n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, - hidden_size=64) + hidden_size=64, + pad_token_id=50256) # register the following models model_zoo.register(name='transformers_bloom', diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 90bb70bc7f79..c6473ee2a025 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -17,14 +17,24 @@ def data_gen(): return dict(input_ids=input_ids, attention_mask=attention_mask) +def data_gen_for_conditional_generation(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + # define output transform function output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum() -loss_fn = lambda x: x.logits.sum() +loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, + torch.ones_like(x.last_hidden_state)) +loss_fn = lambda x: x.loss -config = ChatGLMConfig(num_layers=1, +config = ChatGLMConfig(num_layers=2, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, @@ -33,7 +43,6 @@ def data_gen(): use_cache=True, torch_dtype=torch.float32) - model_zoo.register(name='transformers_chatglm', model_fn=lambda: ChatGLMModel(config, empty_init=False), data_gen_fn=data_gen, @@ -43,7 +52,7 @@ def data_gen(): model_zoo.register(name="transformers_chatglm_for_conditional_generation", model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index 93a8d6c615d7..a84b8d31c284 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -7,11 +7,7 @@ # Register single-sentence VIT # =============================== -config = transformers.ViTConfig( - num_hidden_layers=4, - # hidden_size=128, - # intermediate_size=256, - num_attention_heads=4) +config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) # define data gen function diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c51df07f6c11..921af2a8b1d0 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c if 'use_lazy_init' in test_config: use_lazy_init = test_config.pop('use_lazy_init') - if use_lazy_init: - ctx = LazyInitContext() - else: - ctx = nullcontext() - - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) - + ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: - org_model = model_fn().cuda() + org_model = model_fn() sharded_model = copy.deepcopy(org_model) - if use_lazy_init: - org_model = ctx.materialize(org_model) + ctx.materialize(org_model) + org_model = org_model.cuda() org_optimizer = Adam(org_model.parameters(), lr=1e-3) sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster @@ -142,11 +137,12 @@ def _criterion(outputs, inputs): data = data_gen_fn() sharded_model.train() if booster.plugin.stage_manager is not None: - data = { - k: v.to('cuda').repeat(*([4] + [1] * - (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v - for k, v in data.items() - } + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + data_iter = iter([data]) sharded_output = booster.execute_pipeline(data_iter, sharded_model, @@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor, sharded_output: Tensor, stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, - rtol: float = 1e-3): + rtol: float = 1e-3, + dim: int = 0): org_hidden_state = org_output.last_hidden_state @@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor, sharded_hidden_state = sharded_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0) + sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index e11bcf92ea3c..d5a4ce083e2b 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,57 +3,101 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - # do backward - org_loss.backward() - shard_loss.backward() + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - assert torch.allclose(org_loss, shard_loss, - atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + + if org_model.__class__.__name__ == 'BloomModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'BloomModel': bloom = org_model - sharded_bloom = sharded_model + sharded_bloom = sharded_model.unwrap() else: bloom = org_model.transformer - sharded_bloom = sharded_model.transformer + sharded_bloom = sharded_model.unwrap().transformer # check grad - col_layer_for_check = ['h[0].self_attention.query_key_value'] - row_layer_for_check = ['h[0].self_attention.dense'] - check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) - - -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused, - use_lazy_init): + row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] + col_layer_for_check = ['h[0].self_attention.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False) + check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_bloom_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bloom(): - spawn(check_bloom, 2) + spawn(check_bloom, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py deleted file mode 100644 index 6695e8a687bd..000000000000 --- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py +++ /dev/null @@ -1,90 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.shard import ShardConfig -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): - policy = get_autopolicy(model) - policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - policy.set_shard_config(model_config) - layers = policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 0 + 2 - else: - if name == 'transformers_bloom': - assert len(layers) == 1 + 1 - elif name == 'transformers_bloom_for_token_classification': - assert len(layers) == 1 + 3 - else: - assert len(layers) == 1 + 2 - - -def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): - if stage_manager.stage == 0: - x = torch.randint(0, 1000, (1, 3)).cuda() - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (1, 3, 64) - else: - attention_mask = torch.ones((1, 3)).cuda() - hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0].shape[0] == 1 - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_bloom -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - check_bloom_model_policy(name, org_model, stage_manager) - check_bloom_model_pipeline_forward(name, sharded_model, stage_manager) - - torch.cuda.empty_cache() - - -def check_bloom(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_bloom_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_bloom(): - spawn(check_bloom, 2) - - -if __name__ == "__main__": - test_bloom() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index c455a99d26ce..69e63ffc854e 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -1,99 +1,126 @@ -import copy -import os - import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) - # do backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'ChatGLMModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'ChatGLMModel': chatglm_model = org_model - shard_chatglm_model = sharded_model + shard_chatglm_model = sharded_model.unwrap() else: chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.transformer - - # check attention grad - org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad - shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad - shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight + shard_chatglm_model = sharded_model.unwrap().transformer + + # check grad + row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] + col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=0, + verbose=False) + + check_grad(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=1e-4, + rtol=1e-3, + dim=1, + verbose=False) - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - - # check embedding weights - org_grad = chatglm_model.embedding.word_embeddings.weight.grad - shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad - shard_weight = shard_chatglm_model.embedding.word_embeddings.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + torch.cuda.empty_cache() - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_chatglm_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # create new model - org_model = model_fn().cuda() - - # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - if name == "transformers_chatglm": - sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy()) - else: - sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()) - sharded_model = sharded_model.cuda() - - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_chatglm(): - spawn(check_chatglm, 2) + spawn(check_chatglm, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py deleted file mode 100644 index ee474ac7be3b..000000000000 --- a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py +++ /dev/null @@ -1,86 +0,0 @@ -import copy -import os - -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # create new model for test - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids = inputs['input_ids'] - hidden_size = 64 - batch_size, seq_len = input_ids.shape - hidden_state_shape = (seq_len, batch_size, hidden_size) - if name == "transformers_chatglm": - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy()) - if stage_manager.is_last_stage(): - hidden_states = torch.randn(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert outputs[0].shape == hidden_state_shape - - else: - assert outputs['hidden_states'].shape == hidden_state_shape - - if name == "transformers_chatglm_for_conditional_generation": - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init, - ChatGLMForConditionalGenerationPolicy()) - if stage_manager.is_last_stage(): - hidden_states = torch.randn(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - outputs = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert outputs[0].shape == (batch_size, seq_len, 65024) - else: - assert outputs['hidden_states'].shape == hidden_state_shape - - torch.cuda.empty_cache() - - -def check_chatglm(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_chatglm_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_chatglm(): - spawn(check_chatglm, 4) - - -if __name__ == "__main__": - test_chatglm() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ead14ab111e6..c5f8d22f18c9 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,69 +2,139 @@ import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - # forward check - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - # run backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'LlamaModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model - if hasattr(org_model, 'model'): - llama_model = org_model.model - shard_llama_model = sharded_model.model - else: + if org_model.__class__.__name__ == 'LlamaModel': llama_model = org_model - shard_llama_model = sharded_model + shard_llama_model = sharded_model.unwrap() + else: + llama_model = org_model.model + shard_llama_model = sharded_model.unwrap().model # check grad - col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - row_layer_for_check = ['layers[0].self_attn.o_proj'] - check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False) - check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False) + row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] + col_layer_for_check = ['layers[0].self_attn.o_proj'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(llama_model, + shard_llama_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-4, + dim=0, + verbose=False) + check_grad(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-4, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=1e-4, + rtol=1e-3, + dim=1, + verbose=False) + + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('use_lazy_init', [False, True]) -def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init): +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False +}]) +def run_llama_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt2_llama() + run_llama_test() @pytest.mark.dist diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py deleted file mode 100644 index 6f1f0cb34508..000000000000 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.shard import ShardConfig -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - - -def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): - policy = get_autopolicy(model) - policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - policy.set_shard_config(model_config) - layers = policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 2 + 1 - else: - if name == "transformers_llama": - assert len(layers) == 2 + 1 - else: - assert len(layers) == 2 + 2 - - -def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): - x = torch.randint(0, 1000, (2, 3)).cuda() - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0] is not None - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_llama -def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - check_llama_model_policy(name, org_model, stage_manager) - check_llama_model_pipeline_forward(name, sharded_model, stage_manager) - - torch.cuda.empty_cache() - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_llama_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, 2) - - -if __name__ == "__main__": - test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 99a278d4303a..d8fa8104bb07 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -1,64 +1,129 @@ -import copy import os import pytest import torch +from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - # run backward - org_loss.backward() - shard_loss.backward() + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + if org_model.__class__.__name__ == 'OPTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model - if hasattr(org_model, 'model'): - opt_model = org_model.model - shard_opt_model = sharded_model.model - else: + if org_model.__class__.__name__ == 'OPTModel': opt_model = org_model - shard_opt_model = sharded_model + shard_opt_model = sharded_model.unwrap() + else: + opt_model = org_model.model + shard_opt_model = sharded_model.unwrap().model # check grad - col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] - row_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False) - check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False) - - -@parameterize('use_lazy_init', [False, True]) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, - enable_jit_fused): + row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] + col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=0, + verbose=False) + check_grad(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=1e-6, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=1e-3, + rtol=1e-3, + dim=1, + verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_opt_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py b/tests/test_shardformer/test_model/test_shard_opt_pipeline.py deleted file mode 100644 index 0684418d0a8d..000000000000 --- a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_opt -def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids, _ = inputs['input_ids'], inputs['attention_mask'] - batch_size, seq_len = input_ids.shape - hidden_size = 128 - hidden_state_shape = (batch_size, seq_len, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - - hidden_states = torch.zeros(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - sharded_model.train() - - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - assert output[0] is not None - else: - assert output['hidden_states'].shape == hidden_state_shape - torch.cuda.empty_cache() - - -def check_opt(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_opt_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_opt(): - spawn(check_opt, 4) - - -if __name__ == "__main__": - test_opt() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index d179c8a8ee32..8a78d7c2b8ce 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -1,60 +1,127 @@ -import os - import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group - assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3) + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): - # do backward - org_loss.backward() - shard_loss.backward() + if org_model.__class__.__name__ == 'ViTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) # unwrap model if org_model.__class__.__name__ == 'ViTModel': vit_model = org_model - shard_vit_model = sharded_model + shard_vit_model = sharded_model.unwrap() else: vit_model = org_model.vit - shard_vit_model = sharded_model.vit + shard_vit_model = sharded_model.unwrap().vit # check grad - col_layer_for_check = ['encoder.layer[0].attention.attention.query'] - row_layer_for_check = ['encoder.layer[0].attention.output.dense'] - check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) - check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False) + row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] + col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(vit_model, + shard_vit_model, + row_layer_for_check, + tp_group, + atol=1e-5, + rtol=1e-3, + dim=0, + verbose=False) + check_grad(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=1e-5, + rtol=1e-3, + dim=1, + verbose=False) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=5e-3, + rtol=1e-3, + dim=1, + verbose=False) + torch.cuda.empty_cache() + + +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) +def run_vit_test(test_config): + + # TODO: add test_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -68,7 +135,7 @@ def check_vit(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): - spawn(check_vit, 2) + spawn(check_vit, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py deleted file mode 100644 index 114992a2a2a5..000000000000 --- a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_vit -def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - pixel_values = inputs['pixel_values'] - batch_size = len(pixel_values) - hidden_size = 768 - hidden_state_shape = (batch_size, 197, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - hidden_states = torch.randn(*hidden_state_shape).cuda() - # inputs['pixel_values'] = None - inputs['hidden_states'] = hidden_states - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - sharded_model.train() - - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - if name != 'transformers_vit': - assert output.loss is not None - else: - assert output['hidden_states'].shape == hidden_state_shape, \ - f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' - - torch.cuda.empty_cache() - - -def check_vit(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_vit_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_vit(): - spawn(check_vit, 4) - - -if __name__ == "__main__": - test_vit() From d79f3cf9c57f830dc803a415b41415470661b30c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 11 Aug 2023 16:40:06 +0800 Subject: [PATCH 74/80] [shardformer] update tests for all optimization (#4413) [shardformer] update tests for all optimization --- colossalai/shardformer/modeling/bert.py | 5 ++- tests/kit/model_zoo/transformers/bert.py | 29 +++++++++----- .../test_model/test_shard_bert.py | 39 +++++++++++++------ 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index b9d4b5fda7af..eaafd67b8968 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1048,9 +1048,12 @@ def forward( final_attention_mask = final_attention_mask * scale + attention_mask else: final_attention_mask = attention_mask + + if final_attention_mask is not None: batch_size, src_len = query_layer.size()[0], query_layer.size()[2] tgt_len = key_layer.size()[2] - final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len) + final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, + tgt_len).contiguous() query_layer = query_layer.permute(0, 2, 1, 3).contiguous() key_layer = key_layer.permute(0, 2, 1, 3).contiguous() diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 52158596bcf8..e16d3b269ba0 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -69,21 +69,30 @@ def data_gen_for_mcq(): # data['labels'] = torch.tensor([0], dtype=torch.int64) input_ids = torch.tensor([[[ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, - 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102 + 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442, + 1012, 102, 102 ], [ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, - 2218, 1999, 1996, 2192, 1012, 102, 0, 0 + 2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0 ]]]) - token_type_ids = torch.tensor( - [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0]]]) - attention_mask = torch.tensor( - [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0]]]) + token_type_ids = torch.tensor([[[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1 + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 + ]]]) + attention_mask = torch.tensor([[[ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1 + ], + [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 + ]]]) labels = torch.tensor([0], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index fdbcd014e1b8..0a24e46d28f2 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -36,10 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, tp_group = booster.plugin.tp_group # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'BertModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'BertModel': bert = org_model @@ -51,17 +55,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3) #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3) - check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) - check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False) + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) torch.cuda.empty_cache() @@ -70,23 +82,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'use_lazy_init': True + 'use_lazy_init': True, + 'precision': 'fp32', }, { 'tp_size': 2, 'pp_size': 2, - 'num_microbatches': 4, - 'enable_fused_normalization': False, - 'use_lazy_init': False + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') - test_config['precision'] = 'float' for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From afbee8eeced7595f37bf9d960a941063714e4faa Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 14 Aug 2023 15:49:13 +0800 Subject: [PATCH 75/80] [shardformer]update t5 tests for using all optimizations. (#4407) * [shardformer] gpt2 tests fix [shardformer] test all optimizations (#4399) [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] gpt2 tests fix * [shardformer]update t5 to use all optimizations --- colossalai/shardformer/README.md | 2 +- tests/kit/model_zoo/transformers/t5.py | 8 ++-- .../test_model/test_shard_t5.py | 39 +++++++++++++------ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 5d00e606dc94..7dc15f0a0635 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -30,7 +30,7 @@ ### Quick Start -The sample API usage is given below(If you enable the use of flash attention, please install xformers.): +The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.): ```python from colossalai.shardformer import ShardConfig, Shard diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 435cb6f46937..175d48963480 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -16,8 +16,8 @@ def data_gen_for_encoder_only(): # config = T5Config(decoder_start_token_id=0) # tokenizer = T5Tokenizer.from_pretrained("t5-small") # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids - input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long() + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -26,7 +26,7 @@ def data_gen_for_conditional_generation(): # # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids data = data_gen_for_encoder_only() - labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long() data['labels'] = labels return data @@ -35,7 +35,7 @@ def data_gen_for_t5_model(): # decoder_inputs_ids is obtained with the following code # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() - decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long() data['decoder_input_ids'] = decoder_input_ids return data diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index d807ffa06296..fb065b42250b 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ != 'T5ForConditionalGeneration': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model t5 = org_model @@ -50,14 +54,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] # check weights and gradients + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0) + check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) torch.cuda.empty_cache() @@ -66,23 +78,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 2, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'use_lazy_init': False + 'use_lazy_init': False, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, - 'use_lazy_init': False + 'use_lazy_init': False, + 'precision': 'fp32', }]) @clear_cache_before_run() def run_t5_test(test_config): @@ -93,7 +111,6 @@ def run_t5_test(test_config): # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From 6968eb8582ed36b590fee1be2bb772fa104c6f95 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 14 Aug 2023 15:51:13 +0800 Subject: [PATCH 76/80] [shardformer] update bloom/llama/vit/chatglm tests (#4420) [shardformer] update bloom/llama/vit/chatglm tests [shardformer] update opt tests [shardformer] update opt tests [shardformer] update bloom/llama/vit/chatglm tests [shardformer] update bloom/llama/vit/chatglm tests [shardformer] update bloom/llama/vit/chatglm tests --- .../test_model/test_shard_bloom.py | 43 ++++++++++------ .../test_model/test_shard_chatglm.py | 48 ++++++++++------- .../test_model/test_shard_gpt2.py | 16 +++--- .../test_model/test_shard_llama.py | 49 +++++++++++------- .../test_model/test_shard_opt.py | 51 +++++++++++-------- .../test_model/test_shard_vit.py | 48 ++++++++++------- 6 files changed, 157 insertions(+), 98 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index d5a4ce083e2b..145ccf97c388 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -36,11 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'BloomModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'BloomModel': @@ -54,14 +57,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] col_layer_for_check = ['h[0].self_attention.dense'] if stage_manager is None or stage_manager.is_first_stage(): - check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): - check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False) + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) torch.cuda.empty_cache() @@ -70,29 +81,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, - 'use_lazy_init': False + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_bloom_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 69e63ffc854e..e9c74b300daa 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'ChatGLMModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) - check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'ChatGLMModel': @@ -55,12 +59,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] col_layer_for_check = ['encoder.layers[0].self_attention.dense'] if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_grad(chatglm_model, shard_chatglm_model, row_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=0, verbose=False) @@ -68,8 +76,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, shard_chatglm_model, col_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -77,12 +85,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_weight(chatglm_model, shard_chatglm_model, col_layer_for_check, tp_group, - atol=1e-4, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -93,29 +105,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, - 'use_lazy_init': False + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_chatglm_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 274cfaa39ad1..8b7a6bf29c8b 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -63,22 +63,22 @@ def unwrap(module): row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] # check grad - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) # check weights after optimizer.step() org_optimizer.step() sharded_optimizer.step() - if test_config['precision'] == 'fp32': - atol, rtol = 5e-3, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c5f8d22f18c9..fa4ee43e3114 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -41,11 +41,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'LlamaModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'LlamaModel': @@ -59,20 +63,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 check_grad(llama_model, shard_llama_model, row_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-4, + atol=atol, + rtol=rtol, dim=0, verbose=False) check_grad(llama_model, shard_llama_model, col_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-4, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -80,12 +88,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_weight(llama_model, shard_llama_model, col_layer_for_check, tp_group, - atol=1e-4, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -96,33 +108,34 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 2, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'use_lazy_init': False + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, - 'use_lazy_init': False + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_llama_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index d8fa8104bb07..403c3e75f52c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -41,11 +41,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'OPTModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'OPTModel': @@ -56,23 +59,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, shard_opt_model = sharded_model.unwrap().model # check grad - row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] + row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 3e-2, 3e-2 check_grad(opt_model, shard_opt_model, row_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=0, verbose=False) check_grad(opt_model, shard_opt_model, col_layer_for_check, tp_group, - atol=1e-6, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -80,12 +87,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_weight(opt_model, shard_opt_model, col_layer_for_check, tp_group, - atol=1e-3, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -96,29 +107,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': True, - 'use_lazy_init': True + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, - 'use_lazy_init': False + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_opt_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 8a78d7c2b8ce..919bceffc847 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == 'ViTModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model if org_model.__class__.__name__ == 'ViTModel': @@ -55,20 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] col_layer_for_check = ['encoder.layer[0].attention.output.dense'] if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_grad(vit_model, shard_vit_model, row_layer_for_check, tp_group, - atol=1e-5, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=0, verbose=False) check_grad(vit_model, shard_vit_model, col_layer_for_check, tp_group, - atol=1e-5, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -76,12 +84,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 check_weight(vit_model, shard_vit_model, col_layer_for_check, tp_group, - atol=5e-3, - rtol=1e-3, + atol=atol, + rtol=rtol, dim=1, verbose=False) @@ -92,30 +104,30 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, - 'use_lazy_init': False + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }]) def run_vit_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO: add test_config for flash attention & jit operator after supporting # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From e7dc4a3a7386005a967138d27ee8249accb40d40 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 14 Aug 2023 17:43:33 +0800 Subject: [PATCH 77/80] [misc] resolve code factor issues (#4433) --- colossalai/booster/booster.py | 2 +- colossalai/shardformer/layer/utils.py | 2 - colossalai/shardformer/modeling/bert.py | 8 +- colossalai/shardformer/modeling/bloom.py | 12 +- colossalai/shardformer/modeling/chatglm.py | 2 +- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 6 +- colossalai/shardformer/modeling/opt.py | 2 +- colossalai/shardformer/modeling/t5.py | 6 +- colossalai/shardformer/modeling/vit.py | 2 +- colossalai/shardformer/shard/shard_config.py | 1 - .../test_tracer/test_hf_model/test_hf_gpt.py | 2 +- .../test_model/test_pure_pipeline.py | 171 ------------------ .../test_model/test_shard_bloom.py | 2 +- .../test_model/test_shard_chatglm.py | 2 +- .../test_model/test_shard_gpt2.py | 2 +- .../test_model/test_shard_llama.py | 2 +- .../test_model/test_shard_opt.py | 2 +- .../test_model/test_shard_t5.py | 4 +- .../test_model/test_shard_vit.py | 4 +- 20 files changed, 31 insertions(+), 205 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_pure_pipeline.py diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 8a28b1286cfa..adb8f62a5084 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -139,7 +139,7 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: loss (torch.Tensor): The loss to be backpropagated. optimizer (Optimizer): The optimizer to be updated. """ - # TODO: implement this method with plugin + # TODO(frank lee): implement this method with plugin optimizer.backward(loss) def execute_pipeline(self, diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 09cb7bfe1407..577bef076a7e 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -29,8 +29,6 @@ class Randomizer: _INDEX = 0 def __init__(self, seed: int): - # TODO: remove colossalai.context.random - self.seed = seed # Handle CUDA rng state diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index eaafd67b8968..5bd1c531cc68 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -57,7 +57,7 @@ def bert_model_forward( hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, ): - # TODO: add explaination of the output here. + # TODO(jianghai): add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -113,7 +113,7 @@ def bert_model_forward( batch_size, seq_length = input_shape device = hidden_states.device - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -272,7 +272,7 @@ def bert_for_pretraining_forward( logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -534,7 +534,7 @@ def bert_for_next_sentence_prediction_forward( stage_index: Optional[List[int]] = None, **kwargs, ): - #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 57c45bc6adfa..12276635ecfa 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -252,7 +252,7 @@ def custom_forward(*inputs): # Add last hidden state hidden_states = self.ln_f(hidden_states) - # TODO: deal with all_hidden_states, all_self_attentions, presents + # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -307,7 +307,7 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -402,7 +402,7 @@ def bloom_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -431,7 +431,7 @@ def bloom_for_sequence_classification_forward( all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] - #update batch size + # update batch size hidden_states = transformer_outputs[0] logits = self.score(hidden_states) @@ -525,7 +525,7 @@ def bloom_for_token_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -611,7 +611,7 @@ def bloom_for_question_answering_forward( logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py index a95966c3b99e..409e2e1f5497 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm.py @@ -152,7 +152,7 @@ def chatglm_model_forward( if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index a12a9796fa8a..47835d5d5468 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -57,7 +57,7 @@ def gpt2_model_forward( logger = logging.get_logger(__name__) # Preprocess passed in arguments - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2f54daac586a..f1d2998bbee4 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -65,7 +65,7 @@ def llama_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -216,7 +216,7 @@ def llama_for_causal_lm_forward( if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False @@ -301,7 +301,7 @@ def llama_for_sequence_classification_forward( logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') output_attentions = False diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 9afdfff4d71d..b4251f33b457 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -148,7 +148,7 @@ def opt_model_forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index d622da452366..9cc071f91dfc 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -50,7 +50,7 @@ def t5_stack_forward( logger = logging.get_logger(__name__) - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None @@ -285,7 +285,7 @@ def t5_model_forward( logger = logging.get_logger(__name__) - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None @@ -422,7 +422,7 @@ def t5_for_conditional_generation_forward( logger = logging.get_logger(__name__) - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') past_key_values = None diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index eb0ea4c7502b..9fc0b7488803 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -96,7 +96,7 @@ def pp_forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?) expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype if pixel_values.dtype != expected_dtype: pixel_values = pixel_values.to(expected_dtype) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index ec6e0cd0d4be..0c28f115d018 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -29,7 +29,6 @@ class ShardConfig: enable_flash_attention: bool = False enable_jit_fused: bool = False - # TODO: add support for tensor parallel # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index e29afe786c46..1cd3b90db917 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -15,7 +15,7 @@ def test_gpt(): for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - # TODO: support the following models + # TODO(ver217): support the following models # 1. GPT2DoubleHeadsModel # as they are not supported, let's skip them if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']: diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py deleted file mode 100644 index 31e76ef5107c..000000000000 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ /dev/null @@ -1,171 +0,0 @@ -import copy -import random -from typing import Any, Callable, Iterator, List, Optional, Tuple - -import numpy as np -import pytest -import torch -import torch.distributed as dist -from torch.nn import Module -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward - -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 - - -class PipelineOptimizer(OptimizerWrapper): - - def __init__(self, optim: Optimizer, model: Module): - super().__init__(optim) - params = set(model.parameters()) - new_param_groups = [] - for group in optim.param_groups: - params = [p for p in group['params'] if p in params] - new_param_groups.append({**group, 'params': params}) - optim.__setstate__({'param_groups': new_param_groups}) - # TODO: support amp - - -class PipelinedModel(ModelWrapper): - - def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None: - self.stage_manager = stage_manager - shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) - self.shared_param_process_groups = [] - super().__init__(module) - - -def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0): - sampler = DistributedSampler( - dataset, - # rank=self.pg_mesh.coordinate(DP_AXIS), - shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - ) - - -def execute_pipeline( - data_iter: Iterator, - model: PipelinedModel, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: PipelineOptimizer, - return_loss: bool = True, - return_outputs: bool = False, - schedule: OneForwardOneBackwardSchedule = None, -) -> dict: - # return loss or outputs if needed - outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) - return outputs - - -class data_loader(): - - def __getitem__(self, x): - return torch.ones((4, 128), dtype=torch.int).cuda() * 10 - - -def loss(y, x): - return (y[0].float().mean() - x[0].float().mean()) - - -@parameterize('enable_fused_normalization', [False]) -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('use_lazy_init', [False]) -def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - PP_DIM = 0 - PP_SIZE = 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - - pg_mesh = ProcessGroupMesh(PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != 'transformers_llama': - continue - num_microbatches = 2 - org_model = model_fn().cuda() - data_iter = iter(data_loader()) - - model_copy = copy.deepcopy(org_model) - batch = next(data_iter) - with torch.no_grad(): - y = model_copy(batch) - org_loss = loss(y, batch) - optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) - schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - pipeline_stage_manager=stage_manager) - pipelined_model = PipelinedModel(org_model, shard_config, stage_manager) - pp_optimizer = PipelineOptimizer(optimizer, pipelined_model) - results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule) - - if stage_manager.is_last_stage(): - assert results['loss'] == org_loss - else: - assert results['loss'] is None - assert results['outputs'] is None - torch.cuda.empty_cache() - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_llama_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, 2) - - -if __name__ == "__main__": - test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 145ccf97c388..ed0d1d8e401d 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -101,7 +101,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, }]) def run_bloom_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index e9c74b300daa..bb77759048b3 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -125,7 +125,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, }]) def run_chatglm_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 8b7a6bf29c8b..ca086bf12776 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -110,7 +110,7 @@ def unwrap(module): @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index fa4ee43e3114..30ebdfbe5cd9 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -133,7 +133,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, }]) def run_llama_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 403c3e75f52c..8d1154d82638 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -127,7 +127,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, }]) def run_opt_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index fb065b42250b..066f7ee815b4 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -105,10 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @clear_cache_before_run() def run_t5_test(test_config): - # TODO: add plugin_config for TP+DP after supporting & debugging it + # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - # TODO: add test_config for flash attention & jit operator after supporting + # TODO(baizhou): add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 919bceffc847..18df8ef555f2 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -124,8 +124,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, }]) def run_vit_test(test_config): - # TODO: add test_config for TP+DP after supporting & debugging it - # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models + # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') From f7b82f5473af81f69e1d6fbfba9e04c27e4721f8 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 15 Aug 2023 17:07:29 +0800 Subject: [PATCH 78/80] [misc] update requirements --- requirements/requirements-test.txt | 4 +--- requirements/requirements.txt | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index a37d00326a08..ba5ea0936010 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -16,7 +16,5 @@ triton==2.0.0.dev20221202 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja -flash_attn>=2.0 +flash_attn==2.0.5 datasets -ninja -flash-attn>=2.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 65eecce2c34f..9aa5f2822e40 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,6 @@ click fabric contexttimer ninja -torch>=1.11 +torch>=1.12 safetensors -flash_attn>=2.0 einops From 48970f2f6bf6c2fdcf8a90104683b4214116e17a Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 15 Aug 2023 17:59:12 +0800 Subject: [PATCH 79/80] [shardformer] fix embedding --- colossalai/shardformer/layer/embedding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index f07a93bd6908..847ca175ad57 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -214,6 +214,9 @@ def __init__(self, self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + # padding index + self.padding_idx = self._select_padding_idx(padding_idx) + # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) From 025c543df521adda8003a762e3a4c009560d1272 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 15 Aug 2023 18:56:16 +0800 Subject: [PATCH 80/80] [shardformer] fix import --- colossalai/shardformer/layer/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0c44e6621711..c4586d18b90c 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,10 +3,11 @@ from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm +from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col' + 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col', 'ParallelModule' ]