67 torch script #80
Draft · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion docs/notebooks/introduction.ipynb
@@ -28,7 +28,7 @@
"# %pip install lifelines\n",
"# %pip install matplotlib\n",
"# %pip install sklearn\n",
"# %pip install pandas\n"
"# %pip install pandas"
]
},
{
177 changes: 83 additions & 94 deletions src/torchsurv/loss/cox.py
@@ -6,7 +6,88 @@

import torch

from torchsurv.tools.validate_data import validate_loss


@torch.jit.script
def _partial_likelihood_cox(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
in the absence of ties in event time.
"""
log_hz_flipped = log_hz_sorted.flip(0)
log_denominator = torch.logcumsumexp(log_hz_flipped, dim=0).flip(0)
return (log_hz_sorted - log_denominator)[event_sorted]
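
As a quick sanity check on the logcumsumexp trick above — a sketch that is not part of the diff, assuming inputs already sorted by time and a boolean event mask (an integer 0/1 tensor would silently gather rather than mask on the last line):

import torch

log_hz_sorted = torch.tensor([0.3, -0.1, 0.8, 0.0])
event_sorted = torch.tensor([True, False, True, True])

# Naive reference: for each subject i, log hz_i minus the log-sum-exp of
# log hz_j over the risk set {j : j >= i} (data sorted by time).
naive = torch.stack(
    [
        log_hz_sorted[i] - torch.logsumexp(log_hz_sorted[i:], dim=0)
        for i in range(len(log_hz_sorted))
    ]
)[event_sorted]

assert torch.allclose(naive, _partial_likelihood_cox(log_hz_sorted, event_sorted))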


@torch.jit.script
def _partial_likelihood_efron(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
time_sorted: torch.Tensor,
time_unique: torch.Tensor,
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
using Efron's method to handle ties in event time.
"""
J = len(time_unique)

H = [
torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0]
for j in range(J)
]
R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)]

# Calculate the length of each element in H and store it in a tensor
m = torch.tensor([len(h) for h in H])

# Create a boolean tensor indicating whether each element in H has a length greater than 0
include = torch.tensor([len(h) > 0 for h in H])

log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H])

denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R])
denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H])

log_denominator_efron = torch.zeros(J, device=log_hz_sorted.device)
for j in range(J):
mj = int(m[j].item())
for l in range(1, mj + 1):
log_denominator_efron[j] += torch.log(
denominator_naive[j] - (l - 1) / float(m[j]) * denominator_ties[j]
)
return (log_nominator - log_denominator_efron)[include]
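
A toy illustration of the tie handling — the values are made up and the private helper is called directly only for exposition, assuming a boolean event indicator and a `time_unique` produced by `torch.unique`:

import torch

time_sorted = torch.tensor([1.0, 2.0, 2.0, 3.0])       # two events tied at t = 2
event_sorted = torch.tensor([True, True, True, False])  # last subject censored
log_hz_sorted = torch.tensor([0.2, -0.4, 0.1, 0.5])
time_unique = torch.unique(time_sorted)                 # sorted unique times

ll = _partial_likelihood_efron(log_hz_sorted, event_sorted, time_sorted, time_unique)
# One term per unique time with at least one event: here 2 of the 3 unique times.
print(ll.shape)  # torch.Size([2])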


@torch.jit.script
def _partial_likelihood_breslow(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
time_sorted: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the partial log likelihood for the Cox proportional hazards model
    using Breslow's method to handle ties in event time.

    Args:
        log_hz_sorted (torch.Tensor): Log hazard rates sorted by time.
        event_sorted (torch.Tensor): Boolean tensor indicating whether the event
            occurred (True) or the observation was censored (False), sorted by time.
        time_sorted (torch.Tensor): Event or censoring times sorted in ascending order.

    Returns:
        torch.Tensor: The partial log likelihood terms for the observed events.
    """
N = len(time_sorted)
R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)]
log_denominator = torch.stack(
[torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)]
)

return (log_hz_sorted - log_denominator)[event_sorted]
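
With no tied event times, Breslow's denominator reduces to the plain risk-set sum, so this helper should agree with `_partial_likelihood_cox`; a hedged consistency check, not part of the PR:

import torch

time_sorted = torch.tensor([1.0, 2.0, 3.0, 4.0])  # strictly increasing: no ties
event_sorted = torch.tensor([True, False, True, True])
log_hz_sorted = torch.randn(4)

assert torch.allclose(
    _partial_likelihood_breslow(log_hz_sorted, event_sorted, time_sorted),
    _partial_likelihood_cox(log_hz_sorted, event_sorted),
)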


@torch.jit.script
def neg_partial_log_likelihood(
log_hz: torch.Tensor,
event: torch.Tensor,
@@ -118,9 +199,9 @@ def neg_partial_log_likelihood(
"""

if checks:
_check_inputs(log_hz, event, time)
validate_loss(log_hz, event, time, model_type="cox")

if any([event.sum() == 0, len(log_hz.size()) == 0]):
if any([event.sum().item() == 0, len(log_hz.size()) == 0]):
warnings.warn("No events OR single sample. Returning zero loss for the batch")
return torch.tensor(0.0, requires_grad=True)

@@ -164,98 +245,6 @@ def neg_partial_log_likelihood(
return loss
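
For context, a minimal end-to-end call of the public function; the 1-D shapes and default tie handling are assumptions based on this file rather than guarantees:

import torch
from torchsurv.loss.cox import neg_partial_log_likelihood

log_hz = torch.randn(8, requires_grad=True)  # e.g. output of a risk network
event = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1]).bool()
time = torch.rand(8) * 10.0

loss = neg_partial_log_likelihood(log_hz, event, time)
loss.backward()  # gradients flow through the scripted helpers

# Edge case handled above: an all-censored batch emits the warning and
# returns a differentiable zero loss instead of raising.
zero = neg_partial_log_likelihood(torch.randn(3), torch.zeros(3).bool(), torch.rand(3))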


def _partial_likelihood_cox(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
in the absence of ties in event time.
"""
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
return (log_hz_sorted - log_denominator)[event_sorted]


def _partial_likelihood_efron(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
time_sorted: torch.Tensor,
time_unique: torch.Tensor,
) -> torch.Tensor:
"""Calculate the partial log likelihood for the Cox proportional hazards model
using Efron's method to handle ties in event time.
"""
J = len(time_unique)

H = [
torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0]
for j in range(J)
]
R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)]

m = torch.tensor([len(h) for h in H])
include = torch.tensor([len(h) > 0 for h in H])

log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H])

denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R])
denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H])

log_denominator_efron = torch.zeros(J).to(log_hz_sorted.device)
for j in range(J):
for l in range(1, m[j] + 1):
log_denominator_efron[j] += torch.log(
denominator_naive[j] - (l - 1) / m[j] * denominator_ties[j]
)
return (log_nominator - log_denominator_efron)[include]


def _partial_likelihood_breslow(
log_hz_sorted: torch.Tensor,
event_sorted: torch.Tensor,
time_sorted: torch.Tensor,
):
"""Calculate the partial log likelihood for the Cox proportional hazards model
using Breslow's method to handle ties in event time.
"""
N = len(time_sorted)

R = [torch.where(time_sorted >= time_sorted[i])[0] for i in range(N)]
log_denominator = torch.tensor(
[torch.logsumexp(log_hz_sorted[R[i]], dim=0) for i in range(N)]
)

return (log_hz_sorted - log_denominator)[event_sorted]


def _check_inputs(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor):
if not isinstance(log_hz, torch.Tensor):
raise TypeError("Input 'log_hz' must be a tensor.")

if not isinstance(event, torch.Tensor):
raise TypeError("Input 'event' must be a tensor.")

if not isinstance(time, torch.Tensor):
raise TypeError("Input 'time' must be a tensor.")

if len(log_hz) != len(event):
raise ValueError(
"Length mismatch: 'log_hz' and 'event' must have the same length."
)

if len(time) != len(event):
raise ValueError(
"Length mismatch: 'time' must have the same length as 'event'."
)

if any(val < 0 for val in time):
raise ValueError("Invalid values: All elements in 'time' must be non-negative.")

if any(val not in [True, False, 0, 1] for val in event):
raise ValueError(
"Invalid values: 'event' must contain only boolean values (True/False or 1/0)"
)


if __name__ == "__main__":
import doctest

32 changes: 19 additions & 13 deletions src/torchsurv/loss/momentum.py
@@ -151,15 +151,15 @@ def forward(

"""

estimate_q = self.online(inputs)
for estimate in zip(estimate_q, event, time):
self.memory_q.append(self.survtuple(*list(estimate)))
online_estimate = self.online(inputs)
for estimate in zip(online_estimate, event, time):
self.memory_q.append(self.survtuple(*estimate))
loss = self._bank_loss()
with torch.no_grad():
self._update_momentum_encoder()
estimate_k = self.target(inputs)
for estimate in zip(estimate_k, event, time):
self.memory_k.append(self.survtuple(*list(estimate)))
target_estimate = self.target(inputs)
for estimate in zip(target_estimate, event, time):
self.memory_k.append(self.survtuple(*estimate))
return loss
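
A usage sketch for the forward pass above; the constructor arguments beyond the backbone and the loss callable are assumptions, since their names and defaults are not shown in this diff:

import torch
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.momentum import Momentum

backbone = torch.nn.Linear(10, 1)  # hypothetical risk network
model = Momentum(backbone, neg_partial_log_likelihood)

x = torch.randn(4, 10)
event = torch.tensor([True, False, True, True])
time = torch.tensor([1.0, 3.0, 2.0, 5.0])

# forward(): online pass -> memory_q, bank loss, EMA update, target pass -> memory_k
loss = model(x, event, time)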

@torch.no_grad() # deactivates autograd
@@ -187,27 +187,33 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor:
return self.target(inputs)

def _bank_loss(self) -> torch.Tensor:
"""computer the negative loss likelyhood from memory bank"""
"""compute the negative log-likelihood from memory bank"""

# Combine current batch and momentum
bank = self.memory_k + self.memory_q
assert all(
x in bank[0]._fields for x in ["estimate", "event", "time"]
), "All fields must be present"
return self.loss(
torch.stack([mem.estimate.cpu() for mem in bank]).squeeze(),
torch.stack([mem.event.cpu() for mem in bank]).squeeze(),
torch.stack([mem.time.cpu() for mem in bank]).squeeze(),
)
log_estimates = torch.stack([mem.estimate.cpu() for mem in bank]).squeeze()
events = torch.stack([mem.event.cpu() for mem in bank]).squeeze()
times = torch.stack([mem.time.cpu() for mem in bank]).squeeze()
return self.loss(log_estimates, events, times)

@torch.no_grad()
def _update_momentum_encoder(self):
"""Exponantial moving average"""
"""Exponential moving average"""
for param_b, param_m in zip(self.online.parameters(), self.target.parameters()):
param_m.data = param_m.data * self.rate + param_b.data * (1.0 - self.rate)
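
Numerically, with a rate close to 1 the target tracks the online weights slowly; a minimal plain-Python illustration of the update rule above:

rate = 0.999
param_m, param_b = 0.0, 1.0  # target and online parameter values
for _ in range(3):
    param_m = param_m * rate + param_b * (1.0 - rate)
print(param_m)  # ~0.003: the target moves 0.1% toward the online value per step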

@torch.no_grad()
def _init_encoder_k(self):
"""
Initialize the target network (encoder_k) with the parameters of the online network (encoder_q).
The requires_grad attribute of the target network parameters is set to False to prevent gradient updates during training,
ensuring that the target network remains a stable reference point.
        This method copies the parameters from the online network into the target
        network with `copy_` and sets their `requires_grad` attribute to False to
        prevent gradient updates.
"""
for param_q, param_k in zip(self.online.parameters(), self.target.parameters()):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False