Skip to content

Commit

Permalink
updating, running roxygen2 and adding censored files
Browse files Browse the repository at this point in the history
  • Loading branch information
mtwesley committed Nov 25, 2024
1 parent a2a2c84 commit f64580a
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.0
RoxygenNote: 7.3.2
SystemRequirements: Python (>= 2.7.0) with header files and shared
library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow
Probability (v0.7.0; https://www.tensorflow.org/probability/)
54 changes: 54 additions & 0 deletions R/censored.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
exponential_censored_distribution <- R6::R6Class(
"exponential_censored_distribution",
inherit = distribution_node,
public = list(
initialize = function(rate, is_censored, censoring_type, lower, upper, dim) {
rate <- as.greta_array(rate)
is_censored <- check_param_greta_array(is_censored)
check_numeric_length_1(lower)
check_numeric_length_1(upper)
check_finite(lower)
check_finite(upper)
check_x_gte_y(lower, upper)

dim <- check_dims(rate, is_censored, target_dim = dim)

super$initialize("exponential_censored", dim)
self$add_parameter(rate, "rate")
self$add_parameter(is_censored, "is_censored")
self$censoring_type <- censoring_type
self$lower <- lower
self$upper <- upper
},
tf_distrib = function(parameters, dag) {
rate <- parameters$rate
is_censored <- parameters$is_censored
exp_dist <- tfp$distributions$Exponential(rate = rate)
censored_log_prob <- switch(
self$censoring_type,
"right" = function(y) exp_dist$log_survival_function(y),
"left" = function(y) exp_dist$log_cdf(y),
"interval" = function(y) {
log_cdf_upper <- exp_dist$log_cdf(self$upper)
log_cdf_lower <- exp_dist$log_cdf(self$lower)
tf$log(tf$exp(log_cdf_upper) - tf$exp(log_cdf_lower))
},
function(y) exp_dist$log_prob(y)
)
uncensored_log_prob <- function(y) exp_dist$log_prob(y)
list(
log_prob = function(y) {
tf$where(
tf$equal(is_censored, 1),
censored_log_prob(y),
uncensored_log_prob(y)
)
}
)
}
)
)

exponential_censored <- function(rate, is_censored, censoring_type, lower, upper, dim) {
distrib("exponential_censored", rate, is_censored, censoring_type = censoring_type, lower = lower, upper = upper, dim = dim)
}
13 changes: 13 additions & 0 deletions man/greta.template.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/test-censored.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
test_that("exponential_censored_distribution initializes correctly", {
rate <- as.greta_array(1)
is_censored <- as.greta_array(0)
dist <- exponential_censored_distribution$new(rate, is_censored, "right", 0, 1, dim = c(1, 1))
expect_s3_class(dist, "exponential_censored_distribution")
})

test_that("exponential_censored function works", {
rate <- as.greta_array(1)
is_censored <- as.greta_array(0)
dist <- exponential_censored(rate, is_censored, "right", 0, 1, dim = c(1, 1))
expect_s3_class(dist, "distribution_node")
})

0 comments on commit f64580a

Please sign in to comment.