From e4ba45d23f4283acc5e4ff60637101f080dd3e20 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:25:00 -0700 Subject: [PATCH] Add basic ranking example. (#7) Example was run with all 3 backends (tf, jax and torch). --- examples/basic_ranking.py | 224 ++++++++++++++++++++++++++++++++++++ examples/basic_retrieval.py | 34 ++++-- 2 files changed, 246 insertions(+), 12 deletions(-) create mode 100644 examples/basic_ranking.py diff --git a/examples/basic_ranking.py b/examples/basic_ranking.py new file mode 100644 index 0000000..f998b06 --- /dev/null +++ b/examples/basic_ranking.py @@ -0,0 +1,224 @@ +""" +# Recommending movies: ranking + +Recommender systems are often composed of two stages: + +1. The retrieval stage is responsible for selecting an initial set of hundreds + of candidates from all possible candidates. The main objective of this model + is to efficiently weed out all candidates that the user is not interested in. + Because the retrieval model may be dealing with millions of candidates, it + has to be computationally efficient. +2. The ranking stage takes the outputs of the retrieval model and fine-tunes + them to select the best possible handful of recommendations. Its task is to + narrow down the set of items the user may be interested in to a shortlist of + likely candidates. + +In this tutorial, we're going to focus on the first stage, retrieval. If you are +interested in the ranking stage, have a look at our +[retrieval](https://github.com/keras-team/keras-rs/blob/main/examples/basic_retrieval.py) +tutorial. + +In this tutorial, we're going to: + +1. Get our data and split it into a training and test set. +2. Implement a ranking model. +3. Fit and evaluate it. +4. Test running predictions with the model. +""" + +import keras + +# Needed for the dataset +import tensorflow as tf +import tensorflow_datasets as tfds + +""" +## Preparing the dataset + +We're going to use the same data as the +[retrieval](https://github.com/keras-team/keras-rs/blob/main/examples/basic_retrieval.py) +tutorial. The ratings are the objectives we are trying to predict. +""" + +# Ratings data. +ratings = tfds.load("movielens/100k-ratings", split="train") +# Features of all the available movies. +movies = tfds.load("movielens/100k-movies", split="train") + +""" +In the Movielens dataset, user ids are integers (represented as strings) +starting at 1 and with no gap. Normally, you would need to create a lookup table +to map user ids to integers from 0 to N-1. But as a simplication, we'll use the +user id directly as an index in our model, in particular to lookup the user +embedding from the user embedding table. So we need do know the number of users. +""" + +users_count = ( + ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) + .reduce(tf.constant(0, tf.int32), tf.maximum) + .numpy() +) + +""" +In the Movielens dataset, movie ids are integers (represented as strings) +starting at 1 and with no gap. Normally, you would need to create a lookup table +to map movie ids to integers from 0 to N-1. But as a simplication, we'll use the +movie id directly as an index in our model, in particular to lookup the movie +embedding from the movie embedding table. So we need do know the number of +movies. +""" + +movies_count = movies.cardinality().numpy() + +""" +The inputs to the model are the user ids and movie ids and the labels are the +ratings. +""" + + +def preprocess_rating(x): + return ( + # Inputs are user ids and movie ids + { + "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), + "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), + }, + # Labels are ratings between 0 and 1. + (x["user_rating"] - 1.0) / 4.0, + ) + + +""" +We'll split the data by putting 80% of the ratings in the train set, and 20% in +the test set. +""" + +shuffled_ratings = ratings.map(preprocess_rating).shuffle( + 100_000, seed=42, reshuffle_each_iteration=False +) +train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() +test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() + +""" +## Implementing the Model + +### Architecture + +Ranking models do not face the same efficiency constraints as retrieval models +do, and so we have a little bit more freedom in our choice of architectures. + +A model composed of multiple stacked dense layers is a relatively common +architecture for ranking tasks. We can implement it as follows: +""" + + +class RankingModel(keras.Model): + + def __init__( + self, + num_users, + num_candidates, + embedding_dimension=32, + **kwargs, + ): + super().__init__(**kwargs) + # Embedding table for users. + self.user_embedding = keras.layers.Embedding( + num_users, embedding_dimension + ) + # Embedding table for candidates. + self.candidate_embedding = keras.layers.Embedding( + num_candidates, embedding_dimension + ) + # Predictions. + self.ratings = keras.Sequential( + [ + # Learn multiple dense layers. + keras.layers.Dense(256, activation="relu"), + keras.layers.Dense(64, activation="relu"), + # Make rating predictions in the final layer. + keras.layers.Dense(1), + ] + ) + + def call(self, inputs): + user_id, movie_id = inputs["user_id"], inputs["movie_id"] + user_embeddings = self.user_embedding(user_id) + candidate_embeddings = self.candidate_embedding(movie_id) + return self.ratings( + keras.ops.concatenate( + [user_embeddings, candidate_embeddings], axis=1 + ) + ) + + +""" +Let's first instantiate the model. Note that we add `+ 1` to the number of users +and movies to account for the fact that id zero is not used for either (ids +start at 1), but still takes a row in the embedding tables. +""" + +model = RankingModel(users_count + 1, movies_count + 1) + +""" +### Loss and metrics + +The next component is the loss used to train our model. Keras has several losses +to make this easy. In this instance, we'll make use of the `MeanSquaredError` +loss in order to predict the ratings. We'll also look at the +`RootMeanSquaredError` metric. +""" + +model.compile( + loss=keras.losses.MeanSquaredError(), + metrics=[keras.metrics.RootMeanSquaredError()], + optimizer=keras.optimizers.Adagrad(learning_rate=0.1), +) + +""" +## Fitting and evaluating + +After defining the model, we can use the standard Keras `model.fit()` to train +the model. +""" + +model.fit(train_ratings, epochs=5) + +""" +As the model trains, the loss is falling and the RMSE metric is improving. + +Finally, we can evaluate our model on the test set. The lower the RMSE metric, +the more accurate our model is at predicting ratings. +""" + +model.evaluate(test_ratings, return_dict=True) + +""" +## Testing the ranking model + +So far, we have only handled movies by id. Now is the time to create a mapping +keyed by movie ids to be able to surface the titles. +""" + +movie_id_to_movie_title = { + int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator() +} +movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset. + +""" +Now we can test the ranking model by computing predictions for a set of movies +and then rank these movies based on the predictions: +""" + +user_id = 42 +movie_ids = [204, 141, 131] +predictions = model.predict( + { + "user_id": keras.ops.array([user_id] * len(movie_ids)), + "movie_id": keras.ops.array(movie_ids), + } +) +predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1)) + +for movie_id, prediction in zip(movie_ids, predictions): + print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}") diff --git a/examples/basic_retrieval.py b/examples/basic_retrieval.py index 6ec75d3..7397eaf 100644 --- a/examples/basic_retrieval.py +++ b/examples/basic_retrieval.py @@ -13,6 +13,11 @@ narrow down the set of items the user may be interested in to a shortlist of likely candidates. +In this tutorial, we're going to focus on the first stage, retrieval. If you are +interested in the ranking stage, have a look at our +[ranking](https://github.com/keras-team/keras-rs/blob/main/examples/basic_ranking.py) +tutorial. + Retrieval models are often composed of two sub-models: 1. A query tower computing the query representation (normally a @@ -93,9 +98,11 @@ print(str(data).replace(", '", ",\n '")) """ -User ids are integers (represented as strings) between 1 and 943. We'll use the -user id as an index in our model, in particular to lookup the user embedding -from the user embedding table. So we need do know the number of users. +In the Movielens dataset, user ids are integers (represented as strings) +starting at 1 and with no gap. Normally, you would need to create a lookup table +to map user ids to integers from 0 to N-1. But as a simplication, we'll use the +user id directly as an index in our model, in particular to lookup the user +embedding from the user embedding table. So we need do know the number of users. """ users_count = ( @@ -113,8 +120,10 @@ print(str(data).replace(", '", ",\n '")) """ -Movie ids are integers (represented as strings) between 1 and 1682. We'll use -the movie id as an index in our model, in particular to lookup the movie +In the Movielens dataset, movie ids are integers (represented as strings) +starting at 1 and with no gap. Normally, you would need to create a lookup table +to map movie ids to integers from 0 to N-1. But as a simplication, we'll use the +movie id directly as an index in our model, in particular to lookup the movie embedding from the movie embedding table. So we need do know the number of movies. """ @@ -220,18 +229,19 @@ class RetrievalModel(keras.Model): def __init__( self, - user_embeddings_count, - candidate_embeddings_count, + num_users, + num_candidates, embedding_dimension=32, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) # Our query tower, simply an embedding table. self.user_embedding = keras.layers.Embedding( - user_embeddings_count, embedding_dimension + num_users, embedding_dimension ) # Our candidate tower, simply an embedding table. self.candidate_embedding = keras.layers.Embedding( - candidate_embeddings_count, embedding_dimension + num_candidates, embedding_dimension ) # The layer that performs the retrieval. self.retrieval = keras_rs.layers.BruteForceRetrieval( @@ -283,8 +293,8 @@ def compute_loss(self, x, y, y_pred, sample_weight, training=True): and evaluate the model. Let's first instantiate the model. Note that we add `+ 1` to the number of users -and movies to account for the fact that id zero is not used, but still take a -row in the embedding tables. +and movies to account for the fact that id zero is not used for either (ids +start at 1), but still takes a row in the embedding tables. """ model = RetrievalModel(users_count + 1, movies_count + 1)