diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py index c29cd81d0..f1cef52e6 100644 --- a/python/mlx/nn/layers/distributed.py +++ b/python/mlx/nn/layers/distributed.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +import math from functools import lru_cache from typing import Optional @@ -168,7 +169,7 @@ def __call__(self, x: mx.array) -> mx.array: if self.group.size() > 1: # Perform the local projection and aggregate the results x = x @ self["weight"].T - x = mx.distributed.all_sum(x, group=group) + x = mx.distributed.all_sum(x, group=self.group) # Add the bias if we have one if "bias" in self: @@ -316,9 +317,9 @@ def from_quantized_linear( bits=quantized_linear_layer.bits, group=group, ) - sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1 - sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1 - sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1 + sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1 + sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1 + sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1 if "bias" in quantized_linear_layer: sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1 @@ -413,7 +414,7 @@ def __call__(self, x: mx.array) -> mx.array: bits=self.bits, ) if self.group.size() > 1: - x = mx.distributed.sum_all(x, group=group) + x = mx.distributed.all_sum(x, group=self.group) if "bias" in self: x = x + self["bias"] return x @@ -428,6 +429,8 @@ def from_quantized_linear( N = group.size() r = group.rank() output_dims, input_dims = quantized_linear_layer.weight.shape + step = input_dims // N + step_grouped = quantized_linear_layer.scales.shape[1] // N input_dims *= (32 // quantized_linear_layer.bits) * N sl = cls( @@ -438,9 +441,15 @@ def from_quantized_linear( bits=quantized_linear_layer.bits, group=group, ) - sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1 - sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1 - sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1 + sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1 + sl.scales = ( + quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped] + * 1 + ) + sl.biases = ( + quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped] + * 1 + ) if "bias" in quantized_linear_layer: sl.bias = quantized_linear_layer.bias