
feat: enable feature grouping for attention mechanism #443

Merged: 1 commit merged into develop on Dec 12, 2022

Conversation

@Optimox (Collaborator) commented Dec 1, 2022

What kind of change does this PR introduce?
This PR solves #122 in a new way.

Embeddings coming from the same column are automatically grouped together.
Attention now works at the group level rather than the feature level, so groups for the different categorical features are created without the user specifying anything.

This is also a new feature in that it allows users to specify features they would like to be grouped together by the attention mechanism. This can be very useful when sparse features are created (for example after a TF-IDF transform): attention has a hard time using this kind of feature because of the sparsity coming from both the data and the attention itself. Now you can group all those features together and have a single attention weight for the whole group.
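As a rough illustration (a minimal sketch, not taken from this PR's diff): assuming the public constructor exposes the grouping as a grouped_features argument, as in the census example notebook linked later in this thread, grouping a block of TF-IDF columns might look like the following. The data and column indices here are made up.

import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Toy data: 20 numerical columns, of which columns 5..14 come from a TF-IDF transform.
X_train = np.random.rand(256, 20).astype(np.float32)
y_train = np.random.randint(0, 2, size=256)

tfidf_group = list(range(5, 15))  # treat all TF-IDF columns as one attention group

# Assumption: the grouping parameter is named grouped_features (the internal
# utility calls it list_groups); check the released API for the exact name.
clf = TabNetClassifier(grouped_features=[tfidf_group])
clf.fit(X_train, y_train, max_epochs=5)

# All columns in the group share a single attention weight, hence equal importance.
print(clf.feature_importances_[5:15])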

Does this PR introduce a breaking change?
I'm not sure whether this should be considered a breaking change. I think so, as previously trained models used with this new code will have a different behaviour.

What needs to be documented once your changes are merged?

I think all changes are already documented in this PR.

Closing issues
closes #122

@Optimox force-pushed the feat/grouped-attention branch 8 times, most recently from 177a0f9 to a12c039 on December 1, 2022 at 15:40
pytorch_tabnet/tab_network.py (resolved review thread)
pytorch_tabnet/utils.py (resolved review thread)
- list_groups : list of list of int
Each element is a list representing features in the same group.
One feature should appear in maximum one group.
Feature that don't get assign a group will be in their own group of one feature.
Collaborator:
Features that don't get assigned

for group_pos, group in enumerate(list_groups):
msg = f"Groups must be given as a list of list, but found {group} in position {group_pos}." # noqa
assert isinstance(group, list), msg
assert len(group) > 0, "Empty groups are forbidding please remove empty groups []"
Collaborator:
forbidding -> forbidden
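
To make the contract in the docstring above concrete, here is a minimal standalone sketch (the helper name and structure are illustrative, not the library's actual code): explicit groups are validated, each feature may appear in at most one group, and every unassigned feature ends up in its own group of one.

def complete_groups(list_groups, input_dim):
    # Validate the structure, mirroring the assertions quoted above.
    seen = set()
    for group_pos, group in enumerate(list_groups):
        assert isinstance(group, list), (
            f"Groups must be given as a list of list, but found {group} in position {group_pos}."
        )
        assert len(group) > 0, "Empty groups are forbidden, please remove empty groups []"
        for feat in group:
            assert feat not in seen, f"Feature {feat} appears in more than one group."
            seen.add(feat)

    # Features that were not assigned to any group each get their own group of one.
    singletons = [[feat] for feat in range(input_dim) if feat not in seen]
    return list_groups + singletons

print(complete_groups([[0, 1], [4, 5, 6]], input_dim=8))
# -> [[0, 1], [4, 5, 6], [2], [3], [7]]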

    out = self.feat_transformers[step](masked_x)
    d = ReLU()(out[:, : self.n_d])
    # explain
    step_importance = torch.sum(d, dim=1)
-   M_explain += torch.mul(M, step_importance.unsqueeze(dim=1))
+   M_explain += torch.mul(M_feature_level, step_importance.unsqueeze(dim=1))
Collaborator:
I wonder if we can multiply this by the transpose at the end to divide (equally given how the matrix is created) the importance for each feature of the group. Don't know if it's a good idea.

Collaborator (Author):
There is already a mapping to get importance from the post-embedding dimension back to the initial features. They might be redundant indeed, but I don't think I have time to think about it.
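
For readers following this thread, a minimal sketch of the idea being discussed (shapes and names are illustrative, not the library's implementation): a group-to-feature matrix whose rows split each group's weight equally can map a group-level mask back to feature-level importance.

import torch

groups = [[0, 1], [2, 3, 4]]
n_groups, input_dim = len(groups), 5

# Row g holds 1/len(group) for every feature in group g.
group_matrix = torch.zeros(n_groups, input_dim)
for g, group in enumerate(groups):
    group_matrix[g, group] = 1.0 / len(group)

# Group-level mask for a batch of 3 samples (rows sum to 1).
M_group = torch.tensor([[0.7, 0.3],
                        [0.2, 0.8],
                        [0.5, 0.5]])

# Each group's weight is divided equally among its member features.
M_feature_level = M_group @ group_matrix
print(M_feature_level)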

@eduardocarvp eduardocarvp merged commit bcae5f4 into develop Dec 12, 2022
@eduardocarvp eduardocarvp deleted the feat/grouped-attention branch December 12, 2022 09:34
@gauravbrills:
Can we add an example of how to use this? I'm a bit confused about it. @Optimox

@Optimox (Collaborator, Author) commented Oct 30, 2023

@gauravbrills There is an example in this notebook: https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb

You simply need to give a list of groups; each group is a list with the indices of the features forming the group.
The attention mechanism will consider each group as one single feature, so all features of the group will get the same attention (and importance). Note that all embedding dimensions generated by a categorical feature will be grouped together.
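
A minimal sketch of that description (assuming, as in the notebook above, that the constructor parameter is named grouped_features):

from pytorch_tabnet.tab_model import TabNetRegressor

# Features 0 and 1 form one group, features 4, 5 and 6 form another;
# every remaining feature implicitly stays in its own group of one.
model = TabNetRegressor(grouped_features=[[0, 1], [4, 5, 6]])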

Successfully merging this pull request may close these issues.

Research : Embedding Aware Attention