diff --git a/examples/scripts/adult_census_bayesian_tabmlp.py b/examples/scripts/adult_census_bayesian_tabmlp.py
index 758285e8..682f2808 100644
--- a/examples/scripts/adult_census_bayesian_tabmlp.py
+++ b/examples/scripts/adult_census_bayesian_tabmlp.py
@@ -61,9 +61,9 @@

 model = BayesianWide(
     input_dim=np.unique(X_tab).shape[0],
-    pred_dim=df["age_buckets"].nunique()
-    if objective == "multiclass"
-    else 1,
+    pred_dim=(
+        df["age_buckets"].nunique() if objective == "multiclass" else 1
+    ),
     prior_sigma_1=1.0,
     prior_sigma_2=0.002,
     prior_pi=0.8,
@@ -88,9 +88,9 @@
     prior_pi=0.8,
     posterior_mu_init=0,
     posterior_rho_init=-7.0,
-    pred_dim=df["age_buckets"].nunique()
-    if objective == "multiclass"
-    else 1,
+    pred_dim=(
+        df["age_buckets"].nunique() if objective == "multiclass" else 1
+    ),
 )

 model_checkpoint = ModelCheckpoint(
diff --git a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py
index 0b31784e..bead3d8e 100644
--- a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py
+++ b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py
@@ -226,9 +226,9 @@ def __init__(
                     prior_pi=self.prior_pi,
                     posterior_mu_init=self.posterior_mu_init,
                     posterior_rho_init=self.posterior_rho_init,
-                    use_bias=False
-                    if self.use_cont_bias is None
-                    else self.use_cont_bias,
+                    use_bias=(
+                        False if self.use_cont_bias is None else self.use_cont_bias
+                    ),
                     activation_fn=self.cont_embed_activation,
                 )
                 self.cont_out_dim = len(self.continuous_cols) * self.cont_embed_dim
diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py
index d73957c1..061fa691 100644
--- a/pytorch_widedeep/callbacks.py
+++ b/pytorch_widedeep/callbacks.py
@@ -3,6 +3,7 @@

 CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
 """
+
 import os
 import copy
 import datetime
diff --git a/pytorch_widedeep/models/tabular/_base_tabular_model.py b/pytorch_widedeep/models/tabular/_base_tabular_model.py
index 60874528..b8f04564 100644
--- a/pytorch_widedeep/models/tabular/_base_tabular_model.py
+++ b/pytorch_widedeep/models/tabular/_base_tabular_model.py
@@ -101,9 +101,9 @@ def __init__(
             self.cat_embed = DiffSizeCatEmbeddings(
                 column_idx=self.column_idx,
                 embed_input=self.cat_embed_input,
-                embed_dropout=0.0
-                if self.cat_embed_dropout is None
-                else self.cat_embed_dropout,
+                embed_dropout=(
+                    0.0 if self.cat_embed_dropout is None else self.cat_embed_dropout
+                ),
                 use_bias=False if self.use_cat_bias is None else self.use_cat_bias,
                 activation_fn=self.cat_embed_activation,
             )
@@ -253,20 +253,22 @@ def __init__(
                 embed_dim=self.input_dim,
                 column_idx=self.column_idx,
                 embed_input=self.cat_embed_input,
-                embed_dropout=0.0
-                if self.cat_embed_dropout is None
-                else self.cat_embed_dropout,
-                full_embed_dropout=False
-                if self.full_embed_dropout is None
-                else self.full_embed_dropout,
+                embed_dropout=(
+                    0.0 if self.cat_embed_dropout is None else self.cat_embed_dropout
+                ),
+                full_embed_dropout=(
+                    False
+                    if self.full_embed_dropout is None
+                    else self.full_embed_dropout
+                ),
                 use_bias=False if self.use_cat_bias is None else self.use_cat_bias,
                 shared_embed=False if self.shared_embed is None else self.shared_embed,
-                add_shared_embed=False
-                if self.add_shared_embed is None
-                else self.add_shared_embed,
-                frac_shared_embed=0.0
-                if self.frac_shared_embed is None
-                else self.frac_shared_embed,
+                add_shared_embed=(
+                    False if self.add_shared_embed is None else self.add_shared_embed
+                ),
+                frac_shared_embed=(
+                    0.0 if self.frac_shared_embed is None else self.frac_shared_embed
+                ),
                 activation_fn=self.cat_embed_activation,
             )

@@ -359,9 +361,9 @@ def _set_continous_embeddings_layer(
             n_cont_cols=len(continuous_cols),
             embed_dim=cont_embed_dim,
             embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
-            full_embed_dropout=False
-            if full_embed_dropout is None
-            else full_embed_dropout,
+            full_embed_dropout=(
+                False if full_embed_dropout is None else full_embed_dropout
+            ),
             activation_fn=cont_embed_activation,
         )

@@ -377,9 +379,9 @@ def _set_continous_embeddings_layer(
             quantization_setup=quantization_setup,
             embed_dim=cont_embed_dim,
             embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
-            full_embed_dropout=False
-            if full_embed_dropout is None
-            else full_embed_dropout,
+            full_embed_dropout=(
+                False if full_embed_dropout is None else full_embed_dropout
+            ),
             activation_fn=cont_embed_activation,
         )

@@ -396,9 +398,9 @@ def _set_continous_embeddings_layer(
             n_cont_cols=len(continuous_cols),
             embed_dim=cont_embed_dim,
             embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
-            full_embed_dropout=False
-            if full_embed_dropout is None
-            else full_embed_dropout,
+            full_embed_dropout=(
+                False if full_embed_dropout is None else full_embed_dropout
+            ),
             n_frequencies=n_frequencies,
             sigma=sigma,
             share_last_layer=share_last_layer,
diff --git a/pytorch_widedeep/models/tabular/resnet/tab_resnet.py b/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
index 6df54d00..22830ba3 100644
--- a/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
+++ b/pytorch_widedeep/models/tabular/resnet/tab_resnet.py
@@ -252,17 +252,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.blocks_dims[-1]] + self.mlp_hidden_dims,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=True
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    True if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
@@ -399,17 +401,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=True
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    True if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
             self.decoder = DenseResnet(
                 self.mlp_hidden_dims[-1],
diff --git a/pytorch_widedeep/models/tabular/self_supervised/_augmentations.py b/pytorch_widedeep/models/tabular/self_supervised/_augmentations.py
index b38b9871..7a59e10b 100644
--- a/pytorch_widedeep/models/tabular/self_supervised/_augmentations.py
+++ b/pytorch_widedeep/models/tabular/self_supervised/_augmentations.py
@@ -5,6 +5,7 @@
 - CutMix: https://github.com/clovaai/CutMix-PyTorch
 - MixUp: https://github.com/facebookresearch/mixup-cifar10
 """
+
 import numpy as np
 import torch

diff --git a/pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py b/pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py
index b03f0431..7fc3eecd 100644
--- a/pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py
+++ b/pytorch_widedeep/models/tabular/self_supervised/contrastive_denoising_model.py
@@ -71,9 +71,7 @@ def __init__(
             self.denoise_cont_mlp,
         ) = self._set_cat_and_cont_denoise_mlps()

-    def forward(
-        self, X: Tensor
-    ) -> Tuple[
+    def forward(self, X: Tensor) -> Tuple[
         Optional[Tuple[Tensor, Tensor]],
         Optional[Tuple[Tensor, Tensor]],
         Optional[Tuple[Tensor, Tensor]],
@@ -190,13 +188,13 @@ def _set_cat_and_cont_denoise_mlps(

         if self.model.continuous_cols is not None:
             if self.cont_mlp_type == "single":
-                denoise_cont_mlp: Union[
-                    ContSingleMlp, ContMlpPerFeature
-                ] = ContSingleMlp(
-                    self.model.input_dim,
-                    self.model.continuous_cols,
-                    self.model.column_idx,
-                    self.denoise_mlps_activation,
+                denoise_cont_mlp: Union[ContSingleMlp, ContMlpPerFeature] = (
+                    ContSingleMlp(
+                        self.model.input_dim,
+                        self.model.continuous_cols,
+                        self.model.column_idx,
+                        self.denoise_mlps_activation,
+                    )
                 )
             elif self.cont_mlp_type == "multiple":
                 denoise_cont_mlp = ContMlpPerFeature(
diff --git a/pytorch_widedeep/models/tabular/tabnet/_layers.py b/pytorch_widedeep/models/tabular/tabnet/_layers.py
index f435f0f4..b6fbec6b 100644
--- a/pytorch_widedeep/models/tabular/tabnet/_layers.py
+++ b/pytorch_widedeep/models/tabular/tabnet/_layers.py
@@ -206,9 +206,9 @@ def __init__(
         self.bn = nn.BatchNorm1d(output_dim, momentum=momentum)

         if mask_type == "sparsemax":
-            self.mask: Union[
-                sparsemax.Sparsemax, sparsemax.Entmax15
-            ] = sparsemax.Sparsemax(dim=-1)
+            self.mask: Union[sparsemax.Sparsemax, sparsemax.Entmax15] = (
+                sparsemax.Sparsemax(dim=-1)
+            )
         elif mask_type == "entmax":
             self.mask = sparsemax.Entmax15(dim=-1)
         else:
diff --git a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
index fbd38f8b..bca5db35 100644
--- a/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/ft_transformer.py
@@ -308,17 +308,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=False
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    False if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
diff --git a/pytorch_widedeep/models/tabular/transformers/saint.py b/pytorch_widedeep/models/tabular/transformers/saint.py
index fbc967e7..0e43c730 100644
--- a/pytorch_widedeep/models/tabular/transformers/saint.py
+++ b/pytorch_widedeep/models/tabular/transformers/saint.py
@@ -289,17 +289,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=False
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    False if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
index cb65fb76..2cfeb089 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
@@ -325,17 +325,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=False
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    False if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py b/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
index 479def5a..78035a28 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
@@ -322,17 +322,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=False
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    False if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
diff --git a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
index 1db83600..b6f00646 100644
--- a/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
+++ b/pytorch_widedeep/models/tabular/transformers/tab_transformer.py
@@ -312,17 +312,19 @@ def __init__(
         if self.mlp_hidden_dims is not None:
             self.mlp = MLP(
                 d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
-                activation="relu"
-                if self.mlp_activation is None
-                else self.mlp_activation,
+                activation=(
+                    "relu" if self.mlp_activation is None else self.mlp_activation
+                ),
                 dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
                 batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
-                batchnorm_last=False
-                if self.mlp_batchnorm_last is None
-                else self.mlp_batchnorm_last,
-                linear_first=False
-                if self.mlp_linear_first is None
-                else self.mlp_linear_first,
+                batchnorm_last=(
+                    False
+                    if self.mlp_batchnorm_last is None
+                    else self.mlp_batchnorm_last
+                ),
+                linear_first=(
+                    False if self.mlp_linear_first is None else self.mlp_linear_first
+                ),
             )
         else:
             self.mlp = None
diff --git a/pytorch_widedeep/models/text/basic_transformer.py b/pytorch_widedeep/models/text/basic_transformer.py
index 9ff721ca..d4dd2206 100644
--- a/pytorch_widedeep/models/text/basic_transformer.py
+++ b/pytorch_widedeep/models/text/basic_transformer.py
@@ -127,9 +127,9 @@ def __init__(

         if with_pos_encoding:
             if pos_encoder is not None:
-                self.pos_encoder: Union[
-                    nn.Module, nn.Identity, PositionalEncoding
-                ] = pos_encoder
+                self.pos_encoder: Union[nn.Module, nn.Identity, PositionalEncoding] = (
+                    pos_encoder
+                )
             else:
                 self.pos_encoder = PositionalEncoding(
                     input_dim, pos_encoding_dropout, seq_length
diff --git a/pytorch_widedeep/preprocessing/tab_preprocessor.py b/pytorch_widedeep/preprocessing/tab_preprocessor.py
index ad67327e..24afef45 100644
--- a/pytorch_widedeep/preprocessing/tab_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/tab_preprocessor.py
@@ -325,9 +325,9 @@ def fit(self, df: pd.DataFrame) -> BasePreprocessor:  # noqa: C901

         # Categorical embeddings logic
         if self.cat_embed_cols is not None or self.quantization_setup is not None:
-            self.cat_embed_input: List[
-                Union[Tuple[str, int], Tuple[str, int, int]]
-            ] = []
+            self.cat_embed_input: List[Union[Tuple[str, int], Tuple[str, int, int]]] = (
+                []
+            )
             if self.cat_embed_cols is not None:
                 df_cat, cat_embed_dim = self._prepare_categorical(df_adj)

@@ -546,11 +546,13 @@ def _prepare_continuous(
                     (
                         col,
                         self.quantization_setup,
-                        embed_sz_rule(
-                            self.quantization_setup + 1, self.embedding_rule  # type: ignore[arg-type]
-                        )
-                        if self.auto_embed_dim
-                        else self.default_embed_dim,
+                        (
+                            embed_sz_rule(
+                                self.quantization_setup + 1, self.embedding_rule  # type: ignore[arg-type]
+                            )
+                            if self.auto_embed_dim
+                            else self.default_embed_dim
+                        ),
                     )
                 )
             else:
@@ -560,11 +562,13 @@ def _prepare_continuous(
                         (
                             col,
                             val,
-                            embed_sz_rule(
-                                val + 1, self.embedding_rule  # type: ignore[arg-type]
-                            )
-                            if self.auto_embed_dim
-                            else self.default_embed_dim,
+                            (
+                                embed_sz_rule(
+                                    val + 1, self.embedding_rule  # type: ignore[arg-type]
+                                )
+                                if self.auto_embed_dim
+                                else self.default_embed_dim
+                            ),
                         )
                     )
                 else:
@@ -572,9 +576,11 @@ def _prepare_continuous(
                         (
                             col,
                             len(val) - 1,
-                            embed_sz_rule(len(val), self.embedding_rule)  # type: ignore[arg-type]
-                            if self.auto_embed_dim
-                            else self.default_embed_dim,
+                            (
+                                embed_sz_rule(len(val), self.embedding_rule)  # type: ignore[arg-type]
+                                if self.auto_embed_dim
+                                else self.default_embed_dim
+                            ),
                         )
                     )

diff --git a/pytorch_widedeep/preprocessing/wide_preprocessor.py b/pytorch_widedeep/preprocessing/wide_preprocessor.py
index c5cdf3af..caa8b7cb 100644
--- a/pytorch_widedeep/preprocessing/wide_preprocessor.py
+++ b/pytorch_widedeep/preprocessing/wide_preprocessor.py
@@ -119,9 +119,11 @@ def transform(self, df: pd.DataFrame) -> np.ndarray:
         encoded = np.zeros([len(df_wide), len(self.wide_crossed_cols)])
         for col_i, col in enumerate(self.wide_crossed_cols):
             encoded[:, col_i] = df_wide[col].apply(
-                lambda x: self.encoding_dict[col + "_" + str(x)]
-                if col + "_" + str(x) in self.encoding_dict
-                else 0
+                lambda x: (
+                    self.encoding_dict[col + "_" + str(x)]
+                    if col + "_" + str(x) in self.encoding_dict
+                    else 0
+                )
             )
         return encoded.astype("int64")

diff --git a/pytorch_widedeep/training/_trainer_utils.py b/pytorch_widedeep/training/_trainer_utils.py
index 248ea3ae..ddb88164 100644
--- a/pytorch_widedeep/training/_trainer_utils.py
+++ b/pytorch_widedeep/training/_trainer_utils.py
@@ -171,9 +171,11 @@ def wd_train_val_split(  # noqa: C901
             np.arange(len(X_train["target"])),
             test_size=val_split,
             random_state=seed,
-            stratify=X_train["target"]
-            if method not in ["regression", "qregression"]
-            else None,
+            stratify=(
+                X_train["target"]
+                if method not in ["regression", "qregression"]
+                else None
+            ),
         )
         X_tr, X_val = {"target": y_tr}, {"target": y_val}
         if "X_wide" in X_train.keys():
diff --git a/tests/test_data_utils/test_du_tabular.py b/tests/test_data_utils/test_du_tabular.py
index 2f167f39..10325bf3 100644
--- a/tests/test_data_utils/test_du_tabular.py
+++ b/tests/test_data_utils/test_du_tabular.py
@@ -383,9 +383,11 @@ def test_quantization(quantization_setup, expected_bins_col1, expected_bins_col2
     assert (
         len(set(X_quant[:, 0])) == expected_bins_col1
         if expected_bins_col1
-        else 20 and len(set(X_quant[:, 1])) == expected_bins_col2
-        if expected_bins_col2
-        else 20
+        else (
+            20 and len(set(X_quant[:, 1])) == expected_bins_col2
+            if expected_bins_col2
+            else 20
+        )
     )


diff --git a/tests/test_model_components/test_mc_transformers.py b/tests/test_model_components/test_mc_transformers.py
index 29a83879..db58be51 100644
--- a/tests/test_model_components/test_mc_transformers.py
+++ b/tests/test_model_components/test_mc_transformers.py
@@ -232,9 +232,9 @@ def test_embed_continuous_and_with_cls_token_tabtransformer(

     params = {
         "column_idx": {k: v for v, k in enumerate(n_colnames)},
-        "cat_embed_input": with_cls_token_embed_input
-        if with_cls_token
-        else embed_input,
+        "cat_embed_input": (
+            with_cls_token_embed_input if with_cls_token else embed_input
+        ),
         "continuous_cols": n_colnames[cont_idx:],
         "embed_continuous": embed_continuous,
         "embed_continuous_method": "standard" if embed_continuous else None,
@@ -287,9 +287,9 @@ def test_embed_continuous_and_with_cls_token_transformer_family(

     params = {
         "column_idx": {k: v for v, k in enumerate(n_colnames)},
-        "cat_embed_input": with_cls_token_embed_input
-        if with_cls_token
-        else embed_input,
+        "cat_embed_input": (
+            with_cls_token_embed_input if with_cls_token else embed_input
+        ),
         "continuous_cols": n_colnames[cont_idx:],
     }