Create CUDA extension #259

Merged (8 commits) on Oct 10, 2023
9 changes: 8 additions & 1 deletion Project.toml
Expand Up @@ -28,7 +28,11 @@ StatsBase = "0.32, 0.33, 0.34"
Tables = "1.9"
julia = "1.6"

[extensions]
EvoTreesCUDAExt = "CUDA"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand All @@ -37,4 +41,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
docs = ["Documenter"]
test = ["DataFrames", "Test", "MLJBase", "MLJTestInterface"]
test = ["CUDA", "DataFrames", "Test", "MLJBase", "MLJTestInterface"]

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Member:

From Julia's formatter, it appears that the expected order of project keys is:

  • deps
  • compat
  • weakdeps
  • extensions
  • compat
  • extras
  • targets

Contributor Author:

It depends on the Julia version you use: On Julia 1.6, the order Pkg uses is

  • deps
  • compat
  • extensions
  • extras
  • targets
  • weakdeps

whereas on Julia 1.7, 1.8, and 1.9 it is:
  • deps
  • weakdeps
  • extensions
  • compat
  • extras
  • targets

See e.g. the discussion in JuliaTesting/Aqua.jl#105 and in particular JuliaTesting/Aqua.jl#105 (comment)

I guess you were referring to the order in Julia > 1.6 and accidentally included the compat section twice?

Member:

Yes, I meant to refer to Julia > 1.6; sorry about the confusion with the duplicated compat.

Contributor Author:

I updated the order to the convention in Julia > 1.6.
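For reference, a Project.toml laid out in the Julia ≥ 1.7 key order discussed above might look roughly like this (illustrative sketch, showing only the sections relevant to the ordering; the CUDA UUID is the one from the diff):

```toml
name = "EvoTrees"

[deps]
# regular hard dependencies go here

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
EvoTreesCUDAExt = "CUDA"

[compat]
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[targets]
test = ["CUDA", "Test"]
```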

6 changes: 6 additions & 0 deletions docs/src/index.md
Expand Up @@ -68,6 +68,12 @@ m = fit_evotree(config, dtrain; target_name="y", fnames=["x1", "x3"]);

### GPU Acceleration

EvoTrees supports training and inference on Nvidia GPUs with [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl).
Note that on Julia ≥ 1.9, CUDA support is only enabled when CUDA.jl is installed and loaded, either by another package or explicitly, e.g. with:
```julia
using CUDA
```

If running on a CUDA-enabled machine, training and inference on GPU can be triggered through the `device` kwarg:

```julia
Expand Down
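Putting the two pieces together, a complete GPU training call might look like the following sketch (the `config`/`dtrain` names follow the earlier docs snippet; the exact constructor arguments are illustrative, not part of this diff):

```julia
using EvoTrees
using CUDA  # on Julia ≥ 1.9 this loads the EvoTreesCUDAExt extension

# hypothetical config; any EvoTrees model type works the same way
config = EvoTreeRegressor(max_depth=5, nrounds=100)

# `device` selects GPU training and inference
m = fit_evotree(config, dtrain; target_name="y", device="gpu")
p = m(dtrain; device="gpu")
```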
4 changes: 0 additions & 4 deletions docs/src/internals.md
Expand Up @@ -19,7 +19,6 @@ EvoTrees.update_gains!
EvoTrees.predict!
EvoTrees.subsample
EvoTrees.split_set_chunk!
EvoTrees.split_chunk_kernel!
```

## Histogram
Expand All @@ -28,7 +27,4 @@ EvoTrees.split_chunk_kernel!
EvoTrees.get_edges
EvoTrees.binarize
EvoTrees.update_hist!
EvoTrees.hist_kernel!
EvoTrees.hist_kernel_vec!
EvoTrees.predict_kernel!
```
22 changes: 22 additions & 0 deletions ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl
@@ -0,0 +1,22 @@
module EvoTreesCUDAExt

using EvoTrees
using CUDA

# This should be different on CPUs and GPUs
EvoTrees.device_ones(::Type{<:EvoTrees.GPU}, ::Type{T}, n::Int) where {T} = CUDA.ones(T, n)

EvoTrees.device_array_type(::Type{<:EvoTrees.GPU}) = CuArray
function EvoTrees.post_fit_gc(::Type{<:EvoTrees.GPU})
GC.gc(true)
CUDA.reclaim()

end

include("loss.jl")
include("eval.jl")
include("predict.jl")
include("init.jl")
include("subsample.jl")
include("fit-utils.jl")
include("fit.jl")

end # module
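The module above follows the standard weak-dependency pattern: the parent package owns the function names (with CPU fallbacks or stubs), and the extension adds GPU methods once CUDA.jl is present. A minimal check that the extension actually loaded, assuming Julia ≥ 1.9, could look like:

```julia
using EvoTrees, CUDA

# Base.get_extension returns the extension module, or `nothing` if the
# weak dependency has not triggered loading.
ext = Base.get_extension(EvoTrees, :EvoTreesCUDAExt)
ext === nothing || @info "EvoTreesCUDAExt is loaded"
```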
22 changes: 11 additions & 11 deletions src/gpu/eval.jl → ext/EvoTreesCUDAExt/eval.jl
Expand Up @@ -8,7 +8,7 @@
end
return nothing
end
function mse(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.mse(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_mse_kernel!(eval, p, y, w)
Expand All @@ -19,8 +19,8 @@
########################
# RMSE
########################
rmse(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat} =
sqrt(rmse(p, y, w; MAX_THREADS, kwargs...))
EvoTrees.rmse(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat} =

sqrt(EvoTrees.rmse(p, y, w; MAX_THREADS, kwargs...))
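The launch configuration repeated throughout this file follows a common CUDA.jl idiom: cap the threads per block at `MAX_THREADS` and derive the block count with a ceiling division so every element is covered. As a standalone sketch (the kernel and function names are illustrative, not part of this PR):

```julia
using CUDA

# one thread per element; guard against the last partial block
function eval_abs_kernel!(eval, p, y, w)
    i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    if i <= length(y)
        @inbounds eval[i] = w[i] * abs(p[i] - y[i])
    end
    return nothing
end

function launch_eval!(eval, p, y, w; MAX_THREADS=1024)
    threads = min(MAX_THREADS, length(y))  # threads per block
    blocks = cld(length(y), threads)       # ceiling division
    @cuda blocks = blocks threads = threads eval_abs_kernel!(eval, p, y, w)
    CUDA.synchronize()
    return sum(eval) / sum(w)              # weighted mean, as in this file
end
```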

########################
# MAE
Expand All @@ -32,7 +32,7 @@
end
return nothing
end
function mae(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.mae(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_mae_kernel!(eval, p, y, w)
Expand All @@ -51,7 +51,7 @@
end
return nothing
end
function logloss(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.logloss(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_logloss_kernel!(eval, p, y, w)
Expand All @@ -70,7 +70,7 @@
end
return nothing
end
function gaussian_mle(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.gaussian_mle(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_gaussian_kernel!(eval, p, y, w)
Expand All @@ -91,7 +91,7 @@
return nothing
end

function poisson(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.poisson(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_poisson_kernel!(eval, p, y, w)
Expand All @@ -111,7 +111,7 @@
return nothing
end

function gamma(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.gamma(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_gamma_kernel!(eval, p, y, w)
Expand All @@ -133,7 +133,7 @@
return nothing
end

function tweedie(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.tweedie(p::CuMatrix{T}, y::CuVector{T}, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_tweedie_kernel!(eval, p, y, w)
Expand All @@ -158,10 +158,10 @@
return nothing
end

function mlogloss(p::CuMatrix{T}, y::CuVector, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}
function EvoTrees.mlogloss(p::CuMatrix{T}, y::CuVector, w::CuVector{T}, eval::CuVector{T}; MAX_THREADS=1024, kwargs...) where {T<:AbstractFloat}

threads = min(MAX_THREADS, length(y))
blocks = cld(length(y), threads)
@cuda blocks = blocks threads = threads eval_mlogloss_kernel!(eval, p, y, w)
CUDA.synchronize()
return sum(eval) / sum(w)
end
end
16 changes: 4 additions & 12 deletions src/gpu/fit-utils.jl → ext/EvoTreesCUDAExt/fit-utils.jl
@@ -1,6 +1,3 @@
"""
hist_kernel!
"""
function hist_kernel!(h∇::CuDeviceArray{T,3}, ∇::CuDeviceMatrix{S}, x_bin, is, js) where {T,S}
tix, tiy, k = threadIdx().z, threadIdx().y, threadIdx().x
bdx, bdy = blockDim().z, blockDim().y
Expand Down Expand Up @@ -48,9 +45,6 @@
return nothing
end

"""
hist_kernel_vec!
"""
function hist_kernel_vec!(h∇, ∇, x_bin, is)
tix, k = threadIdx().x, threadIdx().y
bdx = blockDim().x
Expand Down Expand Up @@ -103,10 +97,8 @@
return nothing
end

"""
Multi-threads split_set!
Take a view into left and right placeholders. Right ids are assigned at the end of the length of the current node set.
"""
# Multi-threads split_set!
# Take a view into left and right placeholders. Right ids are assigned at the end of the length of the current node set.
function split_chunk_kernel!(
left::CuDeviceVector{S},
right::CuDeviceVector{S},
Expand Down Expand Up @@ -149,7 +141,7 @@
return nothing
end

function split_views_kernel!(
function EvoTrees.split_views_kernel!(

out::CuDeviceVector{S},
left::CuDeviceVector{S},
right::CuDeviceVector{S},
Expand Down Expand Up @@ -208,7 +200,7 @@
sum_lefts = sum(lefts)
cumsum_lefts = cumsum(lefts)
cumsum_rights = cumsum(rights)
@cuda blocks = nblocks threads = 1 split_views_kernel!(
@cuda blocks = nblocks threads = 1 EvoTrees.split_views_kernel!(

out,
left,
right,
Expand Down
48 changes: 24 additions & 24 deletions src/gpu/fit.jl → ext/EvoTreesCUDAExt/fit.jl
@@ -1,15 +1,15 @@
function grow_evotree!(evotree::EvoTree{L,K}, cache, params::EvoTypes{L}, ::Type{GPU}) where {L,K}
function EvoTrees.grow_evotree!(evotree::EvoTree{L,K}, cache, params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}) where {L,K}


# compute gradients
update_grads!(cache.∇, cache.pred, cache.y, params)
EvoTrees.update_grads!(cache.∇, cache.pred, cache.y, params)

# subsample rows
cache.nodes[1].is =
subsample(cache.is_in, cache.is_out, cache.mask, params.rowsample, params.rng)
EvoTrees.subsample(cache.is_in, cache.is_out, cache.mask, params.rowsample, params.rng)
# subsample cols
sample!(params.rng, cache.js_, cache.js, replace=false, ordered=true)
EvoTrees.sample!(params.rng, cache.js_, cache.js, replace=false, ordered=true)


# assign a root and grow tree
tree = Tree{L,K}(params.max_depth)
tree = EvoTrees.Tree{L,K}(params.max_depth)

grow! = params.tree_type == "oblivious" ? grow_otree! : grow_tree!
grow!(
tree,
Expand All @@ -34,9 +34,9 @@

# grow a single binary tree - grow through all depth
function grow_tree!(
tree::Tree{L,K},
tree::EvoTrees.Tree{L,K},
nodes::Vector{N},
params::EvoTypes{L},
params::EvoTrees.EvoTypes{L},
∇::CuMatrix,
edges,
js,
Expand Down Expand Up @@ -66,7 +66,7 @@

# initialize summary stats
nodes[1].∑ .= Vector(vec(sum(∇[:, nodes[1].is], dims=2)))
nodes[1].gain = get_gain(params, nodes[1].∑) # should use a GPU version?
nodes[1].gain = EvoTrees.get_gain(params, nodes[1].∑) # should use a GPU version?


# grow while there are remaining active nodes
while length(n_current) > 0 && depth <= params.max_depth
Expand All @@ -90,14 +90,14 @@
update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
end
end
@threads for n ∈ sort(n_current)
update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
Threads.@threads for n ∈ sort(n_current)
EvoTrees.update_gains!(nodes[n], js, params, feattypes, monotone_constraints)

end
end

for n ∈ sort(n_current)
if depth == params.max_depth || nodes[n].∑[end] <= params.min_weight
pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
EvoTrees.pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)

else
best = findmax(findmax.(nodes[n].gains))
best_gain = best[1][1]
Expand Down Expand Up @@ -126,8 +126,8 @@
nodes[n<<1].is, nodes[n<<1+1].is = _left, _right
nodes[n<<1].∑ .= nodes[n].hL[best_feat][:, best_bin]
nodes[n<<1+1].∑ .= nodes[n].hR[best_feat][:, best_bin]
nodes[n<<1].gain = get_gain(params, nodes[n<<1].∑)
nodes[n<<1+1].gain = get_gain(params, nodes[n<<1+1].∑)
nodes[n<<1].gain = EvoTrees.get_gain(params, nodes[n<<1].∑)
nodes[n<<1+1].gain = EvoTrees.get_gain(params, nodes[n<<1+1].∑)


if length(_right) >= length(_left)
push!(n_next, n << 1)
Expand All @@ -137,7 +137,7 @@
push!(n_next, n << 1)
end
else
pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
EvoTrees.pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)

end
end
end
Expand All @@ -151,9 +151,9 @@

# grow a single oblivious tree - grow through all depth
function grow_otree!(
tree::Tree{L,K},
tree::EvoTrees.Tree{L,K},
nodes::Vector{N},
params::EvoTypes{L},
params::EvoTrees.EvoTypes{L},
∇::CuMatrix,
edges,
js,
Expand Down Expand Up @@ -183,7 +183,7 @@

# initialize summary stats
nodes[1].∑ .= Vector(vec(sum(∇[:, nodes[1].is], dims=2)))
nodes[1].gain = get_gain(params, nodes[1].∑) # should use a GPU version?
nodes[1].gain = EvoTrees.get_gain(params, nodes[1].∑) # should use a GPU version?


# grow while there are remaining active nodes
while length(n_current) > 0 && depth <= params.max_depth
Expand All @@ -197,7 +197,7 @@
if depth == params.max_depth || min_weight_flag
for n in n_current
# @info "length(nodes[n].is)" length(nodes[n].is) depth n
pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
EvoTrees.pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)

end
else
# update histograms
Expand All @@ -217,8 +217,8 @@
update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
end
end
@threads for n ∈ n_current
update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
Threads.@threads for n ∈ n_current
EvoTrees.update_gains!(nodes[n], js, params, feattypes, monotone_constraints)

end

# initialize gains for node 1 in which all gains of a given depth will be accumulated
Expand Down Expand Up @@ -273,8 +273,8 @@
nodes[n<<1].is, nodes[n<<1+1].is = _left, _right
nodes[n<<1].∑ .= nodes[n].hL[best_feat][:, best_bin]
nodes[n<<1+1].∑ .= nodes[n].hR[best_feat][:, best_bin]
nodes[n<<1].gain = get_gain(params, nodes[n<<1].∑)
nodes[n<<1+1].gain = get_gain(params, nodes[n<<1+1].∑)
nodes[n<<1].gain = EvoTrees.get_gain(params, nodes[n<<1].∑)
nodes[n<<1+1].gain = EvoTrees.get_gain(params, nodes[n<<1+1].∑)


if length(_right) >= length(_left)
push!(n_next, n << 1)
Expand All @@ -286,7 +286,7 @@
end
else
for n in n_current
pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
EvoTrees.pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)

end
end
end
Expand All @@ -295,4 +295,4 @@
end # end of loop over current nodes for a given depth

return nothing
end
end
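The recurring change in this file is mechanical: once the code lives in an extension module rather than in EvoTrees itself, every function owned by the parent package must be referenced, and extended, with a qualified name. A minimal illustration of the rule (all names hypothetical):

```julia
module MyPkgCUDAExt  # hypothetical extension module

using MyPkg
using CUDA

# Adding a GPU method to a function owned by the parent package:
# without the `MyPkg.` qualifier this would define a new, unrelated
# local function instead of extending MyPkg.frobnicate.
MyPkg.frobnicate(x::CuArray) = CUDA.sum(x)

# Non-exported parent helpers likewise need qualification when called:
# MyPkg.update_state!(cache)

end # module
```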