From 071fb868ae42f042d885ed1b75fa83c2267de380 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 15 Jul 2024 13:23:51 -0700 Subject: [PATCH] Add distributed layers to nn top-level --- python/mlx/nn/layers/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index f528c9908..a2881486d 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -55,6 +55,12 @@ from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d +from mlx.nn.layers.distributed import ( + AllToShardedLinear, + QuantizedAllToShardedLinear, + QuantizedShardedToAllLinear, + ShardedToAllLinear, +) from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear