
Implement TabNet #21

Open
ChungsooKim opened this issue Feb 9, 2022 · 6 comments

@ChungsooKim
Collaborator

ChungsooKim commented Feb 9, 2022

TabNet (paper) was recently implemented in R (mlverse/tabnet), and it is even based on the torch package! It is known for good performance on tabular data. It would therefore be interesting to implement TabNet and compare its performance with the traditional algorithms (lasso logistic regression, xgboost).


@egillax
Collaborator

egillax commented Feb 14, 2022

Good idea! Do you want to take a stab at it? They have a pretty high-level API. I can also help.
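
For reference, a minimal sketch of what a call to their high-level interface might look like (the formula, data, and settings here are placeholders, not our actual setup):

```r
library(tabnet)

# Hypothetical training data with a binary `outcome` column.
fit <- tabnet_fit(outcome ~ ., data = trainData, epochs = 10)

# Predicted probabilities on new data.
preds <- predict(fit, newData, type = "prob")
```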

@ChungsooKim
Collaborator Author

Great, I think we can implement this in our package. I'll try to write generic code for model fitting with the tabnet functions.
Here is a tutorial video.

@ChungsooKim
Collaborator Author

Hi @jreps, @egillax, I wrote generic code for TabNet in the develop branch. Could you review the code and share any comments? After that, I think we can compare its performance to the traditional algorithms or to other deep learning algorithms like the ResNet.

@egillax
Collaborator

egillax commented Mar 28, 2022

Hi @ted9219,

Great job! I looked at the code and made small changes to get it to work on my end. I also added things like saving the model and feature importances, and verified that it works and displays correctly in the shiny results viewer. I pushed it into develop.

There are a few things that worry me, though.

  • The tabnet package currently only accepts dense matrices, so we are limited in how large a dataset we can use. An alternative I could explore would be to see whether the nn module from the tabnet package can be made to fit into the Estimator class I made. Then we would control the input ourselves and could make it support sparse matrices (the sketch after this list shows why dense input is a problem).
  • Currently, when we call tabnet_fit() it always calculates the feature importances after fitting. I used a small simulated dataset from simulatePlp with 2000 rows and 32000 features, and this step seems to take some time, even though it's really only necessary for the final model. I'm worried that on real data this step will take a non-trivial amount of time. Solutions would be either to ask the tabnet folks to make this step optional when fitting or, as mentioned above, to explore a solution that only uses the nn module from the tabnet package.
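
To make the first point concrete, a rough sketch of the densifying step the package currently forces on us (the matrix is simulated; the names are hypothetical):

```r
library(Matrix)

# Hypothetical sparse covariate matrix of the size used in the test above.
sparseCovariates <- rsparsematrix(nrow = 2000, ncol = 32000, density = 0.01)

# tabnet currently requires dense input, so the sparse matrix has to be
# materialized before fitting:
denseCovariates <- as.matrix(sparseCovariates)

# 2000 x 32000 doubles is already ~0.5 GB; real cohorts are much larger.
print(object.size(denseCovariates), units = "GB")
```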

@ChungsooKim
Collaborator Author

ChungsooKim commented Apr 12, 2022

Thanks, @egillax. I totally agree with your comments.

Using this repo as a reference, could we implement the TabNet modules ourselves?

I also found that the performance is lower than expected, and lower than other modules like ResNet. We need to figure out why the performance is consistently low. I'll try.

@egillax
Collaborator

egillax commented Apr 14, 2022

Just noting down a few observations I've made.

The output of the model was weird for me with the default loss function. It was logits (it had some negative values), but the mean of the output was the outcome rate, suggesting the loss function was treating the output as probabilities. So I switched to using loss <- torch::nn_bce_with_logits_loss() in my config (the same loss function used by default in my Estimator class). Then the output makes sense after I use torch::nnf_sigmoid() to convert the logits to probabilities: the mean is the outcome rate and the distribution is what I would expect.
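
A minimal sketch of that fix, with random tensors standing in for the model output and labels:

```r
library(torch)

# `logits` stands in for the raw model output; `labels` for binary outcomes.
logits <- torch_randn(8)
labels <- (torch_rand(8) > 0.5)$to(dtype = torch_float())

# Treat the output as logits rather than probabilities when computing loss.
loss_fn <- nn_bce_with_logits_loss()
loss <- loss_fn(logits, labels)

# Convert logits to probabilities before inspecting predicted risk; on real
# data the mean should track the outcome rate.
probs <- nnf_sigmoid(logits)
probs$mean()
```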

Another thing: it seems the columns in the input need to be factors, otherwise it treats them as numerical and skips the embedding layer. I think this could be the reason for the poor performance. I tried converting the columns to factors with trainDataset[] <- lapply(trainDataset, factor), but that takes a lot of time, and then the training is extremely slow. After running it overnight I'm still not getting better AUCs.

I suspect the best move forward would be to use the module you linked from the tabnet repo, but possibly rewriting the embedding layer to deal with our case where almost all variables are binary. I already started implementing another transformer model from scratch, and in the process made the estimator and the fit function more general, so in the future I think adding a model should only require creating a setModel function.

I will commit those changes later today, but I was stuck on the embedding layer with the same problem as for tabnet: how to efficiently create an embedding layer for a matrix with binary features. So when I solve that I might have a solution for tabnet as well.
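
One direction I'm considering (just a sketch, not a working solution yet): when every feature is binary, summing the embeddings of the active features is equivalent to multiplying the binary input matrix by the embedding weights, which would avoid the factor conversion entirely. The module below is hypothetical:

```r
library(torch)

# Sketch: for 0/1 features, an embedding lookup over the active features
# reduces to a matrix multiply, since each active feature contributes
# exactly its own row of the embedding weight matrix.
binary_embedding <- nn_module(
  initialize = function(n_features, embedding_dim) {
    self$weight <- nn_parameter(torch_randn(n_features, embedding_dim) * 0.01)
  },
  forward = function(x) {
    # x: (batch, n_features) binary matrix
    torch_matmul(x, self$weight)
  }
)

emb <- binary_embedding(n_features = 32000, embedding_dim = 16)
x <- (torch_rand(4, 32000) > 0.99)$to(dtype = torch_float())
out <- emb(x)  # shape: (4, 16)
```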
