Commit

Fixed a bug related to multiple tabular components with a FC head. Increased test coverage
jrzaurin committed Aug 26, 2024
1 parent f244689 commit 07155ec
Showing 3 changed files with 137 additions and 6 deletions.
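The scenario behind the fix, several tabular components combined under a single fully connected head, looks roughly like the minimal sketch below. It is loosely based on the tests added in this commit; the fitted preprocessors (tab_preprocessor_user, tab_preprocessor_item), the input arrays (X_tab_user_tr, X_tab_item_tr) and the target (y_train) are illustrative placeholders, not part of the diff.

    from pytorch_widedeep import Trainer
    from pytorch_widedeep.models import TabMlp, WideDeep

    # one TabMlp per tabular input set (assumes fitted TabPreprocessor objects)
    tab_mlp_user = TabMlp(
        column_idx=tab_preprocessor_user.column_idx,
        cat_embed_input=tab_preprocessor_user.cat_embed_input,
        continuous_cols=tab_preprocessor_user.continuous_cols,
    )
    tab_mlp_item = TabMlp(
        column_idx=tab_preprocessor_item.column_idx,
        cat_embed_input=tab_preprocessor_item.cat_embed_input,
        continuous_cols=tab_preprocessor_item.continuous_cols,
    )

    # FC head specified via head_hidden_dims (hidden sizes mirror the tests below)
    model = WideDeep(
        deeptabular=[tab_mlp_user, tab_mlp_item],
        pred_dim=1,
        head_hidden_dims=[tab_mlp_user.output_dim + tab_mlp_item.output_dim, 8],
    )

    trainer = Trainer(model, objective="binary", verbose=0)
    trainer.fit(
        X_train={"X_tab": [X_tab_user_tr, X_tab_item_tr], "target": y_train},
        n_epochs=2,
        batch_size=4,
    )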
8 changes: 8 additions & 0 deletions pytorch_widedeep/models/model_fusion.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 
+from pytorch_widedeep.models import TabNet
 from pytorch_widedeep.wdtypes import List, Union, Tensor, Literal, Optional
 from pytorch_widedeep.models.tabular.mlp._layers import MLP
 from pytorch_widedeep.models._base_wd_model_component import (
@@ -300,6 +301,13 @@ def output_dim(self) -> int:
         return output_dim
 
     def check_input_parameters(self):  # noqa: C901
+
+        if any(isinstance(model, TabNet) for model in self.models):
+            raise ValueError(
+                "TabNet is not supported in ModelFuser. "
+                "Please, use another model for tabular data"
+            )
+
         if isinstance(self.fusion_method, str):
             if not any(
                 x == self.fusion_method
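With this guard in place, fusing a TabNet component fails fast when the ModelFuser is built. A hedged illustration, reusing tab_preprocessor_user and tab_mlp_item from the sketch above and from the tests at the bottom of this commit:

    from pytorch_widedeep.models import TabNet, ModelFuser

    tabnet_user = TabNet(
        column_idx=tab_preprocessor_user.column_idx,
        cat_embed_input=tab_preprocessor_user.cat_embed_input,
        continuous_cols=tab_preprocessor_user.continuous_cols,
    )

    try:
        ModelFuser(
            models=[tabnet_user, tab_mlp_item],
            fusion_method="mean",
            projection_method="max",
        )
    except ValueError as err:
        print(err)  # "TabNet is not supported in ModelFuser. ..."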
20 changes: 16 additions & 4 deletions pytorch_widedeep/models/wide_deep.py
@@ -400,7 +400,7 @@ def _forward_component_with_head(
     ) -> Tensor:
         if isinstance(component, nn.ModuleList):
             component_out = torch.cat(  # type: ignore[call-overload]
-                [cp(X[component_type]) for cp in component], axis=1
+                [cp(X[component_type][i]) for i, cp in enumerate(component)], axis=1
             )
         else:
             component_out = component(X[component_type])
@@ -547,11 +547,23 @@ def _check_inputs(  # noqa: C901
             deephead_inp_feat = next(deephead.parameters()).size(1)
             output_dim = 0
             if deeptabular is not None:
-                output_dim += deeptabular.output_dim
+                if isinstance(deeptabular, list):
+                    for dt in deeptabular:
+                        output_dim += dt.output_dim
+                else:
+                    output_dim += deeptabular.output_dim
             if deeptext is not None:
-                output_dim += deeptext.output_dim
+                if isinstance(deeptext, list):
+                    for dt in deeptext:
+                        output_dim += dt.output_dim
+                else:
+                    output_dim += deeptext.output_dim
             if deepimage is not None:
-                output_dim += deepimage.output_dim
+                if isinstance(deepimage, list):
+                    for di in deepimage:
+                        output_dim += di.output_dim
+                else:
+                    output_dim += deepimage.output_dim
             if deephead_inp_feat != output_dim:
                 warnings.warn(
                     "A custom 'deephead' is used and it seems that the input features "
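With list components, the check now sums the output_dim of every model in each list before comparing against the input features of the custom head. A small hedged sketch of that arithmetic, reusing the component names from the tests below (CustomHead is a helper defined in the test module, not shown in this diff):

    # expected input features of a custom deephead when passing list components
    deephead_in_feat = (
        tab_mlp_user.output_dim
        + tab_mlp_item.output_dim
        + rnn_reviews.output_dim
        + rnn_descriptions.output_dim
    )
    custom_head = CustomHead(deephead_in_feat, 8)

    model = WideDeep(
        deeptabular=[tab_mlp_user, tab_mlp_item],
        deeptext=[rnn_reviews, rnn_descriptions],
        deephead=custom_head,
        pred_dim=1,
    )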
@@ -10,7 +10,13 @@
 import pytest
 
 from pytorch_widedeep import Trainer
-from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
+from pytorch_widedeep.models import (
+    TabMlp,
+    TabNet,
+    BasicRNN,
+    WideDeep,
+    ModelFuser,
+)
 from pytorch_widedeep.metrics import F1Score, Accuracy
 from pytorch_widedeep.callbacks import LRHistory
 from pytorch_widedeep.initializers import XavierNormal, KaimingNormal
@@ -569,7 +575,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim
 
 
-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):
 
     fused_tab_model = ModelFuser(
         models=[tab_mlp_user, tab_mlp_item],
@@ -583,10 +590,87 @@ def test_model_fusion_full_process():
         projection_method="min",
     )
 
+    if head_type == "via_params":
+        head_hidden_dims = [fused_tab_model.output_dim + fused_text_model.output_dim, 8]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            fused_tab_model.output_dim + fused_text_model.output_dim, 8
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
     model = WideDeep(
         deeptabular=fused_tab_model,
         deeptext=fused_text_model,
         pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
     )
 
     n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
+    X_train = {
+        "X_tab": [X_tab_user_tr, X_tab_item_tr],
+        "X_text": [X_text_review_tr, X_text_description_tr],
+        "target": train_df["purchased"].values,
+    }
+    X_val = {
+        "X_tab": [X_tab_user_val, X_tab_item_val],
+        "X_text": [X_text_review_val, X_text_description_val],
+        "target": valid_df["purchased"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    # the 4 models to be combined are tab_mlp_user, tab_mlp_item, rnn_reviews,
+    # rnn_descriptions
+    if head_type == "via_params":
+        head_hidden_dims = [
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
+    model = WideDeep(
+        deeptabular=[tab_mlp_user, tab_mlp_item],
+        deeptext=[rnn_reviews, rnn_descriptions],
+        pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+    )
+
+    n_epochs = 2
@@ -615,3 +699,30 @@ def test_model_fusion_full_process():

     # weak assertion, but anyway...
     assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("fuse_models", [True, False])
+def test_catch_tabnet_error(fuse_models):
+
+    tabnet_user = TabNet(
+        column_idx=tab_preprocessor_user.column_idx,
+        cat_embed_input=tab_preprocessor_user.cat_embed_input,
+        continuous_cols=tab_preprocessor_user.continuous_cols,
+    )
+
+    tab_mlp_item = TabMlp(
+        column_idx=tab_preprocessor_item.column_idx,
+        cat_embed_input=tab_preprocessor_item.cat_embed_input,
+        continuous_cols=tab_preprocessor_item.continuous_cols,
+    )
+
+    if fuse_models:
+        with pytest.raises(ValueError):
+            fused_model = ModelFuser(  # noqa: F841
+                models=[tabnet_user, tab_mlp_item],
+                fusion_method="mean",
+                projection_method="max",
+            )
+    else:
+        with pytest.raises(ValueError):
+            model = WideDeep(deeptabular=[tabnet_user, tab_mlp_item])  # noqa: F841
