diff --git a/docs/notebooks/introduction.ipynb b/docs/notebooks/introduction.ipynb index bdfb74d..eec04fe 100644 --- a/docs/notebooks/introduction.ipynb +++ b/docs/notebooks/introduction.ipynb @@ -28,7 +28,7 @@ "# %pip install lifelines\n", "# %pip install matplotlib\n", "# %pip install sklearn\n", - "# %pip install pandas\n" + "# %pip install pandas" ] }, { diff --git a/src/torchsurv/loss/cox.py b/src/torchsurv/loss/cox.py index ce49f46..3eba4e6 100644 --- a/src/torchsurv/loss/cox.py +++ b/src/torchsurv/loss/cox.py @@ -6,7 +6,88 @@ import torch +from torchsurv.tools.validate_data import validate_loss + +@torch.jit.script +def _partial_likelihood_cox( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, +) -> torch.Tensor: + """Calculate the partial log likelihood for the Cox proportional hazards model + in the absence of ties in event time. + """ + log_hz_flipped = log_hz_sorted.flip(0) + log_denominator = torch.logcumsumexp(log_hz_flipped, dim=0).flip(0) + return (log_hz_sorted - log_denominator)[event_sorted] + + +@torch.jit.script +def _partial_likelihood_efron( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, + time_sorted: torch.Tensor, + time_unique: torch.Tensor, +) -> torch.Tensor: + """Calculate the partial log likelihood for the Cox proportional hazards model + using Efron's method to handle ties in event time. + """ + J = len(time_unique) + + H = [ + torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0] + for j in range(J) + ] + R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)] + + # Calculate the length of each element in H and store it in a tensor + m = torch.tensor([len(h) for h in H]) + + # Create a boolean tensor indicating whether each element in H has a length greater than 0 + include = torch.tensor([len(h) > 0 for h in H]) + + log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H]) + + denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R]) + denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H]) + + log_denominator_efron = torch.zeros(J, device=log_hz_sorted.device) + for j in range(J): + mj = int(m[j].item()) + for l in range(1, mj + 1): + log_denominator_efron[j] += torch.log( + denominator_naive[j] - (l - 1) / float(m[j]) * denominator_ties[j] + ) + return (log_nominator - log_denominator_efron)[include] + + +@torch.jit.script +def _partial_likelihood_breslow( + log_hz_sorted: torch.Tensor, + event_sorted: torch.Tensor, + time_sorted: torch.Tensor, +): + """ + Compute the partial likelihood using Breslow's method for Cox proportional hazards model. + + Args: + log_hz_sorted (torch.Tensor): Log hazard rates sorted by time. + event_sorted (torch.Tensor): Binary tensor indicating if the event occurred (1) or was censored (0), sorted by time. + time_sorted (torch.Tensor): Event or censoring times sorted in ascending order. + + Returns: + torch.Tensor: The partial likelihood for the observed events. + """ + N = len(time_sorted) + R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] + log_denominator = torch.stack( + [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] + ) + + return (log_hz_sorted - log_denominator)[event_sorted] + + +@torch.jit.script def neg_partial_log_likelihood( log_hz: torch.Tensor, event: torch.Tensor, @@ -118,9 +199,9 @@ def neg_partial_log_likelihood( """ if checks: - _check_inputs(log_hz, event, time) + validate_loss(log_hz, event, time, model_type="cox") - if any([event.sum() == 0, len(log_hz.size()) == 0]): + if any([event.sum().item() == 0, len(log_hz.size()) == 0]): warnings.warn("No events OR single sample. Returning zero loss for the batch") return torch.tensor(0.0, requires_grad=True) @@ -164,98 +245,6 @@ def neg_partial_log_likelihood( return loss -def _partial_likelihood_cox( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, -) -> torch.Tensor: - """Calculate the partial log likelihood for the Cox proportional hazards model - in the absence of ties in event time. - """ - log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0) - return (log_hz_sorted - log_denominator)[event_sorted] - - -def _partial_likelihood_efron( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, - time_sorted: torch.Tensor, - time_unique: torch.Tensor, -) -> torch.Tensor: - """Calculate the partial log likelihood for the Cox proportional hazards model - using Efron's method to handle ties in event time. - """ - J = len(time_unique) - - H = [ - torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0] - for j in range(J) - ] - R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)] - - m = torch.tensor([len(h) for h in H]) - include = torch.tensor([len(h) > 0 for h in H]) - - log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H]) - - denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R]) - denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H]) - - log_denominator_efron = torch.zeros(J).to(log_hz_sorted.device) - for j in range(J): - for l in range(1, m[j] + 1): - log_denominator_efron[j] += torch.log( - denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j] - ) - return (log_nominator - log_denominator_efron)[include] - - -def _partial_likelihood_breslow( - log_hz_sorted: torch.Tensor, - event_sorted: torch.Tensor, - time_sorted: torch.Tensor, -): - """Calculate the partial log likelihood for the Cox proportional hazards model - using Breslow's method to handle ties in event time. - """ - N = len(time_sorted) - - R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)] - log_denominator = torch.tensor( - [torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)] - ) - - return (log_hz_sorted - log_denominator)[event_sorted] - - -def _check_inputs(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - if not isinstance(log_hz, torch.Tensor): - raise TypeError("Input 'log_hz' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("Input 'time' must be a tensor.") - - if len(log_hz) != len(event): - raise ValueError( - "Length mismatch: 'log_hz' and 'event' must have the same length." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: 'time' must have the same length as 'event'." - ) - - if any(val < 0 for val in time): - raise ValueError("Invalid values: All elements in 'time' must be non-negative.") - - if any(val not in [True, False, 0, 1] for val in event): - raise ValueError( - "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" - ) - - if __name__ == "__main__": import doctest diff --git a/src/torchsurv/loss/momentum.py b/src/torchsurv/loss/momentum.py index 8cab479..98f640a 100644 --- a/src/torchsurv/loss/momentum.py +++ b/src/torchsurv/loss/momentum.py @@ -151,15 +151,15 @@ def forward( """ - estimate_q = self.online(inputs) - for estimate in zip(estimate_q, event, time): - self.memory_q.append(self.survtuple(*list(estimate))) + online_estimate = self.online(inputs) + for estimate in zip(online_estimate, event, time): + self.memory_q.append(self.survtuple(*estimate)) loss = self._bank_loss() with torch.no_grad(): self._update_momentum_encoder() - estimate_k = self.target(inputs) - for estimate in zip(estimate_k, event, time): - self.memory_k.append(self.survtuple(*list(estimate))) + target_estimate = self.target(inputs) + for estimate in zip(target_estimate, event, time): + self.memory_k.append(self.survtuple(*estimate)) return loss @torch.no_grad() # deactivates autograd @@ -187,27 +187,33 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor: return self.target(inputs) def _bank_loss(self) -> torch.Tensor: - """computer the negative loss likelyhood from memory bank""" + """compute the negative log-likelihood from memory bank""" # Combine current batch and momentum bank = self.memory_k + self.memory_q assert all( x in bank[0]._fields for x in ["estimate", "event", "time"] ), "All fields must be present" - return self.loss( - torch.stack([mem.estimate.cpu() for mem in bank]).squeeze(), - torch.stack([mem.event.cpu() for mem in bank]).squeeze(), - torch.stack([mem.time.cpu() for mem in bank]).squeeze(), - ) + log_estimates = torch.stack([mem.estimate.cpu() for mem in bank]).squeeze() + events = torch.stack([mem.event.cpu() for mem in bank]).squeeze() + times = torch.stack([mem.time.cpu() for mem in bank]).squeeze() + return self.loss(log_estimates, events, times) @torch.no_grad() def _update_momentum_encoder(self): - """Exponantial moving average""" + """Exponential moving average""" for param_b, param_m in zip(self.online.parameters(), self.target.parameters()): param_m.data = param_m.data * self.rate + param_b.data * (1.0 - self.rate) @torch.no_grad() def _init_encoder_k(self): + """ + Initialize the target network (encoder_k) with the parameters of the online network (encoder_q). + The requires_grad attribute of the target network parameters is set to False to prevent gradient updates during training, + ensuring that the target network remains a stable reference point. + This method uses the `copy_` method to copy the parameters from the online network to the target network + and sets the requires_grad attribute of the target network parameters to False to prevent gradient updates. + """ for param_q, param_k in zip(self.online.parameters(), self.target.parameters()): param_k.data.copy_(param_q.data) param_k.requires_grad = False diff --git a/src/torchsurv/loss/weibull.py b/src/torchsurv/loss/weibull.py index 926488a..469aa39 100644 --- a/src/torchsurv/loss/weibull.py +++ b/src/torchsurv/loss/weibull.py @@ -2,7 +2,142 @@ import torch -TORCH_CLAMP_VALUE = 1e10 +from torchsurv.tools.validate_data import validate_log_shape, validate_loss + + +def cumulative_hazard( + log_params: torch.Tensor, + time: torch.Tensor, + all_times: bool = True, + clamp_value: float = 1e10, +) -> torch.Tensor: + """Cumulative hazard for the Weibull Accelerated Time Failure (AFT) survival model. + + Args: + log_params (torch.Tensor, float): + Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). + The first column corresponds to the log scale parameter. The second column + corresponds to the log shape parameter. If the log shape parameter is missing, it is + imputed with 0. + time (torch.Tensor, float): + Time-to-event or censoring of length n_samples. + all_times (bool) + If True, subject-specific cumulative hazard is evaluated at all ``time`` (used for evaluation metrics). + If False, subject-specific cumulative hazard is evaluated at respective ``time``. + Defaults is True. + + Returns: + (torch.Tensor, float): Subject-specific cumulative hazard evaluated at ``time``. + + Examples: + >>> _ = torch.manual_seed(42) + >>> time = torch.randint(low=1, high=100, size=(4,)) + >>> log_params = torch.randn((4, 2)) + >>> cumulative_hazard(log_params, time, all_times=False) # Cumulative hazard at respective time + tensor([ 8.6257, 112.2115, 3.5105, 112.6339]) + >>> cumulative_hazard(log_params, time, all_times=True) # Default. Cumulative hazard at all time + tensor([[ 8.6257, 233.0865, 239.2167, 126.2805], + [ 12.7698, 112.2115, 114.1484, 74.9134], + [ 0.8706, 3.4725, 3.5105, 2.6850], + [ 6.9530, 212.7592, 218.5687, 112.6339]]) + """ + log_scale, log_shape = validate_log_shape(log_params).unbind(1) + + if all_times: + # Use all times for each sample + time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows + log_scale = log_scale.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + log_shape = log_shape.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + + return torch.clamp( + torch.exp( + torch.exp(log_shape) + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale) + ), + min=0, + max=clamp_value, + ) + + +def log_hazard( + log_params: torch.Tensor, + time: torch.Tensor, + all_times: bool = True, + clamp_value: float = 1e10, +) -> torch.Tensor: + """Log hazard of the Weibull Accelerated Time Failure (AFT) survival model. + + Args: + log_params (torch.Tensor, float): + Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). + The first column corresponds to the log scale parameter. The second column + corresponds to the log shape parameter. If the log shape parameter is missing, it is + imputed with 0. + time (torch.Tensor, float): + Time at which to evaluate the log hazard. + Should be of length n_samples to evaluate the log hazard at observed time-to-event or censoring, + or of length one to evaluate the log hazard at a new time. + all_times (bool): + If True, subject-specific log hazard is evaluated at all ``time`` (used for evaluation metrics). + If False, subject-specific log hazard is evaluated at respective ``time``. + Defaults is True. + Ignored if ``time`` is of length one. + + Returns: + (torch.Tensor, float): Subject-specific log hazard evaluated at ``time``. + + Examples: + >>> _ = torch.manual_seed(42) + >>> time = torch.randint(low=1, high=100, size=(4,)) + >>> log_params = torch.randn((4, 2)) + >>> log_hazard(log_params, time, all_times = False) # Log hazard at respective time + tensor([ 0.4392, -0.0303, -3.9672, 0.9140]) + >>> log_hazard(log_params, time, all_times = True) # Default. Log hazard at all time + tensor([[ 0.4392, 1.1174, 1.1227, 0.9913], + [ 0.4148, -0.0303, -0.0338, 0.0525], + [-2.7225, -3.9575, -3.9672, -3.7279], + [ 0.2606, 1.0632, 1.0695, 0.9140]]) + >>> log_hazard(log_params, time=torch.tensor(10.0)) # Log hazard at one new time (e.g., 10 years) + tensor([ 0.5316, 0.3542, -2.8907, 0.3699]) + >>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times + tensor([ 1.1280, -0.0372, -3.9767, 1.0757]) + tensor([ 1.2330, -0.1062, -4.1680, 1.1999]) + >>> log_params *= 1e2 # Increase scale + >>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values + tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10]) + """ + + log_scale, log_shape = validate_log_shape(log_params).unbind(1) + + if time.dim() == 0: + # Use fixed time for each sample + time = time.repeat(len(log_params)) + elif time.size(0) == log_params.size(0) and all_times: + # Use all times for each sample + time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows + log_scale = log_scale.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + log_shape = log_shape.unsqueeze(1).expand( + len(time), len(time) + ) # expand across columns + if time.size(0) != log_params.size(0): + raise ValueError( + f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})." + ) + + return torch.clamp( + log_shape + - log_scale + + torch.expm1(log_shape) + * (torch.log(torch.clamp(time, min=1e-100, max=torch.inf)) - log_scale), + min=-clamp_value, + max=clamp_value, + ) def neg_log_likelihood( @@ -95,12 +230,12 @@ def neg_log_likelihood( """ if checks: - _check_inputs(log_params, event, time) + validate_loss(log_params, event, time, model_type="weibull") # Negative log likelihood nll = torch.neg( - event * log_hazard(log_params, time, all_times=False) - - cumulative_hazard(log_params, time, all_times=False) # Huge values here + event * log_hazard(log_params, time, False) + - cumulative_hazard(log_params, time, False) # Huge values here ) if any(torch.isinf(nll)): @@ -163,7 +298,7 @@ def survival_function( """ - log_scale, log_shape = _check_log_shape(log_params).unbind(1) + log_scale, log_shape = validate_log_shape(log_params).unbind(1) if time.dim() == 0: # Use one time for each sample @@ -186,183 +321,6 @@ def survival_function( ).cdf(time) -def log_hazard( - log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True -) -> torch.Tensor: - """Log hazard of the Weibull Accelerated Time Failure (AFT) survival model. - - Args: - log_params (torch.Tensor, float): - Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). - The first column corresponds to the log scale parameter. The second column - corresponds to the log shape parameter. If the log shape parameter is missing, it is - imputed with 0. - time (torch.Tensor, float): - Time at which to evaluate the log hazard. - Should be of length n_samples to evaluate the log hazard at observed time-to-event or censoring, - or of length one to evaluate the log hazard at a new time. - all_times (bool): - If True, subject-specific log hazard is evaluated at all ``time`` (used for evaluation metrics). - If False, subject-specific log hazard is evaluated at respective ``time``. - Defaults is True. - Ignored if ``time`` is of length one. - - Returns: - (torch.Tensor, float): Subject-specific log hazard evaluated at ``time``. - - Examples: - >>> _ = torch.manual_seed(42) - >>> time = torch.randint(low=1, high=100, size=(4,)) - >>> log_params = torch.randn((4, 2)) - >>> log_hazard(log_params, time, all_times = False) # Log hazard at respective time - tensor([ 0.4392, -0.0303, -3.9672, 0.9140]) - >>> log_hazard(log_params, time, all_times = True) # Default. Log hazard at all time - tensor([[ 0.4392, 1.1174, 1.1227, 0.9913], - [ 0.4148, -0.0303, -0.0338, 0.0525], - [-2.7225, -3.9575, -3.9672, -3.7279], - [ 0.2606, 1.0632, 1.0695, 0.9140]]) - >>> log_hazard(log_params, time=torch.tensor(10.0)) # Log hazard at one new time (e.g., 10 years) - tensor([ 0.5316, 0.3542, -2.8907, 0.3699]) - >>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times - tensor([ 1.1280, -0.0372, -3.9767, 1.0757]) - tensor([ 1.2330, -0.1062, -4.1680, 1.1999]) - >>> log_params *= 1e2 # Increase scale - >>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values - tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10]) - """ - - log_scale, log_shape = _check_log_shape(log_params).unbind(1) - - if time.dim() == 0: - # Use fixed time for each sample - time = time.repeat(len(log_params)) - elif all([time.size(0) == log_params.size(0), all_times]): - # Use all times for each sample - time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows - log_scale = log_scale.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - log_shape = log_shape.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - if time.size(0) != log_params.size(0): - raise ValueError( - f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})." - ) - - return torch.clamp( - log_shape - - log_scale - + torch.expm1(log_shape) - * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale), - min=-TORCH_CLAMP_VALUE, - max=TORCH_CLAMP_VALUE, - ) - - -def cumulative_hazard( - log_params: torch.Tensor, time: torch.Tensor, all_times: bool = True -) -> torch.Tensor: - """Cumulative hazard for the Weibull Accelerated Time Failure (AFT) survival model. - - Args: - log_params (torch.Tensor, float): - Parameters of the Weibull distribution of shape = (n_samples, 1) or (n_samples, 2). - The first column corresponds to the log scale parameter. The second column - corresponds to the log shape parameter. If the log shape parameter is missing, it is - imputed with 0. - time (torch.Tensor, float): - Time-to-event or censoring of length n_samples. - all_times (bool) - If True, subject-specific cumulative hazard is evaluated at all ``time`` (used for evaluation metrics). - If False, subject-specific cumulative hazard is evaluated at respective ``time``. - Defaults is True. - - Returns: - (torch.Tensor, float): Subject-specific cumulative hazard evaluated at ``time``. - - Examples: - >>> _ = torch.manual_seed(42) - >>> time = torch.randint(low=1, high=100, size=(4,)) - >>> log_params = torch.randn((4, 2)) - >>> cumulative_hazard(log_params, time, all_times=False) # Cumulative hazard at respective time - tensor([ 8.6257, 112.2115, 3.5105, 112.6339]) - >>> cumulative_hazard(log_params, time, all_times=True) # Default. Cumulative hazard at all time - tensor([[ 8.6257, 233.0865, 239.2167, 126.2805], - [ 12.7698, 112.2115, 114.1484, 74.9134], - [ 0.8706, 3.4725, 3.5105, 2.6850], - [ 6.9530, 212.7592, 218.5687, 112.6339]]) - """ - log_scale, log_shape = _check_log_shape(log_params).unbind(1) - - if all_times: - # Use all times for each sample - time = time.unsqueeze(0).expand(len(time), len(time)) # expand across rows - log_scale = log_scale.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - log_shape = log_shape.unsqueeze(1).expand( - len(time), len(time) - ) # expand across columns - - return torch.clamp( - torch.exp( - torch.exp(log_shape) - * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale) - ), - min=0, - max=TORCH_CLAMP_VALUE, - ) - - -def _check_log_shape(log_params: torch.Tensor) -> torch.Tensor: - """Private function, check if the log shape is missing and impute it with 0 - if needed.""" - if any( - [ - log_params.dim() == 0, - log_params.dim() == 1, # if shape = [n_samples] - log_params.dim() > 1 - and log_params.size(1) == 1, # if shape = [n_samples, 1] - ] - ): - if log_params.dim() == 1: - log_params = log_params.unsqueeze(1) - - # Missing log shape parameter. Creating zeros placeholder instead. - log_params = torch.hstack((log_params, torch.zeros_like(log_params))) - - return log_params - - -def _check_inputs(log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor): - """Private function, perform input format checks.""" - if not isinstance(log_params, torch.Tensor): - raise TypeError("Input 'log_params' must be a tensor.") - - if not isinstance(event, torch.Tensor): - raise TypeError("Input 'event' must be a tensor.") - - if not isinstance(time, torch.Tensor): - raise TypeError("``Input 'time' must be a tensor.") - - if log_params.shape[0] != len(event): - raise ValueError( - "Length mismatch: The length of 'log_params' must match the length of 'event'." - ) - - if len(time) != len(event): - raise ValueError( - "Length mismatch: The length of 'time' must match the length of 'event'.`" - ) - - if any(val < 0 for val in time): - raise ValueError("All elements in 'time' must be non-negative.") - - if any(val not in [True, False, 0, 1] for val in event): - raise ValueError("All elements in 'event' must be boolean (True/False or 0/1).") - - if __name__ == "__main__": import doctest diff --git a/src/torchsurv/metrics/auc.py b/src/torchsurv/metrics/auc.py index b5a6a98..303e68e 100644 --- a/src/torchsurv/metrics/auc.py +++ b/src/torchsurv/metrics/auc.py @@ -6,8 +6,12 @@ from scipy import stats from torchmetrics import regression -from ..stats import kaplan_meier -from ..tools import validate_inputs +from torchsurv.stats import kaplan_meier +from torchsurv.tools.validate_data import ( + validate_log_shape, + validate_new_time, + validate_survival_data, +) class Auc: @@ -212,9 +216,9 @@ def __call__( # further input format checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_evaluation_time(new_time, time) - validate_inputs.validate_estimate(estimate, time, new_time) + validate_survival_data(event, time) + validate_new_time(new_time, time) + validate_log_shape(estimate) # sample size and length of time n_samples, n_times = estimate.shape[0], new_time.shape[0] @@ -579,7 +583,7 @@ def compare( return pvalue # pylint: disable=invalid-name - def _integrate_incident(self, S: torch.tensor, tmax: torch.tensor) -> torch.Tensor: + def _integrate_incident(self, S: torch.Tensor, tmax: torch.Tensor) -> torch.Tensor: """Integrates the incident/dynamic AUC, int_t AUC(t) x w(t) dt where w(t) = 2*f(t)*S(t) and f(t) is the lifeline distribution, S(t) is the survival distribution estimated with the Kaplan @@ -617,7 +621,7 @@ def _integrate_incident(self, S: torch.tensor, tmax: torch.tensor) -> torch.Tens # pylint: disable=invalid-name def _integrate_cumulative( - self, S: torch.tensor, tmax: torch.tensor + self, S: torch.Tensor, tmax: torch.Tensor ) -> torch.Tensor: """Integrates the cumulative/dynamic AUC, int_t AUC(t) ยท f(t) dt where f(t) is the lifeline distribution estimated from the discrete @@ -1168,7 +1172,7 @@ def _bootstrap_auc( @staticmethod def _find_torch_unique_indices( - inverse_indices: torch.tensor, counts: torch.tensor + inverse_indices: torch.Tensor, counts: torch.Tensor ) -> torch.tensor: """return unique_sorted_indices such that sorted_unique_tensor[inverse_indices] = original_tensor @@ -1214,12 +1218,12 @@ def _validate_auc_inputs( @staticmethod def _update_auc_new_time( - estimate: torch.tensor, - event: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + event: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # update new time if ( @@ -1255,7 +1259,7 @@ def _update_auc_new_time( @staticmethod def _update_auc_estimate( - estimate: torch.tensor, new_time: torch.tensor + estimate: torch.Tensor, new_time: torch.Tensor ) -> torch.tensor: # squeeze estimate if shape = (n_samples, 1) if estimate.ndim == 2 and estimate.shape[1] == 1: @@ -1271,10 +1275,10 @@ def _update_auc_estimate( @staticmethod def _update_auc_weight( - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # if weight was not specified, weight of 1 if weight is None: diff --git a/src/torchsurv/metrics/brier_score.py b/src/torchsurv/metrics/brier_score.py index e50ebe1..77c1f53 100644 --- a/src/torchsurv/metrics/brier_score.py +++ b/src/torchsurv/metrics/brier_score.py @@ -5,7 +5,11 @@ import torch from scipy import stats -from ..tools import validate_inputs +from torchsurv.tools.validate_data import ( + validate_log_shape, + validate_new_time, + validate_survival_data, +) class BrierScore: @@ -185,11 +189,9 @@ def __call__( # further input format checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_evaluation_time( - new_time, time, within_follow_up=False - ) - validate_inputs.validate_estimate(estimate, time, new_time) + validate_survival_data(event, time) + validate_new_time(new_time, time, within_follow_up=False) + validate_log_shape(estimate) # Calculating the residuals for each subject and time point residuals = torch.zeros_like(estimate) @@ -802,7 +804,7 @@ def _bootstrap_brier_score( @staticmethod def _find_torch_unique_indices( - inverse_indices: torch.tensor, counts: torch.tensor + inverse_indices: torch.Tensor, counts: torch.Tensor ) -> torch.tensor: """return unique_sorted_indices such that sorted_unique_tensor[inverse_indices] = original_tensor @@ -824,11 +826,11 @@ def _find_torch_unique_indices( @staticmethod def _validate_brier_score_inputs( - estimate: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # check new_time and weight are provided, weight_new_time should be provided if all([new_time is not None, weight is not None, weight_new_time is None]): @@ -855,11 +857,11 @@ def _validate_brier_score_inputs( @staticmethod def _update_brier_score_new_time( - estimate: torch.tensor, - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + estimate: torch.Tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # check format of new_time if ( @@ -891,10 +893,10 @@ def _update_brier_score_new_time( @staticmethod def _update_brier_score_weight( - time: torch.tensor, - new_time: torch.tensor, - weight: torch.tensor, - weight_new_time: torch.tensor, + time: torch.Tensor, + new_time: torch.Tensor, + weight: torch.Tensor, + weight_new_time: torch.Tensor, ) -> torch.tensor: # if weight was not specified, weight of 1 if weight is None: diff --git a/src/torchsurv/metrics/cindex.py b/src/torchsurv/metrics/cindex.py index 9276a07..4882c6d 100644 --- a/src/torchsurv/metrics/cindex.py +++ b/src/torchsurv/metrics/cindex.py @@ -7,7 +7,7 @@ from scipy import stats from torchmetrics import regression -from ..tools import validate_inputs +from torchsurv.tools.validate_data import validate_log_shape, validate_survival_data class ConcordanceIndex: @@ -185,8 +185,8 @@ def __call__( # Inputs checks if self.checks: - validate_inputs.validate_survival_data(event, time) - validate_inputs.validate_estimate(estimate, time) + validate_survival_data(event, time) + validate_log_shape(estimate) # find comparable pairs comparable = self._get_comparable_and_tied_time(event, time) diff --git a/src/torchsurv/stats/ipcw.py b/src/torchsurv/stats/ipcw.py index 871cdd7..20446c9 100644 --- a/src/torchsurv/stats/ipcw.py +++ b/src/torchsurv/stats/ipcw.py @@ -1,18 +1,17 @@ -import sys import warnings from typing import Optional import torch -from ..tools import validate_inputs -from . import kaplan_meier +from torchsurv.stats import kaplan_meier +from torchsurv.tools.validate_data import validate_survival_data # pylint: disable=anomalous-backslash-in-string def get_ipcw( - event: torch.tensor, - time: torch.tensor, - new_time: Optional[torch.tensor] = None, + event: torch.Tensor, + time: torch.Tensor, + new_time: Optional[torch.Tensor] = None, checks: bool = True, ) -> torch.Tensor: """Calculate the inverse probability censoring weights (IPCW). @@ -56,7 +55,7 @@ def get_ipcw( """ if checks: - validate_inputs.validate_survival_data(event, time) + validate_survival_data(event, time) # time on which to evaluate IPCW if new_time is None: # if none, return ipcw of same size as time @@ -95,12 +94,14 @@ def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: tensor([2.9701, 7.7634, 4.2651, 4.3415]) """ - if torch.any(ct.eq(0.0)): + if torch.any(ct == 0.0): + zero_indices = torch.nonzero(ct.eq(0.0)).squeeze() + zero_indices_list = zero_indices.tolist() # Explicitly convert to list warnings.warn( - "Censoring distribution zero at one or more time points. Returning ones as weight" + f"Censoring distribution zero at time points: {zero_indices_list}. Returning ones as weight" ) - return torch.ones_like(ct, dtype=ct.dtype) - weight = torch.ones(1, dtype=ct.dtype) / ct + weight = 1.0 / ct + weight = torch.ones_like(ct) / ct return weight @@ -111,6 +112,3 @@ def _inverse_censoring_dist(ct: torch.Tensor) -> torch.Tensor: results = doctest.testmod() if results.failed == 0: print("All tests passed.") - else: - print("Some doctests failed.") - sys.exit(1) diff --git a/src/torchsurv/stats/kaplan_meier.py b/src/torchsurv/stats/kaplan_meier.py index 750c2b2..da79e68 100644 --- a/src/torchsurv/stats/kaplan_meier.py +++ b/src/torchsurv/stats/kaplan_meier.py @@ -4,7 +4,7 @@ import torch -from ..tools import validate_inputs +from torchsurv.tools.validate_data import validate_survival_data class KaplanMeierEstimator: @@ -12,8 +12,8 @@ class KaplanMeierEstimator: def __call__( self, - event: torch.tensor, - time: torch.tensor, + event: torch.Tensor, + time: torch.Tensor, censoring_dist: bool = False, check: bool = True, ): @@ -62,7 +62,7 @@ def __call__( # Check input validity if required if check: - validate_inputs.validate_survival_data(event, time) + validate_survival_data(event, time) # Compute the counts of events, censorings, and the number at risk at each unique time uniq_times, n_events, n_at_risk, n_censored = self._compute_counts() @@ -200,7 +200,7 @@ def print_survival_table(self): def _compute_counts( self, - ) -> Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the counts of events, censorings and risk set at ``time``. Returns: Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor] diff --git a/src/torchsurv/tools/validate_data.py b/src/torchsurv/tools/validate_data.py new file mode 100644 index 0000000..571dd8f --- /dev/null +++ b/src/torchsurv/tools/validate_data.py @@ -0,0 +1,192 @@ +import torch + + +@torch.jit.script +def validate_log_shape(log_params: torch.Tensor) -> torch.Tensor: + """Private function, check if the log shape is missing and impute it with 0 + if needed.""" + if any( + [ + log_params.dim() == 0, + log_params.dim() == 1, # if shape = [n_samples] + log_params.dim() > 1 + and log_params.size(1) == 1, # if shape = [n_samples, 1] + ] + ): + if log_params.dim() == 1: + log_params = log_params.unsqueeze(1) + + # Missing log shape parameter. Creating zeros placeholder instead. + log_params = torch.hstack((log_params, torch.zeros_like(log_params))) + + return log_params + + +@torch.jit.script +def check_within_follow_up( + new_time: torch.Tensor, time: torch.Tensor, within_follow_up: bool +) -> None: + # Check if the within_follow_up flag is set to True + if within_follow_up: + # Check if any value in new_time is outside the range of time + if new_time.max() >= time.max() or new_time.min() < time.min(): + # Get the minimum and maximum values of time + min_time = time.min().item() + max_time = time.max().item() + # Raise a ValueError if new_time is not within the follow-up time range + raise ValueError( + "Value error: All new_time must be within follow-up time of test data: [{}; {}[".format( + min_time, max_time + ) + ) + + +@torch.jit.script +def validate_new_time( + new_time: torch.Tensor, time: torch.Tensor, within_follow_up: bool = True +) -> None: + """ + Validate the new_time tensor for survival analysis functions. + + Args: + new_time (torch.Tensor, float): Time points for metric computation of size n_times. + time (torch.Tensor, float): Event or censoring time of size n_samples. + within_follow_up (bool, optional): Whether values of ``new_time`` must be within values in ``time``. Defaults to True. + + Raises: + ValueError: If ``new_time`` contains duplicate values. + ValueError: If ``new_time`` is not sorted. + TypeError: If ``new_time`` is not a tensor. + ValueError: If ``new_time`` is not of floating-point type. + ValueError: If ``new_time`` is not within the range of follow-up in ``time``. Assessed only if ``within_follow_up`` is True. + """ + if not isinstance(new_time, torch.Tensor): + raise TypeError("Type error: Input 'new_time' should be a tensor.") + + if not torch.is_floating_point(new_time): + raise ValueError( + "Value error: Input 'new_time' should be of floating-point type." + ) + + new_time_sorted, _ = torch.sort(new_time) + if not torch.equal(new_time_sorted, new_time): + raise ValueError( + "Value error: Input 'new_time' should be sorted from the smallest time to the largest." + ) + + if len(new_time_sorted) != len(torch.unique(new_time_sorted)): + raise ValueError("Value error: Input 'new_time' should contain unique values.") + + check_within_follow_up(new_time, time, within_follow_up) + + +def validate_survival_data(event, time): + """Perform format and validity checks for survival data. + + Args: + event (torch.Tensor, boolean): + Event indicator of size n_samples (= True if event occured). + time (torch.Tensor, float): + Event or censoring time of size n_samples. + + Raises: + TypeError: If ``event`` or ``time`` are not tensors. + ValueError: If ``event`` is not boolean. + ValueError: If ``event`` and ``time`` are not of the same length. + ValueError: If all ``event`` are False. + ValueError: If any ``time`` are negative. + """ + if not torch.is_tensor(event) or not torch.is_tensor(time): + raise TypeError("Inputs 'event' and 'time' should be tensors") + + if not event.dtype == torch.bool: + raise ValueError("Input 'event' should be of boolean type.") + + if not torch.is_floating_point(time): + raise ValueError("Input 'time' should be of float type.") + + if len(event) != len(time): + raise ValueError( + "Dimension mismatch: Incompatible length between inputs 'time' and 'event'." + ) + + if torch.sum(event) <= 0: + raise ValueError("All samples are censored.") + + if torch.any(time < 0.0): + raise ValueError("Input 'time' should be non-negative.") + + +@torch.jit.script +def validate_loss( + log_params: torch.Tensor, event: torch.Tensor, time: torch.Tensor, model_type: str +) -> None: + """ + Validate the inputs for survival analysis functions. + + Args: + log_params (torch.Tensor): Parameters of the model (log hazards or log parameters). + event (torch.Tensor): Event indicator tensor. + time (torch.Tensor): Time-to-event or censoring tensor. + model_type (str): Type of the model ('weibull' or 'cox'). + + Raises: + TypeError: If inputs are not tensors. + ValueError: If any `time` are negative. + ValueError: If `event` contains invalid values. + ValueError: If lengths of inputs do not match. + """ + if not isinstance(log_params, torch.Tensor): + raise TypeError("Input 'log_params' must be a tensor.") + + if not isinstance(event, torch.Tensor): + raise TypeError("Input 'event' must be a tensor.") + + if not isinstance(time, torch.Tensor): + raise TypeError("Input 'time' must be a tensor.") + + if log_params.shape[0] != len(event): + raise ValueError( + "Length mismatch: The length of 'log_params' must match the length of 'event'." + ) + + if len(time) != len(event): + raise ValueError( + "Length mismatch: The length of 'time' must match the length of 'event'." + ) + + if torch.any(time < 0): + raise ValueError("All elements in 'time' must be non-negative.") + + if torch.any((event != 0) & (event != 1)): + raise ValueError( + "Invalid values: 'event' must contain only boolean values (True/False or 1/0)" + ) + + if model_type == "weibull": + # if log_params.shape[1] is not 1 or 2: + if log_params.shape[1] not in [1, 2]: + raise ValueError( + f"For Weibull model, 'log_params' must have shape (n_samples, 2) or (n_samples, 1)." + ) + elif model_type == "cox": + if log_params.shape[1] != 1: + raise ValueError( + "For Cox model, 'log_params' must have shape (n_samples, 1)." + ) + else: + raise ValueError("Invalid model type. Must be 'weibull' or 'cox'.") + + +# Example usage +if __name__ == "__main__": + log_params_weibull = torch.randn((5, 2)) + log_params_cox = torch.randn((5, 1)) + event = torch.tensor([1, 0, 1, 1, 0], dtype=torch.float32) + time = torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32) + + # Validate Weibull model inputs + validate_loss(log_params_weibull, event, time, model_type="weibull") + + # Validate Cox model inputs + validate_loss(log_params_cox, event, time, model_type="cox") diff --git a/src/torchsurv/tools/validate_inputs.py b/src/torchsurv/tools/validate_inputs.py deleted file mode 100644 index 1d19023..0000000 --- a/src/torchsurv/tools/validate_inputs.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch - - -def validate_survival_data(event, time): - """Perform format and validity checks for survival data. - - Args: - event (torch.Tensor, boolean): - Event indicator of size n_samples (= True if event occured). - time (torch.Tensor, float): - Event or censoring time of size n_samples. - - Raises: - TypeError: If ``event`` or ``time`` are not tensors. - ValueError: If ``event`` is not boolean. - ValueError: If ``event`` and ``time`` are not of the same length. - ValueError: If all ``event`` are False. - ValueError: If any ``time`` are negative. - """ - if not torch.is_tensor(event) or not torch.is_tensor(time): - raise TypeError("Inputs 'event' and 'time' should be tensors") - - if not event.dtype == torch.bool: - raise ValueError("Input 'event' should be of boolean type.") - - if not torch.is_floating_point(time): - raise ValueError("Input 'time' should be of float type.") - - if len(event) != len(time): - raise ValueError( - "Dimension mismatch: Incompatible length between inputs 'time' and 'event'." - ) - - if torch.sum(event) <= 0: - raise ValueError("All samples are censored.") - - if torch.any(time < 0.0): - raise ValueError("Input 'time' should be non-negative.") - - -def validate_evaluation_time(new_time, time, within_follow_up=True): - """Perform format and validity checks for evaluation time. - - Args: - new_time (torch.Tensor, optional): - Time points for metric computation of size n_times. - time (torch.Tensor, float): - Event or censoring time of size n_samples. - within_follow_up (bool, optional): - Whether values of ``new_time`` must be within values in ``time``. - Defaults to True. - - Raises: - ValueError: If ``new_time`` contains duplicate values. - ValueError: If ``new_time`` is not sorted. - TypeError: If ``new_time`` is not a tensor. - ValueError: If ``new_time`` is not of floating-point type. - ValueError: - If ``new_time`` is not within the range of follow-up in ``time``. - Assessed only if ``within_follow_up`` is True. - """ - new_time_sorted, indices = torch.unique(new_time, return_inverse=True) - - if len(new_time_sorted) != len(new_time): - raise ValueError("'Value error: Input 'new_time' should contain unique values.") - - if not torch.all(indices == torch.arange(len(new_time))): - raise ValueError( - "Value error: Input 'new_time' should be sorted from the smallest time to the largest." - ) - - if not torch.is_tensor(new_time): - raise TypeError("Type error: Input 'new_time' should be a tensor.") - - if not torch.is_floating_point(new_time): - raise ValueError( - "Value error: Input 'new_time' should be of floating-point type." - ) - - if within_follow_up: - if new_time.max() >= time.max() or new_time.min() < time.min(): - raise ValueError( - "Value error: All new_time must be within follow-up time of test data: [{}; {}[".format( - time.min(), time.max() - ) - ) - - -def validate_estimate(estimate, time, new_time=None) -> None: - """Perform format and validity checks for estimate. - - Args: - estimate (torch.Tensor): - Estimates of shape = (n_samples,) or (n_samples, n_times). - time (torch.Tensor, float): - Time of event or censoring of size n_samples. - new_time (torch.Tensor, optional): - Time points for metric computation of size n_times. - - Raises: - TypeError: If ``estimate`` is not a tensor. - ValueError: If ``estimate`` has more than 2 dimensions. - ValueError: If number of rows of ``estimate`` is not n_samples. - ValueError: - If number of columns of ``estimate`` is not n_times. - Assessed only if ``new_time`` is not None. - """ - if not torch.is_tensor(estimate): - raise TypeError("Type error: Input 'estimate' should be a tensor.") - - if estimate.ndim > 2: - raise ValueError("Value error: Input 'estimate' should have 1 or 2 dimensions.") - - if estimate.shape[0] != time.shape[0]: - raise ValueError( - "Dimension mismatch: Inconsistent number of samples between inputs 'time' and 'estimate'." - ) - - if not new_time is None: - if estimate.ndim == 2 and estimate.shape[1] != new_time.shape[0]: - raise ValueError( - "Dimension mismatch: Expected inputs 'estimate' with {} columns, but got {}".format( - new_time.shape[0], estimate.shape[1] - ) - ) diff --git a/tests/test_kaplan_meier.py b/tests/test_kaplan_meier.py index 62e0a5d..76eaa6d 100644 --- a/tests/test_kaplan_meier.py +++ b/tests/test_kaplan_meier.py @@ -216,6 +216,45 @@ def test_kaplan_meier_prediction_error_raised(self): self.assertRaises(ValueError, km.predict, test_time) + def test_kaplan_meier_plot_km(self): + """test Kaplan Meier plot function""" + import matplotlib.pyplot as plt + + event = torch.tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 1], dtype=torch.bool) + time = torch.tensor([1, 2, 2, 3, 3, 4, 4, 5, 6, 7], dtype=torch.float32) + + km = KaplanMeierEstimator() + km(event, time, censoring_dist=False) + + fig, ax = plt.subplots() + km.plot_km(ax=ax) + plt.close(fig) # Close the plot to avoid displaying it during tests + + def test_kaplan_meier_print_survival_table(self): + """test Kaplan Meier print survival table function""" + event = torch.tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 1], dtype=torch.bool) + time = torch.tensor([1, 2, 2, 3, 3, 4, 4, 5, 6, 7], dtype=torch.float32) + + km = KaplanMeierEstimator() + km(event, time, censoring_dist=False) + km.print_survival_table() + + # Check if the survival table is printed correctly + expected_output = ( + "Time\tSurvival\n" + "----------------\n" + "1.00\t1.0000\n" + "2.00\t0.8889\n" + "3.00\t0.6667\n" + "4.00\t0.5000\n" + "5.00\t0.5000\n" + "6.00\t0.3333\n" + "7.00\t0.0000\n" + ) + with self.assertLogs(level="INFO") as log: + km.print_survival_table() + self.assertIn(expected_output, log.output) + if __name__ == "__main__": unittest.main()