Skip to content

Commit

Permalink
Clean up/speed up numpy array construction
Browse files Browse the repository at this point in the history
  • Loading branch information
sdfordham committed Dec 31, 2023
1 parent e8a919a commit 244dac9
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pysyncon/augsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def fit(self, dataprep: Dataprep, lambda_: Optional[float] = None) -> None:
else:
self.lambda_ = lambda_

V_mat = np.diag([1.0 / X0.shape[0]] * X0.shape[0])
n_r, _ = X0.shape
V_mat = np.diag(np.full(n_r, 1 / n_r))
W, _ = self.w_optimize(V_mat=V_mat, X0=X0.to_numpy(), X1=X1.to_numpy())

W_ridge = self.solve_ridge(
Expand Down
6 changes: 3 additions & 3 deletions pysyncon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def w_optimize(
def fun(x):
return q.T @ x + 0.5 * x.T @ P @ x

bounds = Bounds(lb=np.array([0.0] * n_c).T, ub=np.array([1.0] * n_c).T)
constraints = LinearConstraint(A=np.array([1.0] * n_c), lb=1.0, ub=1.0)
bounds = Bounds(lb=np.full(n_c, 0.0), ub=np.full(n_c, 1.0))
constraints = LinearConstraint(A=np.full(n_c, 1.0), lb=1.0, ub=1.0)

x0 = np.array([1 / n_c] * n_c)
x0 = np.full(n_c, 1 / n_c)
res = minimize(
fun=fun,
x0=x0,
Expand Down
6 changes: 3 additions & 3 deletions pysyncon/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def w_optimize(
def fun(x):
return (X1 - X0 @ x).T @ (X1 - X0 @ x) + lambda_ * (r.T @ x)

bounds = Bounds(lb=np.array([0.0] * n_c).T, ub=np.array([1.0] * n_c).T)
constraints = LinearConstraint(A=np.array([1.0] * n_c), lb=1.0, ub=1.0)
bounds = Bounds(lb=np.full(n_c, 0.0), ub=np.full(n_c, 1.0))
constraints = LinearConstraint(A=np.full(n_c, 1.0), lb=1.0, ub=1.0)

if initial:
x0 = initial
else:
x0 = np.array([1 / n_c] * n_c)
x0 = np.full(n_c, 1 / n_c)

res = minimize(fun=fun, x0=x0, bounds=bounds, constraints=constraints)
W, loss_W = res["x"], res["fun"]
Expand Down
2 changes: 1 addition & 1 deletion pysyncon/synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def fit(
x0 = [1 / n_r] * n_r
elif optim_initial == "ols":
X_arr = np.hstack([X0_arr, X1_arr.reshape(-1, 1)])
X_arr = np.hstack([np.array([1] * X_arr.shape[1], ndmin=2).T, X_arr.T])
X_arr = np.hstack([np.full((X_arr.shape[1], 1), 1), X_arr.T])
Z_arr = np.hstack([Z0_arr, Z1_arr.reshape(-1, 1)])

try:
Expand Down

0 comments on commit 244dac9

Please sign in to comment.