Skip to content

Commit

Permalink
Add missing name qualifications
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Oct 5, 2023
1 parent 3a3fc7f commit fd64c3f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using EvoTrees
using CUDA

# This should be different on CPUs and GPUs
EvoTrees.device_ones(::Type{<:GPU}, ::Type{T}, n::Int) where {T} = CUDA.ones(T, n)
EvoTrees.device_array_type(::Type{<:GPU}) = CuArray
function EvoTrees.post_fit_gc(::Type{<:GPU})
EvoTrees.device_ones(::Type{<:EvoTrees.GPU}, ::Type{T}, n::Int) where {T} = CUDA.ones(T, n)

Check warning on line 7 in ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl#L7

Added line #L7 was not covered by tests
EvoTrees.device_array_type(::Type{<:EvoTrees.GPU}) = CuArray
function EvoTrees.post_fit_gc(::Type{<:EvoTrees.GPU})
GC.gc(true)
CUDA.reclaim()

Check warning on line 11 in ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/EvoTreesCUDAExt.jl#L9-L11

Added lines #L9 - L11 were not covered by tests
end
Expand Down
6 changes: 3 additions & 3 deletions ext/EvoTreesCUDAExt/fit.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function EvoTrees.grow_evotree!(evotree::EvoTree{L,K}, cache, params::EvoTrees.EvoTypes{L}, ::Type{GPU}) where {L,K}
function EvoTrees.grow_evotree!(evotree::EvoTree{L,K}, cache, params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}) where {L,K}

Check warning on line 1 in ext/EvoTreesCUDAExt/fit.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/fit.jl#L1

Added line #L1 was not covered by tests

# compute gradients
EvoTrees.update_grads!(cache.∇, cache.pred, cache.y, params)

Check warning on line 4 in ext/EvoTreesCUDAExt/fit.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/fit.jl#L4

Added line #L4 was not covered by tests
Expand Down Expand Up @@ -90,7 +90,7 @@ function grow_tree!(
update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
end
end
@threads for n sort(n_current)
Threads.@threads for n sort(n_current)
EvoTrees.update_gains!(nodes[n], js, params, feattypes, monotone_constraints)

Check warning on line 94 in ext/EvoTreesCUDAExt/fit.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/fit.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
end
end
Expand Down Expand Up @@ -217,7 +217,7 @@ function grow_otree!(
update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
end
end
@threads for n n_current
Threads.@threads for n n_current
EvoTrees.update_gains!(nodes[n], js, params, feattypes, monotone_constraints)

Check warning on line 221 in ext/EvoTreesCUDAExt/fit.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/fit.jl#L220-L221

Added lines #L220 - L221 were not covered by tests
end

Expand Down
2 changes: 1 addition & 1 deletion ext/EvoTreesCUDAExt/init.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{EvoTrees.GPU}, data, fnames, y_train, w, offset) where {L}
function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}, data, fnames, y_train, w, offset) where {L}

Check warning on line 1 in ext/EvoTreesCUDAExt/init.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/init.jl#L1

Added line #L1 was not covered by tests

# binarize data into quantiles
edges, featbins, feattypes = EvoTrees.get_edges(data; fnames, nbins=params.nbins, rng=params.rng)
Expand Down
16 changes: 8 additions & 8 deletions ext/EvoTreesCUDAExt/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ end
GradientRegression
"""
function predict_kernel!(
::Type{L},
::Type{<:EvoTrees.GradientRegression},
pred::CuDeviceMatrix{T},
split,
feats,
cond_bins,
leaf_pred,
x_bin,
feattypes,
) where {L<:EvoTrees.GradientRegression,T}
) where {T}
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
nid = 1
@inbounds if i <= size(pred, 2)
Expand All @@ -61,15 +61,15 @@ end
Logistic
"""
function predict_kernel!(
::Type{L},
::Type{<:EvoTrees.LogLoss},
pred::CuDeviceMatrix{T},
split,
feats,
cond_bins,
leaf_pred,
x_bin,
feattypes,
) where {L<:EvoTrees.LogLoss,T}
) where {T}
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
nid = 1
@inbounds if i <= size(pred, 2)
Expand All @@ -89,15 +89,15 @@ end
MLE2P
"""
function predict_kernel!(
::Type{L},
::Type{<:EvoTrees.MLE2P},
pred::CuDeviceMatrix{T},
split,
feats,
cond_bins,
leaf_pred,
x_bin,
feattypes,
) where {L<:EvoTrees.MLE2P,T}
) where {T}
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
nid = 1
@inbounds if i <= size(pred, 2)
Expand Down Expand Up @@ -165,13 +165,13 @@ end
function predict(
m::EvoTree{L,K},
data,
::Type{GPU};
::Type{<:EvoTrees.GPU};
ntree_limit=length(m.trees)) where {L,K}

pred = CUDA.zeros(K, size(data, 1))
ntrees = length(m.trees)
ntree_limit > ntrees && error("ntree_limit is larger than number of trees $ntrees.")
x_bin = CuArray(binarize(data; fnames=m.info[:fnames], edges=m.info[:edges]))
x_bin = CuArray(EvoTrees.binarize(data; fnames=m.info[:fnames], edges=m.info[:edges]))

Check warning on line 174 in ext/EvoTreesCUDAExt/predict.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/predict.jl#L174

Added line #L174 was not covered by tests
feattypes = CuArray(m.info[:feattypes])
for i = 1:ntree_limit
EvoTrees.predict!(pred, m.trees[i], x_bin, feattypes)

Check warning on line 177 in ext/EvoTreesCUDAExt/predict.jl

View check run for this annotation

Codecov / codecov/patch

ext/EvoTreesCUDAExt/predict.jl#L177

Added line #L177 was not covered by tests
Expand Down

0 comments on commit fd64c3f

Please sign in to comment.