From 8b6f35ca7be8257a530a6d8ead7dd38b086597db Mon Sep 17 00:00:00 2001 From: JAXopt authors Date: Mon, 9 Sep 2024 16:03:09 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 672706385 --- jaxopt/_src/gradient_descent.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index bbaf100f..51f461c8 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -79,11 +79,9 @@ def init_state(self, """ return super().init_state(init_params, None, *args, **kwargs) - def update(self, - params: Any, - state: NamedTuple, - *args, - **kwargs) -> base.OptStep: + def update( + self, params: Any, state: ProxGradState, *args, **kwargs + ) -> base.OptStep: """Performs one iteration of gradient descent. Args: