Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solves #1512 #1526

Merged
merged 13 commits into from
Nov 20, 2024
9 changes: 8 additions & 1 deletion python/mlx/nn/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@
LayerNorm,
RMSNorm,
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.pooling import (
AvgPool1d,
AvgPool2d,
AvgPool3d,
MaxPool1d,
MaxPool2d,
MaxPool3d,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
from mlx.nn.layers.recurrent import GRU, LSTM, RNN
Expand Down
125 changes: 125 additions & 0 deletions python/mlx/nn/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,30 @@ def __init__(
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)


class _Pool3d(_Pool):
    """Shared base for the 3-dimensional pooling layers.

    Passes ``kernel_size``, ``stride`` and ``padding`` through
    ``_value_or_list`` with a requested length of 3 and forwards the results,
    together with the pooling function and padding value, to ``_Pool``.
    """

    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        owner = type(self).__name__
        template = "[{}] '{}' must be an integer or a tuple containing 3 integers"

        def _as_triple(value, label):
            # The error message names the concrete subclass and the offending
            # argument; presumably _value_or_list raises it on bad input.
            return _value_or_list(value, 3, template.format(owner, label))

        kernel_size = _as_triple(kernel_size, "kernel_size")
        # An omitted stride falls back to the (already normalized) kernel size.
        stride = kernel_size if stride is None else _as_triple(stride, "stride")
        # The same amount of padding is used on both sides of each axis.
        pad_pairs = [(amount, amount) for amount in _as_triple(padding, "padding")]

        super().__init__(
            pooling_function, kernel_size, stride, pad_pairs, padding_value
        )


class MaxPool1d(_Pool1d):
r"""Applies 1-dimensional max pooling.

Expand Down Expand Up @@ -332,3 +356,104 @@ def __init__(
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)


class MaxPool3d(_Pool3d):
    r"""Applies 3-dimensional max pooling.

    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
    H_{out}, W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
                                            & \text{input}(N_i, \text{stride[0]} \times d + l,
                                                           \text{stride[1]} \times h + m,
                                                           \text{stride[2]} \times w + n, C_j),
        \end{aligned}

    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

    - a single ``int`` -- in which case the same value is used for the depth,
      height and width axis;
    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
      for the depth axis, the second ``int`` for the height axis, and the third
      ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
        stride (int or tuple(int, int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int, int), optional): How much negative infinity
            padding to apply to the input. The padding is applied on both sides
            of the depth, height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        # Reduce each window with mx.max; the padding value is -inf.
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)


class AvgPool3d(_Pool3d):
    r"""Applies 3-dimensional average pooling.

    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
    H_{out}, W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
                                            & \text{input}(N_i, \text{stride[0]} \times d + l,
                                                           \text{stride[1]} \times h + m,
                                                           \text{stride[2]} \times w + n, C_j),
        \end{aligned}

    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:

    - a single ``int`` -- in which case the same value is used for the depth,
      height and width axis;
    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
      for the depth axis, the second ``int`` for the height axis, and the third
      ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
        stride (int or tuple(int, int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int, int), optional): How much zero
            padding to apply to the input. The padding is applied on both sides
            of the depth, height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
        >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
    ):
        # Reduce each window with mx.mean; the padding value is 0.
        super().__init__(mx.mean, 0, kernel_size, stride, padding)
117 changes: 117 additions & 0 deletions python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,123 @@ def test_pooling(self):
str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
)
# Test 3d pooling
x = mx.array(
[
[
[
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[9, 10, 11], [12, 13, 14], [15, 16, 17]],
[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
],
[
[[27, 28, 29], [30, 31, 32], [33, 34, 35]],
[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
],
]
]
)
expected_max_pool_output_no_padding_stride_1 = [
[[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
]

expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
expected_max_pool_output_padding_1 = [
[
[[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
[[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
]
]
expected_irregular_max_pool_output = [
[
[[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
[[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
]
]

self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_max_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_max_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_max_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_max_pool_output,
)
)
self.assertEqual(
str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
"MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)

expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
[22.5, 23.5, 24.5]],
[[28.5, 29.5, 30.5],
[31.5, 32.5, 33.5]]]]
]

expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
expected_avg_pool_output_padding_1 = [
[[[[0, 0.125, 0.25],
[1.125, 1.375, 1.625]],
[[3.375, 3.625, 3.875],
[9, 9.5, 10]]],
[[[3.375, 3.5, 3.625],
[7.875, 8.125, 8.375]],
[[10.125, 10.375, 10.625],
[22.5, 23, 23.5]]]]
]
expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
[7.5, 8.5, 9.5],
[10.5, 11.5, 12.5]]],
[[[31.5, 32.5, 33.5],
[34.5, 35.5, 36.5],
[37.5, 38.5, 39.5]]]]
]

self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
expected_avg_pool_output_no_padding_stride_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
expected_avg_pool_output_no_padding_stride_2,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
expected_avg_pool_output_padding_1,
)
)
self.assertTrue(
np.array_equal(
nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
expected_irregular_avg_pool_output,
)
)
self.assertEqual(
str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
)

def test_set_dtype(self):
def assert_dtype(layer, dtype):
Expand Down