From a8b3da7946cb49ce32eba262e9ab321e524cbd2e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 15 Jul 2024 13:23:51 -0700 Subject: [PATCH] Add distributed layers to nn top-level --- python/mlx/nn/layers/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 3cf5e33a8..1dd530765 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -60,6 +60,12 @@ ConvTranspose2d, ConvTranspose3d, ) +from mlx.nn.layers.distributed import ( + AllToShardedLinear, + QuantizedAllToShardedLinear, + QuantizedShardedToAllLinear, + ShardedToAllLinear, +) from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear