From 16dc008eedbf65ce5845b6cc647a7b7ffcb5a09c Mon Sep 17 00:00:00 2001 From: qiancao Date: Fri, 29 Nov 2024 10:25:42 -0500 Subject: [PATCH 1/6] getting started on a test for torch.compile --- tests/test_torch_compile.py | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/test_torch_compile.py diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py new file mode 100644 index 0000000..f665a64 --- /dev/null +++ b/tests/test_torch_compile.py @@ -0,0 +1,51 @@ +""" + +Tests for torch.compile + +Note: conda install conda-forge::gxx + +TODO: + - test conda install conda-forge::cxx-compiler + - check torchtriton for GPU + +References: + - https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html + - https://github.com/pytorch/pytorch/issues/122094 + +""" + +# global modules +import json +import unittest + +import numpy as np +import torch + +# Local modules +from torchsurv.loss.cox import neg_partial_log_likelihood as cox + +# torch compile settings +import torch._inductor.config +torch._inductor.config.cpp.cxx = ("g++",) + +# set seed for reproducibility +torch.manual_seed(42) + +# TODO: wrap this in TestCoxSurvivalLossCompile(unittest.TestCase) +if __name__ == "__main__": + + # random data and parameters + N = 32 + log_hz = torch.randn(N) + event = torch.randint(low=0, high=2, size=(N,)).bool() + time = torch.randint(low=1, high=100, size=(N,)) + + # compile cox + ccox = torch.compile(cox) + + loss_cox = cox(log_hz, event, time) + loss_ccox = ccox(log_hz, event, time) + + + + \ No newline at end of file From 2e98ce79b629307113257b05028c2cd6a35bbc9e Mon Sep 17 00:00:00 2001 From: corolth1 Date: Mon, 9 Dec 2024 13:00:09 -0500 Subject: [PATCH 2/6] init --- src/torchsurv/loss/cox.py | 166 +++++++++++-------------- src/torchsurv/loss/weibull.py | 41 ++---- src/torchsurv/metrics/auc.py | 36 +++--- src/torchsurv/metrics/brier_score.py | 38 +++--- src/torchsurv/metrics/cindex.py | 6 +- src/torchsurv/stats/ipcw.py | 21 ++-- src/torchsurv/stats/kaplan_meier.py | 11 +- src/torchsurv/tools/validate_data.py | 146 ++++++++++++++++++++++ src/torchsurv/tools/validate_inputs.py | 125 ------------------- 9 files changed, 286 insertions(+), 304 deletions(-) create mode 100644 src/torchsurv/tools/validate_data.py delete mode 100644 src/torchsurv/tools/validate_inputs.py diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index ce49f46..774d8eb 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -6,7 +6,77 @@ import torch +from torchsurv.tools.validate_data import validate_cox + +@torch.jit.script +def _partial_likelihood_cox( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, +) -> torch.Tensor: + """Calculate the partial log likelihood for the Cox proportional hazards model + in the absence of ties in event time. + """ + log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) + return (log_hz_sorted - log_denominator)[event_sorted] + + +@torch.jit.script +def _partial_likelihood_efron( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, + time_sorted: torch.Tensor, + time_unique: torch.Tensor, +) -> torch.Tensor: + """Calculate the partial log likelihood for the Cox proportional hazards model + using Efron's method to handle ties in event time. + """ + J = len(time_unique) + + H = [ + torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0] + for j in range(J) + ] + R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)] + + m = torch.tensor([len(h) for h in H]) + include = torch.tensor([len(h) > 0 for h in H]) + + log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H]) + + denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R]) + denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H]) + + log_denominator_efron = torch.zeros(J).to(log_hz_sorted.device) + for j in range(J): + mj = int(m[j].item()) # Convert tensor to int + for l in range(1, mj + 1): + log_denominator_efron[j] += torch.log( + denominator_naive[j] - (l - 1) / float(m[j]) * denominator_ties[j] + ) + return (log_nominator - log_denominator_efron)[include] + + +@torch.jit.script +def _partial_likelihood_breslow( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, + time_sorted: torch.Tensor, +): + """Calculate the partial log likelihood for the Cox proportional hazards model + using Breslow's method to handle ties in event time. + """ + N = len(time_sorted) + + R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] + log_denominator = torch.tensor( + [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] + ) + + return (log_hz_sorted - log_denominator)[event_sorted] + + +@torch.jit.script def neg_partial_log_likelihood( log_hz: torch.Tensor, event: torch.Tensor, @@ -118,9 +188,9 @@ def neg_partial_log_likelihood( """ if checks: - _check_inputs(log_hz, event, time) + validate_cox(log_hz, event, time) - if any([event.sum() == 0, len(log_hz.size()) == 0]): + if torch.any([event.sum().item() == 0, len(log_hz.size()) == 0]): warnings.warn("No events OR single sample. Returning zero loss for the batch") return torch.tensor(0.0, requires_grad=True) @@ -164,98 +234,6 @@ def neg_partial_log_likelihood( return loss -def _partial_likelihood_cox( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, -) -> torch.Tensor: - """Calculate the partial log likelihood for the Cox proportional hazards model - in the absence of ties in event time. - """ - log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) - return (log_hz_sorted - log_denominator)[event_sorted] - - -def _partial_likelihood_efron( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, - time_sorted: torch.Tensor, - time_unique: torch.Tensor, -) -> torch.Tensor: - """Calculate the partial log likelihood for the Cox proportional hazards model - using Efron's method to handle ties in event time. - """ - J = len(time_unique) - - H = [ - torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0] - for j in range(J) - ] - R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)] - - m = torch.tensor([len(h) for h in H]) - include = torch.tensor([len(h) > 0 for h in H]) - - log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H]) - - denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R]) - denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H]) - - log_denominator_efron = torch.zeros(J).to(log_hz_sorted.device) - for j in range(J): - for l in range(1, m[j] + 1): - log_denominator_efron[j] += torch.log( - denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j] - ) - return (log_nominator - log_denominator_efron)[include] - - -def _partial_likelihood_breslow( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, - time_sorted: torch.Tensor, -): - """Calculate the partial log likelihood for the Cox proportional hazards model - using Breslow's method to handle ties in event time. - """ - N = len(time_sorted) - - R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] - log_denominator = torch.tensor( - [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] - ) - - return (log_hz_sorted - log_denominator)[event_sorted] - - -def _check_inputs(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - if not isinstance(log_hz, torch.Tensor): - raise TypeError("Input 'log_hz' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("Input 'time' must be a tensor.") - - if len(log_hz) != len(event): - raise ValueError( - "Length mismatch: 'log_hz' and 'event' must have the same length." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: 'time' must have the same length as 'event'." - ) - - if any(val < 0 for val in time): - raise ValueError("Invalid values: All elements in 'time' must be non-negative.") - - if any(val not in [True, False, 0, 1] for val in event): - raise ValueError( - "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" - ) - - if __name__ == "__main__": import doctest diff --git a/src/torchsurv/loss/weibull.py b/src/torchsurv/loss/weibull.py index 926488a..bc25141 100644 --- a/src/torchsurv/loss/weibull.py +++ b/src/torchsurv/loss/weibull.py @@ -4,7 +4,10 @@ TORCH_CLAMP_VALUE = 1e10 +from torchsurv.tools.validate_data import validate_weibull + +@torch.jit.script def neg_log_likelihood( log_params: torch.Tensor, event: torch.Tensor, @@ -95,7 +98,7 @@ def neg_log_likelihood( """ if checks: - _check_inputs(log_params, event, time) + validate_weibull(log_params, event, time) # Negative log likelihood nll = torch.neg( @@ -120,6 +123,7 @@ def neg_log_likelihood( return loss +@torch.jit.script def survival_function( log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True ) -> torch.Tensor: @@ -186,6 +190,7 @@ def survival_function( ).cdf(time) +@torch.jit.script def log_hazard( log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True ) -> torch.Tensor: @@ -254,12 +259,13 @@ def log_hazard( log_shape - log_scale + torch.expm1(log_shape) - * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale), + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale), min=-TORCH_CLAMP_VALUE, max=TORCH_CLAMP_VALUE, ) +@torch.jit.script def cumulative_hazard( log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True ) -> torch.Tensor: @@ -308,13 +314,14 @@ def cumulative_hazard( return torch.clamp( torch.exp( torch.exp(log_shape) - * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale) + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale) ), min=0, max=TORCH_CLAMP_VALUE, ) +@torch.jit.script def _check_log_shape(log_params: torch.Tensor) -> torch.Tensor: """Private function, check if the log shape is missing and impute it with 0 if needed.""" @@ -335,34 +342,6 @@ def _check_log_shape(log_params: torch.Tensor) -> torch.Tensor: return log_params -def _check_inputs(log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - """Private function, perform input format checks.""" - if not isinstance(log_params, torch.Tensor): - raise TypeError("Input 'log_params' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("``Input 'time' must be a tensor.") - - if log_params.shape[0] != len(event): - raise ValueError( - "Length mismatch: The length of 'log_params' must match the length of 'event'." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: The length of 'time' must match the length of 'event'.`" - ) - - if any(val < 0 for val in time): - raise ValueError("All elements in 'time' must be non-negative.") - - if any(val not in [True, False, 0, 1] for val in event): - raise ValueError("All elements in 'event' must be boolean (True/False or 0/1).") - - if __name__ == "__main__": import doctest diff --git a/src/torchsurv/metrics/auc.py b/src/torchsurv/metrics/auc.py index b5a6a98..35444da 100644 --- a/src/torchsurv/metrics/auc.py +++ b/src/torchsurv/metrics/auc.py @@ -7,7 +7,7 @@ from torchmetrics import regression from ..stats import kaplan_meier -from ..tools import validate_inputs +from ..tools import validate_data class Auc: @@ -212,9 +212,9 @@ def __call__( # further input format checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_evaluation_time(new_time, time) - validate_inputs.validate_estimate(estimate, time, new_time) + validate_data.validate_survival_data(event, time) + validate_data.validate_evaluation_time(new_time, time) + validate_data.validate_estimate(estimate, time, new_time) # sample size and length of time n_samples, n_times = estimate.shape[0], new_time.shape[0] @@ -579,7 +579,7 @@ def compare( return pvalue # pylint: disable=invalid-name - def _integrate_incident(self, S: torch.tensor, tmax: torch.tensor) -> torch.Tensor: + def _integrate_incident(self, S: torch.Tensor, tmax: torch.Tensor) -> torch.Tensor: """Integrates the incident/dynamic AUC, int_t AUC(t) x w(t) dt where w(t) = 2*f(t)*S(t) and f(t) is the lifeline distribution, S(t) is the survival distribution estimated with the Kaplan @@ -617,7 +617,7 @@ def _integrate_incident(self, S: torch.tensor, tmax: torch.tensor) -> torch.Tens # pylint: disable=invalid-name def _integrate_cumulative( - self, S: torch.tensor, tmax: torch.tensor + self, S: torch.Tensor, tmax: torch.Tensor ) -> torch.Tensor: """Integrates the cumulative/dynamic AUC, int_t AUC(t) ยท f(t) dt where f(t) is the lifeline distribution estimated from the discrete @@ -1168,7 +1168,7 @@ def _bootstrap_auc( @staticmethod def _find_torch_unique_indices( - inverse_indices: torch.tensor, counts: torch.tensor + inverse_indices: torch.Tensor, counts: torch.Tensor ) -> torch.tensor: """return unique_sorted_indices such that sorted_unique_tensor[inverse_indices] = original_tensor @@ -1214,12 +1214,12 @@ def _validate_auc_inputs( @staticmethod def _update_auc_new_time( - estimate: torch.tensor, - event: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + event: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # update new time if ( @@ -1255,7 +1255,7 @@ def _update_auc_new_time( @staticmethod def _update_auc_estimate( - estimate: torch.tensor, new_time: torch.tensor + estimate: torch.Tensor, new_time: torch.Tensor ) -> torch.tensor: # squeeze estimate if shape = (n_samples, 1) if estimate.ndim == 2 and estimate.shape[1] == 1: @@ -1271,10 +1271,10 @@ def _update_auc_estimate( @staticmethod def _update_auc_weight( - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # if weight was not specified, weight of 1 if weight is None: diff --git a/src/torchsurv/metrics/brier_score.py b/src/torchsurv/metrics/brier_score.py index e50ebe1..d6c07d6 100644 --- a/src/torchsurv/metrics/brier_score.py +++ b/src/torchsurv/metrics/brier_score.py @@ -5,7 +5,7 @@ import torch from scipy import stats -from ..tools import validate_inputs +from ..tools import validate_data class BrierScore: @@ -185,11 +185,11 @@ def __call__( # further input format checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_evaluation_time( + validate_data.validate_survival_data(event, time) + validate_data.validate_evaluation_time( new_time, time, within_follow_up=False ) - validate_inputs.validate_estimate(estimate, time, new_time) + validate_data.validate_estimate(estimate, time, new_time) # Calculating the residuals for each subject and time point residuals = torch.zeros_like(estimate) @@ -802,7 +802,7 @@ def _bootstrap_brier_score( @staticmethod def _find_torch_unique_indices( - inverse_indices: torch.tensor, counts: torch.tensor + inverse_indices: torch.Tensor, counts: torch.Tensor ) -> torch.tensor: """return unique_sorted_indices such that sorted_unique_tensor[inverse_indices] = original_tensor @@ -824,11 +824,11 @@ def _find_torch_unique_indices( @staticmethod def _validate_brier_score_inputs( - estimate: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # check new_time and weight are provided, weight_new_time should be provided if all([new_time is not None, weight is not None, weight_new_time is None]): @@ -855,11 +855,11 @@ def _validate_brier_score_inputs( @staticmethod def _update_brier_score_new_time( - estimate: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # check format of new_time if ( @@ -891,10 +891,10 @@ def _update_brier_score_new_time( @staticmethod def _update_brier_score_weight( - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # if weight was not specified, weight of 1 if weight is None: diff --git a/src/torchsurv/metrics/cindex.py b/src/torchsurv/metrics/cindex.py index 9276a07..a79cd9b 100644 --- a/src/torchsurv/metrics/cindex.py +++ b/src/torchsurv/metrics/cindex.py @@ -7,7 +7,7 @@ from scipy import stats from torchmetrics import regression -from ..tools import validate_inputs +from ..tools import validate_data class ConcordanceIndex: @@ -185,8 +185,8 @@ def __call__( # Inputs checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_estimate(estimate, time) + validate_data.validate_survival_data(event, time) + validate_data.validate_estimate(estimate, time) # find comparable pairs comparable = self._get_comparable_and_tied_time(event, time) diff --git a/src/torchsurv/stats/ipcw.py b/src/torchsurv/stats/ipcw.py index 871cdd7..5277ee4 100644 --- a/src/torchsurv/stats/ipcw.py +++ b/src/torchsurv/stats/ipcw.py @@ -4,15 +4,16 @@ import torch -from ..tools import validate_inputs -from . import kaplan_meier +from torchsurv.tools.validate_data import validate_inputs +from torchsurv.stats import kaplan_meier # pylint: disable=anomalous-backslash-in-string +@torch.jit.script def get_ipcw( - event: torch.tensor, - time: torch.tensor, - new_time: Optional[torch.tensor] = None, + event: torch.Tensor, + time: torch.Tensor, + new_time: Optional[torch.Tensor] = None, checks: bool = True, ) -> torch.Tensor: """Calculate the inverse probability censoring weights (IPCW). @@ -56,7 +57,7 @@ def get_ipcw( """ if checks: - validate_inputs.validate_survival_data(event, time) + validate_inputs(event, time) # time on which to evaluate IPCW if new_time is None: # if none, return ipcw of same size as time @@ -77,6 +78,7 @@ def get_ipcw( return ipcw +@torch.jit.script def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: """Compute inverse of the censoring distribution. @@ -95,11 +97,12 @@ def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: tensor([2.9701, 7.7634, 4.2651, 4.3415]) """ - if torch.any(ct.eq(0.0)): + if torch.any(ct == 0.0): + zero_indices = torch.nonzero(ct.eq(0.0), as_tuple=True)[0] warnings.warn( - "Censoring distribution zero at one or more time points. Returning ones as weight" + f"Censoring distribution zero at time points: {zero_indices.tolist()}. Returning ones as weight" ) - return torch.ones_like(ct, dtype=ct.dtype) + weight = 1.0 / ct weight = torch.ones(1, dtype=ct.dtype) / ct return weight diff --git a/src/torchsurv/stats/kaplan_meier.py b/src/torchsurv/stats/kaplan_meier.py index 750c2b2..c25d465 100644 --- a/src/torchsurv/stats/kaplan_meier.py +++ b/src/torchsurv/stats/kaplan_meier.py @@ -4,16 +4,17 @@ import torch -from ..tools import validate_inputs +from torchsurv.tools import validate_data +# @torch.jit.script class KaplanMeierEstimator: """Kaplan-Meier estimate of survival or censoring distribution for right-censored data :cite:p:`Kaplan1958`.""" def __call__( self, - event: torch.tensor, - time: torch.tensor, + event: torch.Tensor, + time: torch.Tensor, censoring_dist: bool = False, check: bool = True, ): @@ -62,7 +63,7 @@ def __call__( # Check input validity if required if check: - validate_inputs.validate_survival_data(event, time) + validate_data.validate_inputs(event, time) # Compute the counts of events, censorings, and the number at risk at each unique time uniq_times, n_events, n_at_risk, n_censored = self._compute_counts() @@ -200,7 +201,7 @@ def print_survival_table(self): def _compute_counts( self, - ) -> Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the counts of events, censorings and risk set at ``time``. Returns: Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor] diff --git a/src/torchsurv/tools/validate_data.py b/src/torchsurv/tools/validate_data.py new file mode 100644 index 0000000..27faa6a --- /dev/null +++ b/src/torchsurv/tools/validate_data.py @@ -0,0 +1,146 @@ +import torch +import torch + + +@torch.jit.script +def validate_weibull(log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor): + """Private function, perform input format checks.""" + if not isinstance(log_params, torch.Tensor): + raise TypeError("Input 'log_params' must be a tensor.") + + if not isinstance(event, torch.Tensor): + raise TypeError("Input 'event' must be a tensor.") + + if not isinstance(time, torch.Tensor): + raise TypeError("``Input 'time' must be a tensor.") + + if log_params.shape[0] != len(event): + raise ValueError( + "Length mismatch: The length of 'log_params' must match the length of 'event'." + ) + + if len(time) != len(event): + raise ValueError( + "Length mismatch: The length of 'time' must match the length of 'event'.`" + ) + + if torch.any(time < 0): + raise ValueError("All elements in 'time' must be non-negative.") + + if torch.any((event != 0) & (event != 1)): + raise ValueError( + "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" + ) + + +@torch.jit.script +def validate_cox(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor): + if not isinstance(log_hz, torch.Tensor): + raise TypeError("Input 'log_hz' must be a tensor.") + + if not isinstance(event, torch.Tensor): + raise TypeError("Input 'event' must be a tensor.") + + if not isinstance(time, torch.Tensor): + raise TypeError("Input 'time' must be a tensor.") + + if len(log_hz) != len(event): + raise ValueError( + "Length mismatch: 'log_hz' and 'event' must have the same length." + ) + + if len(time) != len(event): + raise ValueError( + "Length mismatch: 'time' must have the same length as 'event'." + ) + + if torch.any(time < 0): + raise ValueError("Invalid values: All elements in 'time' must be non-negative.") + + if torch.any((event != 0) & (event != 1)): + raise ValueError( + "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" + ) + + +@torch.jit.script +def validate_inputs(event: torch.Tensor, time: torch.Tensor) -> None: + """ + Validate the inputs for survival analysis functions. + + Args: + event (torch.Tensor): Event indicator tensor. + time (torch.Tensor): Time-to-event or censoring tensor. + + Raises: + TypeError: If inputs are not tensors. + ValueError: If any ``time`` are negative. + """ + if not isinstance(event, torch.Tensor) or not isinstance(time, torch.Tensor): + raise TypeError("Inputs 'event' and 'time' should be tensors") + + if torch.any(time < 0): + raise ValueError("All elements in 'time' must be non-negative") + + if torch.any((event != 0) & (event != 1)): + raise ValueError( + "Input 'event' must contain only boolean values (True/False or 1/0)" + ) + + +@torch.jit.script +def check_within_follow_up( + new_time: torch.Tensor, time: torch.Tensor, within_follow_up: bool +) -> None: + # Check if the within_follow_up flag is set to True + if within_follow_up: + # Check if any value in new_time is outside the range of time + if new_time.max() >= time.max() or new_time.min() < time.min(): + # Get the minimum and maximum values of time + min_time = time.min().item() + max_time = time.max().item() + # Raise a ValueError if new_time is not within the follow-up time range + raise ValueError( + "Value error: All new_time must be within follow-up time of test data: [{}; {}[".format( + min_time, max_time + ) + ) + + +@torch.jit.script +def validate_new_time( + new_time: torch.Tensor, time: torch.Tensor, within_follow_up: bool = True +) -> None: + """ + Validate the new_time tensor for survival analysis functions. + + Args: + new_time (torch.Tensor, float): Time points for metric computation of size n_times. + time (torch.Tensor, float): Event or censoring time of size n_samples. + within_follow_up (bool, optional): Whether values of ``new_time`` must be within values in ``time``. Defaults to True. + + Raises: + ValueError: If ``new_time`` contains duplicate values. + ValueError: If ``new_time`` is not sorted. + TypeError: If ``new_time`` is not a tensor. + ValueError: If ``new_time`` is not of floating-point type. + ValueError: If ``new_time`` is not within the range of follow-up in ``time``. Assessed only if ``within_follow_up`` is True. + """ + if not isinstance(new_time, torch.Tensor): + raise TypeError("Type error: Input 'new_time' should be a tensor.") + + if not torch.is_floating_point(new_time): + raise ValueError( + "Value error: Input 'new_time' should be of floating-point type." + ) + + new_time_sorted, _ = torch.sort(new_time) + if not torch.equal(new_time_sorted, new_time): + raise ValueError( + "Value error: Input 'new_time' should be sorted from the smallest time to the largest." + ) + + if len(new_time_sorted) != len(torch.unique(new_time_sorted)): + raise ValueError("Value error: Input 'new_time' should contain unique values.") + + check_within_follow_up(new_time, time, within_follow_up) diff --git a/src/torchsurv/tools/validate_inputs.py b/src/torchsurv/tools/validate_inputs.py deleted file mode 100644 index 1d19023..0000000 --- a/src/torchsurv/tools/validate_inputs.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch - - -def validate_survival_data(event, time): - """Perform format and validity checks for survival data. - - Args: - event (torch.Tensor, boolean): - Event indicator of size n_samples (= True if event occured). - time (torch.Tensor, float): - Event or censoring time of size n_samples. - - Raises: - TypeError: If ``event`` or ``time`` are not tensors. - ValueError: If ``event`` is not boolean. - ValueError: If ``event`` and ``time`` are not of the same length. - ValueError: If all ``event`` are False. - ValueError: If any ``time`` are negative. - """ - if not torch.is_tensor(event) or not torch.is_tensor(time): - raise TypeError("Inputs 'event' and 'time' should be tensors") - - if not event.dtype == torch.bool: - raise ValueError("Input 'event' should be of boolean type.") - - if not torch.is_floating_point(time): - raise ValueError("Input 'time' should be of float type.") - - if len(event) != len(time): - raise ValueError( - "Dimension mismatch: Incompatible length between inputs 'time' and 'event'." - ) - - if torch.sum(event) <= 0: - raise ValueError("All samples are censored.") - - if torch.any(time < 0.0): - raise ValueError("Input 'time' should be non-negative.") - - -def validate_evaluation_time(new_time, time, within_follow_up=True): - """Perform format and validity checks for evaluation time. - - Args: - new_time (torch.Tensor, optional): - Time points for metric computation of size n_times. - time (torch.Tensor, float): - Event or censoring time of size n_samples. - within_follow_up (bool, optional): - Whether values of ``new_time`` must be within values in ``time``. - Defaults to True. - - Raises: - ValueError: If ``new_time`` contains duplicate values. - ValueError: If ``new_time`` is not sorted. - TypeError: If ``new_time`` is not a tensor. - ValueError: If ``new_time`` is not of floating-point type. - ValueError: - If ``new_time`` is not within the range of follow-up in ``time``. - Assessed only if ``within_follow_up`` is True. - """ - new_time_sorted, indices = torch.unique(new_time, return_inverse=True) - - if len(new_time_sorted) != len(new_time): - raise ValueError("'Value error: Input 'new_time' should contain unique values.") - - if not torch.all(indices == torch.arange(len(new_time))): - raise ValueError( - "Value error: Input 'new_time' should be sorted from the smallest time to the largest." - ) - - if not torch.is_tensor(new_time): - raise TypeError("Type error: Input 'new_time' should be a tensor.") - - if not torch.is_floating_point(new_time): - raise ValueError( - "Value error: Input 'new_time' should be of floating-point type." - ) - - if within_follow_up: - if new_time.max() >= time.max() or new_time.min() < time.min(): - raise ValueError( - "Value error: All new_time must be within follow-up time of test data: [{}; {}[".format( - time.min(), time.max() - ) - ) - - -def validate_estimate(estimate, time, new_time=None) -> None: - """Perform format and validity checks for estimate. - - Args: - estimate (torch.Tensor): - Estimates of shape = (n_samples,) or (n_samples, n_times). - time (torch.Tensor, float): - Time of event or censoring of size n_samples. - new_time (torch.Tensor, optional): - Time points for metric computation of size n_times. - - Raises: - TypeError: If ``estimate`` is not a tensor. - ValueError: If ``estimate`` has more than 2 dimensions. - ValueError: If number of rows of ``estimate`` is not n_samples. - ValueError: - If number of columns of ``estimate`` is not n_times. - Assessed only if ``new_time`` is not None. - """ - if not torch.is_tensor(estimate): - raise TypeError("Type error: Input 'estimate' should be a tensor.") - - if estimate.ndim > 2: - raise ValueError("Value error: Input 'estimate' should have 1 or 2 dimensions.") - - if estimate.shape[0] != time.shape[0]: - raise ValueError( - "Dimension mismatch: Inconsistent number of samples between inputs 'time' and 'estimate'." - ) - - if not new_time is None: - if estimate.ndim == 2 and estimate.shape[1] != new_time.shape[0]: - raise ValueError( - "Dimension mismatch: Expected inputs 'estimate' with {} columns, but got {}".format( - new_time.shape[0], estimate.shape[1] - ) - ) From 9df099a4046d7abf7e34166ad93d27954ddfba2c Mon Sep 17 00:00:00 2001 From: corolth1 Date: Tue, 10 Dec 2024 11:38:49 -0500 Subject: [PATCH 3/6] wip --- docs/notebooks/introduction.ipynb | 2 +- src/torchsurv/loss/cox.py | 19 +- src/torchsurv/loss/momentum.py | 32 +-- src/torchsurv/loss/weibull.py | 296 +++++++++++++-------------- src/torchsurv/stats/ipcw.py | 13 +- src/torchsurv/stats/kaplan_meier.py | 2 +- src/torchsurv/tools/validate_data.py | 22 +- tests/test_kaplan_meier.py | 39 ++++ tests/test_torch_compile.py | 15 +- 9 files changed, 246 insertions(+), 194 deletions(-) diff --git a/docs/notebooks/introduction.ipynb b/docs/notebooks/introduction.ipynb index bdfb74d..eec04fe 100644 --- a/docs/notebooks/introduction.ipynb +++ b/docs/notebooks/introduction.ipynb @@ -28,7 +28,7 @@ "# %pip install lifelines\n", "# %pip install matplotlib\n", "# %pip install sklearn\n", - "# %pip install pandas\n" + "# %pip install pandas" ] }, { diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index 774d8eb..8200add 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -39,7 +39,10 @@ def _partial_likelihood_efron( ] R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)] + # Calculate the length of each element in H and store it in a tensor m = torch.tensor([len(h) for h in H]) + + # Create a boolean tensor indicating whether each element in H has a length greater than 0 include = torch.tensor([len(h) > 0 for h in H]) log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H]) @@ -47,9 +50,9 @@ def _partial_likelihood_efron( denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R]) denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H]) - log_denominator_efron = torch.zeros(J).to(log_hz_sorted.device) + log_denominator_efron = torch.zeros(J, device=log_hz_sorted.device) for j in range(J): - mj = int(m[j].item()) # Convert tensor to int + mj = int(m[j].item()) for l in range(1, mj + 1): log_denominator_efron[j] += torch.log( denominator_naive[j] - (l - 1) / float(m[j]) * denominator_ties[j] @@ -63,8 +66,16 @@ def _partial_likelihood_breslow( event_sorted: torch.Tensor, time_sorted: torch.Tensor, ): - """Calculate the partial log likelihood for the Cox proportional hazards model - using Breslow's method to handle ties in event time. + """ + Compute the partial likelihood using Breslow's method for Cox proportional hazards model. + + Args: + log_hz_sorted (torch.Tensor): Log hazard rates sorted by time. + event_sorted (torch.Tensor): Binary tensor indicating if the event occurred (1) or was censored (0), sorted by time. + time_sorted (torch.Tensor): Event or censoring times sorted in ascending order. + + Returns: + torch.Tensor: The partial likelihood for the observed events. """ N = len(time_sorted) diff --git a/src/torchsurv/loss/momentum.py b/src/torchsurv/loss/momentum.py index 8cab479..98f640a 100644 --- a/src/torchsurv/loss/momentum.py +++ b/src/torchsurv/loss/momentum.py @@ -151,15 +151,15 @@ def forward( """ - estimate_q = self.online(inputs) - for estimate in zip(estimate_q, event, time): - self.memory_q.append(self.survtuple(*list(estimate))) + online_estimate = self.online(inputs) + for estimate in zip(online_estimate, event, time): + self.memory_q.append(self.survtuple(*estimate)) loss = self._bank_loss() with torch.no_grad(): self._update_momentum_encoder() - estimate_k = self.target(inputs) - for estimate in zip(estimate_k, event, time): - self.memory_k.append(self.survtuple(*list(estimate))) + target_estimate = self.target(inputs) + for estimate in zip(target_estimate, event, time): + self.memory_k.append(self.survtuple(*estimate)) return loss @torch.no_grad() # deactivates autograd @@ -187,27 +187,33 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor: return self.target(inputs) def _bank_loss(self) -> torch.Tensor: - """computer the negative loss likelyhood from memory bank""" + """compute the negative log-likelihood from memory bank""" # Combine current batch and momentum bank = self.memory_k + self.memory_q assert all( x in bank[0]._fields for x in ["estimate", "event", "time"] ), "All fields must be present" - return self.loss( - torch.stack([mem.estimate.cpu() for mem in bank]).squeeze(), - torch.stack([mem.event.cpu() for mem in bank]).squeeze(), - torch.stack([mem.time.cpu() for mem in bank]).squeeze(), - ) + log_estimates = torch.stack([mem.estimate.cpu() for mem in bank]).squeeze() + events = torch.stack([mem.event.cpu() for mem in bank]).squeeze() + times = torch.stack([mem.time.cpu() for mem in bank]).squeeze() + return self.loss(log_estimates, events, times) @torch.no_grad() def _update_momentum_encoder(self): - """Exponantial moving average""" + """Exponential moving average""" for param_b, param_m in zip(self.online.parameters(), self.target.parameters()): param_m.data = param_m.data * self.rate + param_b.data * (1.0 - self.rate) @torch.no_grad() def _init_encoder_k(self): + """ + Initialize the target network (encoder_k) with the parameters of the online network (encoder_q). + The requires_grad attribute of the target network parameters is set to False to prevent gradient updates during training, + ensuring that the target network remains a stable reference point. + This method uses the `copy_` method to copy the parameters from the online network to the target network + and sets the requires_grad attribute of the target network parameters to False to prevent gradient updates. + """ for param_q, param_k in zip(self.online.parameters(), self.target.parameters()): param_k.data.copy_(param_q.data) param_k.requires_grad = False diff --git a/src/torchsurv/loss/weibull.py b/src/torchsurv/loss/weibull.py index bc25141..fc68360 100644 --- a/src/torchsurv/loss/weibull.py +++ b/src/torchsurv/loss/weibull.py @@ -2,12 +2,146 @@ import torch -TORCH_CLAMP_VALUE = 1e10 +from torchsurv.tools.validate_data import validate_log_shape, validate_weibull -from torchsurv.tools.validate_data import validate_weibull + +@torch.jit.script +def cumulative_hazard( + log_params: torch.Tensor, + time: torch.Tensor, + all_times: bool = True, + clamp_value: float = 1e10, +) -> torch.Tensor: + """Cumulative hazard for the Weibull Accelerated Time Failure (AFT) survival model. + + Args: + log_params (torch.Tensor, float): + Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). + The first column corresponds to the log scale parameter. The second column + corresponds to the log shape parameter. If the log shape parameter is missing, it is + imputed with 0. + time (torch.Tensor, float): + Time-to-event or censoring of length n_samples. + all_times (bool) + If True, subject-specific cumulative hazard is evaluated at all ``time`` (used for evaluation metrics). + If False, subject-specific cumulative hazard is evaluated at respective ``time``. + Defaults is True. + + Returns: + (torch.Tensor, float): Subject-specific cumulative hazard evaluated at ``time``. + + Examples: + >>> _ = torch.manual_seed(42) + >>> time = torch.randint(low=1, high=100, size=(4,)) + >>> log_params = torch.randn((4, 2)) + >>> cumulative_hazard(log_params, time, all_times=False) # Cumulative hazard at respective time + tensor([ 8.6257, 112.2115, 3.5105, 112.6339]) + >>> cumulative_hazard(log_params, time, all_times=True) # Default. Cumulative hazard at all time + tensor([[ 8.6257, 233.0865, 239.2167, 126.2805], + [ 12.7698, 112.2115, 114.1484, 74.9134], + [ 0.8706, 3.4725, 3.5105, 2.6850], + [ 6.9530, 212.7592, 218.5687, 112.6339]]) + """ + log_scale, log_shape = validate_log_shape(log_params).unbind(1) + + if all_times: + # Use all times for each sample + time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows + log_scale = log_scale.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + log_shape = log_shape.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + + return torch.clamp( + torch.exp( + torch.exp(log_shape) + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale) + ), + min=0, + max=clamp_value, + ) @torch.jit.script +def log_hazard( + log_params: torch.Tensor, + time: torch.Tensor, + all_times: bool = True, + clamp_value: float = 1e10, +) -> torch.Tensor: + """Log hazard of the Weibull Accelerated Time Failure (AFT) survival model. + + Args: + log_params (torch.Tensor, float): + Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). + The first column corresponds to the log scale parameter. The second column + corresponds to the log shape parameter. If the log shape parameter is missing, it is + imputed with 0. + time (torch.Tensor, float): + Time at which to evaluate the log hazard. + Should be of length n_samples to evaluate the log hazard at observed time-to-event or censoring, + or of length one to evaluate the log hazard at a new time. + all_times (bool): + If True, subject-specific log hazard is evaluated at all ``time`` (used for evaluation metrics). + If False, subject-specific log hazard is evaluated at respective ``time``. + Defaults is True. + Ignored if ``time`` is of length one. + + Returns: + (torch.Tensor, float): Subject-specific log hazard evaluated at ``time``. + + Examples: + >>> _ = torch.manual_seed(42) + >>> time = torch.randint(low=1, high=100, size=(4,)) + >>> log_params = torch.randn((4, 2)) + >>> log_hazard(log_params, time, all_times = False) # Log hazard at respective time + tensor([ 0.4392, -0.0303, -3.9672, 0.9140]) + >>> log_hazard(log_params, time, all_times = True) # Default. Log hazard at all time + tensor([[ 0.4392, 1.1174, 1.1227, 0.9913], + [ 0.4148, -0.0303, -0.0338, 0.0525], + [-2.7225, -3.9575, -3.9672, -3.7279], + [ 0.2606, 1.0632, 1.0695, 0.9140]]) + >>> log_hazard(log_params, time=torch.tensor(10.0)) # Log hazard at one new time (e.g., 10 years) + tensor([ 0.5316, 0.3542, -2.8907, 0.3699]) + >>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times + tensor([ 1.1280, -0.0372, -3.9767, 1.0757]) + tensor([ 1.2330, -0.1062, -4.1680, 1.1999]) + >>> log_params *= 1e2 # Increase scale + >>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values + tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10]) + """ + + log_scale, log_shape = validate_log_shape(log_params).unbind(1) + + if time.dim() == 0: + # Use fixed time for each sample + time = time.repeat(len(log_params)) + elif time.size(0) == log_params.size(0) and all_times: + # Use all times for each sample + time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows + log_scale = log_scale.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + log_shape = log_shape.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + if time.size(0) != log_params.size(0): + raise ValueError( + f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})." + ) + + return torch.clamp( + log_shape + - log_scale + + torch.expm1(log_shape) + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale), + min=-clamp_value, + max=clamp_value, + ) + + def neg_log_likelihood( log_params: torch.Tensor, event: torch.Tensor, @@ -102,8 +236,8 @@ def neg_log_likelihood( # Negative log likelihood nll = torch.neg( - event * log_hazard(log_params, time, all_times=False) - - cumulative_hazard(log_params, time, all_times=False) # Huge values here + event * log_hazard(log_params, time, False) + - cumulative_hazard(log_params, time, False) # Huge values here ) if any(torch.isinf(nll)): @@ -167,7 +301,7 @@ def survival_function( """ - log_scale, log_shape = _check_log_shape(log_params).unbind(1) + log_scale, log_shape = validate_log_shape(log_params).unbind(1) if time.dim() == 0: # Use one time for each sample @@ -190,158 +324,6 @@ def survival_function( ).cdf(time) -@torch.jit.script -def log_hazard( - log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True -) -> torch.Tensor: - """Log hazard of the Weibull Accelerated Time Failure (AFT) survival model. - - Args: - log_params (torch.Tensor, float): - Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). - The first column corresponds to the log scale parameter. The second column - corresponds to the log shape parameter. If the log shape parameter is missing, it is - imputed with 0. - time (torch.Tensor, float): - Time at which to evaluate the log hazard. - Should be of length n_samples to evaluate the log hazard at observed time-to-event or censoring, - or of length one to evaluate the log hazard at a new time. - all_times (bool): - If True, subject-specific log hazard is evaluated at all ``time`` (used for evaluation metrics). - If False, subject-specific log hazard is evaluated at respective ``time``. - Defaults is True. - Ignored if ``time`` is of length one. - - Returns: - (torch.Tensor, float): Subject-specific log hazard evaluated at ``time``. - - Examples: - >>> _ = torch.manual_seed(42) - >>> time = torch.randint(low=1, high=100, size=(4,)) - >>> log_params = torch.randn((4, 2)) - >>> log_hazard(log_params, time, all_times = False) # Log hazard at respective time - tensor([ 0.4392, -0.0303, -3.9672, 0.9140]) - >>> log_hazard(log_params, time, all_times = True) # Default. Log hazard at all time - tensor([[ 0.4392, 1.1174, 1.1227, 0.9913], - [ 0.4148, -0.0303, -0.0338, 0.0525], - [-2.7225, -3.9575, -3.9672, -3.7279], - [ 0.2606, 1.0632, 1.0695, 0.9140]]) - >>> log_hazard(log_params, time=torch.tensor(10.0)) # Log hazard at one new time (e.g., 10 years) - tensor([ 0.5316, 0.3542, -2.8907, 0.3699]) - >>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times - tensor([ 1.1280, -0.0372, -3.9767, 1.0757]) - tensor([ 1.2330, -0.1062, -4.1680, 1.1999]) - >>> log_params *= 1e2 # Increase scale - >>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values - tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10]) - """ - - log_scale, log_shape = _check_log_shape(log_params).unbind(1) - - if time.dim() == 0: - # Use fixed time for each sample - time = time.repeat(len(log_params)) - elif all([time.size(0) == log_params.size(0), all_times]): - # Use all times for each sample - time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows - log_scale = log_scale.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - log_shape = log_shape.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - if time.size(0) != log_params.size(0): - raise ValueError( - f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})." - ) - - return torch.clamp( - log_shape - - log_scale - + torch.expm1(log_shape) - * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale), - min=-TORCH_CLAMP_VALUE, - max=TORCH_CLAMP_VALUE, - ) - - -@torch.jit.script -def cumulative_hazard( - log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True -) -> torch.Tensor: - """Cumulative hazard for the Weibull Accelerated Time Failure (AFT) survival model. - - Args: - log_params (torch.Tensor, float): - Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). - The first column corresponds to the log scale parameter. The second column - corresponds to the log shape parameter. If the log shape parameter is missing, it is - imputed with 0. - time (torch.Tensor, float): - Time-to-event or censoring of length n_samples. - all_times (bool) - If True, subject-specific cumulative hazard is evaluated at all ``time`` (used for evaluation metrics). - If False, subject-specific cumulative hazard is evaluated at respective ``time``. - Defaults is True. - - Returns: - (torch.Tensor, float): Subject-specific cumulative hazard evaluated at ``time``. - - Examples: - >>> _ = torch.manual_seed(42) - >>> time = torch.randint(low=1, high=100, size=(4,)) - >>> log_params = torch.randn((4, 2)) - >>> cumulative_hazard(log_params, time, all_times=False) # Cumulative hazard at respective time - tensor([ 8.6257, 112.2115, 3.5105, 112.6339]) - >>> cumulative_hazard(log_params, time, all_times=True) # Default. Cumulative hazard at all time - tensor([[ 8.6257, 233.0865, 239.2167, 126.2805], - [ 12.7698, 112.2115, 114.1484, 74.9134], - [ 0.8706, 3.4725, 3.5105, 2.6850], - [ 6.9530, 212.7592, 218.5687, 112.6339]]) - """ - log_scale, log_shape = _check_log_shape(log_params).unbind(1) - - if all_times: - # Use all times for each sample - time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows - log_scale = log_scale.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - log_shape = log_shape.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - - return torch.clamp( - torch.exp( - torch.exp(log_shape) - * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale) - ), - min=0, - max=TORCH_CLAMP_VALUE, - ) - - -@torch.jit.script -def _check_log_shape(log_params: torch.Tensor) -> torch.Tensor: - """Private function, check if the log shape is missing and impute it with 0 - if needed.""" - if any( - [ - log_params.dim() == 0, - log_params.dim() == 1, # if shape = [n_samples] - log_params.dim() > 1 - and log_params.size(1) == 1, # if shape = [n_samples, 1] - ] - ): - if log_params.dim() == 1: - log_params = log_params.unsqueeze(1) - - # Missing log shape parameter. Creating zeros placeholder instead. - log_params = torch.hstack((log_params, torch.zeros_like(log_params))) - - return log_params - - if __name__ == "__main__": import doctest diff --git a/src/torchsurv/stats/ipcw.py b/src/torchsurv/stats/ipcw.py index 5277ee4..354b3a6 100644 --- a/src/torchsurv/stats/ipcw.py +++ b/src/torchsurv/stats/ipcw.py @@ -4,12 +4,11 @@ import torch -from torchsurv.tools.validate_data import validate_inputs from torchsurv.stats import kaplan_meier +from torchsurv.tools.validate_data import validate_inputs # pylint: disable=anomalous-backslash-in-string -@torch.jit.script def get_ipcw( event: torch.Tensor, time: torch.Tensor, @@ -98,12 +97,13 @@ def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: """ if torch.any(ct == 0.0): - zero_indices = torch.nonzero(ct.eq(0.0), as_tuple=True)[0] + zero_indices = torch.nonzero(ct.eq(0.0)).squeeze() + zero_indices_list = zero_indices.tolist() # Explicitly convert to list warnings.warn( - f"Censoring distribution zero at time points: {zero_indices.tolist()}. Returning ones as weight" + f"Censoring distribution zero at time points: {zero_indices_list}. Returning ones as weight" ) weight = 1.0 / ct - weight = torch.ones(1, dtype=ct.dtype) / ct + weight = torch.ones_like(ct) / ct return weight @@ -114,6 +114,3 @@ def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: results = doctest.testmod() if results.failed == 0: print("All tests passed.") - else: - print("Some doctests failed.") - sys.exit(1) diff --git a/src/torchsurv/stats/kaplan_meier.py b/src/torchsurv/stats/kaplan_meier.py index c25d465..ae18746 100644 --- a/src/torchsurv/stats/kaplan_meier.py +++ b/src/torchsurv/stats/kaplan_meier.py @@ -7,7 +7,7 @@ from torchsurv.tools import validate_data -# @torch.jit.script +@torch.jit.script class KaplanMeierEstimator: """Kaplan-Meier estimate of survival or censoring distribution for right-censored data :cite:p:`Kaplan1958`.""" diff --git a/src/torchsurv/tools/validate_data.py b/src/torchsurv/tools/validate_data.py index 27faa6a..7e6333b 100644 --- a/src/torchsurv/tools/validate_data.py +++ b/src/torchsurv/tools/validate_data.py @@ -1,5 +1,25 @@ import torch -import torch + + +@torch.jit.script +def validate_log_shape(log_params: torch.Tensor) -> torch.Tensor: + """Private function, check if the log shape is missing and impute it with 0 + if needed.""" + if any( + [ + log_params.dim() == 0, + log_params.dim() == 1, # if shape = [n_samples] + log_params.dim() > 1 + and log_params.size(1) == 1, # if shape = [n_samples, 1] + ] + ): + if log_params.dim() == 1: + log_params = log_params.unsqueeze(1) + + # Missing log shape parameter. Creating zeros placeholder instead. + log_params = torch.hstack((log_params, torch.zeros_like(log_params))) + + return log_params @torch.jit.script diff --git a/tests/test_kaplan_meier.py b/tests/test_kaplan_meier.py index 62e0a5d..76eaa6d 100644 --- a/tests/test_kaplan_meier.py +++ b/tests/test_kaplan_meier.py @@ -216,6 +216,45 @@ def test_kaplan_meier_prediction_error_raised(self): self.assertRaises(ValueError, km.predict, test_time) + def test_kaplan_meier_plot_km(self): + """test Kaplan Meier plot function""" + import matplotlib.pyplot as plt + + event = torch.tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 1], dtype=torch.bool) + time = torch.tensor([1, 2, 2, 3, 3, 4, 4, 5, 6, 7], dtype=torch.float32) + + km = KaplanMeierEstimator() + km(event, time, censoring_dist=False) + + fig, ax = plt.subplots() + km.plot_km(ax=ax) + plt.close(fig) # Close the plot to avoid displaying it during tests + + def test_kaplan_meier_print_survival_table(self): + """test Kaplan Meier print survival table function""" + event = torch.tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 1], dtype=torch.bool) + time = torch.tensor([1, 2, 2, 3, 3, 4, 4, 5, 6, 7], dtype=torch.float32) + + km = KaplanMeierEstimator() + km(event, time, censoring_dist=False) + km.print_survival_table() + + # Check if the survival table is printed correctly + expected_output = ( + "Time\tSurvival\n" + "----------------\n" + "1.00\t1.0000\n" + "2.00\t0.8889\n" + "3.00\t0.6667\n" + "4.00\t0.5000\n" + "5.00\t0.5000\n" + "6.00\t0.3333\n" + "7.00\t0.0000\n" + ) + with self.assertLogs(level="INFO") as log: + km.print_survival_table() + self.assertIn(expected_output, log.output) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index f665a64..9d7514d 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -21,11 +21,12 @@ import numpy as np import torch +# torch compile settings +import torch._inductor.config + # Local modules from torchsurv.loss.cox import neg_partial_log_likelihood as cox -# torch compile settings -import torch._inductor.config torch._inductor.config.cpp.cxx = ("g++",) # set seed for reproducibility @@ -33,19 +34,15 @@ # TODO: wrap this in TestCoxSurvivalLossCompile(unittest.TestCase) if __name__ == "__main__": - + # random data and parameters N = 32 log_hz = torch.randn(N) event = torch.randint(low=0, high=2, size=(N,)).bool() time = torch.randint(low=1, high=100, size=(N,)) - + # compile cox ccox = torch.compile(cox) - + loss_cox = cox(log_hz, event, time) loss_ccox = ccox(log_hz, event, time) - - - - \ No newline at end of file From bfad67be2212363aa2bd31bb8de6937b0b90e085 Mon Sep 17 00:00:00 2001 From: corolth1 Date: Tue, 10 Dec 2024 11:51:42 -0500 Subject: [PATCH 4/6] cleanup, cox working --- src/torchsurv/loss/cox.py | 9 +- src/torchsurv/loss/weibull.py | 4 +- src/torchsurv/tools/validate_data.py | 138 +++++++++++++++------------ tests/test_torch_compile.py | 48 ---------- 4 files changed, 83 insertions(+), 116 deletions(-) delete mode 100644 tests/test_torch_compile.py diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index 8200add..222b305 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -6,7 +6,7 @@ import torch -from torchsurv.tools.validate_data import validate_cox +from torchsurv.tools.validate_data import validate_inputs @torch.jit.script @@ -78,9 +78,8 @@ def _partial_likelihood_breslow( torch.Tensor: The partial likelihood for the observed events. """ N = len(time_sorted) - R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] - log_denominator = torch.tensor( + log_denominator = torch._stack( [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] ) @@ -199,9 +198,9 @@ def neg_partial_log_likelihood( """ if checks: - validate_cox(log_hz, event, time) + validate_inputs(log_hz, event, time, model_type="cox") - if torch.any([event.sum().item() == 0, len(log_hz.size()) == 0]): + if any([event.sum().item() == 0, len(log_hz.size()) == 0]): warnings.warn("No events OR single sample. Returning zero loss for the batch") return torch.tensor(0.0, requires_grad=True) diff --git a/src/torchsurv/loss/weibull.py b/src/torchsurv/loss/weibull.py index fc68360..32ca9ba 100644 --- a/src/torchsurv/loss/weibull.py +++ b/src/torchsurv/loss/weibull.py @@ -2,7 +2,7 @@ import torch -from torchsurv.tools.validate_data import validate_log_shape, validate_weibull +from torchsurv.tools.validate_data import validate_log_shape, validate_inputs @torch.jit.script @@ -232,7 +232,7 @@ def neg_log_likelihood( """ if checks: - validate_weibull(log_params, event, time) + validate_inputs(log_params, event, time, model_type="weibull") # Negative log likelihood nll = torch.neg( diff --git a/src/torchsurv/tools/validate_data.py b/src/torchsurv/tools/validate_data.py index 7e6333b..f0813c7 100644 --- a/src/torchsurv/tools/validate_data.py +++ b/src/torchsurv/tools/validate_data.py @@ -22,67 +22,6 @@ def validate_log_shape(log_params: torch.Tensor) -> torch.Tensor: return log_params -@torch.jit.script -def validate_weibull(log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - """Private function, perform input format checks.""" - if not isinstance(log_params, torch.Tensor): - raise TypeError("Input 'log_params' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("``Input 'time' must be a tensor.") - - if log_params.shape[0] != len(event): - raise ValueError( - "Length mismatch: The length of 'log_params' must match the length of 'event'." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: The length of 'time' must match the length of 'event'.`" - ) - - if torch.any(time < 0): - raise ValueError("All elements in 'time' must be non-negative.") - - if torch.any((event != 0) & (event != 1)): - raise ValueError( - "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" - ) - - -@torch.jit.script -def validate_cox(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - if not isinstance(log_hz, torch.Tensor): - raise TypeError("Input 'log_hz' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("Input 'time' must be a tensor.") - - if len(log_hz) != len(event): - raise ValueError( - "Length mismatch: 'log_hz' and 'event' must have the same length." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: 'time' must have the same length as 'event'." - ) - - if torch.any(time < 0): - raise ValueError("Invalid values: All elements in 'time' must be non-negative.") - - if torch.any((event != 0) & (event != 1)): - raise ValueError( - "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" - ) - - @torch.jit.script def validate_inputs(event: torch.Tensor, time: torch.Tensor) -> None: """ @@ -164,3 +103,80 @@ def validate_new_time( raise ValueError("Value error: Input 'new_time' should contain unique values.") check_within_follow_up(new_time, time, within_follow_up) + + +import torch + + +@torch.jit.script +def validate_inputs( + log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor, model_type: str +) -> None: + """ + Validate the inputs for survival analysis functions. + + Args: + log_params (torch.Tensor): Parameters of the model (log hazards or log parameters). + event (torch.Tensor): Event indicator tensor. + time (torch.Tensor): Time-to-event or censoring tensor. + model_type (str): Type of the model ('weibull' or 'cox'). + + Raises: + TypeError: If inputs are not tensors. + ValueError: If any `time` are negative. + ValueError: If `event` contains invalid values. + ValueError: If lengths of inputs do not match. + """ + if not isinstance(log_params, torch.Tensor): + raise TypeError("Input 'log_params' must be a tensor.") + + if not isinstance(event, torch.Tensor): + raise TypeError("Input 'event' must be a tensor.") + + if not isinstance(time, torch.Tensor): + raise TypeError("Input 'time' must be a tensor.") + + if log_params.shape[0] != len(event): + raise ValueError( + "Length mismatch: The length of 'log_params' must match the length of 'event'." + ) + + if len(time) != len(event): + raise ValueError( + "Length mismatch: The length of 'time' must match the length of 'event'." + ) + + if torch.any(time < 0): + raise ValueError("All elements in 'time' must be non-negative.") + + if torch.any((event != 0) & (event != 1)): + raise ValueError( + "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" + ) + + if model_type == "weibull": + if log_params.shape[1] != 2: + raise ValueError( + "For Weibull model, 'log_params' must have shape (n_samples, 2)." + ) + elif model_type == "cox": + if log_params.shape[1] != 1: + raise ValueError( + "For Cox model, 'log_params' must have shape (n_samples, 1)." + ) + else: + raise ValueError("Invalid model type. Must be 'weibull' or 'cox'.") + + +# Example usage +if __name__ == "__main__": + log_params_weibull = torch.randn((5, 2)) + log_params_cox = torch.randn((5, 1)) + event = torch.tensor([1, 0, 1, 1, 0], dtype=torch.float32) + time = torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32) + + # Validate Weibull model inputs + validate_inputs(log_params_weibull, event, time, model_type="weibull") + + # Validate Cox model inputs + validate_inputs(log_params_cox, event, time, model_type="cox") diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py deleted file mode 100644 index 9d7514d..0000000 --- a/tests/test_torch_compile.py +++ /dev/null @@ -1,48 +0,0 @@ -""" - -Tests for torch.compile - -Note: conda install conda-forge::gxx - -TODO: - - test conda install conda-forge::cxx-compiler - - check torchtriton for GPU - -References: - - https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html - - https://github.com/pytorch/pytorch/issues/122094 - -""" - -# global modules -import json -import unittest - -import numpy as np -import torch - -# torch compile settings -import torch._inductor.config - -# Local modules -from torchsurv.loss.cox import neg_partial_log_likelihood as cox - -torch._inductor.config.cpp.cxx = ("g++",) - -# set seed for reproducibility -torch.manual_seed(42) - -# TODO: wrap this in TestCoxSurvivalLossCompile(unittest.TestCase) -if __name__ == "__main__": - - # random data and parameters - N = 32 - log_hz = torch.randn(N) - event = torch.randint(low=0, high=2, size=(N,)).bool() - time = torch.randint(low=1, high=100, size=(N,)) - - # compile cox - ccox = torch.compile(cox) - - loss_cox = cox(log_hz, event, time) - loss_ccox = ccox(log_hz, event, time) From ad53a177f94875f17dbe8f3509dd8ef001092010 Mon Sep 17 00:00:00 2001 From: corolth1 Date: Wed, 11 Dec 2024 09:51:36 -0500 Subject: [PATCH 5/6] more explicit --- src/torchsurv/loss/cox.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index 222b305..1e1840f 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -17,7 +17,8 @@ def _partial_likelihood_cox( """Calculate the partial log likelihood for the Cox proportional hazards model in the absence of ties in event time. """ - log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) + log_hz_flipped = log_hz_sorted.flip(0) + log_denominator = torch.logcumsumexp(log_hz_flipped, dim=0).flip(0) return (log_hz_sorted - log_denominator)[event_sorted] From 093ce0b53b58af42d41221d44124c5acf80d23a2 Mon Sep 17 00:00:00 2001 From: corolth1 Date: Wed, 11 Dec 2024 14:51:36 -0500 Subject: [PATCH 6/6] working doctest, cleanup and ready to check pytest on CI --- src/torchsurv/loss/cox.py | 6 +-- src/torchsurv/loss/weibull.py | 7 +-- src/torchsurv/metrics/auc.py | 14 ++++-- src/torchsurv/metrics/brier_score.py | 14 +++--- src/torchsurv/metrics/cindex.py | 6 +-- src/torchsurv/stats/ipcw.py | 6 +-- src/torchsurv/stats/kaplan_meier.py | 5 +- src/torchsurv/tools/validate_data.py | 72 ++++++++++++++++------------ 8 files changed, 70 insertions(+), 60 deletions(-) diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index 1e1840f..3eba4e6 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -6,7 +6,7 @@ import torch -from torchsurv.tools.validate_data import validate_inputs +from torchsurv.tools.validate_data import validate_loss @torch.jit.script @@ -80,7 +80,7 @@ def _partial_likelihood_breslow( """ N = len(time_sorted) R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] - log_denominator = torch._stack( + log_denominator = torch.stack( [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] ) @@ -199,7 +199,7 @@ def neg_partial_log_likelihood( """ if checks: - validate_inputs(log_hz, event, time, model_type="cox") + validate_loss(log_hz, event, time, model_type="cox") if any([event.sum().item() == 0, len(log_hz.size()) == 0]): warnings.warn("No events OR single sample. Returning zero loss for the batch") diff --git a/src/torchsurv/loss/weibull.py b/src/torchsurv/loss/weibull.py index 32ca9ba..469aa39 100644 --- a/src/torchsurv/loss/weibull.py +++ b/src/torchsurv/loss/weibull.py @@ -2,10 +2,9 @@ import torch -from torchsurv.tools.validate_data import validate_log_shape, validate_inputs +from torchsurv.tools.validate_data import validate_log_shape, validate_loss -@torch.jit.script def cumulative_hazard( log_params: torch.Tensor, time: torch.Tensor, @@ -64,7 +63,6 @@ def cumulative_hazard( ) -@torch.jit.script def log_hazard( log_params: torch.Tensor, time: torch.Tensor, @@ -232,7 +230,7 @@ def neg_log_likelihood( """ if checks: - validate_inputs(log_params, event, time, model_type="weibull") + validate_loss(log_params, event, time, model_type="weibull") # Negative log likelihood nll = torch.neg( @@ -257,7 +255,6 @@ def neg_log_likelihood( return loss -@torch.jit.script def survival_function( log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True ) -> torch.Tensor: diff --git a/src/torchsurv/metrics/auc.py b/src/torchsurv/metrics/auc.py index 35444da..303e68e 100644 --- a/src/torchsurv/metrics/auc.py +++ b/src/torchsurv/metrics/auc.py @@ -6,8 +6,12 @@ from scipy import stats from torchmetrics import regression -from ..stats import kaplan_meier -from ..tools import validate_data +from torchsurv.stats import kaplan_meier +from torchsurv.tools.validate_data import ( + validate_log_shape, + validate_new_time, + validate_survival_data, +) class Auc: @@ -212,9 +216,9 @@ def __call__( # further input format checks if self.checks: - validate_data.validate_survival_data(event, time) - validate_data.validate_evaluation_time(new_time, time) - validate_data.validate_estimate(estimate, time, new_time) + validate_survival_data(event, time) + validate_new_time(new_time, time) + validate_log_shape(estimate) # sample size and length of time n_samples, n_times = estimate.shape[0], new_time.shape[0] diff --git a/src/torchsurv/metrics/brier_score.py b/src/torchsurv/metrics/brier_score.py index d6c07d6..77c1f53 100644 --- a/src/torchsurv/metrics/brier_score.py +++ b/src/torchsurv/metrics/brier_score.py @@ -5,7 +5,11 @@ import torch from scipy import stats -from ..tools import validate_data +from torchsurv.tools.validate_data import ( + validate_log_shape, + validate_new_time, + validate_survival_data, +) class BrierScore: @@ -185,11 +189,9 @@ def __call__( # further input format checks if self.checks: - validate_data.validate_survival_data(event, time) - validate_data.validate_evaluation_time( - new_time, time, within_follow_up=False - ) - validate_data.validate_estimate(estimate, time, new_time) + validate_survival_data(event, time) + validate_new_time(new_time, time, within_follow_up=False) + validate_log_shape(estimate) # Calculating the residuals for each subject and time point residuals = torch.zeros_like(estimate) diff --git a/src/torchsurv/metrics/cindex.py b/src/torchsurv/metrics/cindex.py index a79cd9b..4882c6d 100644 --- a/src/torchsurv/metrics/cindex.py +++ b/src/torchsurv/metrics/cindex.py @@ -7,7 +7,7 @@ from scipy import stats from torchmetrics import regression -from ..tools import validate_data +from torchsurv.tools.validate_data import validate_log_shape, validate_survival_data class ConcordanceIndex: @@ -185,8 +185,8 @@ def __call__( # Inputs checks if self.checks: - validate_data.validate_survival_data(event, time) - validate_data.validate_estimate(estimate, time) + validate_survival_data(event, time) + validate_log_shape(estimate) # find comparable pairs comparable = self._get_comparable_and_tied_time(event, time) diff --git a/src/torchsurv/stats/ipcw.py b/src/torchsurv/stats/ipcw.py index 354b3a6..20446c9 100644 --- a/src/torchsurv/stats/ipcw.py +++ b/src/torchsurv/stats/ipcw.py @@ -1,11 +1,10 @@ -import sys import warnings from typing import Optional import torch from torchsurv.stats import kaplan_meier -from torchsurv.tools.validate_data import validate_inputs +from torchsurv.tools.validate_data import validate_survival_data # pylint: disable=anomalous-backslash-in-string @@ -56,7 +55,7 @@ def get_ipcw( """ if checks: - validate_inputs(event, time) + validate_survival_data(event, time) # time on which to evaluate IPCW if new_time is None: # if none, return ipcw of same size as time @@ -77,7 +76,6 @@ def get_ipcw( return ipcw -@torch.jit.script def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: """Compute inverse of the censoring distribution. diff --git a/src/torchsurv/stats/kaplan_meier.py b/src/torchsurv/stats/kaplan_meier.py index ae18746..da79e68 100644 --- a/src/torchsurv/stats/kaplan_meier.py +++ b/src/torchsurv/stats/kaplan_meier.py @@ -4,10 +4,9 @@ import torch -from torchsurv.tools import validate_data +from torchsurv.tools.validate_data import validate_survival_data -@torch.jit.script class KaplanMeierEstimator: """Kaplan-Meier estimate of survival or censoring distribution for right-censored data :cite:p:`Kaplan1958`.""" @@ -63,7 +62,7 @@ def __call__( # Check input validity if required if check: - validate_data.validate_inputs(event, time) + validate_survival_data(event, time) # Compute the counts of events, censorings, and the number at risk at each unique time uniq_times, n_events, n_at_risk, n_censored = self._compute_counts() diff --git a/src/torchsurv/tools/validate_data.py b/src/torchsurv/tools/validate_data.py index f0813c7..571dd8f 100644 --- a/src/torchsurv/tools/validate_data.py +++ b/src/torchsurv/tools/validate_data.py @@ -22,31 +22,6 @@ def validate_log_shape(log_params: torch.Tensor) -> torch.Tensor: return log_params -@torch.jit.script -def validate_inputs(event: torch.Tensor, time: torch.Tensor) -> None: - """ - Validate the inputs for survival analysis functions. - - Args: - event (torch.Tensor): Event indicator tensor. - time (torch.Tensor): Time-to-event or censoring tensor. - - Raises: - TypeError: If inputs are not tensors. - ValueError: If any ``time`` are negative. - """ - if not isinstance(event, torch.Tensor) or not isinstance(time, torch.Tensor): - raise TypeError("Inputs 'event' and 'time' should be tensors") - - if torch.any(time < 0): - raise ValueError("All elements in 'time' must be non-negative") - - if torch.any((event != 0) & (event != 1)): - raise ValueError( - "Input 'event' must contain only boolean values (True/False or 1/0)" - ) - - @torch.jit.script def check_within_follow_up( new_time: torch.Tensor, time: torch.Tensor, within_follow_up: bool @@ -105,11 +80,45 @@ def validate_new_time( check_within_follow_up(new_time, time, within_follow_up) -import torch +def validate_survival_data(event, time): + """Perform format and validity checks for survival data. + + Args: + event (torch.Tensor, boolean): + Event indicator of size n_samples (= True if event occured). + time (torch.Tensor, float): + Event or censoring time of size n_samples. + + Raises: + TypeError: If ``event`` or ``time`` are not tensors. + ValueError: If ``event`` is not boolean. + ValueError: If ``event`` and ``time`` are not of the same length. + ValueError: If all ``event`` are False. + ValueError: If any ``time`` are negative. + """ + if not torch.is_tensor(event) or not torch.is_tensor(time): + raise TypeError("Inputs 'event' and 'time' should be tensors") + + if not event.dtype == torch.bool: + raise ValueError("Input 'event' should be of boolean type.") + + if not torch.is_floating_point(time): + raise ValueError("Input 'time' should be of float type.") + + if len(event) != len(time): + raise ValueError( + "Dimension mismatch: Incompatible length between inputs 'time' and 'event'." + ) + + if torch.sum(event) <= 0: + raise ValueError("All samples are censored.") + + if torch.any(time < 0.0): + raise ValueError("Input 'time' should be non-negative.") @torch.jit.script -def validate_inputs( +def validate_loss( log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor, model_type: str ) -> None: """ @@ -155,9 +164,10 @@ def validate_inputs( ) if model_type == "weibull": - if log_params.shape[1] != 2: + # if log_params.shape[1] is not 1 or 2: + if log_params.shape[1] not in [1, 2]: raise ValueError( - "For Weibull model, 'log_params' must have shape (n_samples, 2)." + f"For Weibull model, 'log_params' must have shape (n_samples, 2) or (n_samples, 1)." ) elif model_type == "cox": if log_params.shape[1] != 1: @@ -176,7 +186,7 @@ def validate_inputs( time = torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32) # Validate Weibull model inputs - validate_inputs(log_params_weibull, event, time, model_type="weibull") + validate_loss(log_params_weibull, event, time, model_type="weibull") # Validate Cox model inputs - validate_inputs(log_params_cox, event, time, model_type="cox") + validate_loss(log_params_cox, event, time, model_type="cox")