Cache FFT plans in the OperatorConv layer state.

    src/NeuralOperators.jl: import plan_rfft / plan_irfft from FFTW in place of
        irfft / rfft.
    src/layers.jl: add LuxCore.initialstates for OperatorConv, which builds the
        initial plans from an all-ones-sized Float32 array; operator_conv now
        takes the state `st` and returns the output together with
        (; plan_tform, plan_inv_tform).
    src/transform.jl: document and implement plan_transform / transform and
        plan_inverse / inverse; a previous plan is returned unchanged when its
        size and eltype equal those of the new input, otherwise a new plan is
        built.

NOTE(review): the reuse check in plan_inverse compares only the size and eltype
of the input; the requested output length first(M) is not compared. Confirm
that an even and an odd output length can never share the same input size here.

diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl
index 3c1ea89..869da88 100644
--- a/src/NeuralOperators.jl
+++ b/src/NeuralOperators.jl
@@ -3,7 +3,7 @@ module NeuralOperators
 using ArgCheck: @argcheck
 using ChainRulesCore: @non_differentiable
 using ConcreteStructs: @concrete
-using FFTW: FFTW, irfft, rfft
+using FFTW: FFTW, plan_rfft, plan_irfft
 using Random: Random, AbstractRNG
 using Static: StaticBool, False, True, known, static
 
diff --git a/src/layers.jl b/src/layers.jl
index dff7b13..b83307c 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -48,6 +48,14 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
         rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes))
 end
 
+function LuxCore.initialstates(::AbstractRNG, layer::OperatorConv)
+    fake_x = zeros(Float32, ntuple(Returns(1), ndims(layer.tform))..., 1)
+    plan_tform = plan_transform(layer.tform, fake_x, nothing)
+    x = transform(layer.tform, fake_x, plan_tform)
+    plan_inv_tform = plan_inverse(layer.tform, x, nothing, size(x))
+    return (; plan_tform, plan_inv_tform)
+end
+
 function LuxCore.parameterlength(layer::OperatorConv)
     return layer.prod_modes * layer.in_chs * layer.out_chs
 end
@@ -59,19 +67,21 @@ function OperatorConv(
 end
 
 function (conv::OperatorConv{True})(x::AbstractArray, ps, st)
-    return operator_conv(x, conv.tform, ps.weight), st
+    return operator_conv(x, conv.tform, ps.weight, st)
 end
 
 function (conv::OperatorConv{False})(x::AbstractArray, ps, st)
     N = ndims(conv.tform)
     xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2))
-    yᵀ = operator_conv(xᵀ, conv.tform, ps.weight)
+    yᵀ, stₙ = operator_conv(xᵀ, conv.tform, ps.weight, st)
     y = permutedims(yᵀ, (N + 1, 1:N..., N + 2))
-    return y, st
+    return y, stₙ
 end
 
-function operator_conv(x, tform::AbstractTransform, weights)
-    x_t = transform(tform, x)
+function operator_conv(x, tform::AbstractTransform, weights, st)
+    plan_tform = plan_transform(tform, x, st.plan_tform)
+    x_t = transform(tform, x, plan_tform)
+
     x_tr = truncate_modes(tform, x_t)
     x_p = apply_pattern(x_tr, weights)
 
@@ -79,7 +89,8 @@ function operator_conv(x, tform::AbstractTransform, weights)
     x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false;
         dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p)
 
-    return inverse(tform, x_padded, size(x))
+    plan_inv_tform = plan_inverse(tform, x_padded, st.plan_inv_tform, size(x))
+    return inverse(tform, x_padded, plan_inv_tform, size(x)), (; plan_tform, plan_inv_tform)
 end
 
 """
diff --git a/src/transform.jl b/src/transform.jl
index 7fe53f2..10dd2ee 100644
--- a/src/transform.jl
+++ b/src/transform.jl
@@ -4,10 +4,22 @@
 ## Interface
 
   - `Base.ndims(<:AbstractTransform)`: N dims of modes
-  - `transform(<:AbstractTransform, x::AbstractArray)`: Apply the transform to x
   - `truncate_modes(<:AbstractTransform, x_transformed::AbstractArray)`: Truncate modes that
     contribute to the noise
-  - `inverse(<:AbstractTransform, x_transformed::AbstractArray)`: Apply the inverse
+
+### Transform Interface
+
+  - `plan_transform(<:AbstractTransform, x::AbstractArray, prev_plan)`: Construct a plan to
+    apply the transform to x. Might reuse the previous plan if possible
+  - `transform(<:AbstractTransform, x::AbstractArray, plan)`: Apply the transform to x using
+    the plan
+
+### Inverse Transform Interface
+
+  - `plan_inverse(<:AbstractTransform, x_transformed::AbstractArray, prev_plan, M)`:
+    Construct a plan to apply the inverse transform to `x_transformed`. Might reuse the
+    previous plan if possible
+  - `inverse(<:AbstractTransform, x_transformed::AbstractArray, plan, M)`: Apply the inverse
     transform to `x_transformed`
 """
 abstract type AbstractTransform{T} end
@@ -22,7 +34,18 @@ end
 
 Base.ndims(T::FourierTransform) = length(T.modes)
 
-transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft))
+function plan_transform(ft::FourierTransform, x::AbstractArray, ::Nothing)
+    return plan_rfft(x, 1:ndims(ft))
+end
+
+function plan_transform(ft::FourierTransform, x::AbstractArray, prev_plan)
+    size(prev_plan) == size(x) && eltype(prev_plan) == eltype(x) && return prev_plan
+    return plan_transform(ft, x, nothing)
+end
+
+@non_differentiable plan_transform(::Any...)
+
+transform(::FourierTransform, x::AbstractArray, plan) = plan * x
 
 function low_pass(ft::FourierTransform, x_fft::AbstractArray)
     return view(x_fft, map(d -> 1:d, ft.modes)..., :, :)
@@ -30,7 +53,21 @@ end
 
 truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
 
-function inverse(
-        ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N}
-    return real(irfft(x_fft, first(M), 1:ndims(ft)))
+function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
+        ::Nothing, M::NTuple{N, Int64}) where {T, N}
+    return plan_irfft(x_transformed, first(M), 1:ndims(ft))
+end
+
+function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
+        prev_plan, M::NTuple{N, Int64}) where {T, N}
+    size(prev_plan) == size(x_transformed) && eltype(prev_plan) == eltype(x_transformed) &&
+        return prev_plan
+    return plan_inverse(ft, x_transformed, nothing, M)
+end
+
+@non_differentiable plan_inverse(::Any...)
+
+function inverse(::FourierTransform, x_transformed::AbstractArray{T, N}, plan,
+        ::NTuple{N, Int64}) where {T, N}
+    return real(plan * x_transformed)
 end