Skip to content

Commit

Permalink
[Solvers] Clarified documentation for fixed-point solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
arpastrana committed Jan 26, 2025
1 parent 2629880 commit fb4745e
Showing 1 changed file with 97 additions and 15 deletions.
112 changes: 97 additions & 15 deletions src/jax_fdm/equilibrium/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
def solver_anderson(f, a, x_init, solver_config):
"""
Solve for a fixed point of a function f(a, x) using anderson acceleration in jaxopt.
Parameters
----------
f : The function to iterate upon.
a : The function parameters.
x_init: An initial guess for the values of the solution vector.
solver_config: The configuration options of the solver.
Returns
-------
x : The solution vector at a fixed point.
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
Expand All @@ -47,6 +59,17 @@ def f_swapped(x, a):
def solver_fixedpoint(f, a, x_init, solver_config):
"""
Solve for a fixed point of a function f(a, x) using forward iteration in jaxopt.
Parameters
----------
f : The function to iterate upon.
a : The function parameters.
x_init: An initial guess for the values of the solution vector.
solver_config: The configuration options of the solver.
Returns
-------
x : The solution vector at a fixed point.
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
Expand All @@ -72,6 +95,17 @@ def f_swapped(x, a):
def solver_forward(f, a, x_init, solver_config):
"""
Solve for a fixed point of a function f(a, x) using forward iteration.
Parameters
----------
f : The function to iterate upon.
a : The function parameters.
x_init: An initial guess for the values of the solution vector.
solver_config: The configuration options of the solver.
Returns
-------
x : The solution vector at a fixed point.
"""
tmax = solver_config["tmax"]
eta = solver_config["eta"]
Expand Down Expand Up @@ -105,6 +139,17 @@ def body_fun(carry):
def solver_newton(f, a, x_init, solver_config):
"""
Find a root of the equation f(a, x) - x = 0 using Newton's method.
Parameters
----------
f : The function to iterate upon.
a : The function parameters.
x_init: An initial guess for the values of the solution vector.
solver_config: The configuration options of the solver.
Returns
-------
x : The solution vector at a fixed point.
"""
def f_root(x):
return f(a, x) - x
Expand Down Expand Up @@ -134,37 +179,74 @@ def f_newton(a, x):
@partial(custom_vjp, nondiff_argnums=(0, 1, 2))
def fixed_point(solver, solver_config, f, a, x_init):
"""
Solve for a fixed point of a function f(a, x) using forward iteration.
Solve for a fixed point of a function f(a, x) using an iterative solver.
"""
return solver(f, a, x_init, solver_config)


def fixed_point_fwd(solver, solver_config, f, a, x_init):
"""
The forward pass of a fixed point solver.
The forward pass of an iterative fixed point solver.
Parameters
----------
solver: The function that executes a fixed point solver.
solver_config: The configuration options of the solver.
fn : The function to iterate upon.
a : The function parameters.
x_init: An initial guess for the values of the solution vector.
Returns
-------
x : The solution vector at a fixed point.
res : Auxiliary data to transfer to the backward pass.
"""
x_star = fixed_point(solver, solver_config, f, a, x_init)

return x_star, (a, x_star)


def fixed_point_bwd(solver, solver_config, fn, res, x_star_bar):
def fixed_point_bwd(solver, solver_config, f, res, vec):
"""
The backward pass of a fixed point solver.
The backward pass of an iterative fixed point solver.
Parameters
----------
solver: The function that executes a fixed point solver.
solver_config: The configuration options of the solver.
f : The function to iterate upon.
res : Auxiliary data transferred from the forward pass.
vec: The vector on the left of the VJP.
Returns
-------
x : The solution vector at a fixed point.
res : None
"""
a, x_star = res
_, vjp_a = vjp(lambda a: fn(a, x_star), a)
_, vjp_a = vjp(lambda a: f(a, x_star), a)

def rev_iter(packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: fn(a, x), x_star)
return x_star_bar + vjp_x(u)[0]

partial_func = solver(rev_iter,
(a, x_star, x_star_bar),
x_star_bar,
solver_config)

a_bar = vjp_a(partial_func)[0]
"""
The function ought to have signature f(a, u(a)).
We are looking for a fixed point u*(a) = f(a, u*(a)).
"""
a, x_star, vec = packed

# Calculate the Jacobian df / dx
_, vjp_x = vjp(lambda x: f(a, x), x_star)

# Affine function: u = vector + u * df / dx
return vec + vjp_x(u)[0]

u_star = solver(
rev_iter, # The function to find a fixed-point of
(a, x_star, vec), # The parameters of rev_iter
vec, # The initial guess of the solution vector
solver_config) # The configuration of the solver

# VJP: u * df / da
a_bar = vjp_a(u_star)[0]

return a_bar, None

Expand Down

0 comments on commit fb4745e

Please sign in to comment.