Commit 9a937a3

[Doptim] refine APIs and add more checks for use_distributed_optimizer settings (#15)

Vremold authored Mar 27, 2024
1 parent 878da06 commit 9a937a3
Showing 7 changed files with 249 additions and 60 deletions.
7 changes: 7 additions & 0 deletions python/vescale/__init__.py
@@ -34,6 +34,9 @@
from vescale.dtensor import zeros, ones, empty, full, randn
from vescale.dtensor._utils import equal, allclose
from vescale.dmp import auto_parallelize_module, set_plan_overriding_policy, get_plan_overriding_policy
from vescale.ddp.distributed_data_parallel import DistributedDataParallel
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.optim.base_optimizer import BasicOptimizer, BasicOptimizerHook
from vescale.initialize.deferred_init import deferred_init, is_deferred, materialize_dtensor, materialize_dparameter

# All public APIs from vescale package
@@ -69,6 +72,10 @@
"materialize_dtensor",
"materialize_dparameter",
"deprecated_function",
"DistributedDataParallel",
"DistributedOptimizer",
"BasicOptimizer",
"BasicOptimizerHook",
]


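With these re-exports, the DDP wrapper and both optimizer wrappers can now be imported straight from the package root. A minimal sketch (the fully qualified vescale.ddp.* / vescale.optim.* imports used elsewhere in this diff keep working unchanged):

from vescale import (
    DistributedDataParallel,
    DistributedOptimizer,
    BasicOptimizer,
    BasicOptimizerHook,
)
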
97 changes: 51 additions & 46 deletions python/vescale/optim/base_optimizer.py
@@ -27,6 +27,49 @@
logger = logging.getLogger(__name__)


class GradOptimizerHookBase(ABC):
"""
Abstract base class for hooks, that needed to be run before
and after `optim.step`.
"""

@abstractstaticmethod
def step_pre_hook(optim, *args, **kwargs):
return NotImplementedError("not impl")

@abstractstaticmethod
def step_post_hook(optim, *args, **kwargs):
raise NotImplementedError("not impl")


class BasicOptimizerHook(GradOptimizerHookBase):
"""
A GradOptimizerHookBase subclass, that is responsible to 'fill' flattened
main_grad in DDP to PyTorch '.grad' fields.
Example: see example codes for `BasicOptimizer`
"""

def step_pre_hook(optim, *args, **kwargs):
for param_group in optim.param_groups:
for p in param_group["params"]:
if p.grad is not None:
continue
if p.main_grad is None:
continue
if isinstance(p.data, DTensor):
dtensor_placements = p.data.placements
dtensor_device_mesh = p.data.device_mesh
p.grad = DTensor.from_local(
p.main_grad, device_mesh=dtensor_device_mesh, placements=dtensor_placements
)
else:
p.grad = p.main_grad

def step_post_hook(optim, *args, **kwargs):
return None


class OptimizerBase(ABC):
"""
Abstract base class for vescale optimizer wrapper.
@@ -70,21 +113,6 @@ def get_loss_scale(self):
return 1.0


class GradOptimizerHookBase(ABC):
"""
Abstract base class for hooks, that needed to be run before
and after `optim.step`.
"""

@abstractstaticmethod
def step_pre_hook(optim, *args, **kwargs):
return NotImplementedError("not impl")

@abstractstaticmethod
def step_post_hook(optim, *args, **kwargs):
raise NotImplementedError("not impl")


class BasicOptimizer(OptimizerBase):
"""
A simple wrapper around a concrete optimizer instance. It provides basic
@@ -108,14 +136,14 @@ class BasicOptimizer(OptimizerBase):
# One only need to wrap a Adam optimizer by BasicOptimizer, then everything,
# like flattened main_grad in DDP world will be hidden.
from vescale.optim.base_optimizer import BasicOptimizer, BaseOptimizerHook
from vescale.optim.base_optimizer import BasicOptimizer, BasicOptimizerHook
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.dmodule.api import parallelize_module
mlp = parallelize_module(MLP(), mesh, ..., ...)
ddp_model = DDP(mlp, ...)
optim = torch.optim.Adam(model.parameters())
optim_wrapper = BasicOptimizer(optim, mlp, grad_hook=BaseOptimizerHook())
optim_wrapper = BasicOptimizer(optim, mlp)
# do the forward and backward
ddp_model(torch.rand(xxx)).sum().backward()
@@ -128,13 +156,18 @@ def __init__(
self,
optimizer,
models: Union[torch.nn.Module, List[torch.nn.Module]],
grad_hook: Optional[GradOptimizerHookBase] = None,
grad_hook: Optional[GradOptimizerHookBase] = BasicOptimizerHook(),
) -> None:
super().__init__(optimizer=optimizer)
self.models = models
if not isinstance(self.models, List):
self.models = [self.models]

if any(getattr(x, "use_distributed_optimizer", False) for x in self.models):
raise RuntimeError(
"detected DDP with use_distributed_optimizer on, please consider use a distributed optimizer"
)

if grad_hook is not None:
self.register_optimizer_hook(grad_hook)

@@ -166,31 +199,3 @@ def get_loss_scale(self):
def register_optimizer_hook(self, grad_hook: GradOptimizerHookBase):
self.optimizer.register_step_pre_hook(grad_hook.step_pre_hook)
self.optimizer.register_step_post_hook(grad_hook.step_post_hook)


class BaseOptimizerHook(GradOptimizerHookBase):
"""
A GradOptimizerHookBase subclass, that is responsible to 'fill' flattened
main_grad in DDP to PyTorch '.grad' fields.
Example: see example codes for `BasicOptimizer`
"""

def step_pre_hook(optim, *args, **kwargs):
for param_group in optim.param_groups:
for p in param_group["params"]:
if p.grad is not None:
continue
if p.main_grad is None:
continue
if isinstance(p.data, DTensor):
dtensor_placements = p.data.placements
dtensor_device_mesh = p.data.device_mesh
p.grad = DTensor.from_local(
p.main_grad, device_mesh=dtensor_device_mesh, placements=dtensor_placements
)
else:
p.grad = p.main_grad

def step_post_hook(optim, *args, **kwargs):
return None
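
Taken together, the base_optimizer.py changes make BasicOptimizerHook the default grad_hook and reject DDP modules that were built with use_distributed_optimizer enabled. A minimal usage sketch adapted from the docstring above; MLP, mesh, and the elided parallelize_module/DDP arguments (written as ...) are placeholders rather than real signatures:

import torch
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.base_optimizer import BasicOptimizer

mlp = parallelize_module(MLP(), mesh, ..., ...)
ddp_model = DDP(mlp, ...)  # use_distributed_optimizer must stay off here

optim = torch.optim.Adam(mlp.parameters())
# grad_hook now defaults to BasicOptimizerHook(), so it no longer needs to be
# passed explicitly; wrapping a DDP module that has use_distributed_optimizer
# switched on raises the new RuntimeError instead.
optim_wrapper = BasicOptimizer(optim, models=ddp_model)

ddp_model(torch.rand(...)).sum().backward()
optim_wrapper.step()
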
16 changes: 9 additions & 7 deletions python/vescale/optim/distributed_optimizer.py
@@ -130,6 +130,10 @@ def __init__(
self.data_parallel_group = m.data_parallel_group
elif self.data_parallel_group != m.data_parallel_group:
raise RuntimeError("Detect model chunks of warious data-parallel process groups")
if not all(x.use_distributed_optimizer for x in models):
print(
"You are using a distributed optimizer, it's suggested to set use_distributed_optimizer on for better performance"
)

param_dtype_cnt = {}
main_param_dtype_cnt = 0
@@ -545,10 +549,9 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o
shard_main_param = shard_model_param.clone().float()
# copy sharded info from DTensor
shard_model_param._spec = None if not isinstance(model_param, DTensor) else model_param._spec
# TODO: we need to find another way to judge whether a param is shared
# if hasattr(model_param, "shared"):
# shard_model_param.shared = model_param.shared
# shard_main_param.shared = model_param.shared
if hasattr(model_param, "shared"):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared

# Add to group.
model_float16_params_this_group.append(model_param)
@@ -563,9 +566,8 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o
shard_model_param = model_param_tensor.view(-1)[param_range.start : param_range.end]
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
# TODO: we need to find another way to judge whether a param is shared
# if hasattr(model_param, "shared"):
# shard_model_param.shared = model_param.shared
if hasattr(model_param, "shared"):
shard_model_param.shared = model_param.shared

# copy sharded info from DTensor
shard_model_param._spec = None if not isinstance(model_param, DTensor) else model_param._spec
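
The new check in DistributedOptimizer.__init__ only prints a warning (it does not raise) when a wrapped DDP chunk was built without use_distributed_optimizer. A rough sketch of the recommended pairing, assuming the DDP and DistributedOptimizer keyword arguments exercised by the tests below; parallelized_module and the elided DDP arguments (...) are placeholders:

import torch
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer

# use_distributed_optimizer=True keeps the flattened grad-buffer layout that
# DistributedOptimizer shards across the data-parallel group; leaving it off
# now triggers the performance warning added above.
ddp_model = DDP(parallelized_module, use_distributed_optimizer=True, ...)

base_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
dist_optim = DistributedOptimizer(base_optim, models=[ddp_model])
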
3 changes: 1 addition & 2 deletions python/vescale/optim/utils.py
@@ -59,6 +59,5 @@ def param_is_sharded_or_replicate_on_first_rank(param):
return False


# TODO:
def param_is_shared(param):
return False
return getattr(param, "shared", False)
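
param_is_shared now reflects the .shared attribute that the distributed_optimizer.py hunks above propagate onto sharded parameter views. A toy illustration, assuming only the getattr-based behaviour shown in this hunk:

import torch
from vescale.optim.utils import param_is_shared

p = torch.nn.Parameter(torch.zeros(4))
assert param_is_shared(p) is False  # no .shared attribute -> not shared
p.shared = True                     # e.g. a tied/shared embedding weight
assert param_is_shared(p) is True
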
4 changes: 2 additions & 2 deletions test/parallel/ddp_optim/test_ddp.py
@@ -30,7 +30,7 @@
from vescale.dtensor.api import redistribute_dtensor
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.base_optimizer import BasicOptimizer, BaseOptimizerHook
from vescale.optim.base_optimizer import BasicOptimizer, BasicOptimizerHook

from common_dtensor import DTensorTestBase, with_comms
from test_models.mlp import (
@@ -180,7 +180,7 @@ def test_ddp_e2e(self, overlap_grad_reduce: bool):
)

ve_optimizer = torch.optim.Adam(ve_model.parameters(), lr=0.01)
ve_optimizer = BasicOptimizer(ve_optimizer, models=ve_model, grad_hook=BaseOptimizerHook)
ve_optimizer = BasicOptimizer(ve_optimizer, models=ve_model, grad_hook=BasicOptimizerHook)

# epoch 1
ve_optimizer.zero_grad()
6 changes: 3 additions & 3 deletions test/parallel/ddp_optim/test_grad_sync.py
@@ -26,7 +26,7 @@
from vescale.dtensor.placement_types import Replicate, Shard
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.optim.base_optimizer import BasicOptimizer, BaseOptimizerHook
from vescale.optim.base_optimizer import BasicOptimizer, BasicOptimizerHook

HIDDEN_DIM = 16
BSZ = 2
@@ -103,7 +103,7 @@ def test_basic(self, grad_sync):
grad_sync=grad_sync,
)
optimizer = torch.optim.Adam(m.parameters(), lr=1e-3)
optimizer = BasicOptimizer(optimizer, models=m, grad_hook=BaseOptimizerHook)
optimizer = BasicOptimizer(optimizer, models=m, grad_hook=BasicOptimizerHook)

dx = distribute_tensor(torch.rand(BSZ, SEQ_LEN, HIDDEN_DIM), device_mesh, inout_sharding)
dout = m(dx)
@@ -211,7 +211,7 @@ def test_ddp(self, overlap_grad_reduce: bool, use_distributed_optimizer: bool):
models=[ddp_m],
)
else:
optimizer = BasicOptimizer(optimizer, models=ddp_m, grad_hook=BaseOptimizerHook)
optimizer = BasicOptimizer(optimizer, models=ddp_m, grad_hook=BasicOptimizerHook)
optimizer.zero_grad()
if torch.distributed.get_rank() in (0, 1):
dx = distribute_tensor(batch_1, tp_submesh, inout_sharding)