No error to check convergence in output of SpatioTemporalProblem #768

Open
matthieuheitz opened this issue Dec 5, 2024 · 11 comments
Labels: documentation, examples

Comments

@matthieuheitz

In the output of TemporalProblem, I'm able to look at the marginal error (I'm using balanced OT) over the iterations, to get a sense of the convergence of the algorithm (in a more detailed way than just the boolean "converged") using:
tp.solutions[(0,1)]._errors

However, with the SpatioTemporalProblem, this _errors field doesn't exist.
Is that normal for the FGW problem? Can't we calculate the marginal error for it?
Otherwise, it could maybe return the error vector of the last inner Sinkhorn loop?

Thanks!

@MUCDK
Collaborator

MUCDK commented Dec 6, 2024

Hi @matthieuheitz ,

Thanks for raising this.

What do you mean by "it does not exist"? I guess it should exist, but might be None. At least we set it here:

self._errors = output.errors

But I do see that the errors might not be stored, because store_inner_errors is False by default in ott-jax: https://github.com/ott-jax/ott/blob/690b1aed1c0519899c94dcf0ccdd84500127af61/src/ott/solvers/was_solver.py#L40 which causes them not to be saved here: https://github.com/ott-jax/ott/blob/626aad6efed729a9e167b0963f4c447a2697e119/src/ott/solvers/quadratic/gromov_wasserstein.py#L291

It thus should be possible to set it via kwargs={"store_inner_errors": True}.

Pinging @selmanozleyen who is currently working on updating moscot to ott-jax=0.5.0. @selmanozleyen , can you please verify this ?

Also, @ArinaDanilina , once we have resolved this (and updated to a new version), can we please write an example on how to investigate the convergence? I.e. write an example on plotting the cost and errors, and explaining the difference?

@MUCDK MUCDK added examples documentation Improvements or additions to documentation labels Dec 6, 2024
@matthieuheitz
Author

Thanks for your answer!

Yes, you are correct, the attribute exists, but it's equal to None.

Actually, it didn't work when I passed kwargs={"store_inner_errors": True} to solve(), but it worked when I directly passed store_inner_errors=True to solve().
The _errors field then contains a (50,200) array.
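For readers who want to inspect such an array, here is a minimal sketch in plain Python. It rests on two assumptions that this thread does not confirm: that each row holds the inner Sinkhorn errors of one outer iteration, and that ott-jax marks checks that were never executed with -1. The helper name `summarize_errors` and the toy values are made up for illustration; real values would come from something like `tp.solutions[(0, 1)]._errors` after calling `solve(store_inner_errors=True)`. Under the further assumption that the error is evaluated every 10 inner iterations, 2000 inner iterations would give the 200 columns reported above.

```python
def summarize_errors(errors, sentinel=-1.0):
    """Summarize per-outer-iteration inner error vectors.

    `errors` is a nested list shaped (n_outer, n_checks).
    Returns a list of (outer_index, n_checks_run, last_error) tuples.
    The -1 sentinel for "check not executed" is an assumption.
    """
    summary = []
    for outer_idx, row in enumerate(errors):
        # Keep only the checks that were actually executed.
        computed = [e for e in row if e != sentinel]
        last = computed[-1] if computed else None
        summary.append((outer_idx, len(computed), last))
    return summary


# Toy stand-in for the stored array (hypothetical values).
toy_errors = [
    [0.5, 0.1, 0.02, -1.0],    # inner loop stopped after 3 checks
    [0.2, 0.01, -1.0, -1.0],   # inner loop stopped after 2 checks
    [-1.0, -1.0, -1.0, -1.0],  # outer iteration never executed
]

for outer_idx, n_checks, last in summarize_errors(toy_errors):
    print(outer_idx, n_checks, last)
```

Plotting the last error of each row against the outer iteration index is one way to see whether the inner Sinkhorn loops keep reaching the threshold as the outer loop progresses.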

Of note, there seems to be an inconsistency between the number of (outer) iterations you provide to the function through max_iterations and the actual number of iterations (tracked with the progress_fn): there is always one more actual iteration than what max_iterations says.
For example, when setting max_iterations=1, my callback function is called for all 2000 inner Sinkhorn iterations, and then once more for another 2000 inner Sinkhorn iterations.
The number of error vectors in _errors is consistent with max_iterations, but that makes it inconsistent with the actual number of iterations.
It seems that it's the error of the first outer iteration that is missing from _errors.

@matthieuheitz
Author

matthieuheitz commented Dec 6, 2024

Oh, it might be because the first iteration is iteration 0, so max_iterations is not the max number of iterations, but it's the max iteration number.
But in that case, since e.g. the plot_costs() function plots iteration 0, it would be consistent for the error to also include the error for iteration 0.

@MUCDK
Collaborator

MUCDK commented Dec 9, 2024

It's true that counting starts at 0, but not sure why this is not being plotted then, following

def _plot_lines(
.

@ArinaDanilina could you please check that?

@yuling999666

Hi @MUCDK , if it is not convergent, could it be used as a solution? Thank you!

@MUCDK
Collaborator

MUCDK commented Jan 13, 2025

Hi @yuling999666 ,

you should check the errors, and if they look reasonable, I would suggest increasing the threshold and then rerunning. I would first play with different values of epsilon, and possibly tau_a and tau_b.

@yuling999666

[Image: attached result]

Hi @MUCDK , thank you for your reply. I'm curious about how I should tell whether the errors are reasonable. According to the attached result, I think the errors are very big and I think I need to change the alpha value.

@MUCDK
Collaborator

MUCDK commented Jan 13, 2025

Hi,

yeah not sure whether you printed all of them, but I do agree that you should first play with alpha, epsilon, tau_a/b.

@yuling999666

Hi @MUCDK thank you for your suggestions. I appreciate it!

@yuling999666

Hi @MUCDK , what do the errors specifically stand for? Thank you!

@MUCDK
Collaborator

MUCDK commented Jan 14, 2025

It's the sum of the deviations of the row/column sums from the marginals (in case it's a balanced problem). For unbalanced problems, convergence is determined by a Cauchy sequence criterion.
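To make the balanced case concrete, here is a toy sketch in plain Python of a textbook Sinkhorn loop that records this kind of marginal error after each iteration. It is not the ott-jax or moscot implementation: the function name `sinkhorn_marginal_errors`, the 2x2 cost matrix, the marginals and the epsilon value are all made up for illustration, and the exact norm that ott-jax uses may differ from the L1 sum used here.

```python
import math


def sinkhorn_marginal_errors(cost, a, b, epsilon, n_iters):
    """Run balanced Sinkhorn and return the marginal error per iteration.

    The error is the L1 deviation of the coupling's row sums from `a`
    plus the L1 deviation of its column sums from `b`.
    """
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / epsilon).
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    errors = []
    for _ in range(n_iters):
        # Alternating scaling updates.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # Current coupling P_ij = u_i * K_ij * v_j.
        P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
        row_err = sum(abs(sum(P[i]) - a[i]) for i in range(n))
        col_err = sum(
            abs(sum(P[i][j] for i in range(n)) - b[j]) for j in range(m)
        )
        errors.append(row_err + col_err)
    return errors


errors = sinkhorn_marginal_errors(
    cost=[[0.0, 1.0], [1.0, 0.0]],
    a=[0.7, 0.3],
    b=[0.4, 0.6],
    epsilon=0.5,
    n_iters=100,
)
print(errors[0], errors[-1])
```

In this toy setting the error shrinks towards zero as the iterations proceed, which is the behaviour one would hope to see when plotting the stored errors of a balanced problem.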
