Skip to content

Commit

Permalink
Add quantized distributed layers
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Nov 1, 2024
1 parent 7a0bb4f commit d63dee5
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 3 deletions.
295 changes: 293 additions & 2 deletions python/mlx/nn/layers/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def __init__(
)

def _extra_repr(self) -> str:
out_dims, in_dims = self.weight.shape
N = self.group.size()
return f"input_dims={self.weight.shape[1]}, output_dims={N * self.weight.shape[0]}, bias={'bias' in self}"
out_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

def __call__(self, x: mx.array) -> mx.array:
# Aggregate the gradients coming from each shard
Expand All @@ -86,13 +88,35 @@ def __call__(self, x: mx.array) -> mx.array:
x = x @ self["weight"].T
return x

@classmethod
def from_linear(
    cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
    """Create a sharded layer holding this rank's rows of ``linear_layer``.

    Args:
        linear_layer (Module): The layer whose ``weight`` (and ``bias``, when
            present) is split along the output dimension.
        group (mx.distributed.Group, optional): The group to shard across.
            When not given, the result of ``mx.distributed.init()`` is used.

    Returns:
        A new instance of ``cls`` whose parameters are copies of the rows
        belonging to the calling rank.
    """
    group = group or mx.distributed.init()
    world_size = group.size()
    rank = group.rank()
    output_dims, input_dims = linear_layer.weight.shape
    rows_per_rank = output_dims // world_size
    first_row = rank * rows_per_rank
    last_row = (rank + 1) * rows_per_rank

    # The third positional argument is False (presumably the bias flag; the
    # constructor is not visible here). The bias is attached below if needed.
    sharded = cls(input_dims, output_dims, False, group)

    # Multiplying by 1 forces a copy of the slice; replace with something
    # better when available.
    sharded.weight = linear_layer.weight[first_row:last_row] * 1
    if "bias" in linear_layer:
        sharded.bias = linear_layer.bias[first_row:last_row] * 1

    return sharded


class ShardedToAllLinear(Module):
"""Each member of the group applies part of the affine transformation and
then aggregates the results.
All nodes will have the same exact result after this layer.
:class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to
convert linear layers to sharded :obj:`ShardedToAllLinear` layers.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
Expand Down Expand Up @@ -136,7 +160,9 @@ def __init__(

def _extra_repr(self) -> str:
N = self.group.size()
return f"input_dims={N * self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
out_dims, in_dims = self.weight.shape
in_dims *= N
return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

def __call__(self, x: mx.array) -> mx.array:
if self.group.size() > 1:
Expand All @@ -154,3 +180,268 @@ def __call__(self, x: mx.array) -> mx.array:
else:
x = x @ self["weight"].T
return x

@classmethod
def from_linear(
    cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
):
    """Create a sharded layer holding this rank's columns of ``linear_layer``.

    Args:
        linear_layer (Module): The layer whose ``weight`` is split along the
            input dimension. Its ``bias``, when present, is reused as is.
        group (mx.distributed.Group, optional): The group to shard across.
            When not given, the result of ``mx.distributed.init()`` is used.

    Returns:
        A new instance of ``cls`` whose weight is a copy of the columns
        belonging to the calling rank.
    """
    group = group or mx.distributed.init()
    world_size = group.size()
    rank = group.rank()
    output_dims, input_dims = linear_layer.weight.shape
    cols_per_rank = input_dims // world_size
    first_col = rank * cols_per_rank
    last_col = (rank + 1) * cols_per_rank

    # The third positional argument is False (presumably the bias flag; the
    # constructor is not visible here). The bias is attached below if needed.
    sharded = cls(input_dims, output_dims, False, group)

    # Multiplying by 1 forces a copy of the slice; replace with something
    # better when available.
    sharded.weight = linear_layer.weight[:, first_col:last_col] * 1
    if "bias" in linear_layer:
        # The bias is not sliced: every rank keeps the same full bias.
        sharded.bias = linear_layer.bias

    return sharded


class QuantizedAllToShardedLinear(Module):
    """Each member of the group applies part of the affine transformation with
    a quantized matrix such that the result is sharded across the group.

    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    :class:`QuantizedAllToShardedLinear` provides a classmethod
    :meth:`from_quantized_linear` to convert quantized linear layers to
    sharded :obj:`QuantizedAllToShardedLinear` layers.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.

    Raises:
        ValueError: If ``output_dims`` is not divisible by the group size.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (output_dims % N) != 0:
            raise ValueError(
                f"Cannot shard the output of size {output_dims} across {N} devices."
            )

        # Each member only stores its own output_dims // N rows.
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims // N, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims // N,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        """Return the layer summary, reporting the full (unsharded) sizes."""
        out_dims, in_dims = self.weight.shape
        # Undo the packing of the quantized weight along the input axis ...
        in_dims *= 32 // self.bits
        # ... and the sharding along the output axis.
        out_dims *= self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Apply this member's slice of the quantized affine transformation."""
        # Aggregate the gradients coming from each shard
        if self.group.size() > 1:
            x = sum_gradients(self.group)(x)

        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        group: Optional[mx.distributed.Group] = None,
    ):
        """Create a sharded layer holding this rank's rows of a quantized layer.

        Args:
            quantized_linear_layer (Module): The quantized layer whose
                ``weight``, ``scales``, ``biases`` (and ``bias``, when
                present) are split along the output dimension.
            group (mx.distributed.Group, optional): The group to shard
                across. When not given, the result of
                ``mx.distributed.init()`` is used.

        Returns:
            A new instance of ``cls`` whose parameters are copies of the rows
            belonging to the calling rank.
        """
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        output_dims, input_dims = quantized_linear_layer.weight.shape
        # The packed weight stores 32 // bits values per element along the
        # input axis; undo that to recover the logical input size.
        input_dims *= 32 // quantized_linear_layer.bits
        step = output_dims // N

        sl = cls(
            input_dims,
            output_dims,
            False,
            group_size=quantized_linear_layer.group_size,
            bits=quantized_linear_layer.bits,
            group=group,
        )
        # Rank r owns the contiguous rows [r * step, (r + 1) * step).
        # Multiplying by 1 forces a copy of the slice.
        sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
        sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
        sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
        if "bias" in quantized_linear_layer:
            sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1

        return sl


class QuantizedShardedToAllLinear(Module):
    """Each member of the group applies part of the affine transformation using
    the quantized matrix and then aggregates the results.

    All nodes will have the same exact result after this layer.

    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    :class:`QuantizedShardedToAllLinear` provides a classmethod
    :meth:`from_quantized_linear` to convert quantized linear layers to
    sharded :obj:`QuantizedShardedToAllLinear` layers.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.

    Raises:
        ValueError: If ``input_dims`` is not divisible by the group size.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (input_dims % N) != 0:
            raise ValueError(
                f"The input of size {input_dims} cannot be sharded across {N} devices."
            )

        # Each member only stores its own input_dims // N columns.
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims // N),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        """Return the layer summary, reporting the full (unsharded) sizes."""
        out_dims, in_dims = self.weight.shape
        # Undo both the packing of the quantized weight and the sharding,
        # which both act on the input axis.
        in_dims *= (32 // self.bits) * self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Apply this member's partial product and sum it across the group."""
        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if self.group.size() > 1:
            # Sum the partial results over the layer's own group. The bias is
            # added afterwards so that it is applied exactly once.
            x = mx.distributed.all_sum(x, group=self.group)
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        group: Optional[mx.distributed.Group] = None,
    ):
        """Create a sharded layer holding this rank's columns of a quantized layer.

        Args:
            quantized_linear_layer (Module): The quantized layer whose
                ``weight``, ``scales`` and ``biases`` are split along the
                input dimension. Its ``bias``, when present, is reused as is.
            group (mx.distributed.Group, optional): The group to shard
                across. When not given, the result of
                ``mx.distributed.init()`` is used.

        Returns:
            A new instance of ``cls`` whose quantized parameters are copies
            of the columns belonging to the calling rank.

        Raises:
            ValueError: If the packed weight columns or the per-group
                ``scales`` columns cannot be split evenly across the group.
        """
        group = group or mx.distributed.init()
        N = group.size()
        r = group.rank()
        bits = quantized_linear_layer.bits
        output_dims, packed_dims = quantized_linear_layer.weight.shape
        # The packed weight stores 32 // bits values per element along the
        # input axis; undo that to recover the full logical input size. The
        # constructor divides by N itself, so no extra factor of N here.
        input_dims = packed_dims * (32 // bits)

        # The weight and the scales/biases have a different number of columns
        # (packed values vs. quantization groups), so each needs its own step.
        # NOTE(review): assumes ``biases`` has the same shape as ``scales``,
        # as returned by ``mx.quantize`` -- confirm.
        grouped_dims = quantized_linear_layer.scales.shape[1]
        if (packed_dims % N) != 0 or (grouped_dims % N) != 0:
            raise ValueError(
                f"The quantized input of size {input_dims} cannot be sharded "
                f"across {N} devices."
            )
        step = packed_dims // N
        step_grouped = grouped_dims // N

        sl = cls(
            input_dims,
            output_dims,
            False,
            group_size=quantized_linear_layer.group_size,
            bits=bits,
            group=group,
        )
        # Rank r owns a contiguous range of input columns.
        # Multiplying by 1 forces a copy of the slice.
        sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
        sl.scales = (
            quantized_linear_layer.scales[
                :, r * step_grouped : (r + 1) * step_grouped
            ]
            * 1
        )
        sl.biases = (
            quantized_linear_layer.biases[
                :, r * step_grouped : (r + 1) * step_grouped
            ]
            * 1
        )
        if "bias" in quantized_linear_layer:
            # The bias is not sliced: every rank keeps the same full bias.
            sl.bias = quantized_linear_layer.bias

        return sl
2 changes: 1 addition & 1 deletion python/mlx/nn/layers/quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _extra_repr(self):
out_dims, in_dims = self.weight.shape
in_dims *= 32 // self.bits
return (
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
f"group_size={self.group_size}, bits={self.bits}"
)

Expand Down

0 comments on commit d63dee5

Please sign in to comment.