Skip to content

Commit

Permalink
input_checks fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
montyvesselinov committed Jul 17, 2024
1 parent 78ce6fd commit c7c035a
Showing 1 changed file with 26 additions and 33 deletions.
59 changes: 26 additions & 33 deletions src/NMFkExecute.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import Distributed
import JLD

function check_methods!(X::AbstractArray{T,N}, mixture::Symbol=:null, method::Symbol=:simple, algorithm::Symbol=:multdiv, clusterWmatrix::Bool=false) where {T <: Number, N}
function input_checks(X::AbstractArray{T,N}, load::Bool, save::Bool, casefilename::AbstractString, mixture::Symbol, method::Symbol, algorithm::Symbol, clusterWmatrix::Bool) where {T <: Number, N}
if load && casefilename == ""
@info("Loading of existing results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
if save && casefilename == ""
@info("Saving of obtained results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
if mixture != :null
clusterWmatrix = true
method = :ipopt
Expand All @@ -10,6 +18,11 @@ function check_methods!(X::AbstractArray{T,N}, mixture::Symbol=:null, method::Sy
@info("For multi-dimensional arrays (tensors), use NMFk.tensorfactorization or NTFk!")
throw(ArgumentError("NMFk analysis can be executed for matrices!"))
end
if N == 2 && size(X, 1) < size(X, 2) && clusterWmatrix == false
@warn("Processed matrix size has more columns than rows (matrix size=$(size(X)))!")
@info("In this case, it is recommended to use `clusterWmatrix == true`.")
@info("It is preferred to cluster the smaller of the matrices!")
end
if any(isnan.(X))
@info("Analyzed matrix has NaN's.")
if method != :simple && method != :ipopt && method != :nlopt
Expand Down Expand Up @@ -45,38 +58,31 @@ function check_methods!(X::AbstractArray{T,N}, mixture::Symbol=:null, method::Sy
println("NLopt ...")
elseif method == :nmf
if algorithm == :multdiv
println("Multiplicative update using divergence ...")
println("Multiplicative updates based on divergence ...")
elseif algorithm == :multmse
println("Multiplicative update using mean-squared-error ...")
println("Multiplicative updates based on mean-squared-error ...")
elseif algorithm == :alspgrad
println("Alternate Least Square using Projected Gradient Descent ...")
println("Alternate Least Square based on Projected Gradient Descent ...")
end
elseif method == :sparsity
println("Sparsity penalty ...")
elseif method == :simple
println("Simple NMF multiplicative ...")
println("Simple multiplicative updates ...")
end
return load, save, casefilename, mixture, method, algorithm, clusterWmatrix
end

"Execute NMFk analysis for a range of number of signals"
function execute(X::AbstractArray{T,N}, nkrange::AbstractRange{Int}, nNMF::Integer=10; cutoff::Number=0.5, casefilename::AbstractString="", clusterWmatrix::Bool=false, mixture::Symbol=:null, method::Symbol=:simple, algorithm::Symbol=:multdiv, load::Bool=true, save::Bool=true, kw...) where {T <: Number, N}
function execute(X::AbstractArray{T,N}, nkrange::AbstractRange{Int}, nNMF::Integer=10; cutoff::Number=0.5, clusterWmatrix::Bool=false, mixture::Symbol=:null, method::Symbol=:simple, algorithm::Symbol=:multdiv, load::Bool=true, save::Bool=true, casefilename::AbstractString="", kw...) where {T <: Number, N}
load, save, casefilename, mixture, method, algorithm, clusterWmatrix = input_checks(X, load, save, casefilename, mixture, method, algorithm, clusterWmatrix)
maxk = maximum(collect(nkrange))
W = Vector{Array{T, N}}(undef, maxk)
H = Vector{Matrix{T}}(undef, maxk)
fitquality = zeros(T, maxk)
robustness = zeros(T, maxk)
aic = zeros(T, maxk)
if load && casefilename == ""
@info("Loading of existing results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
if save && casefilename == ""
@info("Saving of obtained results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
check_methods!(X, mixture, method, algorithm, clusterWmatrix)
for nk in nkrange
W[nk], H[nk], fitquality[nk], robustness[nk], aic[nk] = NMFk.execute(X, nk, nNMF; load=load, save=save, casefilename=casefilename, mixture=mixture, method=method, algorithm=algorithm, clusterWmatrix=clusterWmatrix, check_metods=false, kw...)
W[nk], H[nk], fitquality[nk], robustness[nk], aic[nk] = NMFk.execute(X, nk, nNMF; load=load, save=save, casefilename=casefilename, mixture=mixture, method=method, algorithm=algorithm, clusterWmatrix=clusterWmatrix, check_inputs=false, kw...)
end
@info("Results:")
for nk in nkrange
Expand All @@ -92,30 +98,18 @@ function execute(X::AbstractArray{T,N}, nkrange::AbstractRange{Int}, nNMF::Integ
end

"Execute NMFk analysis for a given number of signals"
function execute(X::AbstractArray{T,N}, nk::Integer, nNMF::Integer=10; clusterWmatrix::Bool=false, mixture::Symbol=:null, method::Symbol=:simple, algorithm::Symbol=:multdiv, resultdir::AbstractString=".", casefilename::AbstractString="", loadonly::Bool=false, load::Bool=true, save::Bool=true, quiet::Bool=true, check_metods::Bool=true, kw...) where {T <: Number, N}
function execute(X::AbstractArray{T,N}, nk::Integer, nNMF::Integer=10; clusterWmatrix::Bool=false, mixture::Symbol=:null, method::Symbol=:simple, algorithm::Symbol=:multdiv, resultdir::AbstractString=".", casefilename::AbstractString="", loadonly::Bool=false, load::Bool=true, save::Bool=true, quiet::Bool=true, check_inputs::Bool=true, kw...) where {T <: Number, N}
if .*(size(X)...) == 0
error("Input array has a zero dimension! Array size=$(size(X))")
end
if N == 2 && size(X, 1) < size(X, 2) && clusterWmatrix == false
@warn("Processed matrix size has more columns than rows (matrix size=$(size(X)))!")
@info("In this case, it is recommended to use `clusterWmatrix == true`.")
@info("It is preferred to cluster the smaller of the matrices!")
end
if loadonly
load = true
save = false
runflag = false
else
runflag = true
end
if load && casefilename == ""
@info("Loading of existing results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
if save && casefilename == ""
@info("Saving of obtained results is requested but \`casefilename\` is not specified; casefilename = \"nmfk\" will be used!")
casefilename = "nmfk"
end
check_metods && check_methods!(X, mixture, method, algorithm, clusterWmatrix)
check_inputs && (load, save, casefilename, mixture, method, algorithm, clusterWmatrix = input_checks(X, load, save, casefilename, mixture, method, algorithm, clusterWmatrix))
if load
filename = joinpathcheck(resultdir, "$casefilename-$nk-$nNMF.jld")
if isfile(filename)
Expand All @@ -130,14 +124,13 @@ function execute(X::AbstractArray{T,N}, nk::Integer, nNMF::Integer=10; clusterWm
println("H matrix - Expected size: $((nk, size(X, 2))) Actual size: $(size(H))")
end
else
@info("File $(filename) is missing; runs will be executed!")
if loadonly
W = Matrix{T}(undef, 0, 0);
H = Matrix{T}(undef, 0, 0);
fitquality = Inf;
robustness = -1;
aic = -Inf;
else
@info("File $(filename) is missing; runs will be executed!")
end
end
end
Expand Down

0 comments on commit c7c035a

Please sign in to comment.