Skip to content

Commit

Permalink
Merge pull request #35 from tylerjthomas9/dev
Browse files Browse the repository at this point in the history
fix some MLJ warnings, add MMI.reformat for eventual update support
  • Loading branch information
tylerjthomas9 authored May 12, 2023
2 parents c9461f7 + 3c9847f commit 18b8593
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CUDA = "3"
CUDA = "3, 4"
CondaPkg = "0.2"
MLJBase = "0.20, 0.21"
MLJModelInterface = "1"
Expand Down
22 changes: 15 additions & 7 deletions src/CuML/CuML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,29 @@ const CUML_MODELS = Union{
CUML_TIME_SERIES,
}

function MMI.reformat(::CUML_MODELS, X)
return (to_numpy(X),)
function MMI.reformat(::CUML_MODELS, X, y)
return (to_numpy(X), to_numpy(y))
end

function MMI.reformat(::CUML_MODELS, X, y)
return to_numpy(X), to_numpy(y)
function MMI.reformat(::CUML_MODELS, X)
return (to_numpy(X), )
end

function MMI.selectrows(::CUML_MODELS, I, X)
py_I = numpy.array(I .- 1)
return (X[py_I,],)
py_I = numpy.array(numpy.array(I .- 1))
return (X[py_I,], )
end

function MMI.selectrows(::CUML_MODELS, I::Colon, X)
return (X,)
return (X, )
end

function MMI.selectrows(::CUML_MODELS, I, X, y)
py_I = numpy.array(numpy.array(I .- 1))
return (X[py_I,], y[py_I])
end
function MMI.selectrows(::CUML_MODELS, I::Colon, X, y)
return (X, y)
end

MMI.clean!(model::CUML_MODELS) = ""
Expand Down
2 changes: 1 addition & 1 deletion src/CuML/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function MMI.input_scitype(::Type{<:CUML_CLASSIFICATION})
AbstractMatrix{MMI.Continuous},
}
end
MMI.target_scitype(::Type{<:CUML_CLASSIFICATION}) = AbstractVector{<:Finite}
MMI.target_scitype(::Type{<:CUML_CLASSIFICATION}) = Union{AbstractVector{<:Finite}, AbstractVector{MMI.Continuous}}

function MMI.docstring(::Type{<:LogisticRegression})
return "cuML's LogisticRegression: https://docs.rapids.ai/api/cuml/stable/api.html#logistic-regression"
Expand Down
2 changes: 1 addition & 1 deletion src/CuML/dimensionality_reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ MMI.load_path(::Type{<:TSNE}) = "$PKG.CuML.TSNE"

function MMI.input_scitype(::Type{<:CUML_DIMENSIONALITY_REDUCTION})
return Union{
MMI.Table(MMI.Continuous, MMI.Count, MMI.OrderedFactor, MMI.Multiclass),
Table{<:Union{AbstractVector{<:Continuous}, AbstractVector{<:Count}, AbstractVector{<:OrderedFactor}, AbstractVector{<:Multiclass}}},
AbstractMatrix{MMI.Continuous},
}
end
Expand Down
2 changes: 1 addition & 1 deletion src/RAPIDS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ if Base.VERSION <= v"1.8.3"
@warn warning_msg
end

if !CUDA.functional()
if !CUDA.has_cuda_gpu()
@warn "No CUDA GPU Detected. Unable to load RAPIDS."
const cucim = nothing
const cudf = nothing
Expand Down
4 changes: 2 additions & 2 deletions test/cuml.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ end
end

@testset "SVC" begin
model = SVC()
model = SVC(probability=true)
mach = machine(model, X, y)
fit!(mach)
preds = predict(mach, X)
end

@testset "LinearSVC" begin
model = LinearSVC()
model = LinearSVC(probability=true)
mach = machine(model, X, y)
fit!(mach)
preds = predict(mach, X)
Expand Down
34 changes: 5 additions & 29 deletions test/cuml_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
],
MLJTestInterface.make_regression()...;
mod=@__MODULE__,
verbosity=0, # bump to debug
verbosity=1, # bump to debug
throw=false,
)
@test isempty(failures)
Expand All @@ -30,41 +30,17 @@
y[y_string .== "O"] .= 1.0
failures, summary = MLJTestInterface.test(
[
LogisticRegression,
#LogisticRegression,
MBSGDClassifier,
RandomForestClassifier,
SVC,
LinearSVC,
#SVC,
#LinearSVC,
KNeighborsClassifier,
],
X,
y;
mod=@__MODULE__,
verbosity=0, # bump to debug
throw=false,
)
@test isempty(failures)
end

@testset "Binary Classification" begin
X, y_string = MLJTestInterface.make_binary()
# RAPIDS can only handle numeric values
# TODO: add support for non-numeric labels
y = zeros(200)
y[y_string .== "O"] .= 1.0
failures, summary = MLJTestInterface.test(
[
LogisticRegression,
MBSGDClassifier,
RandomForestClassifier,
SVC,
LinearSVC,
KNeighborsClassifier,
],
X,
y;
mod=@__MODULE__,
verbosity=0, # bump to debug
verbosity=1, # bump to debug
throw=false,
)
@test isempty(failures)
Expand Down

0 comments on commit 18b8593

Please sign in to comment.