Skip to content

Commit

Permalink
convolve operator (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 authored Oct 3, 2024
1 parent 5274201 commit 9030b79
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 32 deletions.
54 changes: 28 additions & 26 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ._compose_quaternion import compose_quaternion
from ._compose_rotation_matrix import compose_rotation_matrix
from ._compose_rotation_vector import compose_rotation_vector
from ._convolve import convolve
from ._differentiate_chebyshev_polynomial import differentiate_chebyshev_polynomial
from ._differentiate_laguerre_polynomial import differentiate_laguerre_polynomial
from ._differentiate_legendre_polynomial import differentiate_legendre_polynomial
Expand Down Expand Up @@ -370,14 +371,9 @@
"apply_rotation_matrix",
"apply_rotation_vector",
"apply_transform",
"evaluate_chebyshev_polynomial_cartesian_2d",
"evaluate_chebyshev_polynomial_cartesian_3d",
"chebyshev_interpolation",
"linear_chebyshev_polynomial",
"multiply_chebyshev_polynomial_by_x",
"chebyshev_zeros",
"chebyshev_extrema",
"chebyshev_gauss_quadrature",
"chebyshev_interpolation",
"chebyshev_polynomial_companion",
"chebyshev_polynomial_domain",
"chebyshev_polynomial_from_roots",
Expand All @@ -391,10 +387,12 @@
"chebyshev_polynomial_weight",
"chebyshev_polynomial_x",
"chebyshev_polynomial_zero",
"chebyshev_zeros",
"compose_euler_angle",
"compose_quaternion",
"compose_rotation_matrix",
"compose_rotation_vector",
"convolve",
"differentiate_chebyshev_polynomial",
"differentiate_laguerre_polynomial",
"differentiate_legendre_polynomial",
Expand All @@ -416,22 +414,34 @@
"evaluate_chebyshev_polynomial",
"evaluate_chebyshev_polynomial_2d",
"evaluate_chebyshev_polynomial_3d",
"evaluate_chebyshev_polynomial_cartesian_2d",
"evaluate_chebyshev_polynomial_cartesian_3d",
"evaluate_laguerre_polynomial",
"evaluate_laguerre_polynomial_2d",
"evaluate_laguerre_polynomial_3d",
"evaluate_laguerre_polynomial_cartesian_2d",
"evaluate_laguerre_polynomial_cartesian_3d",
"evaluate_legendre_polynomial",
"evaluate_legendre_polynomial_2d",
"evaluate_legendre_polynomial_3d",
"evaluate_legendre_polynomial_cartesian_2d",
"evaluate_legendre_polynomial_cartesian_3d",
"evaluate_physicists_hermite_polynomial",
"evaluate_physicists_hermite_polynomial_2d",
"evaluate_physicists_hermite_polynomial_3d",
"evaluate_physicists_hermite_polynomial_cartesian_2d",
"evaluate_physicists_hermite_polynomial_cartesian_3d",
"evaluate_polynomial",
"evaluate_polynomial_2d",
"evaluate_polynomial_3d",
"evaluate_polynomial_cartesian_2d",
"evaluate_polynomial_cartesian_3d",
"evaluate_polynomial_from_roots",
"evaluate_probabilists_hermite_polynomial",
"evaluate_probabilists_hermite_polynomial_2d",
"evaluate_probabilists_hermite_polynomial_3d",
"evaluate_probabilists_hermite_polynomial_cartersian_2d",
"evaluate_probabilists_hermite_polynomial_cartersian_3d",
"farthest_first_traversal",
"fit_chebyshev_polynomial",
"fit_laguerre_polynomial",
Expand All @@ -443,14 +453,6 @@
"gauss_legendre_quadrature",
"gauss_physicists_hermite_polynomial_quadrature",
"gauss_probabilists_hermite_polynomial_quadrature",
"evaluate_probabilists_hermite_polynomial_cartersian_2d",
"evaluate_probabilists_hermite_polynomial_cartersian_3d",
"linear_probabilists_hermite_polynomial",
"multiply_probabilists_hermite_polynomial_by_x",
"evaluate_physicists_hermite_polynomial_cartesian_2d",
"evaluate_physicists_hermite_polynomial_cartesian_3d",
"linear_physicists_hermite_polynomial",
"multiply_physicists_hermite_polynomial_by_x",
"integrate_chebyshev_polynomial",
"integrate_laguerre_polynomial",
"integrate_legendre_polynomial",
Expand All @@ -462,10 +464,6 @@
"invert_rotation_matrix",
"invert_rotation_vector",
"invert_transform",
"evaluate_laguerre_polynomial_cartesian_2d",
"evaluate_laguerre_polynomial_cartesian_3d",
"linear_laguerre_polynomial",
"multiply_laguerre_polynomial_by_x",
"laguerre_polynomial_companion",
"laguerre_polynomial_domain",
"laguerre_polynomial_from_roots",
Expand All @@ -492,17 +490,25 @@
"legendre_polynomial_weight",
"legendre_polynomial_x",
"legendre_polynomial_zero",
"evaluate_legendre_polynomial_cartesian_2d",
"evaluate_legendre_polynomial_cartesian_3d",
"linear_legendre_polynomial",
"multiply_legendre_polynomial_by_x",
"lennard_jones_potential",
"linear_chebyshev_polynomial",
"linear_laguerre_polynomial",
"linear_legendre_polynomial",
"linear_physicists_hermite_polynomial",
"linear_polynomial",
"linear_probabilists_hermite_polynomial",
"multiply_chebyshev_polynomial",
"multiply_chebyshev_polynomial_by_x",
"multiply_laguerre_polynomial",
"multiply_laguerre_polynomial_by_x",
"multiply_legendre_polynomial",
"multiply_legendre_polynomial_by_x",
"multiply_physicists_hermite_polynomial",
"multiply_physicists_hermite_polynomial_by_x",
"multiply_polynomial",
"multiply_polynomial_by_x",
"multiply_probabilists_hermite_polynomial",
"multiply_probabilists_hermite_polynomial_by_x",
"physicists_hermite_polynomial_companion",
"physicists_hermite_polynomial_domain",
"physicists_hermite_polynomial_from_roots",
Expand All @@ -516,10 +522,6 @@
"physicists_hermite_polynomial_weight",
"physicists_hermite_polynomial_x",
"physicists_hermite_polynomial_zero",
"evaluate_polynomial_cartesian_2d",
"evaluate_polynomial_cartesian_3d",
"linear_polynomial",
"multiply_polynomial_by_x",
"polynomial_companion",
"polynomial_domain",
"polynomial_from_roots",
Expand Down
4 changes: 2 additions & 2 deletions src/beignet/_chebyshev_polynomial_power.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math

import torch
import torchaudio.functional
from torch import Tensor
from torchaudio.functional import convolve

from ._add_chebyshev_polynomial import add_chebyshev_polynomial

Expand Down Expand Up @@ -45,7 +45,7 @@ def chebyshev_polynomial_power(
output = output1

for _ in range(2, _exponent + 1):
output = torchaudio.functional.convolve(output, zs, mode="same")
output = convolve(output, zs, mode="same")

n = (math.prod(output.shape) + 1) // 2
c = output[n - 1 :]
Expand Down
92 changes: 92 additions & 0 deletions src/beignet/_convolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import math
from typing import Literal

import torch
import torch.nn.functional
from torch import Tensor


def convolve(
input: Tensor,
other: Tensor,
mode: Literal["full", "same", "valid"] = "full",
) -> Tensor:
if input.ndim != other.ndim:
raise ValueError

for i in range(input.ndim - 1):
a = input.size(i)
b = other.size(i)

if a == b or a == 1 or b == 1:
continue

raise ValueError

if mode not in {"full", "same", "valid"}:
raise ValueError

x_size, y_size = input.shape[-1], other.shape[-1]

if input.shape[-1] < other.shape[-1]:
input, other = other, input

if input.shape[:-1] != other.shape[:-1]:
input_shape = []

for i, j in zip(input.shape[:-1], other.shape[:-1], strict=False):
input_shape = [*input_shape, max(i, j)]

input = torch.broadcast_to(
input,
[*input_shape, input.shape[-1]],
)

other = torch.broadcast_to(
other,
[*input_shape, other.shape[-1]],
)

input = torch.reshape(
input,
[math.prod(input.shape[:-1]), input.shape[-1]],
)

other = torch.unsqueeze(
torch.flip(
torch.reshape(
other,
[math.prod(input.shape[:-1]), other.shape[-1]],
),
dims=[-1],
),
dim=1,
)

output = torch.reshape(
torch.nn.functional.conv1d(
input,
other,
groups=input.shape[0],
padding=other.shape[-1] - 1,
),
[*input.shape[:-1], -1],
)

match mode:
case "same":
size = x_size

m = (output.shape[-1] - size) // 2
n = m + size

output = output[..., m:n]
case "valid":
size = max(x_size, y_size) - min(x_size, y_size) + 1

m = (output.shape[-1] - size) // 2
n = m + size

output = output[..., m:n]

return output
4 changes: 2 additions & 2 deletions src/beignet/_multiply_chebyshev_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Literal

import torch
import torchaudio.functional
from torch import Tensor
from torchaudio.functional import convolve


def multiply_chebyshev_polynomial(
Expand Down Expand Up @@ -47,7 +47,7 @@ def multiply_chebyshev_polynomial(
output2 = torch.flip(output2, dims=[0]) + output2
b = output2

output = torchaudio.functional.convolve(a, b, mode=mode)
output = convolve(a, b, mode=mode)

n = (math.prod(output.shape) + 1) // 2
c = output[n - 1 :]
Expand Down
4 changes: 2 additions & 2 deletions src/beignet/_multiply_polynomial.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Literal

import torch
import torchaudio.functional
from torch import Tensor
from torchaudio.functional import convolve


def multiply_polynomial(
Expand Down Expand Up @@ -34,7 +34,7 @@ def multiply_polynomial(
input = input.to(dtype)
other = other.to(dtype)

output = torchaudio.functional.convolve(input, other)
output = convolve(input, other)

if mode == "same":
output = output[: max(input.shape[0], other.shape[0])]
Expand Down
39 changes: 39 additions & 0 deletions tests/beignet/test__convolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import beignet
import hypothesis.extra.numpy
import hypothesis.strategies
import numpy
import torch.testing


@hypothesis.strategies.composite
def _strategy(function):
size = function(
hypothesis.strategies.integers(
min_value=128,
max_value=512,
),
)

input = torch.rand([size])
other = torch.rand([size])

return (
{
"input": input,
"other": other,
},
torch.reshape(
torch.from_numpy(numpy.convolve(input, other)),
[1, -1],
),
)


@hypothesis.given(_strategy())
def test_convolve(data):
parameters, expected = data

torch.testing.assert_allclose(
beignet.convolve(**parameters),
expected,
)

0 comments on commit 9030b79

Please sign in to comment.