diff --git a/bluefog/torch/mpi_ops.py b/bluefog/torch/mpi_ops.py
index 4c6b1dbc..2b50d642 100644
--- a/bluefog/torch/mpi_ops.py
+++ b/bluefog/torch/mpi_ops.py
@@ -21,6 +21,7 @@
 import torch
 
 from bluefog.torch import mpi_lib  # C library
+from bluefog.torch.utility import deprecated_function_arg
 from bluefog.common.basics import BlueFogBasics, logger
 from bluefog.common.topology_util import GetRecvWeights
 
@@ -528,7 +529,8 @@ def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights,
     _handle_map[handle] = (tensor, output)
     return handle
 
-
+@deprecated_function_arg(arg_name="neighbor_weights", fix="Use src_weights instead")
+@deprecated_function_arg(arg_name="send_neighbors", fix="Use dst_weights instead")
 def neighbor_allreduce(tensor: torch.Tensor, *,
                        self_weight: Optional[float] = None,
                        src_weights: Optional[Dict[int, float]] = None,
@@ -548,9 +550,9 @@ def neighbor_allreduce(tensor: torch.Tensor, *,
 
     Arguments:
         tensor: A tensor to execute weighted average with neighbors.
-        self_weight: The weight for self node, used with neighbor_weights.
+        self_weight: The weight for self node, used with src_weights.
         src_weights: The weights for in-neighbor nodes, used with self_weight.
-            If neighbor_weights is presented, the return tensor will return the weighted average
+            If src_weights is present, the returned tensor will be the weighted average
             defined by these weights and the self_weight. If not, the returned tensor will be
             the weighted average defined by the topology weights, if provided, or the uniform average.
             The data structure of weights should be {rank : weight} and rank has to belong to the
@@ -560,7 +562,7 @@ def neighbor_allreduce(tensor: torch.Tensor, *,
             part of (out-)neighbors will be sent to. If set to be a list, assume all the weights
             are one. In this mode, this node sends its value to partial neighbors listed in this
             variable in a dynamic graph, and `self_weight` and `src_weights` must be present.
-        enable_topo_check: When send_neighbors is present, enabling this option checks if the
+        enable_topo_check: When dst_weights is present, enabling this option checks if the
             sending and receiving neighbors match with each other. Disabling this check can boost
             the performance.
         name: A name of the reduction operation.
@@ -584,6 +586,8 @@ def neighbor_allreduce(tensor: torch.Tensor, *,
     return synchronize(handle)
 
 
+@deprecated_function_arg(arg_name="neighbor_weights", fix="Use src_weights instead")
+@deprecated_function_arg(arg_name="send_neighbors", fix="Use dst_weights instead")
 def neighbor_allreduce_nonblocking(tensor: torch.Tensor, *,
                                    self_weight: Optional[float] = None,
                                    src_weights: Optional[Dict[int, float]] = None,
@@ -603,9 +607,9 @@ def neighbor_allreduce_nonblocking(tensor: torch.Tensor, *,
 
     Arguments:
         tensor: A tensor to execute weighted average with neighbors.
-        self_weight: The weight for self node, used with neighbor_weights.
+        self_weight: The weight for self node, used with src_weights.
         src_weights: The weights for in-neighbor nodes, used with self_weight.
-            If neighbor_weights is presented, the return tensor will return the weighted average
+            If src_weights is present, the returned tensor will be the weighted average
             defined by these weights and the self_weight. If not, the returned tensor will be
             the weighted average defined by the topology weights, if provided, or the uniform average.
             The data structure of weights should be {rank : weight} and rank has to belong to the
@@ -615,7 +619,7 @@ def neighbor_allreduce_nonblocking(tensor: torch.Tensor, *,
             part of (out-)neighbors will be sent to. If set to be a list, assume all the weights
             are one. In this mode, this node sends its value to partial neighbors listed in this
             variable in a dynamic graph, and `self_weight` and `src_weights` must be present.
-        enable_topo_check: When send_neighbors is present, enabling this option checks if the
+        enable_topo_check: When dst_weights is present, enabling this option checks if the
             sending and receiving neighbors match with each other. Disabling this check can boost
             the performance.
         name: A name of the neighbor_allreduce operation.
diff --git a/bluefog/torch/utility.py b/bluefog/torch/utility.py
index 73337207..7d35f56a 100644
--- a/bluefog/torch/utility.py
+++ b/bluefog/torch/utility.py
@@ -15,12 +15,14 @@
 # ==============================================================================
 from typing import Any, List, Optional
+from functools import wraps
 import collections
 
 import numpy as np
 import torch
 
 import bluefog.torch as bf
 
+
 def broadcast_parameters(params, root_rank):
     """
     Broadcasts the parameters from root rank to all other processes.
@@ -212,3 +214,16 @@ def _from_tensor():
     for key, p in params:
         if key in callbacks:
             callbacks[key]()
+
+
+def deprecated_function_arg(arg_name: str, fix: str):
+    def deprecated_decorator(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            if arg_name in kwargs:
+                raise TypeError(f"{arg_name} is deprecated in {f.__name__}: {fix}")
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return deprecated_decorator
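
Usage sketch (illustrative, not part of the patch): after this change, calling `neighbor_allreduce` with one of the removed keyword names fails fast with a `TypeError` that names the replacement, while the renamed keywords behave as before. The file name, the ring-style weights, and the two-process launch (e.g. `bfrun -np 2 python demo_deprecated_arg.py`) below are assumptions made up for this example, not BlueFog documentation.

    # demo_deprecated_arg.py (hypothetical file name)
    import torch
    import bluefog.torch as bf

    bf.init()
    x = torch.ones(4)

    # New spelling: self_weight plus src_weights, as documented above.
    # Weights are illustrative; they average with one in-neighbor.
    avg = bf.neighbor_allreduce(
        x,
        self_weight=0.5,
        src_weights={(bf.rank() - 1) % bf.size(): 0.5},
    )

    # Old spelling: the decorator intercepts the keyword before the call
    # reaches the C library and raises with the suggested fix.
    try:
        bf.neighbor_allreduce(x, self_weight=0.5, neighbor_weights={0: 0.5})
    except TypeError as err:
        print(err)  # neighbor_weights is deprecated in neighbor_allreduce: Use src_weights instead

Since the new signatures no longer declare `neighbor_weights` or `send_neighbors`, plain Python would already reject them as unexpected keyword arguments; the decorator's contribution is the actionable message. Because `wrapper` checks `kwargs` before calling `f`, and `@wraps(f)` preserves `__name__` through the stack, the two decorators compose cleanly on the same function.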