Skip to content

Commit

Permalink
Numba utils refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
AmenRa committed Aug 16, 2023
1 parent 937a2c8 commit 1c335cc
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 28 deletions.
6 changes: 3 additions & 3 deletions retriv/sparse_retriever/sparse_retrieval_models/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numba import njit, prange
from numba.typed import List as TypedList

from ...utils.numba_utils import join_sorted_multi_recursive, unsorted_top_k
from ...utils.numba_utils import union_sorted_multi, unsorted_top_k


@njit(cache=True)
Expand All @@ -18,7 +18,7 @@ def bm25(
doc_count: int,
cutoff: int,
) -> Tuple[np.ndarray]:
unique_doc_ids = join_sorted_multi_recursive(doc_ids)
unique_doc_ids = union_sorted_multi(doc_ids)

scores = np.empty(doc_count, dtype=np.float32)
scores[unique_doc_ids] = 0.0 # Initialize scores
Expand Down Expand Up @@ -63,7 +63,7 @@ def bm25_multi(
_term_doc_freqs = term_doc_freqs[i]
_doc_ids = doc_ids[i]

_unique_doc_ids = join_sorted_multi_recursive(_doc_ids)
_unique_doc_ids = union_sorted_multi(_doc_ids)

_scores = np.empty(doc_count, dtype=np.float32)
_scores[_unique_doc_ids] = 0.0 # Initialize _scores
Expand Down
6 changes: 3 additions & 3 deletions retriv/sparse_retriever/sparse_retrieval_models/tf_idf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numba import njit, prange
from numba.typed import List as TypedList

from ...utils.numba_utils import join_sorted_multi_recursive, unsorted_top_k
from ...utils.numba_utils import union_sorted_multi, unsorted_top_k


@njit(cache=True)
Expand All @@ -15,7 +15,7 @@ def tf_idf(
doc_lens: nb.typed.List[np.ndarray],
cutoff: int,
) -> Tuple[np.ndarray]:
unique_doc_ids = join_sorted_multi_recursive(doc_ids)
unique_doc_ids = union_sorted_multi(doc_ids)

doc_count = len(doc_lens)
scores = np.empty(doc_count, dtype=np.float32)
Expand Down Expand Up @@ -57,7 +57,7 @@ def tf_idf_multi(
_term_doc_freqs = term_doc_freqs[i]
_doc_ids = doc_ids[i]

_unique_doc_ids = join_sorted_multi_recursive(_doc_ids)
_unique_doc_ids = union_sorted_multi(_doc_ids)

doc_count = len(doc_lens)
_scores = np.empty(doc_count, dtype=np.float32)
Expand Down
68 changes: 58 additions & 10 deletions retriv/utils/numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from numba import njit


# UNION ------------------------------------------------------------------------
@njit(cache=True)
def join_sorted(a1: np.array, a2: np.array):
def union_sorted(a1: np.array, a2: np.array):
result = np.empty(len(a1) + len(a2), dtype=np.int32)
i = 0
j = 0
Expand Down Expand Up @@ -33,25 +34,72 @@ def join_sorted(a1: np.array, a2: np.array):


@njit(cache=True)
def join_sorted_multi(arrays):
def union_sorted_multi(arrays):
    """Combine every array in ``arrays`` into one via ``union_sorted``.

    A single array is handed back as-is (no copy). Two arrays are merged
    with one ``union_sorted`` call. Longer inputs are split into the first
    two arrays and the remainder; each part is reduced recursively and the
    two partial results are merged.
    """
    count = len(arrays)

    if count == 1:
        return arrays[0]

    if count == 2:
        return union_sorted(arrays[0], arrays[1])

    # Reduce the leading pair first, then the rest, then merge both halves.
    head = union_sorted_multi(arrays[:2])
    tail = union_sorted_multi(arrays[2:])
    return union_sorted(head, tail)


# INTERSECTION -----------------------------------------------------------------
@njit(cache=True)
def intersect_sorted(a1: np.array, a2: np.array):
    """Return the values present in both ``a1`` and ``a2``.

    The two arrays are scanned once with a pair of cursors, advancing
    whichever cursor points at the smaller value; this relies on both
    inputs being in ascending order. The result is an ``int32`` array
    holding one entry per matched pair.
    """
    n1 = len(a1)
    n2 = len(a2)

    # The intersection can never be longer than the shorter input.
    out = np.empty(min(n1, n2), dtype=np.int32)
    p1 = 0
    p2 = 0
    size = 0

    while p1 < n1 and p2 < n2:
        left, right = a1[p1], a2[p2]

        if left < right:
            p1 += 1
            continue

        if left > right:
            p2 += 1
            continue

        # Neither side is smaller: record the value and advance both cursors.
        out[size] = left
        size += 1
        p1 += 1
        p2 += 1

    # Trim the unused tail of the preallocated buffer.
    return out[:size]


@njit(cache=True)
def intersect_sorted_multi(arrays):
    """Intersect all arrays in ``arrays`` by folding ``intersect_sorted``.

    Starts from ``arrays[0]`` and intersects the running result with each
    following array, in order. ``arrays`` must hold at least one array; when
    it holds exactly one, that array is returned unchanged (not copied).
    """
    a = arrays[0]

    for i in range(1, len(arrays)):
        # Intersect only. The earlier `join_sorted` call on this line was a
        # dead store (overwritten immediately) to a pre-rename function name.
        a = intersect_sorted(a, arrays[i])

    return a


# DIFFERENCE -------------------------------------------------------------------
@njit(cache=True)
def join_sorted_multi_recursive(arrays):
if len(arrays) == 1:
return arrays[0]
elif len(arrays) == 2:
return join_sorted(arrays[0], arrays[1])
else:
return join_sorted(join_sorted_multi(arrays[:2]), join_sorted_multi(arrays[2:]))
def diff_sorted(a1: np.array, a2: np.array):
    """Return the elements of ``a1`` that do not appear in ``a2``.

    Both arrays are walked once with two cursors, so the result is only
    meaningful when each input is in ascending order. The output is an
    ``int32`` array that preserves the order of ``a1``.

    Args:
        a1: array whose elements are kept unless matched in ``a2``.
        a2: array of elements to remove from ``a1``.

    Returns:
        A slice of a freshly allocated ``int32`` buffer with the survivors.
    """
    # The difference can never be longer than a1 itself.
    result = np.empty(len(a1), dtype=np.int32)
    i = 0
    j = 0
    k = 0

    while i < len(a1) and j < len(a2):
        if a1[i] < a2[j]:
            # a1[i] is smaller than everything left in a2: keep it.
            result[k] = a1[i]
            i += 1
            k += 1
        elif a1[i] > a2[j]:
            j += 1
        else:  # a1[i] == a2[j]: the element is excluded
            i += 1
            j += 1

    # a2 is exhausted, so nothing can exclude what is left of a1. Without
    # this tail copy, elements after a2's last value were silently dropped.
    while i < len(a1):
        result[k] = a1[i]
        i += 1
        k += 1

    return result[:k]


# -----------------------------------------------------------------------------
@njit(cache=True)
def concat1d(X):
out = np.empty(sum([len(x) for x in X]), dtype=X[0].dtype)
Expand Down
47 changes: 35 additions & 12 deletions tests/sparse_retriever/numba_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,73 @@
import numpy as np
import pytest
from numba.typed import List as TypedList

from retriv.utils.numba_utils import (
concat1d,
diff_sorted,
get_indices,
join_sorted,
join_sorted_multi,
join_sorted_multi_recursive,
intersect_sorted,
intersect_sorted_multi,
union_sorted,
union_sorted_multi,
unsorted_top_k,
)


# TESTS ========================================================================
def test_join_sorted():
def test_union_sorted():
    """``union_sorted`` merges two sorted arrays, keeping shared values once."""
    a1 = np.array([1, 3, 4, 7], dtype=np.int32)
    a2 = np.array([1, 4, 7, 9], dtype=np.int32)

    # Call only the renamed function; the old `join_sorted` call was stale.
    result = union_sorted(a1, a2)
    expected = np.array([1, 3, 4, 7, 9], dtype=np.int32)

    assert np.array_equal(result, expected)


def test_join_sorted_multi():
def test_union_sorted_multi():
    """``union_sorted_multi`` unions four arrays, two of them disjoint from the rest."""
    a1 = np.array([1, 3, 4, 7], dtype=np.int32)
    a2 = np.array([1, 4, 7, 9], dtype=np.int32)
    a3 = np.array([10, 11], dtype=np.int32)
    a4 = np.array([11, 12, 13], dtype=np.int32)

    arrays = TypedList([a1, a2, a3, a4])

    # Call only the renamed function; the old `join_sorted_multi` call was stale.
    result = union_sorted_multi(arrays)
    expected = np.array([1, 3, 4, 7, 9, 10, 11, 12, 13], dtype=np.int32)

    assert np.array_equal(result, expected)


def test_join_sorted_multi_recursive():
def test_intersect_sorted():
    """``intersect_sorted`` keeps only the values shared by both arrays."""
    # Only two inputs are needed; the unused a3/a4 fixtures were dropped.
    a1 = np.array([1, 3, 4, 7], dtype=np.int32)
    a2 = np.array([1, 4, 7, 9], dtype=np.int32)

    result = intersect_sorted(a1, a2)
    expected = np.array([1, 4, 7], dtype=np.int32)

    assert np.array_equal(result, expected)


def test_intersect_sorted_multi():
    """``intersect_sorted_multi`` reduces four arrays to their one shared value."""
    a1 = np.array([1, 3, 4, 7], dtype=np.int32)
    a2 = np.array([1, 4, 7, 9], dtype=np.int32)
    a3 = np.array([4, 7], dtype=np.int32)
    a4 = np.array([3, 7, 9], dtype=np.int32)

    arrays = TypedList([a1, a2, a3, a4])

    # 7 is the only value present in all four inputs. The stale union call
    # and the debug print that used to sit here were removed.
    result = intersect_sorted_multi(arrays)
    expected = np.array([7], dtype=np.int32)

    assert np.array_equal(result, expected)


def test_diff_sorted():
    """Only the value of the first array that is absent from the second survives."""
    minuend = np.array([1, 3, 4, 7], dtype=np.int32)
    subtrahend = np.array([1, 4, 7, 9], dtype=np.int32)

    # 1, 4 and 7 occur in both arrays, so 3 is the sole remaining element.
    assert np.array_equal(
        diff_sorted(minuend, subtrahend),
        np.array([3], dtype=np.int32),
    )

Expand Down

0 comments on commit 1c335cc

Please sign in to comment.