Fix lda with constant within class variables
fradav committed Jun 4, 2020
1 parent 888701c commit 9b7e0fa
Showing 3 changed files with 89 additions and 11 deletions.
22 changes: 14 additions & 8 deletions src/lda-eigen.hpp
@@ -25,13 +25,13 @@ typedef Matrix<size_t, Dynamic, 1> VectorXs;
  * @param Ld computes matrix whose columns are Ld vectors
  */
 template<class Derived>
-void lda(const MatrixBase<Derived> &x,
+std::vector<int> lda(const MatrixBase<Derived> &x,
          const VectorXs &y,
          MatrixXd& Ld)
 {
     auto n = x.rows();
-    auto p = x.cols();
     auto K = y.maxCoeff() + 1;
+    auto p = x.cols();
 
     // M = Centroids
     MatrixXd M = MatrixXd::Zero(K, p);
@@ -59,18 +59,23 @@ void lda(const MatrixBase<Derived> &x,
     MatrixXd D(n, p);
 
     // W = Within-class covariance matrix
-    MatrixXd W = MatrixXd::Zero(p, p);
+    MatrixXd Wraw = MatrixXd::Zero(p, p);
     size_t slicepos = 0;
     for (auto c = 0; c < K; c++)
     {
         // Dc is a single class subrange of x
         MatrixXd Dc = x.block(slicepos, 0, d[c], p);
         // Centering
         Dc.rowwise() -= M.row(c);
-        W += Dc.transpose() * Dc;
+        Wraw += Dc.transpose() * Dc;
         slicepos += d[c];
     }
-    W /= static_cast<double>(n - K);
+    Wraw /= static_cast<double>(n - K);
+    auto validvars = std::vector<int>();
+    for (auto i = 0; i < x.cols(); i++) {
+        if (std::abs(Wraw(i,i)) >= 1.0e-8) validvars.push_back(i);
+    }
+    auto W = Wraw(validvars,validvars);
     // [4,4]((0.265008,0.0927211,0.167514,0.0384014),(0.0927211,0.115388,0.0552435,0.0327102),(0.167514,0.0552435,0.185188,0.0426653),(0.0384014,0.0327102,0.0426653,0.0418816))
 
     // Calculate pseudo-inverse square root for W with SVD
@@ -84,9 +89,9 @@ void lda(const MatrixBase<Derived> &x,
     // }
     //auto W12 = svd.matrixU() * svd.eigenvalues().array().inverse().sqrt().matrix().asDiagonal() * svd.matrixU().transpose();
     // auto W12 = svd.operatorInverseSqrt();
-
-    MatrixXd Mstar = M * W12;
-    VectorXd mstar = VectorXd::Zero(p);
+    MatrixXd Mvalid = M(all,validvars);
+    MatrixXd Mstar = Mvalid * W12;
+    VectorXd mstar = VectorXd::Zero(validvars.size());
     for (auto c = 0; c < K; c++)
         mstar += Mstar.row(c);
     mstar /= static_cast<double>(K);
@@ -98,4 +103,5 @@ void lda(const MatrixBase<Derived> &x,
     JacobiSVD<MatrixXd> svd2(Bstar,ComputeThinU);
     MatrixXd Vl = W12 * svd2.matrixU();
     Ld = Vl.leftCols(K-1);
+    return validvars;
 }
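
The heart of the fix: a variable that is constant within every class contributes only zeros to the within-class scatter, so W is singular and its SVD-based inverse square root breaks down. The new loop keeps only the indices whose diagonal entry of Wraw is at least 1.0e-8 and restricts W to those. Below is a minimal standalone sketch of that step, not part of the commit; the toy matrix and main() are illustrative, and Eigen >= 3.4 is assumed since indexing Wraw(validvars, validvars) with a std::vector<int> requires it.

#include <Eigen/Dense>
#include <cmath>
#include <iostream>
#include <vector>

using namespace Eigen;

int main()
{
    // Toy Wraw: index 1 plays the role of a constant-within-class variable,
    // giving a zero row, column, and diagonal entry, hence a singular matrix.
    MatrixXd Wraw(3, 3);
    Wraw << 2.0, 0.0, 0.5,
            0.0, 0.0, 0.0,
            0.5, 0.0, 1.0;

    // Keep only indices with a non-negligible within-class variance.
    std::vector<int> validvars;
    for (auto i = 0; i < Wraw.cols(); i++)
        if (std::abs(Wraw(i, i)) >= 1.0e-8) validvars.push_back(i);

    // The reduced matrix drops row/column 1 and is invertible again.
    MatrixXd W = Wraw(validvars, validvars);
    std::cout << "kept " << validvars.size() << " of " << Wraw.cols()
              << " variables\n" << W << std::endl;
    return 0;
}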
25 changes: 22 additions & 3 deletions src/matutils.hpp
@@ -75,21 +75,40 @@ void addLinearComb(Eigen::Ref<MatrixXd, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>
     X.block(0,ncols,X.rows(),M.cols()) = ref * M;
 }
 
+template<class Derived, class OtherDerived>
+void addLinearComb(Eigen::Ref<MatrixXd, 0, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>> ref, MatrixBase<Derived> const & X_, const MatrixBase<OtherDerived>& M, std::vector<int>& validvars) {
+    auto ncols = X_.cols();
+    auto& X = constCastAddColsMatrix(X_,M.cols());
+    X.block(0,ncols,X.rows(),M.cols()) = ref(all,validvars).eval() * M;
+}
+
 template<class Derived, class OtherDerived>
 void addLinearComb(MatrixBase<Derived> const & X_, const MatrixBase<OtherDerived>& M) {
     auto ncols = X_.cols();
     auto& X = constCastAddColsMatrix(X_,M.cols());
     X.block(0,ncols,X.rows(),M.cols()) = X.block(0,0,X.rows(),ncols) * M;
 }
 
+template<class Derived, class OtherDerived>
+void addLinearComb(MatrixBase<Derived> const & X_, const MatrixBase<OtherDerived>& M, std::vector<int>& validvars) {
+    auto ncols = X_.cols();
+    auto& X = constCastAddColsMatrix(X_,M.cols());
+    X.block(0,ncols,X.rows(),M.cols()) = X.block(0,0,X.rows(),ncols)(all,validvars).eval() * M;
+}
+
+
 template<class Derived, class MatrixType>
 void addLda(Reftable<MatrixType>& rf, MatrixXd& data, MatrixBase<Derived> const &statobs) {
     VectorXs scen(rf.nrec);
     for (auto i = 0; i < rf.nrec; i++) scen(i) = static_cast<size_t>(rf.scenarios[i]) - 1;
     MatrixXd Ld;
-    lda(rf.stats, scen, Ld);
-    addLinearComb(rf.stats,data,Ld);
-    addLinearComb(statobs,Ld);
+    std::vector<int> validvars = lda(rf.stats, scen, Ld);
+    addLinearComb(rf.stats,data,Ld,validvars);
+    addLinearComb(statobs,Ld,validvars);
+    for (auto i = 0; i < rf.stats.cols(); i++) {
+        if (std::find(std::begin(validvars),std::end(validvars),i) == std::end(validvars))
+            std::cout << "LDA Warning : " << rf.stats_names[i] << " is constant within class, removed." << std::endl;
+    }
     for (auto i = 0; i < rf.nrecscen.size() - 1; i++) {
         rf.stats_names.push_back("LDA" + std::to_string(i+1));
     }
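
The two new addLinearComb overloads mirror the existing ones, but slice the source columns down to validvars (with .eval() to materialize the selection) before multiplying by the LDA axes, so the appended scores ignore the dropped constant variables. Below is a hypothetical standalone equivalent, assuming Eigen >= 3.4; unlike the committed overloads it returns an enlarged copy instead of growing the matrix in place through constCastAddColsMatrix.

#include <Eigen/Dense>
#include <vector>

using namespace Eigen;

// appendLdaColumns is an illustrative name, not from the commit.
MatrixXd appendLdaColumns(const MatrixXd& X,   // n x p summary statistics
                          const MatrixXd& Ld,  // validvars.size() x (K-1) axes
                          const std::vector<int>& validvars)
{
    // Project only the valid (non-constant) columns onto the LDA axes.
    MatrixXd proj = X(all, validvars) * Ld;
    // Append the K-1 LDA scores after the original columns.
    MatrixXd out(X.rows(), X.cols() + proj.cols());
    out << X, proj;
    return out;
}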
53 changes: 53 additions & 0 deletions test/lda-eigen-test.cpp
@@ -213,3 +213,56 @@ TEST_CASE("LDA with Eigen")
         CHECK(std::min(LdPlus.col(c).lpNorm<Infinity>(),LdMinus.col(c).lpNorm<Infinity>()) == Approx(0.0).margin(1e-13));
     }
 }
+
+TEST_CASE("LDA with Eigen and constant variable")
+{
+    const size_t K = 3;
+    const size_t p = 4;
+    const size_t n = (S.size() + C.size() + V.size()) / p;
+    Matrix<size_t,-1,1> y(n);
+    MatrixXd x(n,p+1);
+    // S
+    for (auto i = 0; i < (S.size()) / p; i++)
+    {
+        y[i] = 0;
+        for (auto j = 0; j < p; j++)
+            x(i, j) = S[i * p + j];
+    }
+    for (auto i = 0; i < (C.size()) / p; i++)
+    {
+        auto ii = (S.size() / p) + i;
+        y[ii] = 1;
+        for (auto j = 0; j < p; j++)
+            x(ii, j) = C[i * p + j];
+    }
+    for (auto i = 0; i < (V.size()) / p; i++)
+    {
+        auto ii = (S.size() + C.size()) / p + i;
+        y[ii] = 2;
+        for (auto j = 0; j < p; j++)
+            x(ii, j) = V[i * p + j];
+    }
+    for (auto i = 0; i < n; i++) {
+        x(i,p) = x(i,2);
+        x(i,2) = y[i];
+    }
+    MatrixXd Ld;
+    lda(x,y,Ld);
+
+    MatrixXd LdMass(p,K-1);
+    LdMass <<
+        // 0.8293776, 0.02410215,
+        // 1.5344731, 2.16452123,
+        // -2.2012117, -0.93192121,
+        // -2.8104603, 2.83918785;
+        -0.82937764226600674, 0.024102148876954166,
+        -1.53447306770001091, 2.164521234658435489,
+        2.81046030884310172, 2.839187852982734128,
+        2.20121165556177356, -0.931921210029371894;
+
+    auto LdPlus = Ld + LdMass;
+    auto LdMinus = Ld - LdMass;
+    for (auto c = 0; c < K - 1; c++) {
+        CHECK(std::min(LdPlus.col(c).lpNorm<Infinity>(),LdMinus.col(c).lpNorm<Infinity>()) == Approx(0.0).margin(1e-13));
+    }
+}
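
A note on the new test case: the last loop over i moves variable 2 into a new trailing column and overwrites column 2 with the class label y, which is by construction constant within each class. Relative to the commented reference loadings, the rows of the last two variables are therefore swapped and the first axis changes sign; the LdPlus/LdMinus check already absorbs the sign indeterminacy of LDA axes. The returned index list could also be asserted directly; a hedged sketch extending the test above, not in the commit:

    // Column 2 holds the class label, so it should be the one dropped.
    std::vector<int> validvars = lda(x, y, Ld);
    CHECK(validvars == std::vector<int>({0, 1, 3, 4}));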
