From 07155ec55261f7f3a8cceabcd5783ababfbb8137 Mon Sep 17 00:00:00 2001
From: Javier
Date: Mon, 26 Aug 2024 12:05:45 +0200
Subject: [PATCH] Fixed a bug related to multiple tabular components with an
 FC head. Increased test coverage

When 'deeptabular', 'deeptext' or 'deepimage' is a list of components,
the forward pass passed the full list of inputs to every component
instead of indexing each component's own array, and '_check_inputs'
could not add up 'output_dim' across such a list when a custom
'deephead' is used. 'ModelFuser' now also raises a ValueError for
'TabNet', which it does not support.
---
 pytorch_widedeep/models/model_fusion.py      |   8 ++
 pytorch_widedeep/models/wide_deep.py         |  20 ++-
 .../test_multi_tab_and_text_components.py    | 115 +++++++++++++++++-
 3 files changed, 137 insertions(+), 6 deletions(-)

diff --git a/pytorch_widedeep/models/model_fusion.py b/pytorch_widedeep/models/model_fusion.py
index 72bb40a6..a667aae4 100644
--- a/pytorch_widedeep/models/model_fusion.py
+++ b/pytorch_widedeep/models/model_fusion.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 
+from pytorch_widedeep.models import TabNet
 from pytorch_widedeep.wdtypes import List, Union, Tensor, Literal, Optional
 from pytorch_widedeep.models.tabular.mlp._layers import MLP
 from pytorch_widedeep.models._base_wd_model_component import (
@@ -300,6 +301,13 @@ def output_dim(self) -> int:
         return output_dim
 
     def check_input_parameters(self):  # noqa: C901
+
+        if any(isinstance(model, TabNet) for model in self.models):
+            raise ValueError(
+                "TabNet is not supported in ModelFuser. "
+                "Please, use another model for tabular data"
+            )
+
         if isinstance(self.fusion_method, str):
             if not any(
                 x == self.fusion_method
diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py
index 544818b9..7d131b44 100644
--- a/pytorch_widedeep/models/wide_deep.py
+++ b/pytorch_widedeep/models/wide_deep.py
@@ -400,7 +400,7 @@ def _forward_component_with_head(
     ) -> Tensor:
         if isinstance(component, nn.ModuleList):
             component_out = torch.cat(  # type: ignore[call-overload]
-                [cp(X[component_type]) for cp in component], axis=1
+                [cp(X[component_type][i]) for i, cp in enumerate(component)], axis=1
             )
         else:
             component_out = component(X[component_type])
@@ -547,11 +547,23 @@ def _check_inputs(  # noqa: C901
             deephead_inp_feat = next(deephead.parameters()).size(1)
             output_dim = 0
             if deeptabular is not None:
-                output_dim += deeptabular.output_dim
+                if isinstance(deeptabular, list):
+                    for dt in deeptabular:
+                        output_dim += dt.output_dim
+                else:
+                    output_dim += deeptabular.output_dim
             if deeptext is not None:
-                output_dim += deeptext.output_dim
+                if isinstance(deeptext, list):
+                    for dt in deeptext:
+                        output_dim += dt.output_dim
+                else:
+                    output_dim += deeptext.output_dim
             if deepimage is not None:
-                output_dim += deepimage.output_dim
+                if isinstance(deepimage, list):
+                    for di in deepimage:
+                        output_dim += di.output_dim
+                else:
+                    output_dim += deepimage.output_dim
             if deephead_inp_feat != output_dim:
                 warnings.warn(
                     "A custom 'deephead' is used and it seems that the input features "
diff --git a/tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py b/tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py
index 691ad74f..4a28c393 100644
--- a/tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py
+++ b/tests/test_multi_model_and_mutil_data/test_multi_tab_and_text_components.py
@@ -10,7 +10,13 @@
 import pytest
 
 from pytorch_widedeep import Trainer
-from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
+from pytorch_widedeep.models import (
+    TabMlp,
+    TabNet,
+    BasicRNN,
+    WideDeep,
+    ModelFuser,
+)
 from pytorch_widedeep.metrics import F1Score, Accuracy
 from pytorch_widedeep.callbacks import LRHistory
 from pytorch_widedeep.initializers import XavierNormal, KaimingNormal
@@ -569,7 +575,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim
 
 
-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):
 
     fused_tab_model = ModelFuser(
         models=[tab_mlp_user, tab_mlp_item],
@@ -583,10 +590,87 @@
         projection_method="min",
     )
 
+    if head_type == "via_params":
+        head_hidden_dims = [fused_tab_model.output_dim + fused_text_model.output_dim, 8]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            fused_tab_model.output_dim + fused_text_model.output_dim, 8
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
     model = WideDeep(
         deeptabular=fused_tab_model,
         deeptext=fused_text_model,
         pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+    )
+
+    n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
+    X_train = {
+        "X_tab": [X_tab_user_tr, X_tab_item_tr],
+        "X_text": [X_text_review_tr, X_text_description_tr],
+        "target": train_df["purchased"].values,
+    }
+    X_val = {
+        "X_tab": [X_tab_user_val, X_tab_item_val],
+        "X_text": [X_text_review_val, X_text_description_val],
+        "target": valid_df["purchased"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    # the 4 models to be combined are tab_mlp_user, tab_mlp_item, rnn_reviews,
+    # rnn_descriptions
+    if head_type == "via_params":
+        head_hidden_dims = [
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        head_hidden_dims = None
+        custom_head = CustomHead(
+            tab_mlp_user.output_dim
+            + tab_mlp_item.output_dim
+            + rnn_reviews.output_dim
+            + rnn_descriptions.output_dim,
+            8,
+        )
+    else:
+        head_hidden_dims = None
+        custom_head = None
+
+    model = WideDeep(
+        deeptabular=[tab_mlp_user, tab_mlp_item],
+        deeptext=[rnn_reviews, rnn_descriptions],
+        pred_dim=1,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
     )
 
     n_epochs = 2
@@ -615,3 +699,30 @@
 
     # weak assertion, but anyway...
     assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("fuse_models", [True, False])
+def test_catch_tabnet_error(fuse_models):
+
+    tabnet_user = TabNet(
+        column_idx=tab_preprocessor_user.column_idx,
+        cat_embed_input=tab_preprocessor_user.cat_embed_input,
+        continuous_cols=tab_preprocessor_user.continuous_cols,
+    )
+
+    tab_mlp_item = TabMlp(
+        column_idx=tab_preprocessor_item.column_idx,
+        cat_embed_input=tab_preprocessor_item.cat_embed_input,
+        continuous_cols=tab_preprocessor_item.continuous_cols,
+    )
+
+    if fuse_models:
+        with pytest.raises(ValueError):
+            fused_model = ModelFuser(  # noqa: F841
+                models=[tabnet_user, tab_mlp_item],
+                fusion_method="mean",
+                projection_method="max",
+            )
+    else:
+        with pytest.raises(ValueError):
+            model = WideDeep(deeptabular=[tabnet_user, tab_mlp_item])  # noqa: F841
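
Illustrative usage (reviewer note): a minimal sketch of the path this patch
fixes, i.e. a WideDeep model built from a list of tabular components plus an
FC head defined via head_hidden_dims. The toy dataframes, column names and
hyperparameters below are invented for illustration; only the
deeptabular-as-list and "X_tab"-list usage is taken from the new tests, and
the TabPreprocessor/TabMlp/Trainer calls assume the library's public API.

import numpy as np
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor

# toy user- and item-side tables (invented columns, 32 rows)
user_df = pd.DataFrame(
    {
        "age": np.random.randint(18, 65, 32),
        "gender": np.random.choice(["m", "f"], 32),
    }
)
item_df = pd.DataFrame(
    {
        "price": np.random.rand(32),
        "category": np.random.choice(["a", "b"], 32),
    }
)
target = np.random.randint(0, 2, 32)

# one preprocessor and one TabMlp per tabular component
user_prep = TabPreprocessor(cat_embed_cols=["gender"], continuous_cols=["age"])
item_prep = TabPreprocessor(cat_embed_cols=["category"], continuous_cols=["price"])
X_tab_user = user_prep.fit_transform(user_df)
X_tab_item = item_prep.fit_transform(item_df)

tab_mlp_user = TabMlp(
    column_idx=user_prep.column_idx,
    cat_embed_input=user_prep.cat_embed_input,
    continuous_cols=["age"],
)
tab_mlp_item = TabMlp(
    column_idx=item_prep.column_idx,
    cat_embed_input=item_prep.cat_embed_input,
    continuous_cols=["price"],
)

# FC head on top of the concatenated component outputs; with this patch each
# component receives its own array from the "X_tab" list in the forward pass
model = WideDeep(
    deeptabular=[tab_mlp_user, tab_mlp_item],
    head_hidden_dims=[tab_mlp_user.output_dim + tab_mlp_item.output_dim, 8],
    pred_dim=1,
)

trainer = Trainer(model, objective="binary", verbose=0)
trainer.fit(
    X_train={"X_tab": [X_tab_user, X_tab_item], "target": target},
    n_epochs=1,
    batch_size=8,
)

Before this patch, each TabMlp in the list would have been applied to the
whole "X_tab" list rather than to its own array, and with a custom deephead
the output_dim sum in _check_inputs would have raised instead of warning.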