Skip to content

Commit

Permalink
Merge pull request #54 from appleparan/ap/knn
Browse files Browse the repository at this point in the history
kNN Imputation
  • Loading branch information
rofinn authored Mar 19, 2020
2 parents d716162 + 86d0597 commit 26e50c7
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 3 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ authors = ["Invenia Technical Computing"]
version = "0.4.0"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
18 changes: 18 additions & 0 deletions src/Impute.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module Impute

using Distances
using IterTools
using Missings
using NearestNeighbors
using Random
using Statistics
using StatsBase
Expand Down Expand Up @@ -68,6 +71,7 @@ const global imputation_methods = (
nocb = NOCB,
srs = SRS,
svd = SVD,
knn = KNN,
)

include("deprecated.jl")
Expand Down Expand Up @@ -334,4 +338,18 @@ Utility method for `impute(data, :svd; limit=limit)`
"""
svd(data::AbstractMatrix; limit=1.0) = impute(data, :svd; limit=limit)

"""
    knn!(data::AbstractMatrix; limit=1.0)

In-place convenience wrapper: forwards to `impute!(data, :knn; limit=limit)`.
"""
function knn!(data::AbstractMatrix; limit=1.0)
    return impute!(data, :knn; limit=limit)
end

"""
    knn(data::AbstractMatrix; limit=1.0)

Non-mutating convenience wrapper: forwards to `impute(data, :knn; limit=limit)`.
"""
function knn(data::AbstractMatrix; limit=1.0)
    return impute(data, :knn; limit=limit)
end

end # module
2 changes: 1 addition & 1 deletion src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,6 @@ function impute!(table, imp::Imputor)
return table
end

for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "svd.jl")
# Include each imputor implementation file from the `imputors` subdirectory, in the
# order listed in the tuple.
for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "svd.jl", "knn.jl")
include(joinpath("imputors", file))
end
75 changes: 75 additions & 0 deletions src/imputors/knn.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# TODO : Support Categorical Distance (NearestNeighbors.jl support needed)
"""
    KNN <: Imputor

Imputation using the k-Nearest Neighbor algorithm.

# Keyword Arguments
* `k::Int`: number of nearest neighbors (default `1`). The keyword constructor stores
  `k + 1` in the struct so that the query point itself can be excluded later.
* `threshold::AbstractFloat`: threshold on the share of neighbors that are themselves
  missing data; must lie strictly between 0 and 1 (default `0.5`)
* `dist::MinkowskiMetric`: distance metric supported by `NearestNeighbors.jl`
  (Euclidean, Chebyshev, Minkowski and Cityblock; default `Euclidean()`)
* `context::AbstractContext`: context the imputation is run under (default `Context()`)

# Reference
* Troyanskaya, Olga, et al. "Missing value estimation methods for DNA microarrays."
  Bioinformatics 17.6 (2001): 520-525.
"""
struct KNN{M<:NearestNeighbors.MinkowskiMetric} <: Imputor
    # The bound on `M` lives in the parameter list: a trailing
    # `<: Imputor where M <: ...` attaches the `where` to the supertype expression and
    # leaves `M` itself unconstrained.
    k::Int
    # NOTE(review): abstractly typed field; kept as-is so existing positional
    # constructor calls keep working.
    threshold::AbstractFloat
    dist::M
    context::AbstractContext
end

# Keyword constructor for `KNN`.
#
# Throws `ArgumentError` when `k < 1` or when `threshold` is not strictly between 0 and 1.
# The stored neighbor count is `k + 1`, leaving room for the query point itself, which
# is dropped again when the neighbors are weighted.
function KNN(; k=1, threshold=0.5, dist=Euclidean(), context=Context())
    k >= 1 || throw(
        ArgumentError("The number of nearest neighbors should be greater than 0")
    )

    0 < threshold < 1 || throw(
        ArgumentError("Missing neighbors threshold should be within 0 to 1")
    )

    # to exclude missing value itself
    return KNN(k + 1, threshold, dist, context)
end

"""
    impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real

Impute the missing entries of `data` in place with k-nearest-neighbor imputation and
return `data`.

The matrix is first mean-filled (via `Fill(; value=mean)`), then transposed to `D x N`
and indexed with a `KDTree` using `imp.dist`. Each originally-missing entry is replaced
by the inverse-distance-weighted average of the same variable over its neighbors,
skipping the first (nearest) neighbor, which is expected to be the point itself.

The mean-filled value is kept when

* the weights are unusable (a zero, infinite or `NaN` weight sum, e.g. because a
  neighbor sits at distance zero), or
* more than `imp.k * imp.threshold` of the neighbors had missing values of their own
  in the original data.
"""
function impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real
    imp.context() do ctx
        # Remember where the missing values were, in the same D x N layout that is
        # used for the tree below. This has to happen before the mean fill.
        mmask = ismissing.(transpose(data))

        # fill missing value as mean value
        impute!(data, Fill(; value=mean, context=ctx))

        # then, transpose to D x N for KDTree
        transposed = transpose(disallowmissing(data))

        kdtree = KDTree(transposed, imp.dist)
        # `true` requests the neighbors sorted by distance
        idxs, dists = NearestNeighbors.knn(kdtree, transposed, imp.k, true)

        # NOTE(review): `imp.k` already includes the +1 added by the keyword
        # constructor, so the fallback is relative to k + 1 neighbors -- confirm intent.
        fallback_threshold = imp.k * imp.threshold

        for I in CartesianIndices(transposed)
            mmask[I] || continue

            neighbors = idxs[I[2]]
            w = 1.0 ./ dists[I[2]]
            # exclude missing value itself because distance would be zero
            ws = sum(w[2:end])

            # if distance is zero or not a number, keep mean imputation
            (isnan(ws) || isinf(ws) || iszero(ws)) && continue

            # Count the neighbors that had at least one missing entry in the original
            # data. This must consult the saved mask: `transposed` was already filled
            # and can no longer contain `missing`.
            incomplete = count(j -> any(view(mmask, :, j)), neighbors[2:end])

            # If too many neighbors are also missing, fallback to mean imputation
            incomplete > fallback_threshold && continue

            # Inverse distance weighting
            wt = w .* transposed[I[1], neighbors]
            transposed[I] = sum(wt[2:end]) / ws
        end

        # `disallowmissing` may have produced a copy, so write the result back into
        # the caller's matrix; returning `data` also preserves the input's type.
        data .= transpose(transposed)
        return data
    end
end
127 changes: 125 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

using AxisArrays
using Combinatorics
using DataFrames
Expand Down Expand Up @@ -31,7 +30,6 @@ using Impute:
interp,
chain


function add_missings(X, ratio=0.1)
result = Matrix{Union{Float64, Missing}}(X)

Expand All @@ -42,6 +40,17 @@ function add_missings(X, ratio=0.1)
return result
end

# Return a `Matrix{Union{Float64, Missing}}` copy of `X` in which each of the first
# `floor(size(X, 2) * ratio)` columns has exactly one randomly chosen row set to
# `missing`.
function add_missings_single(X, ratio=0.1)
    result = Matrix{Union{Float64, Missing}}(X)
    ntargets = floor(Int, size(X, 2) * ratio)

    foreach(1:ntargets) do col
        row = rand(1:size(X, 1))
        result[row, col] = missing
    end

    return result
end

@testset "Impute" begin
# Defining our missing datasets
a = allowmissing(1.0:1.0:20.0)
Expand Down Expand Up @@ -532,6 +541,120 @@ end
end
end

@testset "KNN" begin
    # Average NRMSD of kNN imputation and of mean imputation against `truth`, taken
    # over `num_tests` independently drawn missingness masks. The mask is redrawn on
    # every iteration: both imputors are deterministic, so re-imputing one fixed mask
    # would just average identical numbers.
    function average_nrmsd(truth, ratio, num_tests)
        knn_total, mean_total = 0.0, 0.0

        for _ in 1:num_tests
            X = add_missings(truth, ratio)
            knn_imputed = impute(copy(X), Impute.KNN(; k=2))
            mean_imputed = impute(
                copy(X), Fill(; value=mean, context=Context(; limit=1.0))
            )

            knn_total += nrmsd(truth, knn_imputed)
            mean_total += nrmsd(truth, mean_imputed)
        end

        return knn_total / num_tests, mean_total / num_tests
    end

    # Imputation must hand back the same array type it was given.
    function check_type_stability(truth, ratio)
        X = add_missings(truth, ratio)
        @test typeof(X) == typeof(impute(copy(X), Impute.KNN(; k=2)))
        @test typeof(X) == typeof(
            impute(copy(X), Fill(; value=mean, context=Context(; limit=1.0)))
        )
    end

    @testset "Iris" begin
        # Reference
        # P. Schmitt, et al.
        # A comparison of six methods for missing data imputation
        iris = dataset("datasets", "iris")
        iris2 = filter(
            row -> row[:Species] == "versicolor" || row[:Species] == "virginica", iris
        )
        data = Array(iris2[:, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]])
        num_tests = 100

        @testset "Iris - $ratio" for ratio in (0.15, 0.25, 0.35)
            knn_nrmsd, mean_nrmsd = average_nrmsd(data, ratio, num_tests)

            @test knn_nrmsd < mean_nrmsd
            # test type stability
            check_type_stability(data, ratio)
        end
    end

    # Test a case where we expect kNN to perform well (e.g., many strongly related
    # variables)
    @testset "Data match" begin
        data = mapreduce(hcat, 1:1000) do i
            seeds = [sin(i), cos(i), tan(i), atan(i)]
            mapreduce(vcat, combinations(seeds)) do args
                [
                    +(args...),
                    *(args...),
                    +(args...) * 100,
                    +(abs.(args)...),
                    (+(args...) * 10) ^ 2,
                    (+(abs.(args)...) * 10) ^ 2,
                    log(+(abs.(args)...) * 100),
                    +(args...) * 100 + rand(-10:0.1:10),
                ]
            end
        end

        num_tests = 100
        # 0.1 is the default missing ratio of `add_missings`
        knn_nrmsd, mean_nrmsd = average_nrmsd(data', 0.1, num_tests)

        @test knn_nrmsd < mean_nrmsd
        # test type stability
        check_type_stability(data', 0.1)
    end
end

include("deprecated.jl")
include("testutils.jl")

Expand Down

0 comments on commit 26e50c7

Please sign in to comment.