-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
80e8889
commit 3f5b6bd
Showing
2 changed files
with
144 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
defmodule Scholar.Metrics.Ranking do | ||
@moduledoc """ | ||
Provides metrics and calculations related to ranking quality. | ||
Ranking metrics evaluate the quality of ordered lists of items, | ||
often used in information retrieval and recommendation systems. | ||
This module currently supports the following ranking metrics: | ||
* Discounted Cumulative Gain (DCG) | ||
""" | ||
|
||
import Nx.Defn | ||
import Scholar.Shared | ||
require Nx | ||
|
||
@dcg_opts [ | ||
k: [ | ||
type: {:custom, Scholar.Options, :positive_number, []}, | ||
doc: "Truncation parameter to consider only the top-k elements." | ||
] | ||
] | ||
|
||
@dcg_opts_schema NimbleOptions.new!(@dcg_opts) | ||
|
||
deftransform dcg(y_true, y_score, opts \\ []) do | ||
dcg_n(y_true, y_score, NimbleOptions.validate!(opts, @dcg_opts_schema)) | ||
end | ||
|
||
@doc """ | ||
## Options | ||
#{NimbleOptions.docs(@dcg_opts_schema)} | ||
Computes the DCG based on true relevance scores (`y_true`) and their respective predicted scores (`y_score`). | ||
""" | ||
defn dcg_n(y_true, y_score, opts) do | ||
y_true_shape = Nx.shape(y_true) | ||
y_score_shape = Nx.shape(y_score) | ||
|
||
check_shape(y_true_shape, y_score_shape) | ||
|
||
{adjusted_y_true, adjusted_y_score} = handle_ties(y_true, y_score) | ||
|
||
sorted_indices = Nx.argsort(adjusted_y_score, axis: 0, direction: :desc) | ||
sorted_y_true = Nx.take(adjusted_y_true, sorted_indices) | ||
|
||
truncated_y_true = truncate_at_k(sorted_y_true, opts) | ||
dcg_value(truncated_y_true) | ||
end | ||
|
||
defnp check_shape(y_true, y_pred) do | ||
assert_same_shape!(y_true, y_pred) | ||
end | ||
|
||
defnp handle_ties(y_true, y_score) do | ||
sorted_indices = Nx.argsort(y_score, axis: 0, direction: :desc) | ||
|
||
sorted_y_true = Nx.take(y_true, sorted_indices) | ||
sorted_y_score = Nx.take(y_score, sorted_indices) | ||
|
||
tie_sorted_indices = Nx.argsort(sorted_y_true, axis: 0, direction: :desc) | ||
adjusted_y_true = Nx.take(sorted_y_true, tie_sorted_indices) | ||
adjusted_y_score = Nx.take(sorted_y_score, tie_sorted_indices) | ||
|
||
{adjusted_y_true, adjusted_y_score} | ||
end | ||
|
||
defnp dcg_value(y_true) do | ||
float_y_true = Nx.as_type(y_true, :f32) | ||
|
||
log_tensor = | ||
y_true | ||
|> Nx.shape() | ||
|> Nx.iota() | ||
|> Nx.as_type(:f32) | ||
|> Nx.add(2.0) | ||
|> Nx.log2() | ||
|
||
div_result = Nx.divide(float_y_true, log_tensor) | ||
|
||
Nx.sum(div_result) | ||
end | ||
|
||
defnp truncate_at_k(tensor, opts) do | ||
case opts[:k] do | ||
nil -> | ||
tensor | ||
|
||
_ -> | ||
if opts[:k] > Nx.axis_size(tensor, 0) do | ||
tensor | ||
else | ||
{top_k, _rest} = Nx.split(tensor, opts[:k], axis: 0) | ||
top_k | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
defmodule Scholar.Metrics.RankingTest do | ||
use Scholar.Case, async: true | ||
alias Scholar.Metrics.Ranking | ||
|
||
describe "dcg/3" do | ||
test "computes DCG when there are no ties" do | ||
y_true = Nx.tensor([3, 2, 3, 0, 1, 2]) | ||
y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1]) | ||
|
||
result = Ranking.dcg(y_true, y_score) | ||
|
||
x = Nx.tensor([7.140995025634766]) | ||
assert x == Nx.broadcast(result, {1}) | ||
end | ||
|
||
test "computes DCG with ties" do | ||
y_true = Nx.tensor([3, 3, 3]) | ||
y_score = Nx.tensor([2.0, 2.0, 3.5]) | ||
|
||
result = Ranking.dcg(y_true, y_score) | ||
|
||
x = Nx.tensor([6.3927892607143715]) | ||
assert x == Nx.broadcast(result, {1}) | ||
end | ||
|
||
test "raises error when shapes mismatch" do | ||
y_true = Nx.tensor([3, 2, 3]) | ||
y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5]) | ||
|
||
assert_raise ArgumentError, | ||
"expected tensor to have shape {3}, got tensor with shape {4}", | ||
fn -> | ||
Ranking.dcg(y_true, y_score) | ||
end | ||
end | ||
|
||
test "computes DCG for top-k values" do | ||
y_true = Nx.tensor([3, 2, 3, 0, 1, 2]) | ||
y_score = Nx.tensor([3.0, 2.2, 3.5, 0.5, 1.0, 2.1]) | ||
|
||
result = Ranking.dcg(y_true, y_score, k: 3) | ||
|
||
x = Nx.tensor([5.892789363861084]) | ||
assert x == Nx.broadcast(result, {1}) | ||
end | ||
end | ||
end |