Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discounted Cumulative Gain #188

Merged
merged 9 commits into from
Nov 3, 2023
97 changes: 97 additions & 0 deletions lib/scholar/metrics/ranking.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
defmodule Scholar.Metrics.Ranking do
  @moduledoc """
  Provides metrics and calculations related to ranking quality.

  Ranking metrics evaluate the quality of ordered lists of items,
  often used in information retrieval and recommendation systems.

  This module currently supports the following ranking metrics:

    * Discounted Cumulative Gain (DCG)
  """

  import Nx.Defn
  import Scholar.Shared
  require Nx

  # NOTE(review): `:k` is validated as a positive *number*, but it is forwarded
  # to `Nx.split/3`, which interprets floats as proportions rather than counts.
  # Callers are expected to pass a positive integer - confirm whether the
  # schema should be tightened to `:pos_integer`.
  @dcg_opts [
    k: [
      type: {:custom, Scholar.Options, :positive_number, []},
      doc: "Truncation parameter to consider only the top-k elements."
    ]
  ]

  @dcg_opts_schema NimbleOptions.new!(@dcg_opts)

  @doc """
  Computes the Discounted Cumulative Gain (DCG) based on true relevance
  scores (`y_true`) and their respective predicted scores (`y_score`).

  The elements of `y_true` are ranked by descending `y_score`; the element at
  zero-based rank `i` contributes `y_true[i] / log2(i + 2)` and the
  contributions are summed. `y_true` and `y_score` must have the same shape,
  otherwise an `ArgumentError` is raised.

  ## Options

  #{NimbleOptions.docs(@dcg_opts_schema)}

  When `:k` is not given, or is greater than or equal to the number of
  elements, every element contributes to the result. `:k` is expected to be
  a positive integer.

  ## Examples

      iex> y_true = Nx.tensor([3, 2, 3, 0, 1, 2])
      iex> y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1])
      iex> Scholar.Metrics.Ranking.dcg(y_true, y_score)
      #Nx.Tensor<
        f32
        7.140995025634766
      >
      iex> Scholar.Metrics.Ranking.dcg(y_true, y_score, k: 3)
      #Nx.Tensor<
        f32
        5.892789363861084
      >
  """
  deftransform dcg(y_true, y_score, opts \\ []) do
    dcg_n(y_true, y_score, NimbleOptions.validate!(opts, @dcg_opts_schema))
  end

  @doc """
  Numerical implementation behind `dcg/3`.

  Takes the same tensors as `dcg/3`, but `opts` is used as given and is not
  validated here; prefer calling `dcg/3`.
  """
  defn dcg_n(y_true, y_score, opts) do
    y_true_shape = Nx.shape(y_true)
    y_score_shape = Nx.shape(y_score)

    # Fail early on mismatching inputs.
    check_shape(y_true_shape, y_score_shape)

    {adjusted_y_true, adjusted_y_score} = handle_ties(y_true, y_score)

    # Rank the relevance values by descending predicted score.
    sorted_indices = Nx.argsort(adjusted_y_score, axis: 0, direction: :desc)
    sorted_y_true = Nx.take(adjusted_y_true, sorted_indices)

    truncated_y_true = truncate_at_k(sorted_y_true, opts)
    dcg_value(truncated_y_true)
  end

  # Delegates the shape comparison to `assert_same_shape!/2`, imported from
  # `Scholar.Shared`.
  defnp check_shape(y_true, y_pred) do
    assert_same_shape!(y_true, y_pred)
  end

  # Reorders both tensors with the same permutation: first by descending
  # score, then by descending true relevance. Returns the reordered
  # `{y_true, y_score}` pair; the caller sorts by score once more afterwards.
  #
  # NOTE(review): the second sort overrides the ordering produced by the
  # first. The intent appears to be that equal scores end up ordered by
  # relevance, which only holds if `Nx.argsort/2` keeps the relative order of
  # equal elements - confirm against the Nx version in use.
  defnp handle_ties(y_true, y_score) do
    sorted_indices = Nx.argsort(y_score, axis: 0, direction: :desc)

    sorted_y_true = Nx.take(y_true, sorted_indices)
    sorted_y_score = Nx.take(y_score, sorted_indices)

    tie_sorted_indices = Nx.argsort(sorted_y_true, axis: 0, direction: :desc)
    adjusted_y_true = Nx.take(sorted_y_true, tie_sorted_indices)
    adjusted_y_score = Nx.take(sorted_y_score, tie_sorted_indices)

    {adjusted_y_true, adjusted_y_score}
  end

  # Sums `y_true[i] / log2(i + 2)` over the already ranked (and possibly
  # truncated) relevance values, computed in `:f32`.
  defnp dcg_value(y_true) do
    float_y_true = Nx.as_type(y_true, :f32)

    # Discount denominators: log2(2), log2(3), ... for ranks 0, 1, ...
    log_tensor =
      y_true
      |> Nx.shape()
      |> Nx.iota()
      |> Nx.as_type(:f32)
      |> Nx.add(2.0)
      |> Nx.log2()

    div_result = Nx.divide(float_y_true, log_tensor)

    Nx.sum(div_result)
  end

  # Keeps only the first `opts[:k]` entries along axis 0. The tensor is
  # returned untouched when `:k` is absent or covers the whole axis.
  defnp truncate_at_k(tensor, opts) do
    case opts[:k] do
      nil ->
        tensor

      _ ->
        # `>=` rather than `>`: `Nx.split/3` requires the split index to be
        # strictly smaller than the axis size, so `k == size` must not reach it.
        if opts[:k] >= Nx.axis_size(tensor, 0) do
          tensor
        else
          {top_k, _rest} = Nx.split(tensor, opts[:k], axis: 0)
          top_k
        end
    end
  end
end
47 changes: 47 additions & 0 deletions test/scholar/metrics/ranking_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
defmodule Scholar.Metrics.RankingTest do
  use Scholar.Case, async: true
  alias Scholar.Metrics.Ranking

  describe "dcg/3" do
    # Every item has a distinct predicted score, so the ranking is unambiguous.
    test "computes DCG when there are no ties" do
      relevance = Nx.tensor([3, 2, 3, 0, 1, 2])
      scores = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1])

      actual =
        relevance
        |> Ranking.dcg(scores)
        |> Nx.broadcast({1})

      expected = Nx.tensor([7.140995025634766])
      assert expected == actual
    end

    # Two items share a score; all relevances are equal, so the result does
    # not depend on how the tie is ordered.
    test "computes DCG with ties" do
      relevance = Nx.tensor([3, 3, 3])
      scores = Nx.tensor([2.0, 2.0, 3.5])

      actual =
        relevance
        |> Ranking.dcg(scores)
        |> Nx.broadcast({1})

      expected = Nx.tensor([6.3927892607143715])
      assert expected == actual
    end

    # Inputs of different lengths must be rejected with a descriptive error.
    test "raises error when shapes mismatch" do
      relevance = Nx.tensor([3, 2, 3])
      scores = Nx.tensor([3.0, 2.2, 3.5, 0.5])

      assert_raise ArgumentError,
                   "expected tensor to have shape {3}, got tensor with shape {4}",
                   fn -> Ranking.dcg(relevance, scores) end
    end

    # Only the three best-scored items contribute when `k: 3` is given.
    test "computes DCG for top-k values" do
      relevance = Nx.tensor([3, 2, 3, 0, 1, 2])
      scores = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1])

      actual =
        relevance
        |> Ranking.dcg(scores, k: 3)
        |> Nx.broadcast({1})

      expected = Nx.tensor([5.892789363861084])
      assert expected == actual
    end
  end
end
Loading