Add distributed layers to nn top-level
angeloskath committed Jul 15, 2024
1 parent 84e1444 commit 071fb86
Showing 1 changed file with 6 additions and 0 deletions.
python/mlx/nn/layers/__init__.py (+6, −0)
@@ -55,6 +55,12 @@
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d, Conv3d
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
