Add distributed layers to nn top-level
angeloskath committed Nov 5, 2024
1 parent 060e1c9 commit a8b3da7
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/mlx/nn/layers/__init__.py
@@ -60,6 +60,12 @@
     ConvTranspose2d,
     ConvTranspose3d,
 )
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
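
The effect of the change is that the sharded linear layers become importable from the `mlx.nn` top level rather than only from `mlx.nn.layers.distributed`. A minimal usage sketch follows; the constructor arguments (`input_dims`, `output_dims`, `bias`, `group`) are assumed to mirror `nn.Linear` plus an optional distributed group, which this diff itself does not show.

```python
# Minimal sketch. The (input_dims, output_dims, bias=..., group=...) signature
# is an assumption based on nn.Linear, not something this diff confirms.
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()  # default process group

# After this commit the class is reachable from the nn top level
# instead of only via mlx.nn.layers.distributed.AllToShardedLinear.
layer = nn.AllToShardedLinear(512, 1024, bias=True, group=group)

x = mx.random.normal((4, 512))  # input replicated on every rank
y = layer(x)                    # output features sharded across the group (assumption)
print(y.shape)
```

The quantized variants (`QuantizedAllToShardedLinear`, `QuantizedShardedToAllLinear`) and `ShardedToAllLinear` become reachable from the top-level namespace in the same way.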
