Skip to content

Commit

Permalink
optimized loops
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao committed Feb 12, 2024
1 parent 3900f8a commit a8a7026
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 36 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ChebParticleMesh = "1983ef0c-217d-4026-99b0-9163e7750d85"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExTinyMD = "fec76197-d59f-46dd-a0ed-76a83c21f7aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SumOfExpVPMR = "2c69873a-e0bb-44e1-90b4-d15ac3b7e936"

Expand Down
32 changes: 20 additions & 12 deletions src/FFCT/linear_eqs.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
@inbounds function boundaries_single!(q::T, pos::NTuple{3, T}, b_l::Array{Complex{T}, 2}, b_u::Array{Complex{T}, 2}, k_x::Vector{T}, k_y::Vector{T}, L_z::T, uspara::USeriesPara{T}, M_mid::Int) where{T}
@inbounds function boundaries_single!(q::T, pos::NTuple{3, T}, b_l::Array{Complex{T}, 2}, b_u::Array{Complex{T}, 2}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, L_z::T, uspara::USeriesPara{T}, M_mid::Int) where{T}

x, y, z = pos
for i in 1:size(b_l, 1), j in 1:size(b_l, 2)
k_xi = k_x[i]
k_yj = k_y[j]
phase = exp( - T(1)im * (k_xi * x + k_yj * y))
for l in M_mid + 1:length(uspara.sw)
sl, wl = uspara.sw[l]
temp = q * π * wl * sl^2 * exp(-sl^2 * (k_xi^2 + k_yj^2) / 4) * phase
b_l[i, j] += temp * exp(- z^2 / sl^2)
b_u[i, j] += temp * exp(- (L_z - z)^2 / sl^2)

revise_phase_neg!(phase_x, phase_y, k_x, k_y, x, y)

for l in M_mid + 1:length(uspara.sw)
sl, wl = uspara.sw[l]
temp = q * π * wl * sl^2
temp_l = temp * exp(- z^2 / sl^2)
temp_u = temp * exp(- (L_z - z)^2 / sl^2)
for j in size(b_l, 2)
k_yj = k_y[j]
for i in 1:size(b_l, 1)
k_xi = k_x[i]
phase = phase_x[i] * phase_y[j]
exp_temp = exp(-sl^2 * (k_xi^2 + k_yj^2) / 4)
b_l[i, j] += temp_l * exp_temp * phase
b_u[i, j] += temp_u * exp_temp * phase
end
end
end

return b_l, b_u
end

function boundaries!(qs::Vector{T}, poses::Vector{NTuple{3, T}}, b_l::Array{Complex{T}, 2}, b_u::Array{Complex{T}, 2}, k_x::Vector{T}, k_y::Vector{T}, L_z::T, uspara::USeriesPara{T}, M_mid::Int) where{T}
function boundaries!(qs::Vector{T}, poses::Vector{NTuple{3, T}}, b_l::Array{Complex{T}, 2}, b_u::Array{Complex{T}, 2}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, L_z::T, uspara::USeriesPara{T}, M_mid::Int) where{T}

set_zeros!(b_l)
set_zeros!(b_u)

for i in 1:length(qs)
boundaries_single!(qs[i], poses[i], b_l, b_u, k_x, k_y, L_z, uspara, M_mid)
boundaries_single!(qs[i], poses[i], b_l, b_u, k_x, k_y, phase_x, phase_y, L_z, uspara, M_mid)
end

return b_l, b_u
Expand Down
15 changes: 10 additions & 5 deletions src/FFCT/nugrid_gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
ϕ = zero(Complex{T})

revise_phase_pos!(phase_x, phase_y, k_x, k_y, x, y)

for i in 1:size(H_s, 1), j in 1:size(H_s, 2)
phase = phase_x[i] * phase_y[j]

# k = 1
for j in 1:size(H_s, 2), i in 1:size(H_s, 1)
phase = phase_x[i] * phase_y[j]
ϕ += H_s[i, j, 1] * phase / T(2)
for k in 2:size(H_s, 3)
ϕ += H_s[i, j, k] * phase * chebpoly(k - 1, z - L_z / T(2), L_z / T(2))
end

for k in 2:size(H_s, 3)
cheb_temp = chebpoly(k - 1, z - L_z / T(2), L_z / T(2))
for j in 1:size(H_s, 2), i in 1:size(H_s, 1)
phase = phase_x[i] * phase_y[j]
ϕ += H_s[i, j, k] * phase * cheb_temp
end
end

Expand Down
40 changes: 23 additions & 17 deletions src/FFCT/nugrid_interpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@

x, y, z = pos
revise_phase_neg!(phase_x, phase_y, k_x, k_y, x, y)
for i in 1:size(H_r, 1), j in 1:size(H_r, 2)
k_xi = k_x[i]
k_yj = k_y[j]
phase = phase_x[i] * phase_y[j]

for k in 1:size(H_r, 3)
r_zk = r_z[k]
val = zero(T)

for l in M_mid + 1:length(uspara.sw)
sl, wl = uspara.sw[l]
val += wl * (T(2) - T(4) * (z - r_zk)^2 / sl^2 + (k_xi^2 + k_yj^2) * sl^2) * exp(- (z - r_zk)^2 / sl^2) * us_mat[i, j, l - M_mid]
end

H_r[i, j, k] += q * π * phase * val
for k in 1:size(H_r, 3)
r_zk = r_z[k]
for l in M_mid + 1:length(uspara.sw)
sl, wl = uspara.sw[l]
exp_temp = exp(- (z - r_zk)^2 / sl^2)
z_temp = T(2) - T(4) * (z - r_zk)^2 / sl^2

for j in 1:size(H_r, 2)
k_yj = k_y[j]
phase_yj = phase_y[j]
for i in 1:size(H_r, 1)
k_xi = k_x[i]
phase = phase_x[i] * phase_yj
H_r[i, j, k] += q * π * phase * (z_temp + (k_xi^2 + k_yj^2) * sl^2) * exp_temp * us_mat[i, j, l - M_mid]
end
end
end
end
end

return H_r
end
Expand All @@ -37,8 +39,12 @@ end

set_zeros!(H_c)
N_z = size(H_r, 3)
for i in 1:size(H_r, 1), j in 1:size(H_r, 2), k in 1:N_z, l in 1:N_z
H_c[i, j, k] += 2 / N_z * H_r[i, j, l] * chebpoly(k - 1, r_z[l] - L_z / T(2), L_z / T(2))

for k in 1:N_z, l in 1:N_z
cheb_temp = chebpoly(k - 1, r_z[l] - L_z / T(2), L_z / T(2))
for j in 1:size(H_r, 2), i in 1:size(H_r, 1)
H_c[i, j, k] += 2 / N_z * H_r[i, j, l] * cheb_temp
end
end

return H_c
Expand Down
2 changes: 1 addition & 1 deletion src/FastSpecSoG.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module FastSpecSoG

using ExTinyMD, LinearAlgebra, SpecialFunctions, ChebParticleMesh, SumOfExpVPMR
using ExTinyMD, LinearAlgebra, SpecialFunctions, ChebParticleMesh, SumOfExpVPMR, LoopVectorization

export USeriesPara, U_series, BSA
export FSSoG_naive
Expand Down
2 changes: 1 addition & 1 deletion src/energy/energy_long.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function energy_long(

@assert M_mid length(uspara.sw)

b_l, b_u = boundaries!(qs, poses, b_l, b_u, k_x, k_y, L[3], uspara, M_mid)
b_l, b_u = boundaries!(qs, poses, b_l, b_u, k_x, k_y, phase_x, phase_y, L[3], uspara, M_mid)
H_r = interpolate_nu!(qs, poses, k_x, k_y, phase_x, phase_y, r_z, us_mat, H_r, uspara, M_mid)
H_c = real2Cheb!(H_r, H_c, r_z, L[3])
H_s = solve_eqs!(rhs, sol, H_c, H_s, b_l, b_u, ivsm, L[3])
Expand Down

0 comments on commit a8a7026

Please sign in to comment.