Skip to content

Commit

Permalink
test(axes): Add tensordot tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 15, 2024
1 parent 3dacb89 commit c64b4e4
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,29 @@ def test_linalg_solve_incompatible_left():


def test_tensordot_int_axes():
...
axes_a = {"ax_a": 0, "ax_b": [1, 2]}
axes_b = {"ax_b": [0, 1], "ax_c": 2}
arr = np.arange(8).reshape((2, 2, 2))
arr_a = AxesArray(arr, axes_a)
arr_b = AxesArray(arr, axes_b)
result = np.tensordot(arr_a, arr_b, 2)
super_result = np.tensordot(arr, arr, 2)
expected_axes = {"ax_a": 0, "ax_c": 1}
assert result.axes == expected_axes
assert_array_equal(result, super_result)


def test_tensordot_list_axes():
...
axes_a = {"ax_a": 0, "ax_b": [1, 2]}
axes_b = {"ax_c": [0, 1], "ax_b": 2}
arr = np.arange(8).reshape((2, 2, 2))
arr_a = AxesArray(arr, axes_a)
arr_b = AxesArray(arr, axes_b)
result = np.tensordot(arr_a, arr_b, [[1], [2]])
super_result = np.tensordot(arr, arr, 2)
expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]}
assert result.axes == expected_axes
assert_array_equal(result, super_result)


def test_einsum_implicit():
Expand Down

0 comments on commit c64b4e4

Please sign in to comment.