Skip to content

Commit

Permalink
Merge branch 'master' into trapping-constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas authored Jan 16, 2024
2 parents f53151a + a7eb890 commit dd501a8
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
9 changes: 9 additions & 0 deletions pysindy/optimizers/miosr.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,12 @@ def _reduce(self, x, y):
def complexity(self):
check_is_fitted(self)
return np.count_nonzero(self.coef_)

def __getstate__(self):
state = self.__dict__.copy()
del state["model"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.model = None
7 changes: 5 additions & 2 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def decorator(func):


@implements(np.concatenate)
def concatenate(arrays, axis=0):
def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"):
parents = [np.asarray(obj) for obj in arrays]
ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, AxesArray)]
for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]):
if ax1 != ax2:
raise TypeError("Concatenating >1 AxesArray with incompatible axes")
return AxesArray(np.concatenate(parents, axis), axes=ax_list[0])
result = np.concatenate(parents, axis, out=out, dtype=dtype, casting=casting)
if isinstance(out, AxesArray):
out.__dict__ = ax_list[0]
return AxesArray(result, axes=ax_list[0])


def comprehend_axes(x):
Expand Down
12 changes: 12 additions & 0 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Unit tests for optimizers.
"""
import pickle

import numpy as np
import pytest
from numpy.linalg import norm
Expand Down Expand Up @@ -1151,3 +1153,13 @@ def test_trapping_mixed_only():
constraint_lhs = _antisymm_triple_constraints(3, 3, mixed_terms)
result = np.tensordot(constraint_lhs, stable_coefs, ((1, 2), (1, 0)))
assert result[0] == 0


def test_pickle(data_lorenz):
x, t = data_lorenz
y = PolynomialLibrary(degree=2).fit_transform(x)
opt = MIOSR(target_sparsity=7).fit(x, y)
expected = opt.coef_
new_opt = pickle.loads(pickle.dumps(opt))
result = new_opt.coef_
np.testing.assert_array_equal(result, expected)
7 changes: 7 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from pysindy import AxesArray


def test_concat_out():
arr = AxesArray(np.arange(3).reshape(1, 3), {"ax_a": 0, "ax_b": 1})
arr_out = np.empty((2, 3)).view(AxesArray)
result = np.concatenate((arr, arr), axis=0, out=arr_out)
assert_equal(result, arr_out)


def test_reduce_mean_noinf_recursion():
arr = AxesArray(np.array([[1]]), {})
np.mean(arr, axis=0)
Expand Down

0 comments on commit dd501a8

Please sign in to comment.