Skip to content

Commit

Permalink
ENH: Add function to calculate number of polynomial features
Browse files Browse the repository at this point in the history
Also need to rename a variable so it doesn't shadow imported name
  • Loading branch information
Jacob-Stevens-Haas committed Dec 5, 2023
1 parent 2662af3 commit 43a8ca4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
40 changes: 35 additions & 5 deletions pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import chain
from math import comb
from typing import Iterator

import numpy as np
Expand Down Expand Up @@ -208,10 +209,10 @@ def transform(self, x_full):
)
if sparse.isspmatrix(x):
columns = []
for comb in combinations:
if comb:
for combo in combinations:
if combo:
out_col = 1
for col_idx in comb:
for col_idx in combo:
out_col = x[..., col_idx].multiply(out_col)
columns.append(out_col)
else:
Expand All @@ -227,7 +228,36 @@ def transform(self, x_full):
),
x.__dict__,
)
for i, comb in enumerate(combinations):
xp[..., i] = x[..., comb].prod(-1)
for i, combo in enumerate(combinations):
xp[..., i] = x[..., combo].prod(-1)
xp_full = xp_full + [xp]
return xp_full


def n_poly_features(
n_in_feat: int,
degree: int,
include_bias: bool = False,
include_interation: bool = True,
interaction_only: bool = False,
) -> int:
"""Calculate number of polynomial features
Args:
n_in_feat: number of input features, e.g. 3 for x, y, z
degree: polynomial degree, e.g. 2 for quadratic
include_bias: whether to include a constant term
include_interaction: whether to include terms mixing multiple inputs
interaction_only: whether to omit terms of x_m * x_n^p for p > 1
"""
if not include_interation and interaction_only:
raise ValueError("Cannot set interaction only if include_interaction is False")
n_feat = include_bias
if not include_interation:
return n_feat + n_in_feat * degree
for deg in range(1, degree + 1):
if interaction_only:
n_feat += comb(n_in_feat, deg)
else:
n_feat += comb(n_in_feat + deg - 1, deg)
return n_feat
4 changes: 4 additions & 0 deletions test/test_feature_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pysindy.feature_library import TensoredLibrary
from pysindy.feature_library import WeakPDELibrary
from pysindy.feature_library.base import BaseFeatureLibrary
from pysindy.feature_library.polynomial_library import n_poly_features
from pysindy.optimizers import SINDyPI
from pysindy.optimizers import STLSQ

Expand Down Expand Up @@ -880,3 +881,6 @@ def test_polynomial_combinations(include_interaction, interaction_only, bias, ex
)
result = tuple(sorted(list(combos)))
assert result == expected
assert len(expected) == n_poly_features(
2, 2, bias, include_interaction, interaction_only
)

0 comments on commit 43a8ca4

Please sign in to comment.