From 1a535685e1b130a7511dce6c2bdf07499a09ac67 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 22 Feb 2024 14:11:56 -0500
Subject: [PATCH] lint

---
 composer/trainer/mosaic_fsdp_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index b08a1137b6..1da6811b09 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -658,8 +658,8 @@ def _init_core_state_t2p3p0(
 ) -> _FSDPState:
     if sharding_strategy == ShardingStrategy.NO_SHARD:
         warnings.warn(
-            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
-            "please use DistributedDataParallel instead.",
+            'The `NO_SHARD` sharding strategy is deprecated. If having issues, '
+            'please use DistributedDataParallel instead.',
             # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
             # level 3 is from the true caller
             stacklevel=3,
@@ -668,10 +668,10 @@ def _init_core_state_t2p3p0(
     state.mixed_precision = mixed_precision or MixedPrecision()
     if mixed_precision is not None:
         torch._C._log_api_usage_once(
-            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
+            f'torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}'
         )
     state._use_full_prec_in_eval = (
-        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
+        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, '') == '1'
     )
     state.cpu_offload = cpu_offload or CPUOffload()
     state.limit_all_gathers = limit_all_gathers
@@ -687,7 +687,7 @@ def _init_core_state_t2p3p0(
     )
     # Mapping from fully sharded module to the handles it is responsible to
     # unshard and reshard (see [Note: Fully Sharded Module])
-    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
+    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {}
     state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
     # Invariant: `state.params` contains exactly the `FlatParameter`s of the
     # handles in `state._handle`