Skip to content

Commit

Permalink
bug(axes) Change axis alignment linalg_solve + test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 15, 2024
1 parent 8f1e4bc commit 3dacb89
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
16 changes: 12 additions & 4 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,20 @@ def _join_unique_names(l_of_s: List[str]) -> str:
def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray:
result = np.linalg.solve(np.asarray(a), np.asarray(b))
a_rev = a._ax_map.reverse_map
contracted_axis_name = a_rev[sorted(a_rev)[-1]]
a_names = [a_rev[k] for k in sorted(a_rev)]
contracted_axis_name = a_names[-1]
b_rev = b._ax_map.reverse_map
rest_of_names = [b_rev[k] for k in sorted(b_rev)]
axes = _AxisMapping.fwd_from_names(
[*rest_of_names[:-2], contracted_axis_name, rest_of_names[-1]]
b_names = [b_rev[k] for k in sorted(b_rev)]
match_axes_list = a_names[:-1]
start = max(b.ndim - a.ndim, 0)
end = start + len(match_axes_list)
align = slice(start, end)
if match_axes_list != b_names[align]:
raise ValueError("Mismatch in operand axis names when aligning A and b")
all_names = (
b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :]
)
axes = _AxisMapping.fwd_from_names(all_names)
return AxesArray(result, axes)


Expand Down
42 changes: 41 additions & 1 deletion test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,50 @@ def test_linalg_solve_align_right():
assert_array_equal(result, super_result)


def test_linalg_solve_align_right_xl():
axesA = {"ax_sample": 0, "ax_feature": 1}
arrA = AxesArray(np.arange(4).reshape(2, 2), axesA)
axesb = {"ax_prob": 0, "ax_sample": 1, "ax_target": 2}
arrb = AxesArray(np.arange(8).reshape(2, 2, 2), axesb)
result = np.linalg.solve(arrA, arrb)
expected_axes = {"ax_prob": 0, "ax_feature": 1, "ax_target": 2}
assert result.axes == expected_axes
super_result = np.linalg.solve(np.asarray(arrA), np.asarray(arrb))
assert_array_equal(result, super_result)


def test_linalg_solve_incompatible_left():
axesA = {"ax_prob": 0, "ax_sample": 1, "ax_coord": 2}
arrA = AxesArray(np.arange(8).reshape(2, 2, 2), axesA)
axesb = {"ax_foo": 0, "ax_sample": 1}
arrb = AxesArray(np.arange(4).reshape(2, 2), axesb)
with pytest.raises(ValueError, match="fdsafds"):
with pytest.raises(ValueError, match="Mismatch in operand axis names"):
np.linalg.solve(arrA, arrb)


def test_tensordot_int_axes():
...


def test_tensordot_list_axes():
...


def test_einsum_implicit():
...


def test_einsum_trace():
...


def test_einsum_diag():
...


def test_einsum_contraction():
...


def test_einsum_mixed():
...

0 comments on commit 3dacb89

Please sign in to comment.