From dc88598c87c726021f22348ae26bd62d05295ab7 Mon Sep 17 00:00:00 2001
From: albert-de-montserrat <albert.de.montserrat@gmail.com>
Date: Fri, 29 Nov 2024 17:34:47 +0100
Subject: [PATCH] path way to 3D variational stokes

---
 .../stress_rotation_particles.jl              |  10 +-
 src/variational_stokes/MiniKernels.jl         |  39 ++-
 src/variational_stokes/Stokes2D.jl            |   8 +-
 src/variational_stokes/Stokes3D.jl            | 221 ++++++++++++++
 src/variational_stokes/StressKernels.jl       | 275 ++++++++++++++++++
 src/variational_stokes/VelocityKernels.jl     | 168 +++++++++--
 src/variational_stokes/mask.jl                | 168 +++++++++--
 7 files changed, 835 insertions(+), 54 deletions(-)
 create mode 100644 src/variational_stokes/Stokes3D.jl

diff --git a/src/stress_rotation/stress_rotation_particles.jl b/src/stress_rotation/stress_rotation_particles.jl
index 34441de9..ab874237 100644
--- a/src/stress_rotation/stress_rotation_particles.jl
+++ b/src/stress_rotation/stress_rotation_particles.jl
@@ -3,8 +3,8 @@ using StaticArrays
 # Vorticity tensor
 
 @parallel_indices (I...) function compute_vorticity!(ωxy, Vx, Vy, _dx, _dy)
-    @inline dx(A) = _d_xa(A, I..., _dx)
-    @inline dy(A) = _d_ya(A, I..., _dy)
+    @inline dx(A) = _d_xa(A, _dx, I...)
+    @inline dy(A) = _d_ya(A, _dy, I...)
 
     ωxy[I...] = 0.5 * (dx(Vy) - dy(Vx))
 
@@ -14,9 +14,9 @@ end
 @parallel_indices (I...) function compute_vorticity!(
     ωyz, ωxz, ωxy, Vx, Vy, Vz, _dx, _dy, _dz
 )
-    dx(A) = _d_xa(A, I..., _dx)
-    dy(A) = _d_ya(A, I..., _dy)
-    dz(A) = _d_za(A, I..., _dz)
+    dx(A) = _d_xa(A, _dx, I...)
+    dy(A) = _d_ya(A, _dy, I...)
+    dz(A) = _d_za(A, _dz, I...)
 
     if all(I .≤ size(ωyz))
         ωyz[I...] = 0.5 * (dy(Vz) - dz(Vy))
diff --git a/src/variational_stokes/MiniKernels.jl b/src/variational_stokes/MiniKernels.jl
index 3e2055b8..6f9d217b 100644
--- a/src/variational_stokes/MiniKernels.jl
+++ b/src/variational_stokes/MiniKernels.jl
@@ -10,29 +10,66 @@ end
 # finite differences
 Base.@propagate_inbounds @inline _d_xa(A::T, ϕ::T, _dx, I::Vararg{Integer,N}) where {N,T} =
     (-center(A, ϕ, I...) + right(A, ϕ, I...)) * _dx
+
 Base.@propagate_inbounds @inline _d_ya(A::T, ϕ::T, _dy, I::Vararg{Integer,N}) where {N,T} =
     (-center(A, ϕ, I...) + front(A, ϕ, I...)) * _dy
+
 Base.@propagate_inbounds @inline _d_za(A::T, ϕ::T, _dz, I::Vararg{Integer,N}) where {N,T} =
     (-center(A, ϕ, I...) + front(A, ϕ, I...)) * _dz
+
 Base.@propagate_inbounds @inline _d_xi(A::T, ϕ::T, _dx, I::Vararg{Integer,N}) where {N,T} =
     (-front(A, ϕ, I...) + next(A, ϕ, I...)) * _dx
+
 Base.@propagate_inbounds @inline _d_yi(A::T, ϕ::T, _dy, I::Vararg{Integer,N}) where {N,T} =
     (-right(A, ϕ, I...) + next(A, ϕ, I...)) * _dy
 
-# averages
+Base.@propagate_inbounds @inline _d_zi(A::T, ϕ::T, _dz, I::Vararg{Integer,N}) where {N,T} =
+    (-top(A, ϕ, I...) + next(A, ϕ, I...)) * _dz
+
+# averages 2D
 Base.@propagate_inbounds @inline _av(A::T, ϕ::T, i, j) where {T<:T2} =
     0.25 * mysum(A, ϕ, (i + 1):(i + 2), (j + 1):(j + 2))
+
 Base.@propagate_inbounds @inline _av_a(A::T, ϕ::T, i, j) where {T<:T2} =
     0.25 * mysum(A, ϕ, (i):(i + 1), (j):(j + 1))
+
 Base.@propagate_inbounds @inline _av_xa(A::T, ϕ::T, I::Vararg{Integer,2}) where {T<:T2} =
     (center(A, ϕ, I...) + right(A, ϕ, I...)) * 0.5
+
 Base.@propagate_inbounds @inline _av_ya(A::T, ϕ::T, I::Vararg{Integer,2}) where {T<:T2} =
     (center(A, ϕ, I...) + front(A, ϕ, I...)) * 0.5
+
 Base.@propagate_inbounds @inline _av_xi(A::T, ϕ::T, I::Vararg{Integer,2}) where {T<:T2} =
     (front(A, ϕ, I...), next(A, ϕ, I...)) * 0.5
+
 Base.@propagate_inbounds @inline _av_yi(A::T, ϕ::T, I::Vararg{Integer,2}) where {T<:T2} =
     (right(A, ϕ, I...), next(A, ϕ, I...)) * 0.5
 
+# averages 3D
+Base.@propagate_inbounds @inline _av(A::T, ϕ::T, i, j, k) where {T<:T3} =
+    0.125 * mysum(A, ϕ, (i + 1):(i + 2), (j + 1):(j + 2), (k + 1):(k + 2))
+
+Base.@propagate_inbounds @inline _av_a(A::T, ϕ::T, i, j, k) where {T<:T3} =
+    0.125 * mysum(A, ϕ, (i):(i + 1), (j):(j + 1), (k):(k + 1))
+
+Base.@propagate_inbounds @inline _av_xa(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (center(A, ϕ, I...) + right(A, ϕ, I...)) * 0.5
+
+Base.@propagate_inbounds @inline _av_ya(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (center(A, ϕ, I...) + front(A, ϕ, I...)) * 0.5
+
+Base.@propagate_inbounds @inline _av_za(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (center(A, ϕ, I...) + top(A, ϕ, I...)) * 0.5
+
+Base.@propagate_inbounds @inline _av_xi(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (front(A, ϕ, I...), next(A, ϕ, I...)) * 0.5
+
+Base.@propagate_inbounds @inline _av_yi(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (right(A, ϕ, I...), next(A, ϕ, I...)) * 0.5
+
+Base.@propagate_inbounds @inline _av_zi(A::T, ϕ::T, I::Vararg{Integer,3}) where {T<:T3} =
+    (top(A, ϕ, I...) + next(A, ϕ, I...)) * 0.5
+
 ## Because mysum(::generator) does not work inside CUDA kernels...
 @inline mysum(A, ϕ, ranges::Vararg{T,N}) where {T,N} = mysum(identity, A, ϕ, ranges...)
 
diff --git a/src/variational_stokes/Stokes2D.jl b/src/variational_stokes/Stokes2D.jl
index 9be98e04..52feeb09 100644
--- a/src/variational_stokes/Stokes2D.jl
+++ b/src/variational_stokes/Stokes2D.jl
@@ -81,9 +81,6 @@ function _solve_VS!(
     compute_viscosity!(stokes, phase_ratios, args, rheology, air_phase, viscosity_cutoff)
     displacement2velocity!(stokes, dt, flow_bcs)
 
-    @parallel (@idx ni .+ 1) multi_copy!(@tensor(stokes.τ_o), @tensor(stokes.τ))
-    @parallel (@idx ni) multi_copy!(@tensor_center(stokes.τ_o), @tensor_center(stokes.τ))
-
     while iter ≤ iterMax
         iterMin < iter && err < ϵ && break
 
@@ -100,7 +97,7 @@ function _solve_VS!(
                 stokes.∇V,
                 ητ,
                 rheology,
-                phase_ratios.center,
+                phase_ratios,
                 dt,
                 r,
                 θ_dτ,
@@ -213,6 +210,9 @@ function _solve_VS!(
     # accumulate plastic strain tensor
     @parallel (@idx ni) accumulate_tensor!(stokes.EII_pl, @tensor_center(stokes.ε_pl), dt)
 
+    @parallel (@idx ni .+ 1) multi_copy!(@tensor(stokes.τ_o), @tensor(stokes.τ))
+    @parallel (@idx ni) multi_copy!(@tensor_center(stokes.τ_o), @tensor_center(stokes.τ))
+
     return (
         iter=iter,
         err_evo1=err_evo1,
diff --git a/src/variational_stokes/Stokes3D.jl b/src/variational_stokes/Stokes3D.jl
new file mode 100644
index 00000000..20d28592
--- /dev/null
+++ b/src/variational_stokes/Stokes3D.jl
@@ -0,0 +1,221 @@
+## 3D VISCO-ELASTIC STOKES SOLVER
+
+# backend trait
+function solve_VariationalStokes!(stokes::JustRelax.StokesArrays, args...; kwargs)
+    solve_VariationalStokes!(backend(stokes), stokes, args...; kwargs)
+    return nothing
+end
+
+# entry point for extensions
+function solve_VariationalStokes!(::CPUBackendTrait, stokes, args...; kwargs)
+    return _solve_VS!(stokes, args...; kwargs...)
+end
+
+
+# GeoParams and multiple phases
+function _solve!(
+    stokes::JustRelax.StokesArrays,
+    pt_stokes,
+    di::NTuple{3,T},
+    flow_bcs::AbstractFlowBoundaryConditions,
+    ρg,
+    phase_ratios::JustPIC.PhaseRatios,
+    ϕ::JustRelax.RockRatio,
+    rheology::NTuple{N,AbstractMaterialParamsStruct},
+    args,
+    dt,
+    igg::IGG;
+    iterMax=10e3,
+    nout=500,
+    b_width=(4, 4, 4),
+    verbose=true,
+    viscosity_relaxation=1e-2,
+    viscosity_cutoff=(-Inf, Inf),
+    kwargs...,
+) where {T,N}
+
+    ## UNPACK
+
+    # solver related
+    ϵ = pt_stokes.ϵ
+    # geometry
+    _di = @. 1 / di
+    ni = size(stokes.P)
+    (; η, η_vep) = stokes.viscosity
+
+    # errors
+    err = Inf
+    iter = 0
+    cont = 0
+    err_evo1 = Float64[]
+    err_evo2 = Int64[]
+    norm_Rx = Float64[]
+    norm_Ry = Float64[]
+    norm_Rz = Float64[]
+    norm_∇V = Float64[]
+
+    @copy stokes.P0 stokes.P
+    θ = deepcopy(stokes.P)
+    λ = @zeros(ni...)
+    λv_yz = @zeros(size(stokes.τ.yz)...)
+    λv_xz = @zeros(size(stokes.τ.xz)...)
+    λv_xy = @zeros(size(stokes.τ.xy)...)
+
+    # solver loop
+    wtime0 = 0.0
+    ητ = deepcopy(η)
+
+    # compute buoyancy forces and viscosity
+    compute_ρg!(ρg, phase_ratios, rheology, args)
+    compute_viscosity!(stokes, phase_ratios, args, rheology, viscosity_cutoff)
+
+    # convert displacement to velocity
+    displacement2velocity!(stokes, dt, flow_bcs)
+
+    while iter < 2 || (err > ϵ && iter ≤ iterMax)
+        wtime0 += @elapsed begin
+            # ~preconditioner
+            compute_maxloc!(ητ, η)
+            update_halo!(ητ)
+
+            @parallel (@idx ni) compute_∇V!(stokes.∇V, @velocity(stokes)..., _di...)
+            compute_P!(
+                θ,
+                stokes.P0,
+                stokes.R.RP,
+                stokes.∇V,
+                ητ,
+                rheology,
+                phase_ratios.center,
+                dt,
+                pt_stokes.r,
+                pt_stokes.θ_dτ,
+                args,
+            )
+
+            @parallel (@idx ni) compute_strain_rate!(
+                stokes.∇V, @strain(stokes)..., @velocity(stokes)..., ϕ, _di...
+            )
+
+            # Update buoyancy
+            update_ρg!(ρg, phase_ratios, rheology, args)
+
+            # Update viscosity
+            update_viscosity!(
+                stokes,
+                phase_ratios,
+                args,
+                rheology,
+                viscosity_cutoff;
+                air_phase=air_phase,
+                relaxation=viscosity_relaxation,
+            )
+            # update_stress!(stokes, θ, λ, phase_ratios, rheology, dt, pt_stokes.θ_dτ)
+
+            @parallel (@idx ni .+ 1) update_stresses_center_vertex!(
+                @strain(stokes),
+                @tensor_center(stokes.ε_pl),
+                stokes.EII_pl,
+                @tensor_center(stokes.τ),
+                (stokes.τ.yz, stokes.τ.xz, stokes.τ.xy),
+                @tensor_center(stokes.τ_o),
+                (stokes.τ_o.yz, stokes.τ_o.xz, stokes.τ_o.xy),
+                θ,
+                stokes.P,
+                stokes.viscosity.η,
+                λ,
+                (λv_yz, λv_xz, λv_xy),
+                stokes.τ.II,
+                stokes.viscosity.η_vep,
+                0.2,
+                dt,
+                pt_stokes.θ_dτ,
+                rheology,
+                phase_ratios.center,
+                phase_ratios.vertex,
+                phase_ratios.xy,
+                phase_ratios.yz,
+                phase_ratios.xz,
+                ϕ,
+            )
+            update_halo!(stokes.τ.yz)
+            update_halo!(stokes.τ.xz)
+            update_halo!(stokes.τ.xy)
+
+            @hide_communication b_width begin # communication/computation overlap
+                @parallel compute_V!(
+                    @velocity(stokes)...,
+                    @residuals(stokes.R)...,
+                    stokes.P,
+                    ρg...,
+                    @stress(stokes)...,
+                    ητ,
+                    pt_stokes.ηdτ,
+                    ϕ,
+                    _di...,
+                )
+                # apply boundary conditions
+                velocity2displacement!(stokes, dt)
+                free_surface_bcs!(stokes, flow_bcs, η, rheology, phase_ratios, dt, di)
+                flow_bcs!(stokes, flow_bcs)
+                update_halo!(@velocity(stokes)...)
+            end
+        end
+
+        iter += 1
+        if iter % nout == 0 && iter > 1
+            cont += 1
+            for (norm_Ri, Ri) in zip((norm_Rx, norm_Ry, norm_Rz), @residuals(stokes.R))
+                push!(
+                    norm_Ri,
+                    norm_mpi(Ri[2:(end - 1), 2:(end - 1), 2:(end - 1)]) / length(Ri),
+                )
+            end
+            push!(norm_∇V, norm_mpi(stokes.R.RP) / length(stokes.R.RP))
+            err = max(norm_Rx[cont], norm_Ry[cont], norm_Rz[cont], norm_∇V[cont])
+            push!(err_evo1, err)
+            push!(err_evo2, iter)
+            if igg.me == 0 && (verbose || iter == iterMax)
+                @printf(
+                    "iter = %d, err = %1.3e [norm_Rx=%1.3e, norm_Ry=%1.3e, norm_Rz=%1.3e, norm_∇V=%1.3e] \n",
+                    iter,
+                    err,
+                    norm_Rx[cont],
+                    norm_Ry[cont],
+                    norm_Rz[cont],
+                    norm_∇V[cont]
+                )
+            end
+            isnan(err) && error("NaN(s)")
+        end
+
+        if igg.me == 0 && err ≤ ϵ
+            println("Pseudo-transient iterations converged in $iter iterations")
+        end
+    end
+
+    av_time = wtime0 / (iter - 1) # average time per iteration
+
+    # compute vorticity
+    @parallel (@idx ni .+ 1) compute_vorticity!(
+        stokes.ω.yz, stokes.ω.xz, stokes.ω.xy, @velocity(stokes)..., inv.(di)...
+    )
+
+    # accumulate plastic strain tensor
+    @parallel (@idx ni) accumulate_tensor!(stokes.EII_pl, @tensor_center(stokes.ε_pl), dt)
+
+    @parallel (@idx ni .+ 1) multi_copy!(@tensor(stokes.τ_o), @tensor(stokes.τ))
+    @parallel (@idx ni) multi_copy!(@tensor_center(stokes.τ_o), @tensor_center(stokes.τ))
+
+    return (
+        iter=iter,
+        err_evo1=err_evo1,
+        err_evo2=err_evo2,
+        norm_Rx=norm_Rx,
+        norm_Ry=norm_Ry,
+        norm_Rz=norm_Rz,
+        norm_∇V=norm_∇V,
+        time=wtime0,
+        av_time=av_time,
+    )
+end
diff --git a/src/variational_stokes/StressKernels.jl b/src/variational_stokes/StressKernels.jl
index 9aa70acd..661a0cf0 100644
--- a/src/variational_stokes/StressKernels.jl
+++ b/src/variational_stokes/StressKernels.jl
@@ -141,3 +141,278 @@
 
     return nothing
 end
+
+# 3D kernel
+@parallel_indices (I...) function update_stresses_center_vertex!(
+    ε::NTuple{6},         # normal components @ centers; shear components @ vertices
+    ε_pl::NTuple{6},      # whole Voigt tensor @ centers
+    EII,                  # accumulated plastic strain rate @ centers
+    τ::NTuple{6},         # whole Voigt tensor @ centers
+    τshear_v::NTuple{3},  # shear tensor components @ vertices
+    τ_o::NTuple{6},
+    τshear_ov::NTuple{3}, # shear tensor components @ vertices
+    Pr,
+    Pr_c,
+    η,
+    λ,
+    λv::NTuple{3},
+    τII,
+    η_vep,
+    relλ,
+    dt,
+    θ_dτ,
+    rheology,
+    phase_center,
+    phase_vertex,
+    phase_xy,
+    phase_yz,
+    phase_xz,
+    ϕ::JustRelax.RockRatio,
+)
+    τyzv, τxzv, τxyv = τshear_v
+    τyzv_old, τxzv_old, τxyv_old = τshear_ov
+
+    ni = size(Pr)
+    Ic = clamped_indices(ni, I...)
+
+    ## yz 
+    if all(I .≤ size(ε[4])) && isvalid_yz(ϕ, I...)
+        # interpolate to ith vertex
+        ηv_ij = av_clamped_yz(η, Ic...)
+        Pv_ij = av_clamped_yz(Pr, Ic...)
+        EIIv_ij = av_clamped_yz(EII, Ic...)
+        εxxv_ij = av_clamped_yz(ε[1], Ic...)
+        εyyv_ij = av_clamped_yz(ε[2], Ic...)
+        εzzv_ij = av_clamped_yz(ε[3], Ic...)
+        εyzv_ij = ε[4][I...]
+        εxzv_ij = av_clamped_yz_y(ε[5], Ic...)
+        εxyv_ij = av_clamped_yz_z(ε[6], Ic...)
+
+        τxxv_ij = av_clamped_yz(τ[1], Ic...)
+        τyyv_ij = av_clamped_yz(τ[2], Ic...)
+        τzzv_ij = av_clamped_yz(τ[3], Ic...)
+        τyzv_ij = τyzv[I...]
+        τxzv_ij = av_clamped_yz_y(τxzv, Ic...)
+        τxyv_ij = av_clamped_yz_z(τxyv, Ic...)
+
+        τxxv_old_ij = av_clamped_yz(τ_o[1], Ic...)
+        τyyv_old_ij = av_clamped_yz(τ_o[2], Ic...)
+        τzzv_old_ij = av_clamped_yz(τ_o[3], Ic...)
+        τyzv_old_ij = τyzv_old[I...]
+        τxzv_old_ij = av_clamped_yz_y(τxzv_old, Ic...)
+        τxyv_old_ij = av_clamped_yz_z(τxyv_old, Ic...)
+
+        # vertex parameters
+        phase = @inbounds phase_yz[I...]
+        is_pl, Cv, sinϕv, cosϕv, sinψv, η_regv = plastic_params_phase(
+            rheology, EIIv_ij, phase
+        )
+        _Gvdt = inv(fn_ratio(get_shear_modulus, rheology, phase) * dt)
+        Kv = fn_ratio(get_bulk_modulus, rheology, phase)
+        volumev = isinf(Kv) ? 0.0 : Kv * dt * sinϕv * sinψv # plastic volumetric change K * dt * sinϕ * sinψ
+        dτ_rv = inv(θ_dτ + ηv_ij * _Gvdt + 1.0)
+
+        # stress increments @ vertex
+        dτxxv = compute_stress_increment(τxxv_ij, τxxv_old_ij, ηv_ij, εxxv_ij, _Gvdt, dτ_rv)
+        dτyyv = compute_stress_increment(τyyv_ij, τyyv_old_ij, ηv_ij, εyyv_ij, _Gvdt, dτ_rv)
+        dτzzv = compute_stress_increment(τzzv_ij, τzzv_old_ij, ηv_ij, εzzv_ij, _Gvdt, dτ_rv)
+        dτyzv = compute_stress_increment(τyzv_ij, τyzv_old_ij, ηv_ij, εyzv_ij, _Gvdt, dτ_rv)
+        dτxzv = compute_stress_increment(τxzv_ij, τxzv_old_ij, ηv_ij, εxzv_ij, _Gvdt, dτ_rv)
+        dτxyv = compute_stress_increment(τxyv_ij, τxyv_old_ij, ηv_ij, εxyv_ij, _Gvdt, dτ_rv)
+
+        dτijv = dτxxv, dτyyv, dτzzv, dτyzv, dτxzv, dτxyv
+        τijv = τxxv_ij, τyyv_ij, τzzv_ij, τyzv_ij, τxzv_ij, τxyv_ij
+        τIIv_ij = second_invariant(τijv .+ dτijv)
+
+        # yield function @ vertex
+        Fv = τIIv_ij - Cv - Pv_ij * sinϕv
+        if is_pl && !iszero(τIIv_ij) && Fv > 0
+            # stress correction @ vertex
+            λv[1][I...] =
+                (1.0 - relλ) * λv[1][I...] +
+                relλ * (max(Fv, 0.0) / (ηv_ij * dτ_rv + η_regv + volumev))
+
+            dQdτyz = 0.5 * (τyzv_ij + dτyzv) / τIIv_ij
+            τyzv[I...] += dτyzv - 2.0 * ηv_ij * 0.5 * λv[1][I...] * dQdτyz * dτ_rv
+        else
+            # stress correction @ vertex
+            τyzv[I...] += dτyzv
+        end
+    end
+
+    ## xz
+    if all(I .≤ size(ε[5])) && isvalid_xz(ϕ, I...)
+        # interpolate to ith vertex
+        ηv_ij = av_clamped_xz(η, Ic...)
+        EIIv_ij = av_clamped_xz(EII, Ic...)
+        Pv_ij = av_clamped_xz(Pr, Ic...)
+        εxxv_ij = av_clamped_xz(ε[1], Ic...)
+        εyyv_ij = av_clamped_xz(ε[2], Ic...)
+        εzzv_ij = av_clamped_xz(ε[3], Ic...)
+        εyzv_ij = av_clamped_xz_x(ε[4], Ic...)
+        εxzv_ij = ε[5][I...]
+        εxyv_ij = av_clamped_xz_z(ε[6], Ic...)
+        τxxv_ij = av_clamped_xz(τ[1], Ic...)
+        τyyv_ij = av_clamped_xz(τ[2], Ic...)
+        τzzv_ij = av_clamped_xz(τ[3], Ic...)
+        τyzv_ij = av_clamped_xz_x(τyzv, Ic...)
+        τxzv_ij = τxzv[I...]
+        τxyv_ij = av_clamped_xz_z(τxyv, Ic...)
+        τxxv_old_ij = av_clamped_xz(τ_o[1], Ic...)
+        τyyv_old_ij = av_clamped_xz(τ_o[2], Ic...)
+        τzzv_old_ij = av_clamped_xz(τ_o[3], Ic...)
+        τyzv_old_ij = av_clamped_xz_x(τyzv_old, Ic...)
+        τxzv_old_ij = τxzv_old[I...]
+        τxyv_old_ij = av_clamped_xz_z(τxyv_old, Ic...)
+
+        # vertex parameters
+        phase = @inbounds phase_xz[I...]
+        is_pl, Cv, sinϕv, cosϕv, sinψv, η_regv = plastic_params_phase(
+            rheology, EIIv_ij, phase
+        )
+        _Gvdt = inv(fn_ratio(get_shear_modulus, rheology, phase) * dt)
+        Kv = fn_ratio(get_bulk_modulus, rheology, phase)
+        volumev = isinf(Kv) ? 0.0 : Kv * dt * sinϕv * sinψv # plastic volumetric change K * dt * sinϕ * sinψ
+        dτ_rv = inv(θ_dτ + ηv_ij * _Gvdt + 1.0)
+
+        # stress increments @ vertex
+        dτxxv = compute_stress_increment(τxxv_ij, τxxv_old_ij, ηv_ij, εxxv_ij, _Gvdt, dτ_rv)
+        dτyyv = compute_stress_increment(τyyv_ij, τyyv_old_ij, ηv_ij, εyyv_ij, _Gvdt, dτ_rv)
+        dτzzv = compute_stress_increment(τzzv_ij, τzzv_old_ij, ηv_ij, εzzv_ij, _Gvdt, dτ_rv)
+        dτyzv = compute_stress_increment(τyzv_ij, τyzv_old_ij, ηv_ij, εyzv_ij, _Gvdt, dτ_rv)
+        dτxzv = compute_stress_increment(τxzv_ij, τxzv_old_ij, ηv_ij, εxzv_ij, _Gvdt, dτ_rv)
+        dτxyv = compute_stress_increment(τxyv_ij, τxyv_old_ij, ηv_ij, εxyv_ij, _Gvdt, dτ_rv)
+
+        dτijv = dτxxv, dτyyv, dτzzv, dτyzv, dτxzv, dτxyv
+        τijv = τxxv_ij, τyyv_ij, τzzv_ij, τyzv_ij, τxzv_ij, τxyv_ij
+        τIIv_ij = second_invariant(τijv .+ dτijv)
+
+        # yield function @ vertex
+        Fv = τIIv_ij - Cv - Pv_ij * sinϕv
+        if is_pl && !iszero(τIIv_ij) && Fv > 0
+            # stress correction @ vertex
+            λv[2][I...] =
+                (1.0 - relλ) * λv[2][I...] +
+                relλ * (max(Fv, 0.0) / (ηv_ij * dτ_rv + η_regv + volumev))
+
+            dQdτxz = 0.5 * (τxzv_ij + dτxzv) / τIIv_ij
+            τxzv[I...] += dτxzv - 2.0 * ηv_ij * 0.5 * λv[2][I...] * dQdτxz * dτ_rv
+        else
+            # stress correction @ vertex
+            τxzv[I...] += dτxzv
+        end
+    end
+
+    ## xy
+    if all(I .≤ size(ε[6]))  && isvalid_xy(ϕ, I...)
+        # interpolate to ith vertex
+        ηv_ij = av_clamped_xy(η, Ic...)
+        EIIv_ij = av_clamped_xy(EII, Ic...)
+        Pv_ij = av_clamped_xy(Pr, Ic...)
+        εxxv_ij = av_clamped_xy(ε[1], Ic...)
+        εyyv_ij = av_clamped_xy(ε[2], Ic...)
+        εzzv_ij = av_clamped_xy(ε[3], Ic...)
+        εyzv_ij = av_clamped_xy_x(ε[4], Ic...)
+        εxzv_ij = av_clamped_xy_y(ε[5], Ic...)
+        εxyv_ij = ε[6][I...]
+
+        τxxv_ij = av_clamped_xy(τ[1], Ic...)
+        τyyv_ij = av_clamped_xy(τ[2], Ic...)
+        τzzv_ij = av_clamped_xy(τ[3], Ic...)
+        τyzv_ij = av_clamped_xy_x(τyzv, Ic...)
+        τxzv_ij = av_clamped_xy_y(τxzv, Ic...)
+        τxyv_ij = τxyv[I...]
+
+        τxxv_old_ij = av_clamped_xy(τ_o[1], Ic...)
+        τyyv_old_ij = av_clamped_xy(τ_o[2], Ic...)
+        τzzv_old_ij = av_clamped_xy(τ_o[3], Ic...)
+        τyzv_old_ij = av_clamped_xy_x(τyzv_old, Ic...)
+        τxzv_old_ij = av_clamped_xy_y(τxzv_old, Ic...)
+        τxyv_old_ij = τxyv_old[I...]
+
+        # vertex parameters
+        phase = @inbounds phase_xy[I...]
+        is_pl, Cv, sinϕv, cosϕv, sinψv, η_regv = plastic_params_phase(
+            rheology, EIIv_ij, phase
+        )
+        _Gvdt = inv(fn_ratio(get_shear_modulus, rheology, phase) * dt)
+        Kv = fn_ratio(get_bulk_modulus, rheology, phase)
+        volumev = isinf(Kv) ? 0.0 : Kv * dt * sinϕv * sinψv # plastic volumetric change K * dt * sinϕ * sinψ
+        dτ_rv = inv(θ_dτ + ηv_ij * _Gvdt + 1.0)
+
+        # stress increments @ vertex
+        dτxxv = compute_stress_increment(τxxv_ij, τxxv_old_ij, ηv_ij, εxxv_ij, _Gvdt, dτ_rv)
+        dτyyv = compute_stress_increment(τyyv_ij, τyyv_old_ij, ηv_ij, εyyv_ij, _Gvdt, dτ_rv)
+        dτzzv = compute_stress_increment(τzzv_ij, τzzv_old_ij, ηv_ij, εzzv_ij, _Gvdt, dτ_rv)
+        dτyzv = compute_stress_increment(τyzv_ij, τyzv_old_ij, ηv_ij, εyzv_ij, _Gvdt, dτ_rv)
+        dτxzv = compute_stress_increment(τxzv_ij, τxzv_old_ij, ηv_ij, εxzv_ij, _Gvdt, dτ_rv)
+        dτxyv = compute_stress_increment(τxyv_ij, τxyv_old_ij, ηv_ij, εxyv_ij, _Gvdt, dτ_rv)
+        dτijv = dτxxv, dτyyv, dτzzv, dτyzv, dτxzv, dτxyv
+        τijv = τxxv_ij, τyyv_ij, τzzv_ij, τyzv_ij, τxzv_ij, τxyv_ij
+        τIIv_ij = second_invariant(τijv .+ dτijv)
+
+        # yield function @ vertex
+        Fv = τIIv_ij - Cv - Pv_ij * sinϕv
+        if is_pl && !iszero(τIIv_ij) && Fv > 0
+            # stress correction @ vertex
+            λv[3][I...] =
+                (1.0 - relλ) * λv[3][I...] +
+                relλ * (max(Fv, 0.0) / (ηv_ij * dτ_rv + η_regv + volumev))
+
+            dQdτxy = 0.5 * (τxyv_ij + dτxyv) / τIIv_ij
+            τxyv[I...] += dτxyv - 2.0 * ηv_ij * 0.5 * λv[3][I...] * dQdτxy * dτ_rv
+        else
+            # stress correction @ vertex
+            τxyv[I...] += dτxyv
+        end
+    end
+
+    ## center
+    if all(I .≤ ni) && isvalid_c(ϕ, I...)
+        # Material properties
+        phase = @inbounds phase_center[I...]
+        _Gdt = inv(fn_ratio(get_shear_modulus, rheology, phase) * dt)
+        is_pl, C, sinϕ, cosϕ, sinψ, η_reg = plastic_params_phase(rheology, EII[I...], phase)
+        K = fn_ratio(get_bulk_modulus, rheology, phase)
+        volume = isinf(K) ? 0.0 : K * dt * sinϕ * sinψ # plastic volumetric change K * dt * sinϕ * sinψ
+        ηij = η[I...]
+        dτ_r = inv(θ_dτ + ηij * _Gdt + 1.0)
+
+        # cache strain rates for center calculations
+        τij, τij_o, εij = cache_tensors(τ, τ_o, ε, I...)
+
+        # visco-elastic strain rates @ center
+        εij_ve = @. εij + 0.5 * τij_o * _Gdt
+        εII_ve = second_invariant(εij_ve)
+        # stress increments @ center
+        dτij = @. (-(τij - τij_o) * ηij * _Gdt - τij + 2.0 * ηij * εij) * dτ_r
+        τII_ij = second_invariant(dτij .+ τij)
+        # yield function @ center
+        F = τII_ij - C - Pr[I...] * sinϕ
+
+        if is_pl && !iszero(τII_ij) && F > 0
+            # stress correction @ center
+            λ[I...] =
+                (1.0 - relλ) * λ[I...] +
+                relλ * (max(F, 0.0) / (η[I...] * dτ_r + η_reg + volume))
+            dQdτij = @. 0.5 * (τij + dτij) / τII_ij
+            εij_pl = λ[I...] .* dQdτij
+            dτij = @. dτij - 2.0 * ηij * εij_pl * dτ_r
+            τij = dτij .+ τij
+            setindex!.(τ, τij, I...)
+            setindex!.(ε_pl, εij_pl, I...)
+            τII[I...] = second_invariant(τij)
+            Pr_c[I...] = Pr[I...] + K * dt * λ[I...] * sinψ
+            η_vep[I...] = 0.5 * τII_ij / εII_ve
+        else
+            # stress correction @ center
+            setindex!.(τ, dτij .+ τij, I...)
+            η_vep[I...] = ηij
+            τII[I...] = τII_ij
+        end
+
+        Pr_c[I...] = Pr[I...] + (isinf(K) ? 0.0 : K * dt * λ[I...] * sinψ)
+    end
+
+    return nothing
+end
diff --git a/src/variational_stokes/VelocityKernels.jl b/src/variational_stokes/VelocityKernels.jl
index 4a4bc767..4acfd6b2 100644
--- a/src/variational_stokes/VelocityKernels.jl
+++ b/src/variational_stokes/VelocityKernels.jl
@@ -16,10 +16,6 @@ end
     @inline d_yi(A) = _d_yi(A, _dy, i, j)
     @inline d_xa(A) = _d_xa(A, _dx, i, j)
     @inline d_ya(A) = _d_ya(A, _dy, i, j)
-    @inline d_xi(A, ϕ) = _d_xi(A, ϕ, _dx, i, j)
-    @inline d_xa(A, ϕ) = _d_xa(A, ϕ, _dx, i, j)
-    @inline d_yi(A, ϕ) = _d_yi(A, ϕ, _dy, i, j)
-    @inline d_ya(A, ϕ) = _d_ya(A, ϕ, _dy, i, j)
 
     if all((i, j) .≤ size(εxx))
         if isvalid_c(ϕ, i, j)
@@ -30,34 +26,75 @@ end
             εxx[i, j] = zero(T)
             εyy[i, j] = zero(T)
         end
-        # εxx[i, j] = (Vx[i + 1, j + 1] * ϕ.Vx[i + 1, j] - Vx[i, j + 1] * ϕ.Vx[i, j]) * _dx - ∇V_ij
-        # εyy[i, j] = (Vy[i + 1, j + 1] * ϕ.Vy[i, j + 1] - Vy[i + 1, j] * ϕ.Vy[i, j]) * _dy - ∇V_ij
     end
 
     εxy[i, j] = if isvalid_v(ϕ, i, j)
-        0.5 * (
-            d_ya(Vx) + 
-            d_xa(Vy)
-        )
+        0.5 * (d_ya(Vx) + d_xa(Vy))
     else
         zero(T)
     end
 
-    # εxy[i, j] =  0.5 * (
-    #         d_ya(Vx) + 
-    #         d_xa(Vy)
-    #     )
-    # vy_mask_left  = ϕ.Vy[max(i - 1, 1), j]
-    # vy_mask_right = ϕ.Vy[min(i + 1, size(ϕ.Vy, 1)), j]
-    
-    # vx_mask_bot = ϕ.Vx[i, max(j - 1, 1)]
-    # vx_mask_top = ϕ.Vx[i, min(j + 1, size(ϕ.Vx, 2))]
-    
-    # εxy[i, j] = 0.5 * (
-    #     (Vx[i, j+1] * vx_mask_top - Vx[i, j] * vx_mask_bot) * _dy + 
-    #     (Vy[i+1, j] * vy_mask_right - Vy[i, j] * vy_mask_left) * _dx
-    # )
+    return nothing
+end
+
+@parallel_indices (i, j, k) function compute_strain_rate!(
+    ∇V::AbstractArray{T,3},
+    εxx,
+    εyy,
+    εzz,
+    εyz,
+    εxz,
+    εxy,
+    Vx,
+    Vy,
+    Vz,
+    ϕ::JustRelax.RockRatio,
+    _dx,
+    _dy,
+    _dz,
+) where {T}
+    d_xi(A) = _d_xi(A, _dx, i, j, k)
+    d_yi(A) = _d_yi(A, _dy, i, j, k)
+    d_zi(A) = _d_zi(A, _dz, i, j, k)
 
+    @inbounds begin
+        # normal components are all located @ cell centers
+        if all((i, j, k) .≤ size(εxx))
+            if isvalid_c(ϕ, i, j, k)
+                ∇Vijk = ∇V[i, j, k] * inv(3)
+                # Compute ε_xx
+                εxx[i, j, k] = d_xi(Vx) - ∇Vijk
+                # Compute ε_yy
+                εyy[i, j, k] = d_yi(Vy) - ∇Vijk
+                # Compute ε_zz
+                εzz[i, j, k] = d_zi(Vz) - ∇Vijk
+            end
+        end
+        # Compute ε_yz
+        if all((i, j, k) .≤ size(εyz)) && isvalid_yz(ϕ, i, j, k)
+            εyz[i, j, k] =
+                0.5 * (
+                    _dz * (Vy[i + 1, j, k + 1] - Vy[i + 1, j, k]) +
+                    _dy * (Vz[i + 1, j + 1, k] - Vz[i + 1, j, k])
+                )
+        end
+        # Compute ε_xz
+        if all((i, j, k) .≤ size(εxz)) && isvalid_xz(ϕ, i, j, k)
+            εxz[i, j, k] =
+                0.5 * (
+                    _dz * (Vx[i, j + 1, k + 1] - Vx[i, j + 1, k]) +
+                    _dx * (Vz[i + 1, j + 1, k] - Vz[i, j + 1, k])
+                )
+        end
+        # Compute ε_xy
+        if all((i, j, k) .≤ size(εxy)) && isvalid_xy(ϕ, i, j, k)
+            εxy[i, j, k] =
+                0.5 * (
+                    _dy * (Vx[i, j + 1, k + 1] - Vx[i, j, k + 1]) +
+                    _dx * (Vy[i + 1, j, k + 1] - Vy[i, j, k + 1])
+                )
+        end
+    end
     return nothing
 end
 
@@ -84,6 +121,7 @@ end
     d_ya(A, ϕ) = _d_ya(A, ϕ, _dy, i, j)
     av_xa(A, ϕ) = _av_xa(A, ϕ, i, j)
     av_ya(A, ϕ) = _av_ya(A, ϕ, i, j)
+
     av_xa(A) = _av_xa(A, i, j)
     av_ya(A) = _av_ya(A, i, j)
     harm_xa(A) = _av_xa(A, i, j)
@@ -118,3 +156,85 @@ end
 
     return nothing
 end
+
+@parallel_indices (i, j, k) function compute_V!(
+    Vx::AbstractArray{T,3},
+    Vy,
+    Vz,
+    Rx,
+    Ry,
+    Rz,
+    P,
+    fx,
+    fy,
+    fz,
+    τxx,
+    τyy,
+    τzz,
+    τyz,
+    τxz,
+    τxy,
+    ητ,
+    ηdτ,
+    ϕ::JustRelax.RockRatio,
+    _dx,
+    _dy,
+    _dz,
+) where {T}
+    @inline harm_x(A) = _harm_x(A, i, j, k)
+    @inline harm_y(A) = _harm_y(A, i, j, k)
+    @inline harm_z(A) = _harm_z(A, i, j, k)
+    @inline d_xa(A, ϕ) = _d_xa(A, ϕ, _dx, i, j, k)
+    @inline d_ya(A, ϕ) = _d_ya(A, ϕ, _dy, i, j, k)
+    @inline d_za(A, ϕ) = _d_za(A, ϕ, _dz, i, j, k)
+    @inline d_xi(A, ϕ) = _d_xi(A, ϕ, _dx, i, j, k)
+    @inline d_yi(A, ϕ) = _d_yi(A, ϕ, _dy, i, j, k)
+    @inline d_zi(A, ϕ) = _d_zi(A, ϕ, _dz, i, j, k)
+    @inline av_x(A) = _av_x(A, i, j, k)
+    @inline av_y(A) = _av_y(A, i, j, k)
+    @inline av_z(A) = _av_z(A, i, j, k)
+    @inline av_x(A, ϕ) = _av_x(A, ϕ, i, j, k)
+    @inline av_y(A, ϕ) = _av_y(A, ϕ, i, j, k)
+    @inline av_z(A, ϕ) = _av_z(A, ϕ, i, j, k)
+
+    @inbounds begin
+        if all((i, j, k) .< size(Vx) .- 1)
+            if isvalid_vx(ϕ, i + 1, j, k)
+                Rx_ijk =
+                    Rx[i, j, k] =
+                        d_xa(τxx, ϕ.center) + d_yi(τxy, ϕ.xy) + d_zi(τxz, ϕ.xz) -
+                        d_xa(P, ϕ.center) - av_x(fx, ϕ.center)
+                Vx[i + 1, j + 1, k + 1] += Rx_ijk * ηdτ / av_x(ητ)
+            else
+                Rx[i, j, k] = zero(T)
+                Vx[i + 1, j + 1, k + 1] = zero(T)
+            end
+        end
+        if all((i, j, k) .< size(Vy) .- 1)
+            if isvalid_vy(ϕ, i, j + 1, k)
+                Ry_ijk =
+                    Ry[i, j, k] =
+                        d_ya(τyy, ϕ.center) + d_xi(τxy, ϕ.xy) + d_zi(τyz, ϕ.yz) -
+                        d_ya(P, ϕ.center) - av_y(fy, ϕ.center)
+                Vy[i + 1, j + 1, k + 1] += Ry_ijk * ηdτ / av_y(ητ)
+            else
+                Ry[i, j, k] = zero(T)
+                Vy[i + 1, j + 1, k + 1] = zero(T)
+            end
+        end
+        if all((i, j, k) .< size(Vz) .- 1)
+            if isvalid_vz(ϕ, i, j, k + 1)
+                Rz_ijk =
+                    Rz[i, j, k] =
+                        d_za(τzz, ϕ.center) + d_xi(τxz, ϕ.xz) + d_yi(τyz, ϕ.yz) -
+                        d_za(P, ϕ.center) - av_z(fz, ϕ.center)
+                Vz[i + 1, j + 1, k + 1] += Rz_ijk * ηdτ / av_z(ητ)
+            else
+                Rz[i, j, k] = zero(T)
+                Vz[i + 1, j + 1, k + 1] = zero(T)
+            end
+        end
+    end
+
+    return nothing
+end
diff --git a/src/variational_stokes/mask.jl b/src/variational_stokes/mask.jl
index 3ad186c5..01be3cc9 100644
--- a/src/variational_stokes/mask.jl
+++ b/src/variational_stokes/mask.jl
@@ -96,10 +96,18 @@ Check if  `ϕ.center[inds...]` is a not a nullspace.
 - `inds`: Cartesian indices to check.
 """
 Base.@propagate_inbounds @inline function isvalid_c(ϕ::JustRelax.RockRatio, i, j)
-    vx = (ϕ.Vx[i, j] > 0) * (ϕ.Vx[i + 1, j] > 0)
-    vy = (ϕ.Vy[i, j] > 0) * (ϕ.Vy[i, j + 1] > 0)
+    vx = isvalid(ϕ.Vx, i, j) * isvalid(ϕ.Vx[i + 1, j])
+    vy = isvalid(ϕ.Vy, i, j) * isvalid(ϕ.Vy[i, j + 1])
     v = vx * vy
-    return v * (ϕ.center[i, j] > 0)
+    return v * isvalid(ϕ.center, i, j)
+end
+
+Base.@propagate_inbounds @inline function isvalid_c(ϕ::JustRelax.RockRatio, i, j, k)
+    vx = isvalid(ϕ.Vx, i, j, k) * isvalid(ϕ.Vx, i + 1, j, k)
+    vy = isvalid(ϕ.Vy, i, j, k) * isvalid(ϕ.Vy, i, j + 1, k)
+    vz = isvalid(ϕ.Vz, i, j, k) * isvalid(ϕ.Vz, i, j, k + 1)
+    v = vx * vy * vz
+    return v * isvalid(ϕ.center, i, j, k)
 end
 
 """
@@ -115,16 +123,17 @@ Base.@propagate_inbounds @inline function isvalid_v(ϕ::JustRelax.RockRatio, i,
     nx, ny = size(ϕ.Vx)
     j_bot = max(j - 1, 1)
     j0 = min(j, ny)
-    vx = (ϕ.Vx[i, j0] > 0) * (ϕ.Vx[i, j_bot] > 0)
+    vx = isvalid(ϕ.Vx, i, j0) * isvalid(ϕ.Vx, i, j_bot)
 
     nx, ny = size(ϕ.Vy)
     i_left = max(i - 1, 1)
     i0 = min(i, nx)
-    vy = (ϕ.Vy[i0, j] > 0) * (ϕ.Vy[i_left, j] > 0)
+    vy = isvalid(ϕ.Vy, i0, j) * isvalid(ϕ.Vy, i_left, j)
     v = vx * vy
-    return v * (ϕ.vertex[i, j] > 0)
+    return v * isvalid(ϕ.vertex, i, j)
 end
 
+
 """
     isvalid_vx(ϕ::JustRelax.RockRatio, inds...)
 
@@ -134,14 +143,20 @@ Check if  `ϕ.Vx[inds...]` is a not a nullspace.
 - `ϕ::JustRelax.RockRatio`: The `RockRatio` object to check against.
 - `inds`: Cartesian indices to check.
 """
-Base.@propagate_inbounds @inline function isvalid_vx(ϕ::JustRelax.RockRatio, i, j)
-    # c = (ϕ.center[i, j] > 0) * (ϕ.center[i - 1, j] > 0)
-    # v = (ϕ.vertex[i, j] > 0) * (ϕ.vertex[i, j + 1] > 0)
-    # cv = c * v
-    # return cv * (ϕ.Vx[i, j] > 0)
-    return (ϕ.Vx[i, j] > 0)
+Base.@propagate_inbounds @inline function isvalid_vx(
+    ϕ::JustRelax.RockRatio, I::Vararg{Integer,N}
+) where {N}
+    return isvalid(ϕ.Vx, I...)
 end
 
+# Base.@propagate_inbounds @inline function isvalid_vx(ϕ::JustRelax.RockRatio, I::Vararg{Integer,N}) where {N}
+#     # c = (ϕ.center[i, j] > 0) * (ϕ.center[i - 1, j] > 0)
+#     # v = (ϕ.vertex[i, j] > 0) * (ϕ.vertex[i, j + 1] > 0)
+#     # cv = c * v
+#     # return cv * (ϕ.Vx[i, j] > 0)
+#     return (ϕ.Vx[I...] > 0)
+# end
+
 """
     isvalid_vy(ϕ::JustRelax.RockRatio, inds...)
 
@@ -151,14 +166,129 @@ Check if  `ϕ.Vy[inds...]` is a not a nullspace.
 - `ϕ::JustRelax.RockRatio`: The `RockRatio` object to check against.
 - `inds`: Cartesian indices to check.
 """
-Base.@propagate_inbounds @inline function isvalid_vy(ϕ::JustRelax.RockRatio, i, j)
-    # c = (ϕ.center[i, j] > 0) * (ϕ.center[i, j - 1] > 0)
-    # v = (ϕ.vertex[i, j] > 0) * (ϕ.vertex[i + 1, j] > 0)
-    # cv = c * v
-    # return cv * (ϕ.Vy[i, j] > 0)
-    return (ϕ.Vy[i, j] > 0)
+# Base.@propagate_inbounds @inline function isvalid_vy(ϕ::JustRelax.RockRatio, i, j)
+#     # c = (ϕ.center[i, j] > 0) * (ϕ.center[i, j - 1] > 0)
+#     # v = (ϕ.vertex[i, j] > 0) * (ϕ.vertex[i + 1, j] > 0)
+#     # cv = c * v
+#     # return cv * (ϕ.Vy[i, j] > 0)
+#     return (ϕ.Vy[i, j] > 0)
+# end
+Base.@propagate_inbounds @inline function isvalid_vy(
+    ϕ::JustRelax.RockRatio, I::Vararg{Integer,N}
+) where {N}
+    return isvalid(ϕ.Vy, I...)
+end
+
+"""
+    isvalid_vz(ϕ::JustRelax.RockRatio, inds...)
+
+Check if  `ϕ.Vz[inds...]` is a not a nullspace.
+
+# Arguments
+- `ϕ::JustRelax.RockRatio`: The `RockRatio` object to check against.
+- `inds`: Cartesian indices to check.
+"""
+Base.@propagate_inbounds @inline function isvalid_vz(
+    ϕ::JustRelax.RockRatio, I::Vararg{Integer,N}
+) where {N}
+    return isvalid(ϕ.Vz, I...)
+end
+
+Base.@propagate_inbounds @inline function isvalid_velocity(ϕ::JustRelax.RockRatio, i, j)
+    return isvalid(ϕ.Vx, i, j) * isvalid(ϕ.Vy, i, j)
+end
+
+Base.@propagate_inbounds @inline function isvalid_velocity(ϕ::JustRelax.RockRatio, i, j, k)
+    return isvalid(ϕ.Vx, i, j, k) * isvalid(ϕ.Vy, i, j, k) * isvalid(ϕ.Vz, i, j, k)
+end
+
+Base.@propagate_inbounds @inline function isvalid_v(ϕ::JustRelax.RockRatio, i, j, k)
+    # yz
+    nx, ny, nz = size(ϕ.yz)
+    i_left = max(i - 1, 1)
+    i_right = min(i, nx)
+    yz = isvalid(ϕ.yz, i_left, j, k) * isvalid(ϕ.yz, i_right, j, k)
+
+    # xz
+    nx, ny, nz = size(ϕ.xz)
+    j_front = max(j - 1, 1)
+    j_back = min(j, ny)
+    xz = isvalid(ϕ.xz, i, j_front, k) * isvalid(ϕ.xz, i, j_back, k)
+
+    # xy
+    nx, ny, nz = size(ϕ.xy)
+    k_top = max(k - 1, 1)
+    k_bot = min(k, nz)
+    xy = isvalid(ϕ.xy, i, j, k_top) * isvalid(ϕ.xy, i, j, k_back)
+
+    # V
+    v = yz * xz * xy
+
+    return v * isvalid(ϕ.vertex, i, j, k)
+end
+
+Base.@propagate_inbounds @inline function isvalid_xz(ϕ::JustRelax.RockRatio, i, j, k)
+
+    # check vertices
+    v = isvalid(ϕ.vertex, i, j, k) * isvalid(ϕ.vertex, i, j + 1, k)
+
+    # check vz
+    nx, ny, nz = size(ϕ.vz)
+    i_left = max(i - 1, 1)
+    i_right = min(i, nx)
+    vz = isvalid(ϕ.vz, i_left, j, k) * isvalid(ϕ.vz, i_right, j, k)
+
+    # check vx
+    nx, ny, nz = size(ϕ.vx)
+    k_top = max(k - 1, 1)
+    k_bot = min(k, nz)
+    vx = isvalid(ϕ.vx, i, j, k_top) * isvalid(ϕ.vx, i, j, k_back)
+
+    return v * vx * vz * isvalid(ϕ.vertex, i, j, k)
 end
 
+Base.@propagate_inbounds @inline function isvalid_xy(ϕ::JustRelax.RockRatio, i, j, k)
+
+    # check vertices
+    v = isvalid(ϕ.vertex, i, j, k) * isvalid(ϕ.vertex, i, j, k + 1)
+
+    # check vx
+    nx, ny, nz = size(ϕ.vx)
+    j_front = max(j - 1, 1)
+    j_back = min(j, ny)
+    vx = isvalid(ϕ.vx, i, j_front, k) * isvalid(ϕ.vx, i, j_back, k)
+
+    # check vy
+    nx, ny, nz = size(ϕ.vy)
+    i_left = max(i - 1, 1)
+    i_right = min(i, nx)
+    vy = isvalid(ϕ.vy, i_left, j, k) * isvalid(ϕ.vy, i_right, j, k)
+
+    return v * vy * vz * isvalid(ϕ.vertex, i, j, k)
+end
+
+Base.@propagate_inbounds @inline function isvalid_yz(ϕ::JustRelax.RockRatio, i, j, k)
+
+    # check vertices
+    v = isvalid(ϕ.vertex, i, j, k) * isvalid(ϕ.vertex, i + 1, j, k)
+
+    # check vz
+    nx, ny, nz = size(ϕ.vz)
+    j_front = max(j - 1, 1)
+    j_back = min(j, ny)
+    vz = isvalid(ϕ.vz, i, j_front, k) * isvalid(ϕ.vz, i, j_back, k)
+
+    # check vy
+    nx, ny, nz = size(ϕ.vy)
+    k_top = max(k - 1, 1)
+    k_bot = min(k, nz)
+    vy = isvalid(ϕ.vy, i, j, k_top) * isvalid(ϕ.vy, i, j, k_back)
+
+    return v * vy * vz * isvalid(ϕ.vertex, i, j, k)
+end
+
+Base.@propagate_inbounds @inline isvalid(ϕ, I::Vararg{Integer,N}) where {N} = ϕ[I...] > 0
+
 ######
 
 # """
@@ -218,5 +348,3 @@ end
 #     cv = c || v
 #     return cv || isvalid(ϕ.Vy, i, j)
 # end
-
-Base.@propagate_inbounds @inline isvalid(ϕ, I::Vararg{Integer,N}) where {N} = ϕ[I...] > 0