Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman100 committed Nov 21, 2023
1 parent cbb25ea commit d18b7e3
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2986,13 +2986,8 @@ def get_state_dict(self, model, unwrap=True):

state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
Expand Down

0 comments on commit d18b7e3

Please sign in to comment.