Skip to content

Commit

Permalink
feat: support sparse X input
Browse files Browse the repository at this point in the history
  • Loading branch information
jolars committed Nov 30, 2023
1 parent 78e2918 commit 8d2cf08
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
8 changes: 7 additions & 1 deletion sortedl1/estimators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import _sortedl1 as sl1
import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

Expand Down Expand Up @@ -53,7 +54,12 @@ def fit(self, X, y):

alpha = np.atleast_1d(self.alpha).astype(np.float64)

result = sl1.fit_slope(X, y, lam, alpha)
if sparse.issparse(X):
fit_slope = sl1.fit_slope_sparse
else:
fit_slope = sl1.fit_slope_dense

result = fit_slope(X, y, lam, alpha)

self.intercept_ = result[0]
self.sparse_coef_ = result[1]
Expand Down
22 changes: 17 additions & 5 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,21 @@ using namespace pybind11::literals;
namespace py = pybind11;

pybind11::tuple
fit_slope(const Eigen::MatrixXd x,
const Eigen::MatrixXd y,
Eigen::ArrayXd lambda,
Eigen::ArrayXd alpha)
fit_slope_dense(const Eigen::MatrixXd x,
const Eigen::MatrixXd y,
Eigen::ArrayXd lambda,
Eigen::ArrayXd alpha)
{
auto result = slope::slope(x, y, alpha, lambda);

return py::make_tuple(result.beta0s, result.betas);
}

pybind11::tuple
fit_slope_sparse(const Eigen::SparseMatrix<double> x,
const Eigen::MatrixXd y,
Eigen::ArrayXd lambda,
Eigen::ArrayXd alpha)
{
auto result = slope::slope(x, y, alpha, lambda);

Expand All @@ -21,5 +32,6 @@ fit_slope(const Eigen::MatrixXd x,

PYBIND11_MODULE(_sortedl1, m)
{
m.def("fit_slope", &fit_slope);
m.def("fit_slope_dense", &fit_slope_dense);
m.def("fit_slope_sparse", &fit_slope_sparse);
}

0 comments on commit 8d2cf08

Please sign in to comment.