From 2614aa81d9832479e80b759ab51620fedb0be799 Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Fri, 13 Oct 2023 16:04:07 +0200 Subject: [PATCH 1/8] AMDGPU support --- src/IO/H5.jl | 1 + src/JustRelax.jl | 2 +- src/MetaJustRelax.jl | 12 ++++- src/Utils.jl | 54 ++++++---------------- src/phases/CallArrays.jl | 4 +- src/stokes/Stokes2D.jl | 2 +- src/stokes/Stokes3D.jl | 2 +- src/thermal_diffusion/DiffusionExplicit.jl | 6 +-- 8 files changed, 34 insertions(+), 49 deletions(-) diff --git a/src/IO/H5.jl b/src/IO/H5.jl index 6fb88c21..5caae2d4 100644 --- a/src/IO/H5.jl +++ b/src/IO/H5.jl @@ -10,6 +10,7 @@ end _tocpu(x) = x _tocpu(x::T) where {T<:CuArray} = Array(x) +_tocpu(x::T) where {T<:ROCArray} = Array(x) """ checkpointing(dst, stokes, T, η, time) diff --git a/src/JustRelax.jl b/src/JustRelax.jl index 8eaf6aa9..2c1e2f27 100644 --- a/src/JustRelax.jl +++ b/src/JustRelax.jl @@ -5,7 +5,7 @@ using Reexport @reexport using ImplicitGlobalGrid using LinearAlgebra using Printf -using CUDA +using CUDA, AMDGPU using MPI using GeoParams using HDF5 diff --git a/src/MetaJustRelax.jl b/src/MetaJustRelax.jl index 593460e0..03960a8c 100644 --- a/src/MetaJustRelax.jl +++ b/src/MetaJustRelax.jl @@ -7,19 +7,26 @@ struct PS_Setup{B,C} end function environment!(model::PS_Setup{T,N}) where {T,N} - gpu = model.device == :gpu ? true : false # call appropriate FD module Base.eval(@__MODULE__, Meta.parse("using ParallelStencil.FiniteDifferences$(N)D")) Base.eval(Main, Meta.parse("using ParallelStencil.FiniteDifferences$(N)D")) # start ParallelStencil - if model.device == :gpu + if model.device == :CUDA eval(:(@init_parallel_stencil(CUDA, $T, $N))) Base.eval(Main, Meta.parse("using CUDA")) if !isconst(Main, :PTArray) eval(:(const PTArray = CUDA.CuArray{$T,$N,CUDA.Mem.DeviceBuffer})) end + + elseif model.device == :AMDGPU + eval(:(@init_parallel_stencil(AMDGPU, $T, $N))) + Base.eval(Main, Meta.parse("using AMDGPU")) + if !isconst(Main, :PTArray) + eval(:(const PTArray = AMDGPU.ROCArray{$T,$N,AMDGPU.Runtime.Mem.HIPBuffer})) + end + else @eval begin @init_parallel_stencil(Threads, $T, $N) @@ -27,6 +34,7 @@ function environment!(model::PS_Setup{T,N}) where {T,N} const PTArray = Array{$T,$N} end end + end # CREATE ARRAY STRUCTS diff --git a/src/Utils.jl b/src/Utils.jl index ed0577e2..0f37ca21 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -291,44 +291,13 @@ Compute the maximum value of `A` in the `window = (width_x, width_y, width_z)` a """ function compute_maxloc!(B, A; window=(1, 1, 1)) ni = size(A) - width_x, width_y, width_z = window - @parallel_indices (i, j) function _maxloc!( - B::T, A::T - ) where {T<:AbstractArray{<:Number,2}} - B[i, j] = _maxloc_window_clamped(A, i, j, width_x, width_y) + @parallel_indices (I...) function _maxloc!(B, A, window) + B[I...] = _maxloc_window_clamped(A, I..., window...) return nothing end - @parallel_indices (i, j, k) function _maxloc!( - B::T, A::T - ) where {T<:AbstractArray{<:Number,3}} - B[i, j, k] = _maxloc_window_clamped(A, i, j, k, width_x, width_y, width_z) - return nothing - end - - @parallel (@idx ni) _maxloc!(B, A) -end - -function _compute_maxloc!(B, A, window) - I = indices(window) - - if all(I .<= size(A)) - B[I...] = JustRelax._maxloc_window_clamped(A, I..., window...) - end - - return nothing -end - -function compute_maxloc!( - B::CuArray{T1,N,T2}, A::CuArray{T1,N,T2}; window=ntuple(i -> 1, Val(N)) -) where {T1,T2,N} - nx, ny = size(A) - ntx, nty = 32, 32 - blkx, blky = ceil(Int, nx / ntx), ceil(Int, ny / nty) - CUDA.@sync begin - @cuda threads = (ntx, nty) blocks = (blkx, blky) _compute_maxloc!(B, A, window) - end + @parallel (@idx ni) _maxloc!(B, A, window) return nothing end @@ -404,7 +373,7 @@ Compute the time step `dt` for the velocity field `S.V` and the diffusive maximu @inline function compute_dt(V::NTuple, di, dt_diff) n = inv(length(V) + 0.1) - dt_adv = mapreduce(x -> x[1] * inv(maximum(abs.(x[2]))), min, zip(di, V)) * n + dt_adv = mapreduce(x -> x[1] * inv(maximum_mpi(abs.(x[2]))), min, zip(di, V)) * n return min(dt_diff, dt_adv) end """ @@ -454,21 +423,28 @@ end # MPI reductions function mean_mpi(A) - mean_l = mean(A) + mean_l = _mean(A) return MPI.Allreduce(mean_l, MPI.SUM, MPI.COMM_WORLD) / MPI.Comm_size(MPI.COMM_WORLD) end function norm_mpi(A) - sum2_l = sum(A .^ 2) + sum2_l = _sum(A .^ 2) return sqrt(MPI.Allreduce(sum2_l, MPI.SUM, MPI.COMM_WORLD)) end function minimum_mpi(A) - min_l = minimum(A) + min_l = _minimum(A) return MPI.Allreduce(min_l, MPI.MIN, MPI.COMM_WORLD) end function maximum_mpi(A) - max_l = maximum(A) + max_l = _maximum(A) return MPI.Allreduce(max_l, MPI.MAX, MPI.COMM_WORLD) end + +for (f1,f2) in zip((:_mean, :_norm, :_minimum, :_maximum, :_sum), (:mean, :norm, :minimum, :maximum, :sum)) + @eval begin + $f1(A::ROCArray) = $f2(Array(A)) + $f1(A) = $f2(A) + end +end \ No newline at end of file diff --git a/src/phases/CallArrays.jl b/src/phases/CallArrays.jl index c8b209fb..c5422922 100644 --- a/src/phases/CallArrays.jl +++ b/src/phases/CallArrays.jl @@ -120,8 +120,8 @@ end ## Fallbacks import Base: getindex, setindex! -@inline element(A::Union{Array,CuArray}, I::Vararg{Int,N}) where {N} = getindex(A, I...) -@inline function setelement!(A::Union{Array,CuArray}, x::Number, I::Vararg{Int,N}) where {N} +@inline element(A::Union{Array,CuArray,ROCArray}, I::Vararg{Int,N}) where {N} = getindex(A, I...) +@inline function setelement!(A::Union{Array,CuArray,ROCArray}, x::Number, I::Vararg{Int,N}) where {N} return setindex!(A, x, I...) end diff --git a/src/stokes/Stokes2D.jl b/src/stokes/Stokes2D.jl index 1c7b81bf..4cb21045 100644 --- a/src/stokes/Stokes2D.jl +++ b/src/stokes/Stokes2D.jl @@ -39,7 +39,7 @@ module Stokes2D using ImplicitGlobalGrid using ..JustRelax -using CUDA +using CUDA, AMDGPU using ParallelStencil using ParallelStencil.FiniteDifferences2D using GeoParams, LinearAlgebra, Printf, TimerOutputs diff --git a/src/stokes/Stokes3D.jl b/src/stokes/Stokes3D.jl index 7a21d6cf..dc731121 100644 --- a/src/stokes/Stokes3D.jl +++ b/src/stokes/Stokes3D.jl @@ -41,7 +41,7 @@ using ImplicitGlobalGrid using ParallelStencil using ParallelStencil.FiniteDifferences3D using JustRelax -using CUDA +using CUDA, AMDGPU using LinearAlgebra using Printf using GeoParams diff --git a/src/thermal_diffusion/DiffusionExplicit.jl b/src/thermal_diffusion/DiffusionExplicit.jl index e41fb81e..aacc5f56 100644 --- a/src/thermal_diffusion/DiffusionExplicit.jl +++ b/src/thermal_diffusion/DiffusionExplicit.jl @@ -21,7 +21,7 @@ using ParallelStencil.FiniteDifferences1D using JustRelax using LinearAlgebra using Printf -using CUDA +using CUDA, AMDGPU import JustRelax: ThermalParameters, solve!, assign!, thermal_boundary_conditions!, update_T! @@ -155,7 +155,7 @@ module ThermalDiffusion2D using ParallelStencil using ParallelStencil.FiniteDifferences2D using JustRelax -using CUDA +using CUDA, AMDGPU using GeoParams import JustRelax: ThermalParameters, solve!, assign!, thermal_boundary_conditions! @@ -511,7 +511,7 @@ using ParallelStencil.FiniteDifferences3D using JustRelax using MPI using Printf -using CUDA +using CUDA, AMDGPU using GeoParams import JustRelax: From 6c0bcad18467c5954392a918c6a0f65bb2d0a7c4 Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Fri, 13 Oct 2023 16:09:14 +0200 Subject: [PATCH 2/8] format --- src/MetaJustRelax.jl | 5 ++--- src/Utils.jl | 10 ++++++---- src/phases/CallArrays.jl | 7 +++++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/MetaJustRelax.jl b/src/MetaJustRelax.jl index 03960a8c..ca75af8f 100644 --- a/src/MetaJustRelax.jl +++ b/src/MetaJustRelax.jl @@ -19,14 +19,14 @@ function environment!(model::PS_Setup{T,N}) where {T,N} if !isconst(Main, :PTArray) eval(:(const PTArray = CUDA.CuArray{$T,$N,CUDA.Mem.DeviceBuffer})) end - + elseif model.device == :AMDGPU eval(:(@init_parallel_stencil(AMDGPU, $T, $N))) Base.eval(Main, Meta.parse("using AMDGPU")) if !isconst(Main, :PTArray) eval(:(const PTArray = AMDGPU.ROCArray{$T,$N,AMDGPU.Runtime.Mem.HIPBuffer})) end - + else @eval begin @init_parallel_stencil(Threads, $T, $N) @@ -34,7 +34,6 @@ function environment!(model::PS_Setup{T,N}) where {T,N} const PTArray = Array{$T,$N} end end - end # CREATE ARRAY STRUCTS diff --git a/src/Utils.jl b/src/Utils.jl index 0f37ca21..935f7376 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -292,7 +292,7 @@ Compute the maximum value of `A` in the `window = (width_x, width_y, width_z)` a function compute_maxloc!(B, A; window=(1, 1, 1)) ni = size(A) - @parallel_indices (I...) function _maxloc!(B, A, window) + @parallel_indices (I...) function _maxloc!(B, A, window) B[I...] = _maxloc_window_clamped(A, I..., window...) return nothing end @@ -442,9 +442,11 @@ function maximum_mpi(A) return MPI.Allreduce(max_l, MPI.MAX, MPI.COMM_WORLD) end -for (f1,f2) in zip((:_mean, :_norm, :_minimum, :_maximum, :_sum), (:mean, :norm, :minimum, :maximum, :sum)) +for (f1, f2) in zip( + (:_mean, :_norm, :_minimum, :_maximum, :_sum), (:mean, :norm, :minimum, :maximum, :sum) +) @eval begin $f1(A::ROCArray) = $f2(Array(A)) $f1(A) = $f2(A) - end -end \ No newline at end of file + end +end diff --git a/src/phases/CallArrays.jl b/src/phases/CallArrays.jl index c5422922..8108b0f7 100644 --- a/src/phases/CallArrays.jl +++ b/src/phases/CallArrays.jl @@ -120,8 +120,11 @@ end ## Fallbacks import Base: getindex, setindex! -@inline element(A::Union{Array,CuArray,ROCArray}, I::Vararg{Int,N}) where {N} = getindex(A, I...) -@inline function setelement!(A::Union{Array,CuArray,ROCArray}, x::Number, I::Vararg{Int,N}) where {N} +@inline element(A::Union{Array,CuArray,ROCArray}, I::Vararg{Int,N}) where {N} = + getindex(A, I...) +@inline function setelement!( + A::Union{Array,CuArray,ROCArray}, x::Number, I::Vararg{Int,N} +) where {N} return setindex!(A, x, I...) end From 409758330dcb6ab886203122998b1b0bbb6d7c67 Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Fri, 13 Oct 2023 16:40:17 +0200 Subject: [PATCH 3/8] update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1ec36b83..0704b0cf 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Patrick Sanan , Albert De Montserrat Date: Fri, 13 Oct 2023 16:51:49 +0200 Subject: [PATCH 4/8] AMDGPU support --- src/IO/DataIO.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/IO/DataIO.jl b/src/IO/DataIO.jl index f462b24d..76302f52 100644 --- a/src/IO/DataIO.jl +++ b/src/IO/DataIO.jl @@ -3,7 +3,7 @@ module DataIO using WriteVTK using HDF5 using MPI -using CUDA +using CUDA, AMDGPU import ..JustRelax: Geometry From 55c5a9f94a4a2667febf787ff91d49ff8de553be Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Fri, 13 Oct 2023 17:15:32 +0200 Subject: [PATCH 5/8] remove TimerOutputs --- src/stokes/Stokes2D.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/stokes/Stokes2D.jl b/src/stokes/Stokes2D.jl index 4cb21045..edb3e9da 100644 --- a/src/stokes/Stokes2D.jl +++ b/src/stokes/Stokes2D.jl @@ -42,7 +42,7 @@ using ..JustRelax using CUDA, AMDGPU using ParallelStencil using ParallelStencil.FiniteDifferences2D -using GeoParams, LinearAlgebra, Printf, TimerOutputs +using GeoParams, LinearAlgebra, Printf import JustRelax: elastic_iter_params!, PTArray, Velocity, SymmetricTensor import JustRelax: @@ -509,7 +509,6 @@ function JustRelax.solve!( # solver loop wtime0 = 0.0 λ = @zeros(ni...) - to = TimerOutput() η0 = deepcopy(η) do_visc = true GC.enable(false) @@ -530,9 +529,9 @@ function JustRelax.solve!( θ_dτ, ) - # if rem(iter, 5) == 0 - # @timeit to "ρg" @parallel (@idx ni) compute_ρg!(ρg[2], phase_ratios.center, rheology, args) - # end + if rem(iter, 5) == 0 + @parallel (@idx ni) compute_ρg!(ρg[2], phase_ratios.center, rheology, args) + end @parallel (@idx ni .+ 1) compute_strain_rate!( @strain(stokes)..., stokes.∇V, @velocity(stokes)..., _di... @@ -635,7 +634,6 @@ function JustRelax.solve!( norm_Ry=norm_Ry, norm_∇V=norm_∇V, ) - return to end function JustRelax.solve!( From 9697de318ff4c4924df3cf8cc75764beaa531c5c Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Fri, 13 Oct 2023 18:08:46 +0200 Subject: [PATCH 6/8] remove @timeit --- src/stokes/Stokes2D.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stokes/Stokes2D.jl b/src/stokes/Stokes2D.jl index edb3e9da..b1829230 100644 --- a/src/stokes/Stokes2D.jl +++ b/src/stokes/Stokes2D.jl @@ -542,7 +542,7 @@ function JustRelax.solve!( end if do_visc ν = 1e-2 - @timeit to "viscosity" compute_viscosity!( + compute_viscosity!( η, ν, phase_ratios.center, From 26000f74f4b7eadfacb8d0386cfc8108a416561e Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Sat, 14 Oct 2023 12:19:38 +0200 Subject: [PATCH 7/8] update maxloc calls --- src/stokes/Stokes2D.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/stokes/Stokes2D.jl b/src/stokes/Stokes2D.jl index b1829230..2a57b5c7 100644 --- a/src/stokes/Stokes2D.jl +++ b/src/stokes/Stokes2D.jl @@ -100,7 +100,7 @@ function JustRelax.solve!( # ~preconditioner ητ = deepcopy(η) # @hide_communication b_width begin # communication/computation overlap - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) # end @@ -344,7 +344,7 @@ function JustRelax.solve!( # ~preconditioner ητ = deepcopy(η) # @hide_communication b_width begin # communication/computation overlap - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) # end @@ -379,7 +379,7 @@ function JustRelax.solve!( @parallel (@idx ni) compute_viscosity!( η, ν, @strain(stokes)..., args, rheology, viscosity_cutoff ) - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) @parallel (@idx ni) compute_τ_nonlinear!( @@ -488,7 +488,7 @@ function JustRelax.solve!( # ~preconditioner ητ = deepcopy(η) # @hide_communication b_width begin # communication/computation overlap - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) # end @@ -552,7 +552,7 @@ function JustRelax.solve!( viscosity_cutoff, ) end - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) @parallel (@idx ni) compute_τ_nonlinear!( @@ -665,7 +665,7 @@ function JustRelax.solve!( # ~preconditioner ητ = deepcopy(η) # @hide_communication b_width begin # communication/computation overlap - compute_maxloc!(ητ, η) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) # end From 2a6fe859b5d9a893b8c27f9b3eccde5a74d146a3 Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Sat, 14 Oct 2023 13:57:09 +0200 Subject: [PATCH 8/8] update maxloc call --- src/stokes/Stokes2D.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stokes/Stokes2D.jl b/src/stokes/Stokes2D.jl index 2a57b5c7..7b8c224a 100644 --- a/src/stokes/Stokes2D.jl +++ b/src/stokes/Stokes2D.jl @@ -222,7 +222,7 @@ function JustRelax.solve!( # ~preconditioner ητ = deepcopy(η) # @hide_communication b_width begin # communication/computation overlap - compute_maxloc!(ητ, η; window=(1, 1, 1)) + compute_maxloc!(ητ, η; window=(1, 1)) update_halo!(ητ) # end