diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index bbaf100f..51f461c8 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -79,11 +79,9 @@ def init_state(self, """ return super().init_state(init_params, None, *args, **kwargs) - def update(self, - params: Any, - state: NamedTuple, - *args, - **kwargs) -> base.OptStep: + def update( + self, params: Any, state: ProxGradState, *args, **kwargs + ) -> base.OptStep: """Performs one iteration of gradient descent. Args: