Skip to content

Commit

Permalink
feat: cache plans for fft
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 29, 2024
1 parent 6f1cc2c commit 75d80fe
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module NeuralOperators
using ArgCheck: @argcheck
using ChainRulesCore: @non_differentiable
using ConcreteStructs: @concrete
using FFTW: FFTW, irfft, rfft
using FFTW: FFTW, plan_rfft, plan_irfft
using Random: Random, AbstractRNG
using Static: StaticBool, False, True, known, static

Expand Down
23 changes: 17 additions & 6 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes))
end

function LuxCore.initialstates(::AbstractRNG, layer::OperatorConv)
fake_x = zeros(Float32, ntuple(Returns(1), ndims(layer.tform))..., 1)
plan_tform = plan_transform(layer.tform, fake_x, nothing)
x = transform(layer.tform, fake_x, plan_tform)
plan_inv_tform = plan_inverse(layer.tform, x, nothing, size(x))
return (; plan_tform, plan_inv_tform)
end

function LuxCore.parameterlength(layer::OperatorConv)
return layer.prod_modes * layer.in_chs * layer.out_chs
end
Expand All @@ -59,27 +67,30 @@ function OperatorConv(
end

function (conv::OperatorConv{True})(x::AbstractArray, ps, st)
return operator_conv(x, conv.tform, ps.weight), st
return operator_conv(x, conv.tform, ps.weight, st)
end

function (conv::OperatorConv{False})(x::AbstractArray, ps, st)
N = ndims(conv.tform)
xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2))
yᵀ = operator_conv(xᵀ, conv.tform, ps.weight)
yᵀ, stₙ = operator_conv(xᵀ, conv.tform, ps.weight, st)
y = permutedims(yᵀ, (N + 1, 1:N..., N + 2))
return y, st
return y, stₙ
end

function operator_conv(x, tform::AbstractTransform, weights)
x_t = transform(tform, x)
function operator_conv(x, tform::AbstractTransform, weights, st)
plan_tform = plan_transform(tform, x, st.plan_tform)
x_t = transform(tform, x, plan_tform)

x_tr = truncate_modes(tform, x_t)
x_p = apply_pattern(x_tr, weights)

pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false;
dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p)

return inverse(tform, x_padded, size(x))
plan_inv_tform = plan_inverse(tform, x_padded, st.plan_inv_tform, size(x))
return inverse(tform, x_padded, plan_inv_tform, size(x)), (; plan_tform, plan_inv_tform)
end

"""
Expand Down
49 changes: 43 additions & 6 deletions src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@
## Interface
- `Base.ndims(<:AbstractTransform)`: N dims of modes
- `transform(<:AbstractTransform, x::AbstractArray)`: Apply the transform to x
- `truncate_modes(<:AbstractTransform, x_transformed::AbstractArray)`: Truncate modes
that contribute to the noise
- `inverse(<:AbstractTransform, x_transformed::AbstractArray)`: Apply the inverse
### Transform Interface
- `plan_transform(<:AbstractTransform, x::AbstractArray, prev_plan)`: Construct a plan to
apply the transform to x. Might reuse the previous plan if possible
- `transform(<:AbstractTransform, x::AbstractArray, plan)`: Apply the transform to x using
the plan
### Inverse Transform Interface
- `plan_inverse(<:AbstractTransform, x_transformed::AbstractArray, prev_plan, M)`:
Construct a plan to apply the inverse transform to `x_transformed`. Might reuse the
previous plan if possible
- `inverse(<:AbstractTransform, x_transformed::AbstractArray, plan, M)`: Apply the inverse
transform to `x_transformed`
"""
abstract type AbstractTransform{T} end
Expand All @@ -22,15 +34,40 @@ end

Base.ndims(T::FourierTransform) = length(T.modes)

transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft))
function plan_transform(ft::FourierTransform, x::AbstractArray, ::Nothing)
return plan_rfft(x, 1:ndims(ft))
end

function plan_transform(ft::FourierTransform, x::AbstractArray, prev_plan)
size(prev_plan) == size(x) && eltype(prev_plan) == eltype(x) && return prev_plan
return plan_transform(ft, x, nothing)
end

@non_differentiable plan_transform(::Any...)

transform(::FourierTransform, x::AbstractArray, plan) = plan * x

function low_pass(ft::FourierTransform, x_fft::AbstractArray)
return view(x_fft, map(d -> 1:d, ft.modes)..., :, :)
end

truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)

function inverse(
ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N}
return real(irfft(x_fft, first(M), 1:ndims(ft)))
function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
::Nothing, M::NTuple{N, Int64}) where {T, N}
return plan_irfft(x_transformed, first(M), 1:ndims(ft))
end

function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
prev_plan, M::NTuple{N, Int64}) where {T, N}
size(prev_plan) == size(x_transformed) && eltype(prev_plan) == eltype(x_transformed) &&
return prev_plan
return plan_inverse(ft, x_transformed, nothing, M)
end

@non_differentiable plan_inverse(::Any...)

function inverse(::FourierTransform, x_transformed::AbstractArray{T, N}, plan,
::NTuple{N, Int64}) where {T, N}
return real(plan * x_transformed)
end

0 comments on commit 75d80fe

Please sign in to comment.