feat: embedding-aware attention #217

Open · wants to merge 4 commits into base: develop
Conversation

@athewsey (Contributor) commented Oct 22, 2020

⚠️ **This is a pretty early draft with more testing required, and feedback is very welcome. I know there's been some other work in this area before (e.g. #92), but I thought it was worth sharing early since I'd had some time to play, in case it's useful!**

This implementation amends the attention transformer output dimension to equal the number of features, instead of the number of post-embedding dimensions.

  • EmbeddingGenerator is modified to keep a record of the number of dimensions each feature was embedded to (in a deliberately agnostic way, because I've been experimenting with embedding scalar fields to multiple dimensions too).
  • TabNetNoEmbeddings is modified to use this feature_embed_widths list to expand the raw mask matrix M (by features) to the embedding-compatible M_x, by replicating each feature's mask weights across however many columns it was embedded to (a sketch of the idea follows this list).
  • Since all mask-based calculations are now at the feature level, rather than the embedding-dimension level, I think the correct handling in AbstractModel is just to remove the reducing_matrix altogether? But I should get more familiar with this area of the code.
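
For illustration, here's a minimal sketch of that mask expansion under the loop-based approach of this draft (names such as `expand_feature_mask` are hypothetical, not the PR's actual code):

```python
import torch

# Illustrative sketch only: M holds one mask weight per feature; M_x replicates each
# feature's column across however many embedding columns that feature occupies,
# here by concatenating the replicated chunks in a loop.
def expand_feature_mask(M: torch.Tensor, feature_embed_widths) -> torch.Tensor:
    chunks = []
    for feat, width in enumerate(feature_embed_widths):
        # Repeat this feature's mask column `width` times
        chunks.append(M[:, feat:feat + 1].expand(-1, width))
    return torch.cat(chunks, dim=1)

M = torch.rand(8, 4)                        # feature-level mask (batch=8, n_features=4)
M_x = expand_feature_mask(M, [1, 1, 2, 3])  # embedding-compatible mask, shape (8, 7)
```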

Important limitations:

  • The current implementation does not introduce embedding-aware attention as an option, but makes it the default, non-configurable behaviour - so it would be good to gather more data on its effectiveness/accuracy.
  • So far only tested with basic TabNetClassifier fit & transform: as mentioned above, I would need to drill further into the explainability to check there are no bugs introduced there. To my knowledge this change is agnostic to whether it's a regression or multi-task problem, but hopefully the CI tests will help confirm that 😂

Testing and results:

On a Forest Cover Type-based example, I've observed this change to improve validation-set performance from approximately:

  • 58.7% to 63.1% at epoch 10 (faster convergence in early stages)
  • 82.4% to 90.2% at epoch 50 (still quite early on, but I haven't run any 200 epoch tests yet).

...at essentially the same training speed (11min 41sec to 50 epochs for both pre- and post-change algorithms).

Specifically:

  • 80% training / 10% validation random split of Forest Cover Type (10% test set not yet used)
  • Area and Soil_Type features consolidated from the raw (one-hot) data to categorical fields, with embedding dimensions 2 and 3 respectively (representing 4 distinct Areas, 40 Soil_Types)
  • `batch_size=16384, clip_value=2.0, epsilon=1e-15, gamma=1.5, lambda_sparse=0.0001, lr=0.02, max_epochs=50, model_type='classification', momentum=0.3, n_a=64, n_d=64, n_independent=2, n_shared=2, n_steps=5, patience=100, seed=1337, target='Cover_Type', virtual_batch_size=256` (a setup sketch follows this list)
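
For concreteness, here's a sketch of how that configuration might be wired up with pytorch-tabnet. It is illustrative only: the data arrays and categorical column positions below are placeholders standing in for the Forest Cover Type preparation described above, and the exact `fit()` signature varies between pytorch-tabnet releases.

```python
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetClassifier

# Placeholder data standing in for the Forest Cover Type preparation
X_train = np.random.rand(20000, 13).astype(np.float32)
X_train[:, 10] = np.random.randint(0, 4, 20000)    # consolidated Area (4 categories)
X_train[:, 11] = np.random.randint(0, 40, 20000)   # consolidated Soil_Type (40 categories)
y_train = np.random.randint(0, 7, 20000)           # placeholder Cover_Type labels
X_valid, y_valid = X_train[:2000], y_train[:2000]  # placeholder validation split
cat_idxs = [10, 11]                                # placeholder positions of the categorical columns

clf = TabNetClassifier(
    n_d=64, n_a=64, n_steps=5, gamma=1.5, lambda_sparse=1e-4,
    momentum=0.3, clip_value=2.0, epsilon=1e-15,
    n_independent=2, n_shared=2,
    cat_idxs=cat_idxs, cat_dims=[4, 40], cat_emb_dim=[2, 3],  # 4 Areas -> 2 dims, 40 Soil_Types -> 3 dims
    seed=1337,
    optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=0.02),
)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    max_epochs=50, patience=100,
    batch_size=16384, virtual_batch_size=256,
)
```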

IMPORTANT: Please do not create a Pull Request without creating an issue first.

Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.

What kind of change does this PR introduce? feature

Does this PR introduce a breaking change? ⚠️ Kinda

  • The new embedding-aware attention behaviour is both default and non-configurable in the current implementation
  • Need to check for any changes affecting explainability

What needs to be documented once your changes are merged?

  • Still need to check through READMEs, sample notebooks, etc. for conflicting statements/guidance about how the attention & masking is implemented.

Closing issues

Hopefully #122, eventually

Apply common mask to all embedded columns of a feature
@Optimox (Collaborator) left a comment

Thanks for sharing! I had a very quick look so I'm not sure I fully understood your approach.

Please correct me if I'm wrong:

  • You create a mask for each feature without taking embedding dimensions into account, then give the same attention to each corresponding embedding dimension?

That's indeed a way to make the attention embedding-aware. I'm not a big fan of the for loops it creates, though.
reducing_matrix is a way of getting back the information without for loops.
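
(For context, a conceptual sketch of what such a reducing matrix does is below - not necessarily the library's exact create_explain_matrix code: a 0/1 matrix mapping post-embedding columns back to their originating features, so feature-level values are recovered with a single matrix product.)

```python
import numpy as np

# Conceptual sketch only: a 0/1 matrix of shape (post_embed_dim, n_features) that
# sums each feature's embedding columns back into a single feature column, so the
# reduction is one matrix product rather than a Python loop. develop stores an
# equivalent matrix in sparse (scipy) form.
def make_reducing_matrix(feature_embed_widths):
    post_embed_dim = sum(feature_embed_widths)
    n_features = len(feature_embed_widths)
    reducing = np.zeros((post_embed_dim, n_features))
    col = 0
    for feat, width in enumerate(feature_embed_widths):
        reducing[col:col + width, feat] = 1.0
        col += width
    return reducing

# Post-embedding importances (batch, post_embed_dim) -> feature level (batch, n_features)
importances = np.random.rand(8, 7) @ make_reducing_matrix([1, 1, 2, 3])  # shape (8, 4)
```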

I did not try any benchmark with this, but you seem to have interesting results on Forest Cover Type. Have you tried #92 on Forest Cover Type? I could try to bring that branch up to date - I'll see if I have time this weekend.

```python
mask_type=self.mask_type)
attention = AttentiveTransformer(
    n_a,
    len(self.feature_embed_widths) if self.feature_embed_widths else self.input_dim,
```
Collaborator:
If I understand correctly, don't you always have `len(self.feature_embed_widths) == self.input_dim`?

Contributor Author:
No - since TabNetNoEmbeddings is just the portion of the network after embeddings have been done, the input_dim is the post-embedding dimension.

```diff
@@ -40,6 +40,7 @@ def forward(self, x):

 class TabNetNoEmbeddings(torch.nn.Module):
```
Collaborator:
As the name suggests, this class is supposed to be basic TabNet with no embeddings.

Contributor Author:
Yup, and I agree it's a nice distinction to keep! My thought was to add an optional parameter (feature_embed_widths) to this class so that, if they want, the user can tell TabNet to treat multiple columns of the input as a single "feature" for attention purposes. By default (None), TabNetNoEmbeddings works as before, treating every column independently. There are a couple of ways this API could pass in the required information; the current one is a list of how wide each "feature" in the input is. E.g. [1, 1, 2, 3] would mean:

  • input_dim is 1+1+2+3=7
  • n_features is 4
  • The first two columns are scalar features, the next two are a single feature with emb_dim 2, and the last three are a single feature with emb_dim 3 (see the sketch after this list)
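
A tiny, purely illustrative sketch of that interpretation:

```python
# Illustrative only: how a feature_embed_widths list of [1, 1, 2, 3] would describe
# the post-embedding layout for attention purposes.
feature_embed_widths = [1, 1, 2, 3]

input_dim = sum(feature_embed_widths)    # 7 post-embedding columns
n_features = len(feature_embed_widths)   # 4 attention "features"

# Which post-embedding columns belong to each feature
groups, start = [], 0
for width in feature_embed_widths:
    groups.append(list(range(start, start + width)))
    start += width
print(groups)  # [[0], [1], [2, 3], [4, 5, 6]]
```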

f-string issue and unused imports
Pre-compute n_features count and expand mask matrices via indexing
instead of for-loop concatenation.
@athewsey (Contributor Author)

Completely agree with you about the for loops - concerns about them were why I measured execution time in the first place! Having thought about it for a while, I realised M_x can be expanded via an indexing operation instead of loops 🤯

e.g. for feature_embed_widths=[1, 3, 1, 2], then M_x = M[:, [0, 1, 1, 1, 2, 3, 3]]
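
A sketch of that indexing trick (illustrative; not necessarily the PR's exact code):

```python
import torch

# Build the column-replication index once from the feature widths, then expand M
# to M_x in a single advanced-indexing operation instead of concatenating in a loop.
feature_embed_widths = [1, 3, 1, 2]
expand_index = [
    feat for feat, width in enumerate(feature_embed_widths) for _ in range(width)
]
# expand_index == [0, 1, 1, 1, 2, 3, 3]

M = torch.rand(8, len(feature_embed_widths))  # feature-level mask (batch=8, n_features=4)
M_x = M[:, expand_index]                      # embedding-width mask (batch=8, input_dim=7)
```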

This change seems to deliver a modest speedup on my setup, from 11min 41sec to 11min 32sec at 50 epochs (faster than the original develop-branch code without embedding-aware attention).

Your understanding was correct, by the way: I'm not making any changes to the actual implementation of AttentiveTransformer, just restricting its output dimension from the number of columns down to the number of underlying features... So matrix M is now a true feature-wise mask, and I'm just copying columns of M, as dictated by the emb_dim of each feature, to get back up to M_x, which is input_dim wide. Still haven't had a chance to test against #92 yet!

@athewsey (Contributor Author)

OK, I tried several approaches but have not been able to get #92 to install on PyTorch v1.4 environments because of the torch-scatter dependency it introduces. It should theoretically be possible per their installation docs (as long as CUDA is < 10.2), but I've not had any luck.

However I was able to run a comparison on a PyTorch v1.6 environment.

  • Same Forest Cover Type preparation and hyperparams as mentioned above: ~465k training samples of 13 raw columns (including 2 categoricals embedded to 2 and 3 dims respectively); ~58k validation samples.
  • Same 10-epoch and 50-epoch analysis points.
  • ...But now 6 test runs for each alternative (incl. a couple with fixed seed) to get an idea of variance.
  • Comparing this PR, vs. my quick & dirty attempt to update #92 from develop, vs. develop itself (a couple of commits behind the latest - the same base as this PR and my update of #92).

Experimental Results

Validation accuracy at 10 epochs

| Candidate | Acc@10e (min) | Acc@10e (mean) | Acc@10e (std) | Acc@10e (max) |
|---|---|---|---|---|
| develop | 59.459% | 61.223% | 1.246% | 62.756% |
| prev PR92 | 60.438% | 61.961% | 1.019% | 62.761% |
| this PR217 | 61.634% | 62.408% | 0.747% | 63.837% |

Training time at 10 epochs (seconds)

| Candidate | sec@10e (min) | sec@10e (mean) | sec@10e (std) | sec@10e (max) |
|---|---|---|---|---|
| develop | 126s | 131s | 4.62s | 136s |
| prev PR92 | 126s | 129s | 3.27s | 135s |
| this PR217 | 127s | 129s | 3.37s | 136s |

Validation accuracy at 50 epochs

| Candidate | Acc@50e (min) | Acc@50e (mean) | Acc@50e (std) | Acc@50e (max) |
|---|---|---|---|---|
| develop | 81.887% | 85.037% | 2.381% | 87.685% |
| prev PR92 | 85.284% | 87.901% | 1.907% | 90.152% |
| this PR217 | 89.159% | 89.558% | 0.478% | 90.231% |

Training time at 50 epochs (seconds)

| Candidate | sec@50e (min) | sec@50e (mean) | sec@50e (std) | sec@50e (max) |
|---|---|---|---|---|
| develop | 601s | 637s | 24.8s | 659s |
| prev PR92 | 614s | 627s | 16.3s | 659s |
| this PR217 | 617s | 629s | 16.6s | 662s |

Interpretation / Observations

Per the tables:

  • Both candidates seem to slightly lower average execution times (presumably mainly because of the reduced parameter count from sharing attention across the 5 embedding columns of the 2 categorical features), but the extrema suggest the difference is probably pretty marginal.
  • The previous PR92 may be slightly faster than this one, but there's really very little to choose between them on speed.
  • Both candidates appear to offer pretty good accuracy improvement & variance reduction on this dataset, with the gains more pronounced at 50 epochs than at the very-early 10-epoch point.
  • Maybe there's evidence to suggest that this PR217 offers better accuracy overall? The ranges still overlap, though.

I also noticed:

  • Prev PR92 (or maybe something that went wrong in my merge from develop?) seems to break seeding: results weren't reproducible with the same seed value, whereas they were with develop and this PR.
  • As I've noted before on #127 (Create a set of benchmark dataset), I'm still seeing generally lower accuracies on PyTorch 1.6 vs. 1.4 (which is why I was initially testing on 1.4).

This is of course only one sample dataset, and only one hyperparameter configuration!

My preference would be to push forward with an approach along the lines of this PR if we're comfortable it can deliver comparable performance to #92 - because it avoids introducing the extra CUDA-linked pytorch-scatter dependency and also seems to simplify away the reducing_matrix and associated scipy/sparse usage currently in develop... But I am biased of course! 😂

@Optimox (Collaborator) commented Oct 27, 2020

@athewsey amazing contribution! Thanks for this detailed analysis.
I agree with you that being dependent on torch-scatter is probably not a good idea anyway!

I'll need to spend more time to dig deeper into the impact of your proposal, but I'm on board with the idea.

My only concern is that we are currently working on self-supervised pretraining, which will imply some refactoring of the network part; I'll then need to adapt your proposal so that it fits a bit better with the overall code. But I don't want to "steal" your contribution, so we'll need to figure out a way to do this together. We'll work it out!

Thanks! I'll get back to you!

Let us know if you do a similar benchmark on other datasets!

@athewsey (Contributor Author)

Sure thing, thanks! Happy to support on porting to another branch too if you have one that's looking stable & favoured - I'd expect this change to map over fairly nicely, since it's pretty self-contained and hopefully shouldn't have too much functional overlap; it just touches the same code sections.

@Optimox (Collaborator) commented Oct 28, 2020

@athewsey I think it would be worth running for 200 epochs (not necessarily in a 5-fold setting, but with the original split of the paper, the same as we use in the Forest Cover Type notebook). What you've shown is that we are converging faster (which is good), but not necessarily to a better final score. The final test accuracy should be able to reach 0.96~0.97, and after 50 epochs we are still far from that.

@athewsey (Contributor Author) commented Nov 4, 2020

Hey sorry it's taken a while - now got some extra results from longer testing!

I modified the train/val/test split to 60/20/20 in line with the example (previous tests were on 80/10/10)... But actually, as commented before, I've never seen the library reach ~96% on this dataset/hyperparam combo in PyTorch v1.6 - it got there in PyTorch v1.4, but on 1.6 it always topped out at ~93% in my previous (200-epoch) tests.

...So I ran the tests out to 300 epochs and took measurements every hundred, to see whether there were any easy gains to be had. I did repeat for 5 random seeds again, because (as we see below) the ranges all overlap quite a bit, so it didn't seem safe to just take a single result for each branch.

Code branches are the same as in the previous tests - not updated with any new commit merges.

Validation accuracy at 100 epochs

| Candidate | Acc@100e (min) | Acc@100e (mean) | Acc@100e (std) | Acc@100e (max) |
|---|---|---|---|---|
| develop | 86.113% | 88.258% | 2.235% | 91.753% |
| prev PR92 | 85.761% | 87.776% | 2.157% | 90.471% |
| this PR217 | 89.330% | 91.206% | 1.156% | 92.134% |

Training time at 100 epochs (seconds)

| Candidate | sec@100e (min) | sec@100e (mean) | sec@100e (std) | sec@100e (max) |
|---|---|---|---|---|
| develop | 992s | 1,005s | 9.70s | 1,016s |
| prev PR92 | 1,005s | 1,015s | 8.60s | 1,024s |
| this PR217 | 1,004s | 1,028s | 26.41s | 1,071s |

Validation accuracy at 200 epochs

| Candidate | Acc@200e (min) | Acc@200e (mean) | Acc@200e (std) | Acc@200e (max) |
|---|---|---|---|---|
| develop | 90.523% | 92.212% | 1.455% | 93.710% |
| prev PR92 | 91.549% | 93.018% | 0.953% | 94.029% |
| this PR217 | 92.286% | 93.760% | 0.978% | 94.590% |

Training time at 200 epochs (seconds)

| Candidate | sec@200e (min) | sec@200e (mean) | sec@200e (std) | sec@200e (max) |
|---|---|---|---|---|
| develop | 1,950s | 2,002s | 31.76s | 2,031s |
| prev PR92 | 2,010s | 2,029s | 15.77s | 2,046s |
| this PR217 | 2,004s | 2,053s | 52.65s | 2,137s |

Validation accuracy at 300 epochs

| Candidate | Acc@300e (min) | Acc@300e (mean) | Acc@300e (std) | Acc@300e (max) |
|---|---|---|---|---|
| develop | 92.920% | 94.086% | 0.862% | 94.925% |
| prev PR92 | 93.294% | 94.452% | 0.665% | 94.874% |
| this PR217 | 93.732% | 94.748% | 0.671% | 95.312% |

Training time at 300 epochs (seconds)

| Candidate | sec@300e (min) | sec@300e (mean) | sec@300e (std) | sec@300e (max) |
|---|---|---|---|---|
| develop | 2,974s | 3,014s | 24.75s | 3,035s |
| prev PR92 | 3,015s | 3,046s | 26.17s | 3,079s |
| this PR217 | 3,006s | 3,078s | 74.76s | 3,194s |

Again the accuracy results all overlap significantly, but this candidate seems to come out on top on average.

The timings are notably switched from the short tests: the unchanged develop branch comes out reliably fastest. The difference between the two PR candidates gets a bit clearer over the longer test.

@bratao commented Apr 17, 2021

@Optimox sorry for the bump, but do you have plans to merge this?
+1 on the improvements on my particular problem 😄

@Optimox (Collaborator) commented Apr 18, 2021

> @Optimox sorry for the bump, but do you have plans to merge this? +1 on the improvements on my particular problem 😄

Hmmm I would probably need some more time to review carefully and convince myself that this is the way to solve the problem. But yes I'll have a look as soon as I can.

@bratao Would you mind telling us a bit more about the improvements it gave on your particular problem and the problem itself?

@bratao commented Apr 18, 2021

I use it for a fintech credit model. I got an extra 5% improvement in RMSE compared to master (I ported this patch to regression).

It still doesn't beat a well-tuned random forest for my use case, but it's getting closer.
