Implement TabNet #21
TabNet (paper) has recently been implemented in R (mlverse/tabnet), and it is even based on the torch package! TabNet is known for its good performance on tabular data. It would therefore be interesting to implement TabNet and compare its performance with traditional algorithms (lasso logistic regression, xgboost).
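Since the mlverse/tabnet interface comes up repeatedly in the comments below, here is a minimal sketch of its high-level API, based on the package README; the exact argument names and defaults may differ across versions.

```r
# Minimal sketch of the mlverse/tabnet high-level API (assumed from the
# package README; argument names may vary across versions).
library(tabnet)

# Formula interface on a plain data.frame, fitting a classifier:
fit <- tabnet_fit(Species ~ ., data = iris, epochs = 10, verbose = TRUE)

# Class probabilities for new rows, for comparison against a lasso
# logistic regression or xgboost baseline on the same split:
preds <- predict(fit, iris, type = "prob")
head(preds)
```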
Comments
Good idea! Do you want to take a stab at it? They have a pretty high-level API. I can also help.
Great, I think we can implement this in our package. I'll try to write generic code for model fitting with the tabnet functions.
Hi @ted9219, great job! I looked at the code and made small changes to get it to work on my end. I also added things like saving the model and the feature importances, and verified that it works and displays correctly in the shiny results viewer. I pushed it to develop. There are, though, a few things that worry me.
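For reference, a sketch of how saving the fit and pulling out the feature importances might look. The `$fit$importances` accessor is an assumption based on mlverse tabnet examples and may differ across package versions; this is not necessarily the code that was pushed to develop.

```r
# Sketch (assumed API): extract attention-based feature importances from a
# fitted tabnet model and persist the fit for the results viewer.
library(tabnet)

fit <- tabnet_fit(Species ~ ., data = iris, epochs = 10)

# Assumed accessor: a data frame with `variables` and `importance` columns.
imp <- fit$fit$importances
imp[order(-imp$importance), ]

# Persist the fit; torch-backed objects may need package- or torch-specific
# serialization rather than plain saveRDS(), depending on the version.
saveRDS(fit, "tabnetFit.rds")
```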
Thanks, @egillax. I totally agree with your comments. Using this repo as a benchmark, could we develop modules that implement TabNet ourselves? I also found that the performance is lower than expected, and lower than other modules like ResNet. We need to figure out why the performance is consistently low. I'll try.
Just noting down a few observations I've made. The output of the model was weird for me with the default loss function: it was logits (it had some negative values), but the mean of the output was the outcome rate, suggesting the loss function used treats the output as probabilities. So I switched to loss <- torch::nn_bce_with_logits_loss() in my config (the same loss function used by default in my estimator class). Then the output makes sense after I use torch::nnf_sigmoid() to convert the logits to probabilities: the mean is the outcome rate and the distribution is what I would expect.

Another thing: it seems the columns in the input need to be factors, otherwise it treats them as numerical and skips the embedding layer. I think this could be the reason for the poor performance. I tried converting the columns to factors with trainDataset[] <- lapply(trainDataset, factor), but that takes a lot of time and then the training is extremely slow. After running it overnight I'm still not getting better AUCs.

I suspect the best move forward would be to use the module you linked from the tabnet repo, but possibly rewriting the embedding layer to deal with our case where almost all variables are binary. I already started implementing another transformer model from scratch, and in the process made the estimator and the fit function more general, so in the future I think that to add models you would only need to create a setModel function. I will commit those changes later today, but I was stuck on the embedding layer with the same problem as for tabnet: how to efficiently create an embedding layer for a matrix of binary features. When I solve that I might have a solution for tabnet as well (see the sketches below).
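To make the logits observation concrete, here is a minimal sanity check, assuming a fitted model `model` that returns raw logits, an input tensor `x`, and a float label tensor `y` (all hypothetical names):

```r
library(torch)

# With nn_bce_with_logits_loss() the model's raw output is a logit, so it
# must go through nnf_sigmoid() before being read as a probability. The
# mean predicted probability should then sit near the observed outcome rate.
criterion <- nn_bce_with_logits_loss()
logits <- model(x)             # raw outputs; negative values are expected
probs  <- nnf_sigmoid(logits)  # map logits into (0, 1)
loss   <- criterion(logits, y) # the loss consumes logits directly
cat("mean predicted prob:", as.numeric(probs$mean()),
    "vs outcome rate:", as.numeric(y$mean()), "\n")
```

And for the binary-feature embedding problem, one possible way around the factor conversion: for a 0/1 feature matrix, summing the embeddings of the active features is exactly a matrix product with the embedding weights, so a bias-free linear layer behaves like an embedding lookup-and-sum. A sketch of that idea (an assumption on my part, not the solution committed to the repo):

```r
library(torch)

# For x in {0,1}^(batch x p), the sum of the embeddings of the "on"
# features equals x %*% E, so a bias-free nn_linear is an embedding
# bag in disguise -- no per-column factor conversion needed.
binary_embedding <- nn_module(
  initialize = function(num_features, embedding_dim) {
    self$proj <- nn_linear(num_features, embedding_dim, bias = FALSE)
  },
  forward = function(x) {
    self$proj(x)  # x: float tensor of shape (batch, num_features)
  }
)

emb <- binary_embedding(num_features = 1000, embedding_dim = 32)
x <- torch_randint(0, 2, size = c(8, 1000))$to(dtype = torch_float())
out <- emb(x)  # shape (8, 32)
```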