Skip to content

Commit

Permalink
optimize the FFCT speed
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao committed Feb 11, 2024
1 parent 97b58fd commit 3900f8a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 17 deletions.
11 changes: 5 additions & 6 deletions src/FFCT/nugrid_gather.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
@inbounds function gather_nu_single(q::T, pos::NTuple{3, T}, L::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, H_s::Array{Complex{T}, 3}) where{T}
@inbounds function gather_nu_single(q::T, pos::NTuple{3, T}, L::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, H_s::Array{Complex{T}, 3}) where{T}
x, y, z = pos
L_x, L_y, L_z = L
ϕ = zero(Complex{T})

revise_phase_pos!(phase_x, phase_y, k_x, k_y, x, y)

for i in 1:size(H_s, 1), j in 1:size(H_s, 2)
k_xi = k_x[i]
k_yj = k_y[j]
phase = exp(T(1)im * (k_xi * x + k_yj * y))
phase = phase_x[i] * phase_y[j]

ϕ += H_s[i, j, 1] * phase / T(2)
for k in 2:size(H_s, 3)
Expand All @@ -18,10 +17,10 @@
return q * ϕ / (2 * L_x * L_y)
end

function gather_nu(qs::Vector{T}, poses::Vector{NTuple{3, T}}, L::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, H_s::Array{Complex{T}, 3}) where{T}
function gather_nu(qs::Vector{T}, poses::Vector{NTuple{3, T}}, L::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, H_s::Array{Complex{T}, 3}) where{T}
E = zero(Complex{T})
for i in 1:length(qs)
E += gather_nu_single(qs[i], poses[i], L, k_x, k_y, H_s)
E += gather_nu_single(qs[i], poses[i], L, k_x, k_y, phase_x, phase_y, H_s)
end
return real(E)
end
10 changes: 5 additions & 5 deletions src/FFCT/nugrid_interpolate.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
@inbounds function interpolate_nu_single!(q::T, pos::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, r_z::Vector{T}, us_mat::Array{T, 3}, H_r::Array{Complex{T}, 3}, uspara::USeriesPara{T}, M_mid::Int) where{T}
@inbounds function interpolate_nu_single!(q::T, pos::NTuple{3, T}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, r_z::Vector{T}, us_mat::Array{T, 3}, H_r::Array{Complex{T}, 3}, uspara::USeriesPara{T}, M_mid::Int) where{T}

x, y, z = pos
revise_phase_neg!(phase_x, phase_y, k_x, k_y, x, y)
for i in 1:size(H_r, 1), j in 1:size(H_r, 2)
k_xi = k_x[i]
k_yj = k_y[j]
phase = exp( - T(1)im * (k_xi * x + k_yj * y))
phase = phase_x[i] * phase_y[j]

for k in 1:size(H_r, 3)
r_zk = r_z[k]
Expand All @@ -22,11 +23,11 @@
return H_r
end

function interpolate_nu!(qs::Vector{T}, poses::Vector{NTuple{3, T}}, k_x::Vector{T}, k_y::Vector{T}, r_z::Vector{T}, us_mat::Array{T, 3}, H_r::Array{Complex{T}, 3}, uspara::USeriesPara{T}, M_mid::Int) where{T}
function interpolate_nu!(qs::Vector{T}, poses::Vector{NTuple{3, T}}, k_x::Vector{T}, k_y::Vector{T}, phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}}, r_z::Vector{T}, us_mat::Array{T, 3}, H_r::Array{Complex{T}, 3}, uspara::USeriesPara{T}, M_mid::Int) where{T}

set_zeros!(H_r)
for i in 1:length(qs)
interpolate_nu_single!(qs[i], poses[i], k_x, k_y, r_z, us_mat, H_r, uspara, M_mid)
interpolate_nu_single!(qs[i], poses[i], k_x, k_y, phase_x, phase_y, r_z, us_mat, H_r, uspara, M_mid)
end

return H_r
Expand All @@ -38,7 +39,6 @@ end
N_z = size(H_r, 3)
for i in 1:size(H_r, 1), j in 1:size(H_r, 2), k in 1:N_z, l in 1:N_z
H_c[i, j, k] += 2 / N_z * H_r[i, j, l] * chebpoly(k - 1, r_z[l] - L_z / T(2), L_z / T(2))
# cos((k - 1)* acos(r_z[l] / (L_z / 2) - 1.0))
end

return H_c
Expand Down
24 changes: 23 additions & 1 deletion src/FFCT/precompute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,36 @@ function FFCT_precompute(L::NTuple{3, T}, N_grid::NTuple{3, Int}, uspara::USerie
# the boundary condition at z = 0 and z = L_z
b_l = zeros(Complex{T}, 2N_x + 1, 2N_y + 1)
b_u = zeros(Complex{T}, 2N_x + 1, 2N_y + 1)
phase_x = zeros(Complex{T}, 2N_x + 1)
phase_y = zeros(Complex{T}, 2N_y + 1)

rhs = zeros(Complex{T}, N_z + 2)
sol = zeros(Complex{T}, N_z + 2)

sort_z = zeros(Int, n_atoms)
z = zeros(T, n_atoms)

return k_x, k_y, r_z, us_mat, H_r, H_c, H_s, ivsm, b_l, b_u, rhs, sol, sort_z, z
return k_x, k_y, r_z, us_mat, H_r, H_c, H_s, ivsm, b_l, b_u, phase_x, phase_y, rhs, sol, sort_z, z
end

"""
    revise_phase_neg!(phase_x, phase_y, k_x, k_y, x, y)

Overwrite the phase buffers in place with the negative-sign plane-wave factors

    phase_x[i] = exp(-im * k_x[i] * x)   for every index i of k_x
    phase_y[j] = exp(-im * k_y[j] * y)   for every index j of k_y

The two axes are filled independently, so `k_x` and `k_y` may have different
lengths. Each buffer must be at least as long as its wave-number vector; only
the leading `length(k_x)` / `length(k_y)` entries are written.

# Throws
- `DimensionMismatch` if a phase buffer is shorter than its wave-number vector.

# Returns
`nothing`; the results are stored in `phase_x` and `phase_y`.
"""
function revise_phase_neg!(
    phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}},
    k_x::Vector{T}, k_y::Vector{T}, x::T, y::T
) where{T}

    # Validate once up front so the loops below can safely skip bounds checks.
    length(phase_x) >= length(k_x) || throw(DimensionMismatch(
        "phase_x has length $(length(phase_x)), but k_x has length $(length(k_x))"))
    length(phase_y) >= length(k_y) || throw(DimensionMismatch(
        "phase_y has length $(length(phase_y)), but k_y has length $(length(k_y))"))

    # cis(θ) == exp(im * θ) for real θ.
    @inbounds for i in eachindex(k_x)
        phase_x[i] = cis(-(k_x[i] * x))
    end

    # Separate loop: the y axis must not reuse the x-axis index range.
    @inbounds for j in eachindex(k_y)
        phase_y[j] = cis(-(k_y[j] * y))
    end

    return nothing
end

"""
    revise_phase_pos!(phase_x, phase_y, k_x, k_y, x, y)

Overwrite the phase buffers in place with the positive-sign plane-wave factors

    phase_x[i] = exp(im * k_x[i] * x)   for every index i of k_x
    phase_y[j] = exp(im * k_y[j] * y)   for every index j of k_y

The two axes are filled independently, so `k_x` and `k_y` may have different
lengths. Each buffer must be at least as long as its wave-number vector; only
the leading `length(k_x)` / `length(k_y)` entries are written.

# Throws
- `DimensionMismatch` if a phase buffer is shorter than its wave-number vector.

# Returns
`nothing`; the results are stored in `phase_x` and `phase_y`.
"""
function revise_phase_pos!(
    phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}},
    k_x::Vector{T}, k_y::Vector{T}, x::T, y::T
) where{T}

    # Validate once up front so the loops below can safely skip bounds checks.
    length(phase_x) >= length(k_x) || throw(DimensionMismatch(
        "phase_x has length $(length(phase_x)), but k_x has length $(length(k_x))"))
    length(phase_y) >= length(k_y) || throw(DimensionMismatch(
        "phase_y has length $(length(phase_y)), but k_y has length $(length(k_y))"))

    # cis(θ) == exp(im * θ) for real θ.
    @inbounds for i in eachindex(k_x)
        phase_x[i] = cis(k_x[i] * x)
    end

    # Separate loop: the y axis must not reuse the x-axis index range.
    @inbounds for j in eachindex(k_y)
        phase_y[j] = cis(k_y[j] * y)
    end

    return nothing
end

# precompute the inverse matrix
Expand Down
5 changes: 3 additions & 2 deletions src/energy/energy_long.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
function energy_long(
qs::Vector{T}, poses::Vector{NTuple{3, T}}, L::NTuple{3, T}, M_mid::Int,
k_x::Vector{T}, k_y::Vector{T}, r_z::Vector{T},
phase_x::Vector{Complex{T}}, phase_y::Vector{Complex{T}},
z::Vector{T}, sort_z::Vector{Int},
us_mat::Array{T, 3}, b_l::Array{Complex{T}, 2}, b_u::Array{Complex{T}, 2},
rhs::Vector{Complex{T}}, sol::Vector{Complex{T}}, ivsm::Array{T, 4},
Expand All @@ -23,10 +24,10 @@ function energy_long(
@assert M_mid length(uspara.sw)

b_l, b_u = boundaries!(qs, poses, b_l, b_u, k_x, k_y, L[3], uspara, M_mid)
H_r = interpolate_nu!(qs, poses, k_x, k_y, r_z, us_mat, H_r, uspara, M_mid)
H_r = interpolate_nu!(qs, poses, k_x, k_y, phase_x, phase_y, r_z, us_mat, H_r, uspara, M_mid)
H_c = real2Cheb!(H_r, H_c, r_z, L[3])
H_s = solve_eqs!(rhs, sol, H_c, H_s, b_l, b_u, ivsm, L[3])
E_k = gather_nu(qs, poses, L, k_x, k_y, H_s)
E_k = gather_nu(qs, poses, L, k_x, k_y, phase_x, phase_y, H_s)

revise_z!(z, sort_z, poses)
E_0 = zeroth_order(qs, z, soepara, uspara, sort_z, L, M_mid)
Expand Down
4 changes: 2 additions & 2 deletions test/energy_long.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
soepara = SoePara16()
M_mid = 8

k_x, k_y, r_z, us_mat, H_r, H_c, H_s, ivsm, b_l, b_u, rhs, sol, sort_z, z = FFCT_precompute(L, N_grid, USeriesPara(2), M_mid, n_atoms)
k_x, k_y, r_z, us_mat, H_r, H_c, H_s, ivsm, b_l, b_u, phase_x, phase_y, rhs, sol, sort_z, z = FFCT_precompute(L, N_grid, USeriesPara(2), M_mid, n_atoms)

@info "running the FFCT for the long range part of the energy"
E_FFCT = energy_long(qs, poses, L, M_mid, k_x, k_y, r_z, z, sort_z, us_mat, b_l, b_u, rhs, sol, ivsm, H_r, H_c, H_s, uspara, soepara)
E_FFCT = energy_long(qs, poses, L, M_mid, k_x, k_y, r_z, phase_x, phase_y, z, sort_z, us_mat, b_l, b_u, rhs, sol, ivsm, H_r, H_c, H_s, uspara, soepara)

@info "running the direct summation for the long range part of the energy"
# using the direct summation
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ExTinyMD, EwaldSummations

@testset "FastSpecSoG.jl" begin
include("U_series.jl")
include("energy_naive.jl")
# include("energy_naive.jl")
include("energy_mid.jl")
include("energy_long.jl")
end

0 comments on commit 3900f8a

Please sign in to comment.