TabNet explainability on custom data #38

Open
alexanderwatanabe opened this issue Dec 3, 2020 · 6 comments

@alexanderwatanabe

alexanderwatanabe commented Dec 3, 2020

Hello, thank you for this repo. I am trying to run the TabNet notebook on a custom dataset. Everything works up to the explainability step (learn.explain(dl)), which fails with this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-37-bfedb2755251> in <module>
----> 1 learn.explain(dl)

<ipython-input-36-c2a2fc0e0447> in explain(x, dl)
      6   for batch_nb, data in enumerate(dl):
      7     with torch.no_grad():
----> 8       out, M_loss, M_explain, masks = x.model(data[0], data[1], True)
      9     for key, value in masks.items():
     10       masks[key] = csc_matrix.dot(value.numpy(), matrix)

~/dev/Practical-Deep-Learning-for-Coders-2.0/.venv-nix/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 3 positional arguments but 4 were given

I am reading the docs to better understand how to fix it. Any insights or pointers would be appreciated.
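
The last frame of the traceback points to a signature mismatch: the explain helper calls x.model(data[0], data[1], True), but the installed model's forward() accepts only 3 positional arguments counting self (two inputs), so the extra True flag is rejected. The minimal pure-Python sketch below reproduces that failure and checks the signature with inspect.signature before passing the flag. OldStyleModel, NewStyleModel, accepts_explain_flag and run_explain are hypothetical stand-ins, not fast_tabnet's real classes; how the masks are actually exposed depends on the installed fast_tabnet version.

```python
import inspect


class OldStyleModel:
    """Hypothetical model whose forward() accepts a third 'explain' flag."""

    def forward(self, x_cat, x_cont, explain=False):
        if explain:
            # Mirrors the 4-tuple the notebook unpacks: out, M_loss, M_explain, masks
            return "out", "M_loss", "M_explain", "masks"
        return "out"

    def __call__(self, *args):
        return self.forward(*args)


class NewStyleModel:
    """Hypothetical model whose forward() takes only x_cat and x_cont."""

    def forward(self, x_cat, x_cont):
        return "out"

    def __call__(self, *args):
        return self.forward(*args)


def accepts_explain_flag(model):
    """Return True if model.forward can take a third positional argument."""
    # On a bound method, 'self' is already excluded from the signature.
    params = list(inspect.signature(model.forward).parameters.values())
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    positional = [p for p in params if p.kind in positional_kinds]
    has_varargs = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params)
    return has_varargs or len(positional) >= 3


def run_explain(model, x_cat, x_cont):
    """Pass the explain flag only when forward() can accept it."""
    if accepts_explain_flag(model):
        return model(x_cat, x_cont, True)
    raise TypeError("forward() has no explain flag in this version; "
                    "the masks have to be obtained another way")


if __name__ == "__main__":
    print(run_explain(OldStyleModel(), None, None))
    # Reproduces the error from the issue with the two-input model.
    try:
        NewStyleModel()(None, None, True)
    except TypeError as err:
        print(err)  # "...forward() takes 3 positional arguments but 4 were given"
```

Checking the signature keeps the failure explicit instead of surfacing deep inside the batch loop; it does not recover the masks by itself.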

@muellerzr
Owner

muellerzr commented Dec 3, 2020 via email

@alexanderwatanabe
Author

alexanderwatanabe commented Dec 3, 2020

Here are what I think are the relevant libraries in this environment:
fastai 2.1.5
fastcore 1.3.6
fast_tabnet 0.2.0
pytorch-tabnet 2.0.1
fastinference 0.0.30
pytorch 1.7.0

Also, my model is set up for single-variable regression.
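
A version list like the one above can also be collected programmatically so nothing is misremembered. The sketch below is a minimal example using the standard library's importlib.metadata (available from Python 3.8, matching the python3.8 environment in the traceback). The distribution names are assumptions: PyTorch installs as torch rather than pytorch, and the exact spelling of the others may differ in your environment.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names; PyTorch is published as "torch".
PACKAGES = ["fastai", "fastcore", "fast_tabnet", "pytorch-tabnet",
            "fastinference", "torch"]


def collect_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name} {ver if ver else 'not installed'}")
```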

@muellerzr
Owner

muellerzr commented Dec 3, 2020 via email

@alexanderwatanabe
Author

Got it, thanks!

@muellerzr
Owner

muellerzr commented Dec 3, 2020 via email

@alexanderwatanabe
Author

Got it working with your pinned versions and the new versions of fastai/torch. If you have any notes or an outline of how you would approach fixing it for the new versions, I'd love to get involved and contribute. I understand you are probably busy, so it's a standing offer for later if needed!
