
Question: Can TorchSurv Handle Time Series Data for Survival Analysis? #42

Closed
Lamgayin opened this issue Jul 5, 2024 · 3 comments
Labels
enhancement ✨ New feature or request good first issue 🥇 Good for newcomers

Comments

@Lamgayin

Lamgayin commented Jul 5, 2024

Dear authors, hello! I am a newcomer to the field of survival analysis. My research topic involves applying time series models to survival analysis, and I am not sure whether TorchSurv supports this. Thanks a lot!

@melodiemonod
Collaborator

Hey @Lamgayin, thank you for your question. For now, TorchSurv does not support time-dependent covariates. We will add this as a potential new feature. Best regards, Melodie

@melodiemonod melodiemonod self-assigned this Jul 8, 2024
@melodiemonod melodiemonod added the enhancement ✨ New feature or request label Jul 8, 2024
@melodiemonod melodiemonod added this to the backlog milestone Jul 8, 2024
@tcoroller
Collaborator

tcoroller commented Jul 9, 2024

Hi @Lamgayin,

By design, TorchSurv lets you use any model architecture and any data. If you want to use a time series model to predict survival, you can simply use an RNN (many-to-one) or a transformer architecture to fit your longitudinal data. As long as your model outputs a single estimate per sample (or two for Weibull), you can connect any TorchSurv function (loss and/or metrics) to it.
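For the transformer route, a minimal many-to-one sketch could look like the following. The `TransformerSurvival` class, its layer sizes, the learned positional embedding and the mean pooling over time steps are all illustrative assumptions, not something TorchSurv provides, and it assumes every sequence has the same length. Setting `n_outputs=2` gives the two columns per sample that a Weibull model needs (as far as I know, TorchSurv's `torchsurv.loss.weibull` module reads them as log scale and log shape).

```python
import torch
from torch import nn


class TransformerSurvival(nn.Module):
    """Hypothetical many-to-one transformer: a sequence of visits -> n_outputs estimates per subject."""

    def __init__(self, input_size: int, d_model: int = 32, nhead: int = 4,
                 num_layers: int = 2, n_outputs: int = 1, max_len: int = 100):
        super().__init__()
        self.embed = nn.Linear(input_size, d_model)  # project features to the model width
        self.pos = nn.Parameter(torch.zeros(1, max_len, d_model))  # learned positional embedding
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=64, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(d_model, n_outputs)  # 1 output for Cox, 2 for Weibull

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, input_size)
        seq_len = x.size(1)
        h = self.embed(x) + self.pos[:, :seq_len]  # add time-step information
        h = self.encoder(h)  # (batch, seq_len, d_model)
        pooled = h.mean(dim=1)  # average over time steps -> (batch, d_model)
        return self.head(pooled)  # (batch, n_outputs)


batch_size, seq_length, input_size = 8, 5, 10
x = torch.randn(batch_size, seq_length, input_size)

cox_model = TransformerSurvival(input_size, n_outputs=1)
weibull_model = TransformerSurvival(input_size, n_outputs=2)
log_hz = cox_model(x)  # (8, 1): one estimate per sample, like the RNN output
log_params = weibull_model(x)  # (8, 2): two estimates per sample for a Weibull loss
print(log_hz.shape, log_params.shape)  # torch.Size([8, 1]) torch.Size([8, 2])
```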

I attached a simple code example to illustrate how to use time series data with an RNN in TorchSurv. Here I am using 10 features across 5 time steps for a batch of 8 samples. The RNN outputs a single estimate for each sample, which can then be passed directly to TorchSurv's loss function and metrics.

Good luck!

```python
import torch
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Parameters
input_size = 10  # number of features per time step
output_size = 1  # one estimate (log hazard) per sample
num_layers = 2
seq_length = 5  # number of time steps
batch_size = 8

# Make random boolean events
events = torch.rand(batch_size) > 0.5
print(events)  # e.g. tensor([ True, False,  True,  True, False, False,  True, False])

# Make random positive times to event
time = torch.rand(batch_size) * 100
print(time)  # e.g. tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])

# Create a simple RNN model; inputs have shape (seq_length, batch_size, input_size)
rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)
h0 = torch.randn(num_layers, batch_size, output_size)  # initial hidden state (defaults to zeros if omitted)

# Forward pass of the time series input
outputs, _ = rnn(inputs, h0)  # outputs: (seq_length, batch_size, output_size)
estimates = outputs[-1]  # keep only the last time step's prediction (many-to-one)
print(f"Estimate shape for {batch_size} samples = {estimates.size()}")  # Estimate shape for 8 samples = torch.Size([8, 1])

# Connect TorchSurv's Cox loss and concordance index to the RNN output
loss = cox.neg_partial_log_likelihood(estimates, events, time)
print(f"loss = {loss}, has gradient = {loss.requires_grad}")  # e.g. loss = 1.0389232635498047, has gradient = True

cindex = ConcordanceIndex()
print(f"c-index = {cindex(estimates, events, time)}")  # e.g. c-index = 0.20000000298023224
```

@tcoroller tcoroller added the good first issue 🥇 Good for newcomers label Jul 9, 2024
@tcoroller tcoroller pinned this issue Jul 9, 2024
@tcoroller tcoroller unpinned this issue Jul 15, 2024
@Lamgayin
Author

@tcoroller Thank you for taking the time to reply to my question! It helped me a lot, and I solved my problem based on your suggestions 🙏🙏

@tcoroller tcoroller pinned this issue Jul 17, 2024

3 participants