-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathEikoTypes.jl
38 lines (30 loc) · 1.06 KB
/
EikoTypes.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
    τ0(scale)

Lux layer with no trainable parameters. Its forward pass returns
`scale * ‖x[4:6, :] - x[1:3, :]‖` for every column of `x`.

`scale` is copied into the layer state by `Lux.initialstates` and read back from `st`
during the forward pass.
NOTE(review): rows 1:3 / 4:6 look like source / receiver coordinates — confirm with callers.
"""
struct τ0{S} <: Lux.AbstractExplicitLayer
    scale::S   # multiplier applied to the Euclidean distance
end
# `τ0` has nothing to train, so its parameter container is an empty NamedTuple.
function Lux.initialparameters(::AbstractRNG, ::τ0)
    return NamedTuple()
end
# Store the layer's fixed `scale` in its state; the forward pass reads `st.scale`.
Lux.initialstates(::AbstractRNG, layer::τ0) = (; scale = layer.scale)
"""
    EikoNet(τ1, τ0)

Container layer holding two sub-layers.

- The forward pass returns the elementwise product of their outputs, `τ0(x) .* τ1(x)`.
- `EikonalPDE` differentiates `τ1` by automatic differentiation.
- The gradient of `τ0` is written out analytically there.

Lux manages parameters and states for the sub-layers under the keys `:τ1` and `:τ0`,
in that order.
"""
struct EikoNet{A, B} <: Lux.AbstractExplicitContainerLayer{(:τ1, :τ0)}
    τ1::A   # sub-layer whose output multiplies τ0's output
    τ0::B   # distance-based factor layer
end
"""
    (eikonet::EikoNet)(x, ps, st)

Evaluate both sub-layers on `x` and return `(τ0(x) .* τ1(x), st_new)`.

`st_new` lists the sub-layer states in the same `(τ1, τ0)` order that the container
declares. Lux builds the state in that order, so the key order of `st_new` matches
the incoming `st`.

Returning the keys in the opposite order would produce a NamedTuple of a different
type. Then `st == st_new` would never hold, and loops that rebind `st` would become
type-unstable.
"""
function (eikonet::EikoNet)(x::AbstractArray, ps, st::NamedTuple)
    T0, st_τ0 = eikonet.τ0(x, ps.τ0, st.τ0)
    T1, st_τ1 = eikonet.τ1(x, ps.τ1, st.τ1)
    # Keep key order identical to the container's declared (:τ1, :τ0).
    return T0 .* T1, (τ1 = st_τ1, τ0 = st_τ0)
end
# Forward pass of `τ0`:
#   - compute the Euclidean distance between rows 1:3 and rows 4:6 of each column of `x`;
#   - multiply it by `st.scale`;
#   - return the resulting one-row matrix together with the unchanged state.
function (layer::τ0)(x::AbstractArray, ps, st::NamedTuple)
    offset = x[4:6, :] - x[1:3, :]
    dist = sqrt.(sum(offset .^ 2; dims=1))
    return st.scale * dist, st
end
"""
    EikonalPDE(eikonet::EikoNet, x::AbstractArray, ps, st::NamedTuple)

Return `‖∇T‖` for the factored traveltime `T = τ0 * τ1`. The result is a one-row
matrix with one entry per column of `x`, and the gradient is taken with respect to
rows `4:6` of `x`.

NOTE(review): presumably rows `1:3` are source and rows `4:6` receiver coordinates, so
this is the slowness implied by the network — confirm against callers.

The two gradients are computed differently:

- `∇τ0` comes analytically from the `τ0` layer's definition
  `τ0 = scale * ‖x[4:6] - x[1:3]‖`, using `scale` from `st.τ0.scale`.
- `∇τ1` comes from automatic differentiation via `gradient`.

In any column where `x[4:6] == x[1:3]` the distance is zero and its gradient is
undefined, so the result there is `NaN`.
"""
function EikonalPDE(eikonet::EikoNet, x::AbstractArray, ps, st::NamedTuple)
    T0, _ = Lux.apply(eikonet.τ0, x, ps.τ0, st.τ0)
    T1, _ = Lux.apply(eikonet.τ1, x, ps.τ1, st.τ1)

    # ∇τ0 = scale * (r - s) / ‖r - s‖ with r = x[4:6, :], s = x[1:3, :].
    # Divide by the geometric distance, not by T0: T0 already carries a factor of
    # `scale`, so `Δ ./ T0` would be off by `scale^2` whenever scale ≠ ±1.
    Δ = x[4:6, :] .- x[1:3, :]
    dist = sqrt.(sum(Δ .^ 2, dims=1))
    ∇T0 = st.τ0.scale .* Δ ./ dist

    # Summing τ1's outputs yields every column's gradient in one reverse pass
    # (valid as long as τ1 treats columns independently).
    sumT1(y) = sum(eikonet.τ1(y, ps.τ1, st.τ1)[1])
    ∇T1 = gradient(sumT1, x)[1][4:6, :]

    # Product rule for T = τ0 * τ1.
    ∇T = T1 .* ∇T0 .+ T0 .* ∇T1
    return sqrt.(sum(∇T .^ 2, dims=1))
end