Skip to content

Commit

Permalink
feat(optimizer/hybrid_zero_optim.py): fix lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Oct 20, 2023
1 parent 3c69254 commit eac382a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
5 changes: 2 additions & 3 deletions internlm/model/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Optional, Union
from typing import Optional

import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw
from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
Expand Down Expand Up @@ -397,7 +397,6 @@ def backward(ctx, grad_output, *args):
grad_input = grad_input.contiguous()
process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
module = ctx.module
block_index = ctx.block_index
module_name = ctx.module_name

Expand Down
5 changes: 1 addition & 4 deletions internlm/solver/optimizer/hybrid_zero_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import (
release_reduce_scatter_memory_pool,
split_forward_gather_backward,
)
from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
Expand Down
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, group_id, dp_parallel_mode):

def num_elements_in_bucket(self, reduce_rank: int = None):
    """Return the element count accumulated in this store's bucket for ``reduce_rank``.

    ``reduce_rank`` may be ``None`` (the default bucket key); the lookup is a
    plain dict access on the per-rank counter mapping.
    """
    per_rank_counts = self._num_elements_in_bucket
    return per_rank_counts[reduce_rank]

def num_params_in_bucket(self, reduce_rank: int = None):
    """Return how many parameters are currently queued in the bucket for ``reduce_rank``.

    Delegates to ``len()`` on the parameter list stored under the given rank
    (``None`` selects the default bucket).
    """
    bucketed_params = self._params[reduce_rank]
    return len(bucketed_params)

Expand Down

0 comments on commit eac382a

Please sign in to comment.