From 3dacb89368d48cd778573e64391d8f3876c7616d Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 21:19:35 +0000 Subject: [PATCH] bug(axes) Change axis alignment linalg_solve + test --- pysindy/utils/axes.py | 16 ++++++++++++---- test/utils/test_axes.py | 42 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 0c592a9dc..27c10abc5 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -550,12 +550,20 @@ def _join_unique_names(l_of_s: List[str]) -> str: def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map - contracted_axis_name = a_rev[sorted(a_rev)[-1]] + a_names = [a_rev[k] for k in sorted(a_rev)] + contracted_axis_name = a_names[-1] b_rev = b._ax_map.reverse_map - rest_of_names = [b_rev[k] for k in sorted(b_rev)] - axes = _AxisMapping.fwd_from_names( - [*rest_of_names[:-2], contracted_axis_name, rest_of_names[-1]] + b_names = [b_rev[k] for k in sorted(b_rev)] + match_axes_list = a_names[:-1] + start = max(b.ndim - a.ndim, 0) + end = start + len(match_axes_list) + align = slice(start, end) + if match_axes_list != b_names[align]: + raise ValueError("Mismatch in operand axis names when aligning A and b") + all_names = ( + b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :] ) + axes = _AxisMapping.fwd_from_names(all_names) return AxesArray(result, axes) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 00d992547..38b19350b 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -538,10 +538,50 @@ def test_linalg_solve_align_right(): assert_array_equal(result, super_result) +def test_linalg_solve_align_right_xl(): + axesA = {"ax_sample": 0, "ax_feature": 1} + arrA = AxesArray(np.arange(4).reshape(2, 2), axesA) + axesb = {"ax_prob": 0, "ax_sample": 1, "ax_target": 2} + arrb = AxesArray(np.arange(8).reshape(2, 2, 2), axesb) + result = np.linalg.solve(arrA, arrb) + expected_axes = {"ax_prob": 0, "ax_feature": 1, "ax_target": 2} + assert result.axes == expected_axes + super_result = np.linalg.solve(np.asarray(arrA), np.asarray(arrb)) + assert_array_equal(result, super_result) + + def test_linalg_solve_incompatible_left(): axesA = {"ax_prob": 0, "ax_sample": 1, "ax_coord": 2} arrA = AxesArray(np.arange(8).reshape(2, 2, 2), axesA) axesb = {"ax_foo": 0, "ax_sample": 1} arrb = AxesArray(np.arange(4).reshape(2, 2), axesb) - with pytest.raises(ValueError, match="fdsafds"): + with pytest.raises(ValueError, match="Mismatch in operand axis names"): np.linalg.solve(arrA, arrb) + + +def test_tensordot_int_axes(): + ... + + +def test_tensordot_list_axes(): + ... + + +def test_einsum_implicit(): + ... + + +def test_einsum_trace(): + ... + + +def test_einsum_diag(): + ... + + +def test_einsum_contraction(): + ... + + +def test_einsum_mixed(): + ...