Skip to content

Commit

Permalink
Add least_squares replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeyers314 committed Jul 24, 2024
1 parent 208f9a1 commit a11226c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
4 changes: 1 addition & 3 deletions galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from . import fits
from .errors import GalSimError, GalSimValueError, GalSimIncompatibleValuesError
from .errors import GalSimNotImplementedError, convert_cpp_errors, galsim_warn
from .utilities import horner2d
from .utilities import horner2d, least_squares
from .celestial import CelestialCoord
from ._pyfits import pyfits

Expand Down Expand Up @@ -1783,8 +1783,6 @@ def FittedSIPWCS(x, y, ra, dec, wcs_type='TAN', order=3, center=None):
the tangent plane is centered. [default: None, which means
use the average position of the list of reference stars]
"""
from scipy.optimize import least_squares

if order < 1:
raise GalSimValueError("Illegal SIP order", order)

Expand Down
69 changes: 69 additions & 0 deletions galsim/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,3 +1832,72 @@ def __exit__(self, type, value, traceback):
if self.reverse: # pragma: no cover
ps = ps.reverse_order()
ps.print_stats(self.nlines)


def least_squares(fun, x0, args=(), kwargs={}, max_iter=1000, tol=1e-9, lambda_init=1.0):
"""Perform a non-linear least squares fit using the Levenberg-Marquardt algorithm.
Drop in replacement for scipy.optimize.least_squares when using default options,
though many fewer options available in general.
Parameters:
fun: Function which computes vector of residuals, with the signature
fun(params, *args, **kwargs) -> np.ndarray.
x0: Initial guess for the parameters.
args: Additional arguments to pass to the function.
kwargs: Additional keyword arguments to pass to the function.
max_iter: Maximum number of iterations. [default: 1000]
tol: Tolerance for convergence. [default: 1e-9]
lambda_init: Initial damping factor for Levenberg-Marquardt. [default: 1.0]
Returns:
A named tuple with fields:
x: The final parameter values.
cost: The final cost (sum of squared residuals).
"""
# JM: This is a tweaked version of a ChatGPT-generated implementation of
# Levenberg-Marquardt (cross-checked against the wikipedia page
# https://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm).

from collections import namedtuple
params = np.array(x0)
lambda_ = lambda_init

for _ in range(max_iter): # pragma: no branch
residuals = fun(params, *args, **kwargs)

# Jacobian matrix
J = np.zeros((len(residuals), len(params)))
for j in range(len(params)):
perturbation = np.zeros(len(params))
perturbation[j] = tol
J[:, j] = (fun(params + perturbation, *args, **kwargs) - residuals) / tol

# Regular least squares param update
JTJ = np.dot(J.T, J)
JTr = np.dot(J.T, residuals)

# Levenberg-Marquardt adjustment
A = JTJ + lambda_ * np.eye(len(JTJ))
delta_params = np.linalg.solve(A, JTr)

new_params = params - delta_params
new_residuals = fun(new_params, *args, **kwargs)

if np.linalg.norm(new_residuals) < np.linalg.norm(residuals):
params = new_params
residuals = new_residuals
lambda_ /= 3 # reduce damping
else:
lambda_ *= 3 # increase damping

if np.linalg.norm(delta_params) < tol:
break

cost = 0.5 * np.sum(residuals**2)

# Create a result object similar to scipy.optimize.OptimizeResult
Result = namedtuple('Result', ['x', 'cost'])
result = Result(x=params, cost=cost)

return result

0 comments on commit a11226c

Please sign in to comment.