-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add classification and additional non-compositional covariates #12
Open
viettr
wants to merge
71
commits into
jacobbien:master
Choose a base branch
from
viettr:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 68 commits
Commits
Show all changes
71 commits
Select commit
Hold shift + click to select a range
fd0f980
Add classification and additional covariates
viettr 86a5a43
initiate testing
viettr abd0730
change CI settings
viettr a4dd5c6
try CI for reticulate
viettr 382c24e
Try CI settings
viettr 2378df6
Initiate CI with Travis, Python3 for classo
viettr 642b42a
Bugfix: remove map_dfc and transpose instead
viettr 8b64a81
Fix: cv: cross validation fold y wasn't a vector
viettr d2ff4dd
Add backwards compatibility
viettr 9636fe3
Bugfix for additional variables
viettr acf392b
fix bugs for small sample size and convergence problems.
viettr 94557ca
add probibalisitic output for classification, refactoring code, add m…
viettr 558e4bf
update documentation
viettr 49ed94d
CI: move from travis ci to github actions
viettr 57f0e4c
CI: bug fix github actions
viettr b37e9cd
CI: add first test
viettr d3667f4
update construction of A
viettr dec0ecf
Small bugfix, integrate update from master branch, expand testing
viettr 712909b
Merge remote-tracking branch 'upstream/master'Merge remote-tracking b…
viettr 45907d4
Update github workflow. Install necessary python packages for c-lasso
viettr b374e3b
fix bug: add normalization arguement to cv_trac
viettr da8a1b2
Fix bug: fix normalization bug
viettr 1ec5010
add additional covariates names
viettr 0ab7782
add weights for additional covariates
viettr ca7c53d
fix bug
viettr 0c90538
refactor code
viettr d4cf8d5
Rebase master into fork
viettr 716c6c4
Revert "Rebase master into fork"
viettr 92230ea
Revert "Revert "Rebase master into fork""
viettr f99a414
initiate testing
viettr 641ba98
change CI settings
viettr 4769d9a
try CI for reticulate
viettr 93d1c3c
Try CI settings
viettr 996e3f5
Initiate CI with Travis, Python3 for classo
viettr e17396d
Bugfix: remove map_dfc and transpose instead
viettr 7e46f8a
Fix: cv: cross validation fold y wasn't a vector
viettr d91c256
Add backwards compatibility
viettr 895b155
Bugfix for additional variables
viettr ea6effb
fix bugs for small sample size and convergence problems.
viettr b104bcd
add probibalisitic output for classification, refactoring code, add m…
viettr fbd97b1
update documentation
viettr 75dab85
CI: move from travis ci to github actions
viettr 15c9fb1
CI: bug fix github actions
viettr 5c28ba1
update construction of A
viettr a25caca
Update github workflow. Install necessary python packages for c-lasso
viettr e1f6d60
fix bug: add normalization arguement to cv_trac
viettr 923ee12
Fix bug: fix normalization bug
viettr cad6e07
add additional covariates names
viettr f583149
add weights for additional covariates
viettr 7ccd9e4
fix bug
viettr 9c8ff84
refactor code
viettr 1e78da3
add stratified folds
viettr 26e3cf5
refactor code
viettr f54a15e
Fix github action: dependencies for vignette
viettr 3b88473
Fix github actions: add pandoc
viettr 98af64e
adjust log_contrast to new output from c-lasso
viettr 5918b20
add test for classification
viettr 266bb81
fix bug: probability output
viettr 9b21f85
Fix bug: weights for additional covariates and weights for compositio…
viettr 798aafc
fix bug: weights for compositional and non compositional data
viettr 623fc1e
Add further documentation
viettr 69c27b8
merge upstream branch
viettr f2c9308
fix bug: weights additional covariates
viettr ce13529
fix bug: weights additional covariates
viettr eecf116
Refactor code + add classification to log-contrast + add additional c…
viettr 0b22db4
add comments to helper functions
viettr 99de4ac
Update documentation
viettr aa19e83
remove github workflow
viettr 3bb9ce9
publish website
viettr 32bc822
adjust hsm dependency
viettr 190d6c3
remove ggb
viettr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
^_pkgdown\.yml$ | ||
^docs$ | ||
^pkgdown$ | ||
^\.github$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
.Rproj.user | ||
inst/doc | ||
.Rhistory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,27 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(aggregate_to_level) | ||
export(check_additional_covariates) | ||
export(check_method) | ||
export(cv_sparse_log_contrast) | ||
export(cv_trac) | ||
export(get_probability_cv) | ||
export(get_probability_platt) | ||
export(phylo_to_A) | ||
export(plot_cv_trac) | ||
export(plot_trac_path) | ||
export(predict_second_stage) | ||
export(predict_sparse_log_contrast) | ||
export(predict_trac) | ||
export(probability_transform) | ||
export(refit_sparse_log_contrast) | ||
export(refit_sparse_log_contrast_classif) | ||
export(refit_trac) | ||
export(rescale_betas) | ||
export(second_stage) | ||
export(sparse_log_contrast) | ||
export(tax_table_to_phylo) | ||
export(trac) | ||
importFrom(magrittr,"%>%") | ||
importFrom(rlang,.data) | ||
importFrom(stats,predict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,138 @@ | ||
#' Perform cross validation for tuning parameter selection | ||
#' | ||
#' This function is to be called after calling \code{\link{trac}}. It performs | ||
#' \code{nfold}-fold cross validation. | ||
#' \code{nfold}-fold cross validation. For classification the metric is | ||
#' misclassification error. | ||
#' | ||
#' @param fit output of \code{\link{trac}} function. | ||
#' @param Z,y,A same arguments as passed to \code{\link{trac}} | ||
#' @param Z,y,A,additional_covariates same arguments as passed to | ||
#' \code{\link{trac}} | ||
#' @param folds a partition of \code{1:nrow(Z)}. | ||
#' @param nfolds number of folds for cross-validation | ||
#' @param summary_function how to combine the errors calculated on each | ||
#' observation within a fold (e.g. mean or median) | ||
#' observation within a fold (e.g. mean or median) (only for regression task) | ||
#' @param stratified if `TRUE` use stratified folds based on target variable | ||
#' only for classification. Default set to FALSE. | ||
#' @export | ||
cv_trac <- function(fit, Z, y, A, folds = NULL, nfolds = 5, summary_function = stats::median) { | ||
cv_trac <- function(fit, Z, y, A, additional_covariates = NULL, folds = NULL, | ||
nfolds = 5, summary_function = stats::median, | ||
stratified = FALSE) { | ||
n <- nrow(Z) | ||
p <- ncol(Z) | ||
if (!is.null(additional_covariates) & !is.data.frame(additional_covariates)) { | ||
additional_covariates <- data.frame(additional_covariates) | ||
} | ||
stopifnot(length(y) == n) | ||
if(is.null(folds)) folds <- ggb:::make_folds(n, nfolds) | ||
else | ||
if (is.null(folds)) { | ||
if (stratified) { | ||
folds <- make_folds_stratified(n, nfolds, y) | ||
} else { | ||
folds <- ggb:::make_folds(n, nfolds) | ||
} | ||
} else { | ||
nfolds <- length(folds) | ||
} | ||
|
||
cv <- list() | ||
fit_folds <- list() # save this to reuse by log-ratio's cv function | ||
for (iw in seq_along(fit)) { | ||
if (length(fit) > 1) cat("CV for weight sequence #", iw, fill = TRUE) | ||
errs <- matrix(NA, ncol(fit[[iw]]$beta), nfolds) | ||
for (i in seq(nfolds)) { | ||
cat("fold", i, fill = TRUE) | ||
# add for backward compatibility | ||
if (is.null(fit[[iw]]$method)) fit[[iw]]$method <- "regr" | ||
if (is.null(fit[[iw]]$w_additional_covariates)) { | ||
fit[[iw]]$w_additional_covariates <- NULL | ||
} | ||
if (is.null(fit[[iw]]$rho)) fit[[iw]]$rho <- 0 | ||
|
||
|
||
# train on all but i-th fold (and use settings from fit): | ||
fit_folds[[i]] <- trac(Z[-folds[[i]], ], | ||
y[-folds[[i]]], | ||
A, fraclist = fit[[iw]]$fraclist, w = fit[[iw]]$w) | ||
fit_folds[[i]] <- trac(Z = Z[-folds[[i]], ], | ||
y = y[-folds[[i]]], | ||
A = A, | ||
additional_covariates = | ||
additional_covariates[-folds[[i]], ], | ||
fraclist = fit[[iw]]$fraclist, | ||
w = fit[[iw]]$w, | ||
w_additional_covariates = | ||
fit[[iw]]$w_additional_covariates, | ||
method = fit[[iw]]$method, | ||
rho = fit[[iw]]$rho, | ||
normalized = fit[[iw]]$normalized) | ||
|
||
if (fit[[iw]]$refit) { | ||
fit_folds[[i]] <- refit_trac(fit_folds[[i]], Z[-folds[[i]], ], | ||
y[-folds[[i]]], A) | ||
y[-folds[[i]]], A) | ||
} | ||
if (fit[[iw]]$method == "regr" | is.null(fit[[iw]]$method)) { | ||
errs[, i] <- apply( | ||
(predict_trac( | ||
fit_folds[[i]], | ||
Z[folds[[i]], ], | ||
additional_covariates[folds[[i]], ])[[1]] - y[folds[[i]]])^2, | ||
2, summary_function | ||
) | ||
} | ||
|
||
if (fit[[iw]]$method == "classif" | | ||
fit[[iw]]$method == "classif_huber") { | ||
# loss: max(0, 1 - y_hat * y)^2 | ||
er <- sign(predict_trac(fit_folds[[i]], | ||
Z[folds[[i]],], | ||
additional_covariates[folds[[i]],])[[1]]) != | ||
c(y[folds[[i]]]) | ||
errs[, i] <- colMeans(er) | ||
} | ||
errs[, i] <- apply((predict_trac(fit_folds[[i]], | ||
Z[folds[[i]], ])[[1]] - y[folds[[i]]])^2, 2, summary_function) | ||
} | ||
m <- rowMeans(errs) | ||
se <- apply(errs, 1, stats::sd) / sqrt(nfolds) | ||
ibest <- which.min(m) | ||
i1se <- min(which(m < m[ibest] + se[ibest])) | ||
cv[[iw]] <- list(errs = errs, m = m, se = se, | ||
lambda_best = fit[[iw]]$fraclist[ibest], ibest = ibest, | ||
lambda_1se = fit[[iw]]$fraclist[i1se], i1se = i1se, | ||
fraclist = fit[[iw]]$fraclist, w = fit[[iw]]$w, | ||
nonzeros = colSums(abs(fit[[iw]]$gamma) > 1e-5), | ||
fit_folds = fit_folds) | ||
i1se <- min(which(m <= m[ibest] + se[ibest])) | ||
cv[[iw]] <- list( | ||
errs = errs, m = m, se = se, | ||
lambda_best = fit[[iw]]$fraclist[ibest], ibest = ibest, | ||
lambda_1se = fit[[iw]]$fraclist[i1se], i1se = i1se, | ||
fraclist = fit[[iw]]$fraclist, w = fit[[iw]]$w, | ||
nonzeros = colSums(abs(fit[[iw]]$gamma) > 1e-5), | ||
fit_folds = fit_folds | ||
) | ||
} | ||
list( | ||
cv = cv, | ||
iw_best = which.min(lapply(cv, function(cvv) cvv$m[cvv$ibest])), | ||
iw_1se = which.min(lapply(cv, function(cvv) cvv$m[cvv$i1se])), | ||
folds = folds | ||
) | ||
} | ||
|
||
#' This function creates stratified folds for cross validation for unbalanced | ||
#' data. The code is adopted from ggb:::make_folds | ||
#' | ||
#' @param n number of observations | ||
#' @param nfolds number of folds | ||
#' @param y variable with group assignment. | ||
|
||
make_folds_stratified <- function(n, nfolds, y) { | ||
# Check if number of folds is greater than the max n of observations | ||
# per group. If the number is greater at least one fold will not contain | ||
# any observations of group of interest. | ||
max_n_y <- max(table(y)) | ||
nfolds <- min(nfolds, max_n_y) | ||
# Initiate the list in advance | ||
folds <- vector(mode = "list", length = nfolds) | ||
for (j in unique(y)) { | ||
ixs <- which(y == j) | ||
nn <- round(length(ixs) / nfolds) | ||
sizes <- rep(nn, nfolds) | ||
sizes[nfolds] <- sizes[nfolds] + length(ixs) - nn * nfolds | ||
b <- c(0, cumsum(sizes)) | ||
ii <- sample(length(ixs)) | ||
ii <- ixs[ii] | ||
for (i in seq(nfolds)) { | ||
folds[[i]] <- c(folds[[i]], ii[seq(b[i] + 1, b[i + 1])]) | ||
} | ||
} | ||
list(cv = cv, | ||
iw_best = which.min(lapply(cv, function(cvv) cvv$m[cvv$ibest])), | ||
iw_1se = which.min(lapply(cv, function(cvv) cvv$m[cvv$i1se])), | ||
folds = folds) | ||
folds | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if NULL, then set it to NULL?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thats a good point! I tried to make it backwards compatible but this line is unnecessary.