feat: embedding-aware attention #217

Open · wants to merge 4 commits into base: develop
22 changes: 3 additions & 19 deletions pytorch_tabnet/abstract_model.py
@@ -3,12 +3,10 @@
import torch
from torch.nn.utils import clip_grad_norm_
import numpy as np
from scipy.sparse import csc_matrix
from abc import abstractmethod
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
PredictDataset,
create_explain_matrix,
validate_eval_set,
create_dataloaders,
define_device,
@@ -252,13 +250,9 @@ def explain(self, X):

M_explain, masks = self.network.forward_masks(data)
for key, value in masks.items():
masks[key] = csc_matrix.dot(
value.cpu().detach().numpy(), self.reducing_matrix
)
masks[key] = value.cpu().detach().numpy()

res_explain.append(
csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
)
res_explain.append(M_explain.cpu().detach().numpy())

if batch_nb == 0:
res_masks = masks
@@ -481,13 +475,6 @@ def _set_network(self):
mask_type=self.mask_type,
).to(self.device)

self.reducing_matrix = create_explain_matrix(
self.network.input_dim,
self.network.cat_emb_dim,
self.network.cat_idxs,
self.network.post_embed_dim,
)

def _set_metrics(self, metrics, eval_names):
"""Set attributes relative to the metrics.

@@ -615,15 +602,12 @@ def _compute_feature_importances(self, loader):

"""
self.network.eval()
feature_importances_ = np.zeros((self.network.post_embed_dim))
feature_importances_ = np.zeros((self.network.input_dim))
for data, targets in loader:
data = data.to(self.device).float()
M_explain, masks = self.network.forward_masks(data)
feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()

feature_importances_ = csc_matrix.dot(
feature_importances_, self.reducing_matrix
)
self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)

@abstractmethod
3 changes: 1 addition & 2 deletions pytorch_tabnet/multiclass_utils.py
@@ -394,8 +394,7 @@ def infer_multitask_output(y_train):

if len(y_train.shape) < 2:
raise ValueError(
f"""y_train shoud be of shape (n_examples, n_tasks) """
+ f"""but got {y_train.shape}"""
f"y_train shoud be of shape (n_examples, n_tasks) but got {y_train.shape}"
)
nb_tasks = y_train.shape[1]
tasks_dims = []
52 changes: 44 additions & 8 deletions pytorch_tabnet/tab_network.py
@@ -40,6 +40,7 @@ def forward(self, x):

class TabNetNoEmbeddings(torch.nn.Module):
Collaborator
As the name suggests, this class is supposed to be the basic TabNet with no embeddings.

Contributor Author
Yup, and I agree it's a nice distinction to keep! So my thought was to add an optional parameter (feature_embed_widths) to this class through which the user can, if they want, tell TabNet to treat multiple columns of the input as a single "feature" for attention purposes. By default (None), TabNetNoEmbeddings works as before, treating every column independently. There are a couple of ways this API could pass in the required information; the current one is a list of how wide each "feature" in your input is. E.g. [1, 1, 2, 3] would mean (see the sketch after this list):

  • input_dim is 1+1+2+3=7
  • n_features is 4
  • The first two columns are scalar features, the next two form a feature with emb_dim 2, and the last three form a feature with emb_dim 3
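
Here is a minimal sketch (not part of this diff; all values hypothetical) of how such a width list expands a per-feature attention mask back to per-column width via indexing, the same trick the forward pass in this PR uses:

```python
import torch

# Hypothetical widths: 4 underlying features spread over 7 embedded columns
feature_embed_widths = [1, 1, 2, 3]
input_dim = sum(feature_embed_widths)    # 7 embedded columns
n_features = len(feature_embed_widths)   # 4 attended features

# [0, 1, 2, 2, 3, 3, 3]: the feature each embedded column belongs to
mask_feature_indexes = torch.LongTensor(
    [ix for ix, width in enumerate(feature_embed_widths) for _ in range(width)]
)

M = torch.rand(5, n_features)            # per-feature mask for a batch of 5
M_x = M[:, mask_feature_indexes]         # per-column mask, shape (5, 7)
assert M_x.shape == (5, input_dim)
```

This way the attentive transformer only has to emit one mask value per feature, and expanding the mask to embedding width is a single indexing operation.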

def __init__(self, input_dim, output_dim,
feature_embed_widths=None,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
n_independent=2, n_shared=2, epsilon=1e-15,
Expand All @@ -51,10 +52,13 @@ def __init__(self, input_dim, output_dim,
Parameters
----------
input_dim : int
Number of features
Number of input columns
output_dim : int or list of int for multi task classification
Dimension of network output
examples : one for regression, 2 for binary classification etc...
feature_embed_widths : list of int
The embedding width of each underlying feature in the input, for embedding-aware
attention. If not supplied, every input column will be treated as a separate feature.
n_d : int
Dimension of the prediction layer (usually between 4 and 64)
n_a : int
@@ -79,6 +83,8 @@
super(TabNetNoEmbeddings, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.feature_embed_widths = feature_embed_widths
self.n_features = len(feature_embed_widths) if feature_embed_widths else input_dim
self.is_multi_task = isinstance(output_dim, list)
self.n_d = n_d
self.n_a = n_a
@@ -104,6 +110,14 @@ def __init__(self, input_dim, output_dim,
else:
shared_feat_transform = None

if self.feature_embed_widths:
# Pre-process e.g. [2, 3, 1] into LongTensor([0, 0, 1, 1, 1, 2]) for fast mask matrix
# expansion via indexing in forward pass:
mask_ixs_nested = [[ix] * size for ix, size in enumerate(self.feature_embed_widths)]
self.mask_feature_indexes = torch.LongTensor(
[item for sublist in mask_ixs_nested for item in sublist]
)

self.initial_splitter = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform,
n_glu_independent=self.n_independent,
virtual_batch_size=self.virtual_batch_size,
@@ -117,7 +131,7 @@ def __init__(self, input_dim, output_dim,
n_glu_independent=self.n_independent,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum)
attention = AttentiveTransformer(n_a, self.input_dim,
attention = AttentiveTransformer(n_a, self.n_features,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum,
mask_type=self.mask_type)
@@ -138,7 +152,9 @@ def forward(self, x):
res = 0
x = self.initial_bn(x)

prior = torch.ones(x.shape).to(x.device)
prior = torch.ones(
[x.shape[0], self.n_features] if self.feature_embed_widths else x.shape
).to(x.device)
M_loss = 0
att = self.initial_splitter(x)[:, self.n_d:]

@@ -148,8 +164,12 @@
dim=1))
# update prior
prior = torch.mul(self.gamma - M, prior)

# expand M to match embedded input dimension, if necessary:
M_x = M[:, self.mask_feature_indexes] if self.feature_embed_widths else M

# output
masked_x = torch.mul(M, x)
masked_x = torch.mul(M_x, x)
out = self.feat_transformers[step](masked_x)
d = ReLU()(out[:, :self.n_d])
res = torch.add(res, d)
@@ -170,8 +190,12 @@ def forward(self, x):
def forward_masks(self, x):
x = self.initial_bn(x)

prior = torch.ones(x.shape).to(x.device)
M_explain = torch.zeros(x.shape).to(x.device)
prior = torch.ones(
[x.shape[0], self.n_features] if self.feature_embed_widths else x.shape
).to(x.device)
M_explain = torch.zeros(
[x.shape[0], self.n_features] if self.feature_embed_widths else x.shape
).to(x.device)
att = self.initial_splitter(x)[:, self.n_d:]
masks = {}

@@ -180,8 +204,12 @@ def forward_masks(self, x):
masks[step] = M
# update prior
prior = torch.mul(self.gamma - M, prior)

# expand M to match embedded input dimension, if necessary:
M_x = M[:, self.mask_feature_indexes] if self.feature_embed_widths else M

# output
masked_x = torch.mul(M, x)
masked_x = torch.mul(M_x, x)
out = self.feat_transformers[step](masked_x)
d = ReLU()(out[:, :self.n_d])
# explain
@@ -263,7 +291,9 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
self.virtual_batch_size = virtual_batch_size
self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
self.post_embed_dim = self.embedder.post_embed_dim
self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
self.feature_embed_widths = self.embedder.feature_embed_widths
self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim,
self.feature_embed_widths, n_d, n_a, n_steps,
gamma, n_independent, n_shared, epsilon,
virtual_batch_size, momentum, mask_type)

@@ -475,6 +505,7 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim):
super(EmbeddingGenerator, self).__init__()
if cat_dims == [] or cat_idxs == []:
self.skip_embedding = True
self.feature_embed_widths = None
self.post_embed_dim = input_dim
return

@@ -505,6 +536,11 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim):
self.continuous_idx = torch.ones(input_dim, dtype=torch.bool)
self.continuous_idx[cat_idxs] = 0

# Record final embedded widths of each feature
self.feature_embed_widths = input_dim * [1]
for ix, dim in zip(cat_idxs, self.cat_emb_dims):
self.feature_embed_widths[ix] = dim

def forward(self, x):
"""
Apply embeddings to inputs