diff --git a/README.md b/README.md index 50a4ffdf..ddf11304 100644 --- a/README.md +++ b/README.md @@ -587,13 +587,111 @@ trainer.fit( ) ``` -**7. Tabular with a multi-target loss** +**7. A two-tower model** + +This is a popular model in the context of recommendation systems. Let's say we +have a tabular dataset formed by triples (user features, item features, +target). We can create a two-tower model where the user and item features are +passed through two separate models and then "fused" via a dot product. + +
+ +
+ + +```python +import numpy as np +import pandas as pd + +from pytorch_widedeep import Trainer +from pytorch_widedeep.preprocessing import TabPreprocessor +from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser + +# Let's create the interaction dataset +# user_features dataframe +np.random.seed(42) +user_ids = np.arange(1, 101) +ages = np.random.randint(18, 60, size=100) +genders = np.random.choice(["male", "female"], size=100) +locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100) +user_features = pd.DataFrame( + {"id": user_ids, "age": ages, "gender": genders, "location": locations} +) + +# item_features dataframe +item_ids = np.arange(1, 101) +prices = np.random.uniform(10, 500, size=100).round(2) +colors = np.random.choice(["red", "blue", "green", "black"], size=100) +categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100) + +item_features = pd.DataFrame( + {"id": item_ids, "price": prices, "color": colors, "category": categories} +) + +# Interactions dataframe +interaction_user_ids = np.random.choice(user_ids, size=1000) +interaction_item_ids = np.random.choice(item_ids, size=1000) +purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3]) +interactions = pd.DataFrame( + { + "user_id": interaction_user_ids, + "item_id": interaction_item_ids, + "purchased": purchased, + } +) +user_item_purchased = interactions.merge( + user_features, left_on="user_id", right_on="id" +).merge(item_features, left_on="item_id", right_on="id") + +# Users +tab_preprocessor_user = TabPreprocessor( + cat_embed_cols=["gender", "location"], + continuous_cols=["age"], +) +X_user = tab_preprocessor_user.fit_transform(user_item_purchased) +tab_mlp_user = TabMlp( + column_idx=tab_preprocessor_user.column_idx, + cat_embed_input=tab_preprocessor_user.cat_embed_input, + continuous_cols=["age"], + mlp_hidden_dims=[16, 8], + mlp_dropout=[0.2, 0.2], +) + +# Items +tab_preprocessor_item = TabPreprocessor( + 
cat_embed_cols=["color", "category"], + continuous_cols=["price"], +) +X_item = tab_preprocessor_item.fit_transform(user_item_purchased) +tab_mlp_item = TabMlp( + column_idx=tab_preprocessor_item.column_idx, + cat_embed_input=tab_preprocessor_item.cat_embed_input, + continuous_cols=["price"], + mlp_hidden_dims=[16, 8], + mlp_dropout=[0.2, 0.2], +) + +two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot") + +model = WideDeep(deeptabular=two_tower_model) + +trainer = Trainer(model, objective="binary") + +trainer.fit( + X_tab=[X_user, X_item], + target=interactions.purchased.values, + n_epochs=1, + batch_size=32, +) +``` + +**8. Tabular with a multi-target loss** This one is "a bonus" to illustrate the use of multi-target losses, more than actually a different architecture.- +
diff --git a/VERSION b/VERSION index fdd3be6d..266146b8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.6.2 +1.6.3 diff --git a/docs/examples.rst b/docs/examples.rst index b78fd394..489560fa 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -17,5 +17,4 @@ them to address different problems * `HyperParameter Tuning With RayTune\n", - " | MedInc | \n", - "HouseAge | \n", - "AveRooms | \n", - "AveBedrms | \n", - "Population | \n", - "AveOccup | \n", - "Latitude | \n", - "Longitude | \n", - "MedHouseVal | \n", - "
---|---|---|---|---|---|---|---|---|---|
0 | \n", - "8.3252 | \n", - "41.0 | \n", - "6.984127 | \n", - "1.023810 | \n", - "322.0 | \n", - "2.555556 | \n", - "37.88 | \n", - "-122.23 | \n", - "4.526 | \n", - "
1 | \n", - "8.3014 | \n", - "21.0 | \n", - "6.238137 | \n", - "0.971880 | \n", - "2401.0 | \n", - "2.109842 | \n", - "37.86 | \n", - "-122.22 | \n", - "3.585 | \n", - "
2 | \n", - "7.2574 | \n", - "52.0 | \n", - "8.288136 | \n", - "1.073446 | \n", - "496.0 | \n", - "2.802260 | \n", - "37.85 | \n", - "-122.24 | \n", - "3.521 | \n", - "
3 | \n", - "5.6431 | \n", - "52.0 | \n", - "5.817352 | \n", - "1.073059 | \n", - "558.0 | \n", - "2.547945 | \n", - "37.85 | \n", - "-122.25 | \n", - "3.413 | \n", - "
4 | \n", - "3.8462 | \n", - "52.0 | \n", - "6.281853 | \n", - "1.081081 | \n", - "565.0 | \n", - "2.181467 | \n", - "37.85 | \n", - "-122.25 | \n", - "3.422 | \n", - "