Skip to content

Commit

Permalink
Use direct FFT method for 3D FFT
Browse files Browse the repository at this point in the history
When `Variable` dimensions and convolution dimensions are both equal to
3, set the convolution dimensions to 3. This enables absorption of the
circulant convolution term into the direct least square method.
  • Loading branch information
antonysigma committed Oct 5, 2024
1 parent 08797c5 commit 0c798ac
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
6 changes: 3 additions & 3 deletions proximal/algorithms/invert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Utilities for getting the inverse of lin ops.
from __future__ import print_function
import numpy as np
from proximal.prox_fns import least_squares, sum_squares
from proximal.lin_ops import vstack
Expand Down Expand Up @@ -93,8 +92,9 @@ def get_least_squares_inverse(op_list, b, try_freq_diagonalize=True, verbose=Fal
implem = get_implem(op_list) # If any freqdiag is halide, solve with halide

if verbose:
dimstr = (' with dimensionality %d' % dims) if dims is not None else ''
print('Optimized for diagonal frequency inverse' + dimstr)
print('Optimized for diagonal frequency inverse' +
(f' with dimensionality {dims}' if dims is not None else '')
)

x_update = least_squares(stacked, b,
freq_diag=diag, freq_dims=dims, implem=implem)
Expand Down
2 changes: 1 addition & 1 deletion proximal/lin_ops/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class conv(LinOp):

def __init__(self, kernel, arg, dims=None, implem=None):
self.kernel = kernel
if dims is not None and dims < len(arg.shape):
if dims is not None and dims <= len(arg.shape):
self.dims = dims
else:
self.dims = None
Expand Down
3 changes: 1 addition & 2 deletions proximal/prox_fns/sum_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def solve(self, b: memoized_expr, rho=None, v=None, lin_solver="lsqr", hash=None
if self.implementation == Impl['halide'] and \
(len(self.freq_shape) == 2 or
(len(self.freq_shape) == 2 and self.freq_dims == 2)):

ftmp_halide_out = np.empty(self.freq_shape, dtype=np.float32, order='F')

if rho is None:
Expand All @@ -236,7 +236,6 @@ def solve(self, b: memoized_expr, rho=None, v=None, lin_solver="lsqr", hash=None
return ftmp_halide_out.ravel()

else:

# General frequency inversion
Ktb = fftd(np.reshape(self.Ktb, self.freq_shape), self.freq_dims)

Expand Down

0 comments on commit 0c798ac

Please sign in to comment.