diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index b9c7c03a..19531e4a 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Optional, Union
+from typing import Optional
 
 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw  # , reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
@@ -397,7 +397,6 @@ def backward(ctx, grad_output, *args):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
         block_index = ctx.block_index
         module_name = ctx.module_name
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d5fec315..cb8aa659 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,10 +11,7 @@
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import (
-    release_reduce_scatter_memory_pool,
-    split_forward_gather_backward,
-)
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index 228045ed..f486ccec 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -45,7 +45,7 @@ def __init__(self, group_id, dp_parallel_mode):
     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]
-    
+
     def num_params_in_bucket(self, reduce_rank: int = None):
         return len(self._params[reduce_rank])