
Commit 16922ce
A further increase in coverage of edge cases
jrzaurin committed Aug 26, 2024
1 parent 07155ec commit 16922ce
Showing 2 changed files with 88 additions and 3 deletions.
@@ -677,6 +677,7 @@ def test_full_process_without_fusion(head_type):
     trainer = Trainer(
         model,
         objective="binary",
+        verbose=0,
     )
 
     X_train = {
@@ -694,7 +695,6 @@ def test_full_process_without_fusion(head_type):
         X_val=X_val,
         n_epochs=n_epochs,
         batch_size=4,
-        verbose=1,
     )
 
     # weak assertion, but anyway...
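In the hunks above, the epoch-level verbosity moves from `fit` to the `Trainer` constructor, so the test runs silently. As a fragment for illustration only (reusing the names from the diff, not a self-contained test), the resulting call pattern is:

```python
# Verbosity is now set once when the Trainer is built, not per fit() call.
trainer = Trainer(model, objective="binary", verbose=0)
trainer.fit(X_train=X_train, X_val=X_val, n_epochs=2, batch_size=4)
```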
@@ -694,7 +694,8 @@ def test_model_fusion_projection_methods(projection_method):
     assert out.shape[1] == proj_dim == models_fuser.output_dim
 
 
-def test_model_fusion_full_process():
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_with_fusion(head_type):
 
     fused_text_model = ModelFuser(
         models=[rnn_1, rnn_2],
@@ -708,17 +709,102 @@ def test_model_fusion_full_process():
         projection_method="max",
     )

+    if head_type == "via_params":
+        head_hidden_dims = [
+            fused_text_model.output_dim,
+            fused_image_model.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            fused_text_model.output_dim
+            + fused_image_model.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
     model = WideDeep(
         deeptabular=tab_mlp,
         deeptext=fused_text_model,
         deepimage=fused_image_model,
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
         pred_dim=1,
     )
 
     n_epochs = 2
     trainer = Trainer(
         model,
         objective="binary",
+        verbose=0,
     )
 
+    X_train = {
+        "X_tab": X_tab_tr,
+        "X_text": [X_text_tr_1, X_text_tr_2],
+        "X_img": [X_img_tr_1, X_img_tr_2],
+        "target": train_df["target"].values,
+    }
+    X_val = {
+        "X_tab": X_tab_val,
+        "X_text": [X_text_val_1, X_text_val_2],
+        "X_img": [X_img_val_1, X_img_val_2],
+        "target": valid_df["target"].values,
+    }
+    trainer.fit(
+        X_train=X_train,
+        X_val=X_val,
+        n_epochs=n_epochs,
+        batch_size=4,
+    )
+
+    # weak assertion, but anyway...
+    assert len(trainer.history["train_loss"]) == n_epochs
+
+
+@pytest.mark.parametrize("head_type", [None, "via_params", "custom"])
+def test_full_process_without_fusion(head_type):
+
+    if head_type == "via_params":
+        head_hidden_dims = [
+            rnn_1.output_dim + rnn_2.output_dim,
+            vision_1.output_dim + vision_2.output_dim + tab_mlp.output_dim,
+            8,
+        ]
+        custom_head = None
+    elif head_type == "custom":
+        custom_head = CustomHead(
+            rnn_1.output_dim
+            + rnn_2.output_dim
+            + vision_1.output_dim
+            + vision_2.output_dim
+            + tab_mlp.output_dim,
+            8,
+        )
+        head_hidden_dims = None
+    else:
+        custom_head = None
+        head_hidden_dims = None
+
+    model = WideDeep(
+        deeptabular=tab_mlp,
+        deeptext=[rnn_1, rnn_2],
+        deepimage=[vision_1, vision_2],
+        head_hidden_dims=head_hidden_dims,
+        deephead=custom_head,
+        pred_dim=1,
+    )
+
+    n_epochs = 2
+    trainer = Trainer(
+        model,
+        objective="binary",
+        verbose=0,
+    )
+
     X_train = {
@@ -738,7 +824,6 @@ def test_model_fusion_full_process():
         X_val=X_val,
         n_epochs=n_epochs,
         batch_size=4,
-        verbose=1,
     )
 
     # weak assertion, but anyway...
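For reference, a plausible minimal definition of the `CustomHead` fixture that these tests call as `CustomHead(input_dim, output_dim)`; the repository's actual fixture may differ. Since the tests pass an 8-dimensional head together with `pred_dim=1`, `WideDeep` must stack its final prediction layer on top of the head, which is why the head needs an `output_dim` attribute:

```python
import torch
from torch import nn


class CustomHead(nn.Module):
    # Hypothetical stand-in for the test fixture: a small MLP head.
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_dim = output_dim  # read by WideDeep to size the prediction layer
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.ReLU(),
            nn.Linear(input_dim // 2, output_dim),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X is the concatenation of the deeptabular/deeptext/deepimage outputs
        return self.mlp(X)
```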
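Similarly, a sketch of how the fused text component above could be built and probed in isolation. This assumes the `ModelFuser` behaviour the tests rely on: it takes a list of models plus a `projection_method` that projects the per-model outputs to a common size, and a direct call takes one input per model. The dimensions and `BasicRNN` hyperparameters here are illustrative, not taken from the repo:

```python
import torch
from pytorch_widedeep.models import BasicRNN, ModelFuser

# Two text models with different output sizes, fused as in the tests above
rnn_1 = BasicRNN(vocab_size=100, embed_dim=16, hidden_dim=8)
rnn_2 = BasicRNN(vocab_size=100, embed_dim=16, hidden_dim=16)

fused_text_model = ModelFuser(
    models=[rnn_1, rnn_2],
    fusion_method="mean",
    projection_method="max",  # project both outputs to max(8, 16) before fusing
)

X_text = torch.randint(1, 100, (4, 10))  # batch of 4 token sequences of length 10
out = fused_text_model([X_text, X_text])
assert out.shape[1] == fused_text_model.output_dim  # mirrors the projection test above
```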
