Skip to content

Commit

Permalink
MLJ test
Browse files Browse the repository at this point in the history
GPU launcher fixes
  • Loading branch information
jeremiedb committed Oct 26, 2023
1 parent 9e5d4d6 commit 57ac5dc
Show file tree
Hide file tree
Showing 15 changed files with 21 additions and 6 deletions.
3 changes: 2 additions & 1 deletion ext/EvoTreesCUDAExt/fit-utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ function update_hist_gpu!(h, h∇_cpu, h∇, ∇, x_bin, is, js, jsc)
max_blocks = config.blocks
k = size(h∇, 1)
ty = max(1, min(length(js), fld(max_threads, k)))
tx = max(1, min(length(is), fld(max_threads, k * ty)))
tx = min(64, max(1, min(length(is), fld(max_threads, k * ty))))
threads = (k, ty, tx)
max_blocks = min(65535, max_blocks * fld(max_threads, prod(threads)))
by = cld(length(js), ty)
bx = min(cld(max_blocks, by), cld(length(is), tx))
blocks = (1, by, bx)
Expand Down
6 changes: 4 additions & 2 deletions ext/EvoTreesCUDAExt/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,19 @@ function EvoTrees.predict!(
pred .= max.(T(-15), pred .- maximum(pred, dims=1))
end

# prediction from single tree - assign each observation to its final leaf
# prediction for EvoTree model
function predict(
m::EvoTree{L,K},
data,
::Type{<:EvoTrees.GPU};
ntree_limit=length(m.trees)) where {L,K}

pred = CUDA.zeros(K, size(data, 1))
Tables.istable(data) ? data = Tables.columntable(data) : nothing
ntrees = length(m.trees)
ntree_limit > ntrees && error("ntree_limit is larger than number of trees $ntrees.")
x_bin = CuArray(EvoTrees.binarize(data; fnames=m.info[:fnames], edges=m.info[:edges]))
nobs = size(x_bin, 1)
pred = CUDA.zeros(K, nobs)
feattypes = CuArray(m.info[:feattypes])
for i = 1:ntree_limit
EvoTrees.predict!(pred, m.trees[i], x_bin, feattypes)
Expand Down
Binary file modified figures/gaussian-sinus-binary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/logistic-sinus-binary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/quantiles-sinus-binary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus-binary-gpu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus-binary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus-oblivious-gpu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus-oblivious.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus2-binary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression-sinus2-oblivious.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions src/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function MMI.update(
end

function predict(::EvoTreeRegressor, fitresult, A)
pred = vec(predict(fitresult, A))
pred = predict(fitresult, A)
return pred
end

Expand All @@ -52,7 +52,7 @@ function predict(::EvoTreeClassifier, fitresult, A)
end

function predict(::EvoTreeCount, fitresult, A)
λs = vec(predict(fitresult, A))
λs = predict(fitresult, A)
return [Distributions.Poisson(λ) for λ λs]
end

Expand Down
3 changes: 2 additions & 1 deletion src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ function predict(
::Type{<:Device}=CPU;
ntree_limit=length(m.trees)) where {L,K}

Tables.istable(data) ? data = Tables.columntable(data) : nothing
ntrees = length(m.trees)
ntree_limit > ntrees && error("ntree_limit is larger than number of trees $ntrees.")
x_bin = binarize(data; fnames=m.info[:fnames], edges=m.info[:edges])
nobs = Tables.istable(data) ? length(Tables.getcolumn(data, 1)) : size(data, 1)
nobs = size(x_bin, 1)
pred = zeros(Float32, K, nobs)
for i = 1:ntree_limit
predict!(pred, m.trees[i], x_bin, m.info[:feattypes])
Expand Down
10 changes: 10 additions & 0 deletions test/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,13 @@ mach.model.nrounds += 10
fit!(mach, rows=train, verbosity=1)

report(mach)

@testset "MLJ - rowtables - EvoTreeRegressor" begin
X, y = make_regression(1000, 5)
X = Tables.rowtable(X)
booster = EvoTreeRegressor()
# smoke tests:
mach = machine(booster, X, y) |> fit!
fit!(mach)
predict(mach, X)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Statistics
using EvoTrees
using EvoTrees: predict
using CategoricalArrays
using Tables
using Random
using Test

Expand Down

0 comments on commit 57ac5dc

Please sign in to comment.