From f64580a252935caccbe6a2aa566c81e9f69054c1 Mon Sep 17 00:00:00 2001 From: mtwesley Date: Sun, 24 Nov 2024 21:05:19 -0500 Subject: [PATCH] updating, running roxygen2 and adding censored files --- DESCRIPTION | 2 +- R/censored.R | 54 ++++++++++++++++++++++++++++++++++ man/greta.template.Rd | 13 ++++++++ tests/testthat/test-censored.R | 13 ++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 R/censored.R create mode 100644 tests/testthat/test-censored.R diff --git a/DESCRIPTION b/DESCRIPTION index 986376a..259bb6a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -34,7 +34,7 @@ Encoding: UTF-8 Language: en-GB LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.0 +RoxygenNote: 7.3.2 SystemRequirements: Python (>= 2.7.0) with header files and shared library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow Probability (v0.7.0; https://www.tensorflow.org/probability/) diff --git a/R/censored.R b/R/censored.R new file mode 100644 index 0000000..1e2a5fb --- /dev/null +++ b/R/censored.R @@ -0,0 +1,54 @@ +exponential_censored_distribution <- R6::R6Class( + "exponential_censored_distribution", + inherit = distribution_node, + public = list( + initialize = function(rate, is_censored, censoring_type, lower, upper, dim) { + rate <- as.greta_array(rate) + is_censored <- check_param_greta_array(is_censored) + check_numeric_length_1(lower) + check_numeric_length_1(upper) + check_finite(lower) + check_finite(upper) + check_x_gte_y(lower, upper) + + dim <- check_dims(rate, is_censored, target_dim = dim) + + super$initialize("exponential_censored", dim) + self$add_parameter(rate, "rate") + self$add_parameter(is_censored, "is_censored") + self$censoring_type <- censoring_type + self$lower <- lower + self$upper <- upper + }, + tf_distrib = function(parameters, dag) { + rate <- parameters$rate + is_censored <- parameters$is_censored + exp_dist <- tfp$distributions$Exponential(rate = rate) + censored_log_prob <- switch( + self$censoring_type, + "right" = function(y) exp_dist$log_survival_function(y), + "left" = function(y) exp_dist$log_cdf(y), + "interval" = function(y) { + log_cdf_upper <- exp_dist$log_cdf(self$upper) + log_cdf_lower <- exp_dist$log_cdf(self$lower) + tf$log(tf$exp(log_cdf_upper) - tf$exp(log_cdf_lower)) + }, + function(y) exp_dist$log_prob(y) + ) + uncensored_log_prob <- function(y) exp_dist$log_prob(y) + list( + log_prob = function(y) { + tf$where( + tf$equal(is_censored, 1), + censored_log_prob(y), + uncensored_log_prob(y) + ) + } + ) + } + ) +) + +exponential_censored <- function(rate, is_censored, censoring_type, lower, upper, dim) { + distrib("exponential_censored", rate, is_censored, censoring_type = censoring_type, lower = lower, upper = upper, dim = dim) +} \ No newline at end of file diff --git a/man/greta.template.Rd b/man/greta.template.Rd index 955a6bc..d4231ec 100644 --- a/man/greta.template.Rd +++ b/man/greta.template.Rd @@ -2,6 +2,7 @@ % Please edit documentation in R/package.R \docType{package} \name{greta.template} +\alias{greta.template-package} \alias{greta.template} \title{What the Package Does (One Line, Title Case)} \description{ @@ -12,3 +13,15 @@ describe your package here, you can re-use the text from DESCRIPTION # add a simple example here to introduce the package! } +\seealso{ +Useful links: +\itemize{ + \item \url{https://github.com/your_username/your_repository} + \item Report bugs at \url{https://github.com/your_username/your_repository/issues} +} + +} +\author{ +\strong{Maintainer}: First Last \email{first.last@example.com} (\href{https://orcid.org/YOUR-ORCID-ID}{ORCID}) + +} diff --git a/tests/testthat/test-censored.R b/tests/testthat/test-censored.R new file mode 100644 index 0000000..c27e3d8 --- /dev/null +++ b/tests/testthat/test-censored.R @@ -0,0 +1,13 @@ +test_that("exponential_censored_distribution initializes correctly", { + rate <- as.greta_array(1) + is_censored <- as.greta_array(0) + dist <- exponential_censored_distribution$new(rate, is_censored, "right", 0, 1, dim = c(1, 1)) + expect_s3_class(dist, "exponential_censored_distribution") +}) + +test_that("exponential_censored function works", { + rate <- as.greta_array(1) + is_censored <- as.greta_array(0) + dist <- exponential_censored(rate, is_censored, "right", 0, 1, dim = c(1, 1)) + expect_s3_class(dist, "distribution_node") +}) \ No newline at end of file