Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix print for verbose=False #573

Merged
merged 1 commit into from
Jan 28, 2024
Merged

Conversation

vroulet
Copy link
Collaborator

@vroulet vroulet commented Jan 24, 2024

Fix #571, added tests.

@vroulet vroulet changed the title fix_print Fix print for verbose=False Jan 24, 2024
@Joshuaalbert
Copy link

I noticed a difference between linesearch verbosity setting for LBFGS and NonlinearCG.

In LBFGS

self.linesearch_solver = _setup_linesearch(
        linesearch=self.linesearch,
        fun=_fun_with_aux,
        value_and_grad=self._value_and_grad_with_aux,
        has_aux=True,
        maxlsiter=self.maxls,
        max_stepsize=self.max_stepsize,
        jit=self.jit,
        unroll=unroll,
        verbose=self.verbose,
    )

In NonlinearCG:

linesearch_solver = _setup_linesearch(
        linesearch=self.linesearch,
        fun=_fun_with_aux,
        value_and_grad=self._value_and_grad_with_aux,
        has_aux=True,
        maxlsiter=self.maxls,
        max_stepsize=self.max_stepsize,
        jit=self.jit,
        unroll=unroll,
        verbose=int(self.verbose)-1
    )

Maybe multi-level verbosity is actually not so useful?

@vroulet
Copy link
Collaborator Author

vroulet commented Jan 25, 2024

Yes, I had forgotten the multiple level verbosity for lbfgs. It's been corrected in this PR.

@vroulet vroulet requested a review from fabianp January 26, 2024 12:14
Copy link
Collaborator

@fabianp fabianp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me, thanks!

@copybara-service copybara-service bot merged commit 08ec55f into google:main Jan 28, 2024
7 checks passed
@rupeshknn
Copy link

I still face this issue, probably because there has been no release since Jan 10. when is the next release being planned?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

verbose=False is not working as expected for NonlinearCG
4 participants