Commit
Updated style libraries
jrzaurin committed Feb 16, 2024
1 parent bd76625 commit 973f987
Showing 19 changed files with 171 additions and 143 deletions.
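All of the hunks below apply the same formatting update rather than any functional change: long conditional expressions passed as arguments are wrapped in parentheses instead of being split bare across lines, long annotated assignments keep their annotation on one line and parenthesize the right-hand side, and a blank line is added after module docstrings. This is consistent with a bump of the project's formatter (presumably black's 2024 stable style, although the commit does not name the tool or version, so treat that attribution as an assumption). A minimal sketch of the dominant pattern follows; make_layer, mlp_activation and mlp_dropout are hypothetical stand-ins, not names from the library:

# Hypothetical sketch of the reformatting applied throughout this commit.
# make_layer, mlp_activation and mlp_dropout are stand-ins, not library names.


def make_layer(activation, dropout):
    return activation, dropout


mlp_activation = None
mlp_dropout = 0.5

# Before: the long conditional expression was split bare over three lines.
old_style = make_layer(
    activation="relu"
    if mlp_activation is None
    else mlp_activation,
    dropout=0.0 if mlp_dropout is None else mlp_dropout,
)

# After: the same expression is wrapped in parentheses.
new_style = make_layer(
    activation=(
        "relu" if mlp_activation is None else mlp_activation
    ),
    dropout=0.0 if mlp_dropout is None else mlp_dropout,
)

assert old_style == new_style  # the change is purely cosmetic

A formatter would normally collapse expressions this short onto a single line; the wrapping is kept here only to mirror the longer argument lists in the diffs below.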
12 changes: 6 additions & 6 deletions examples/scripts/adult_census_bayesian_tabmlp.py
@@ -61,9 +61,9 @@

model = BayesianWide(
input_dim=np.unique(X_tab).shape[0],
pred_dim=df["age_buckets"].nunique()
if objective == "multiclass"
else 1,
pred_dim=(
df["age_buckets"].nunique() if objective == "multiclass" else 1
),
prior_sigma_1=1.0,
prior_sigma_2=0.002,
prior_pi=0.8,
@@ -88,9 +88,9 @@
prior_pi=0.8,
posterior_mu_init=0,
posterior_rho_init=-7.0,
pred_dim=df["age_buckets"].nunique()
if objective == "multiclass"
else 1,
pred_dim=(
df["age_buckets"].nunique() if objective == "multiclass" else 1
),
)

model_checkpoint = ModelCheckpoint(

@@ -226,9 +226,9 @@ def __init__(
prior_pi=self.prior_pi,
posterior_mu_init=self.posterior_mu_init,
posterior_rho_init=self.posterior_rho_init,
use_bias=False
if self.use_cont_bias is None
else self.use_cont_bias,
use_bias=(
False if self.use_cont_bias is None else self.use_cont_bias
),
activation_fn=self.cont_embed_activation,
)
self.cont_out_dim = len(self.continuous_cols) * self.cont_embed_dim
1 change: 1 addition & 0 deletions pytorch_widedeep/callbacks.py
@@ -3,6 +3,7 @@
CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""

import os
import copy
import datetime
50 changes: 26 additions & 24 deletions pytorch_widedeep/models/tabular/_base_tabular_model.py
@@ -101,9 +101,9 @@ def __init__(
self.cat_embed = DiffSizeCatEmbeddings(
column_idx=self.column_idx,
embed_input=self.cat_embed_input,
embed_dropout=0.0
if self.cat_embed_dropout is None
else self.cat_embed_dropout,
embed_dropout=(
0.0 if self.cat_embed_dropout is None else self.cat_embed_dropout
),
use_bias=False if self.use_cat_bias is None else self.use_cat_bias,
activation_fn=self.cat_embed_activation,
)
@@ -253,20 +253,22 @@ def __init__(
embed_dim=self.input_dim,
column_idx=self.column_idx,
embed_input=self.cat_embed_input,
embed_dropout=0.0
if self.cat_embed_dropout is None
else self.cat_embed_dropout,
full_embed_dropout=False
if self.full_embed_dropout is None
else self.full_embed_dropout,
embed_dropout=(
0.0 if self.cat_embed_dropout is None else self.cat_embed_dropout
),
full_embed_dropout=(
False
if self.full_embed_dropout is None
else self.full_embed_dropout
),
use_bias=False if self.use_cat_bias is None else self.use_cat_bias,
shared_embed=False if self.shared_embed is None else self.shared_embed,
add_shared_embed=False
if self.add_shared_embed is None
else self.add_shared_embed,
frac_shared_embed=0.0
if self.frac_shared_embed is None
else self.frac_shared_embed,
add_shared_embed=(
False if self.add_shared_embed is None else self.add_shared_embed
),
frac_shared_embed=(
0.0 if self.frac_shared_embed is None else self.frac_shared_embed
),
activation_fn=self.cat_embed_activation,
)

@@ -359,9 +361,9 @@ def _set_continous_embeddings_layer(
n_cont_cols=len(continuous_cols),
embed_dim=cont_embed_dim,
embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
full_embed_dropout=False
if full_embed_dropout is None
else full_embed_dropout,
full_embed_dropout=(
False if full_embed_dropout is None else full_embed_dropout
),
activation_fn=cont_embed_activation,
)

@@ -377,9 +379,9 @@ def _set_continous_embeddings_layer(
quantization_setup=quantization_setup,
embed_dim=cont_embed_dim,
embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
full_embed_dropout=False
if full_embed_dropout is None
else full_embed_dropout,
full_embed_dropout=(
False if full_embed_dropout is None else full_embed_dropout
),
activation_fn=cont_embed_activation,
)

@@ -396,9 +398,9 @@ def _set_continous_embeddings_layer(
n_cont_cols=len(continuous_cols),
embed_dim=cont_embed_dim,
embed_dropout=0.0 if cont_embed_dropout is None else cont_embed_dropout,
full_embed_dropout=False
if full_embed_dropout is None
else full_embed_dropout,
full_embed_dropout=(
False if full_embed_dropout is None else full_embed_dropout
),
n_frequencies=n_frequencies,
sigma=sigma,
share_last_layer=share_last_layer,
40 changes: 22 additions & 18 deletions pytorch_widedeep/models/tabular/resnet/tab_resnet.py
@@ -252,17 +252,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.blocks_dims[-1]] + self.mlp_hidden_dims,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=True
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
True if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None
@@ -399,17 +401,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=True
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
True if self.mlp_linear_first is None else self.mlp_linear_first
),
)
self.decoder = DenseResnet(
self.mlp_hidden_dims[-1],

@@ -5,6 +5,7 @@
- CutMix: https://github.com/clovaai/CutMix-PyTorch
- MixUp: https://github.com/facebookresearch/mixup-cifar10
"""

import numpy as np
import torch


@@ -71,9 +71,7 @@ def __init__(
self.denoise_cont_mlp,
) = self._set_cat_and_cont_denoise_mlps()

def forward(
self, X: Tensor
) -> Tuple[
def forward(self, X: Tensor) -> Tuple[
Optional[Tuple[Tensor, Tensor]],
Optional[Tuple[Tensor, Tensor]],
Optional[Tuple[Tensor, Tensor]],
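The hunk above applies the same style update to a method signature: the parameter list now stays on the def line and only the long return annotation is split across lines. A minimal sketch follows; Tensor is a stand-in type (not torch.Tensor), self is dropped so the snippet runs as a plain function, and the body is a placeholder:

from typing import Optional, Tuple

Tensor = float  # stand-in type so the sketch runs without torch

# Before: both the parameter list and the return annotation were split.
# def forward(
#     X: Tensor
# ) -> Tuple[
#     Optional[Tuple[Tensor, Tensor]],
#     Optional[Tuple[Tensor, Tensor]],
# ]:


# After: the parameters stay on the def line; only the annotation is split.
def forward(X: Tensor) -> Tuple[
    Optional[Tuple[Tensor, Tensor]],
    Optional[Tuple[Tensor, Tensor]],
]:
    # Placeholder body for illustration only.
    return None, None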
@@ -190,13 +188,13 @@ def _set_cat_and_cont_denoise_mlps(

if self.model.continuous_cols is not None:
if self.cont_mlp_type == "single":
denoise_cont_mlp: Union[
ContSingleMlp, ContMlpPerFeature
] = ContSingleMlp(
self.model.input_dim,
self.model.continuous_cols,
self.model.column_idx,
self.denoise_mlps_activation,
denoise_cont_mlp: Union[ContSingleMlp, ContMlpPerFeature] = (
ContSingleMlp(
self.model.input_dim,
self.model.continuous_cols,
self.model.column_idx,
self.denoise_mlps_activation,
)
)
elif self.cont_mlp_type == "multiple":
denoise_cont_mlp = ContMlpPerFeature(
6 changes: 3 additions & 3 deletions pytorch_widedeep/models/tabular/tabnet/_layers.py
@@ -206,9 +206,9 @@ def __init__(
self.bn = nn.BatchNorm1d(output_dim, momentum=momentum)

if mask_type == "sparsemax":
self.mask: Union[
sparsemax.Sparsemax, sparsemax.Entmax15
] = sparsemax.Sparsemax(dim=-1)
self.mask: Union[sparsemax.Sparsemax, sparsemax.Entmax15] = (
sparsemax.Sparsemax(dim=-1)
)
elif mask_type == "entmax":
self.mask = sparsemax.Entmax15(dim=-1)
else:
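The last two hunks show the other recurring pattern in this commit: a long annotated assignment keeps its Union annotation intact and wraps the assigned value in parentheses instead. A small sketch with hypothetical mask classes standing in for sparsemax.Sparsemax and sparsemax.Entmax15:

from typing import Union


class SparseMask:
    """Hypothetical stand-in for sparsemax.Sparsemax."""

    def __init__(self, dim: int = -1) -> None:
        self.dim = dim


class EntmaxMask:
    """Hypothetical stand-in for sparsemax.Entmax15."""

    def __init__(self, dim: int = -1) -> None:
        self.dim = dim


# Before: the Union annotation itself was split across lines.
mask_old_style: Union[
    SparseMask, EntmaxMask
] = SparseMask(dim=-1)

# After: the annotation stays intact and the value is parenthesized.
mask_new_style: Union[SparseMask, EntmaxMask] = (
    SparseMask(dim=-1)
)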
20 changes: 11 additions & 9 deletions pytorch_widedeep/models/tabular/transformers/ft_transformer.py
@@ -308,17 +308,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=False
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
False if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None
20 changes: 11 additions & 9 deletions pytorch_widedeep/models/tabular/transformers/saint.py
@@ -289,17 +289,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=False
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
False if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None
20 changes: 11 additions & 9 deletions pytorch_widedeep/models/tabular/transformers/tab_fastformer.py
@@ -325,17 +325,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=False
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
False if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None
20 changes: 11 additions & 9 deletions pytorch_widedeep/models/tabular/transformers/tab_perceiver.py
@@ -322,17 +322,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=False
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
False if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None
20 changes: 11 additions & 9 deletions pytorch_widedeep/models/tabular/transformers/tab_transformer.py
@@ -312,17 +312,19 @@ def __init__(
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
d_hidden=[self.mlp_first_hidden_dim] + self.mlp_hidden_dim,
activation="relu"
if self.mlp_activation is None
else self.mlp_activation,
activation=(
"relu" if self.mlp_activation is None else self.mlp_activation
),
dropout=0.0 if self.mlp_dropout is None else self.mlp_dropout,
batchnorm=False if self.mlp_batchnorm is None else self.mlp_batchnorm,
batchnorm_last=False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last,
linear_first=False
if self.mlp_linear_first is None
else self.mlp_linear_first,
batchnorm_last=(
False
if self.mlp_batchnorm_last is None
else self.mlp_batchnorm_last
),
linear_first=(
False if self.mlp_linear_first is None else self.mlp_linear_first
),
)
else:
self.mlp = None