[REQ] Support for access to internal variables + plugin output in factorize algorithm #238

Open
campsd opened this issue May 18, 2022 · 0 comments
campsd (Collaborator) commented May 18, 2022

For a historic_convergence plugin, we would want to record the loss, the penalty, and their weighted combination (loss + penalty_weight * penalty) separately for every vector instance.

This is currently not possible from outside factorize, because GradientDescentState only stores the loss combined over all vector instances.
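A small NumPy sketch makes the distinction concrete. It uses made-up values and assumes that vector instances are stacked along the last axis (as vectorized_along_last suggests), that the loss is a mean squared error, and that the per-instance penalties are given. Once the per-instance values are reduced to a single scalar, the per-instance breakdown cannot be recovered.

```python
import numpy as np

# Assumed toy data: 3 vector instances stacked along the last axis.
target = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
approx = np.array([[1.0, 2.5, 2.0],
                   [4.0, 5.0, 7.0]])
penalty = np.array([0.0, 1.0, 2.0])  # hypothetical per-instance penalties
penalty_weight = 0.1

# Mean squared error over axis 0 (the elements of each instance),
# giving one loss value per vector instance.
loss_per_vec = ((approx - target) ** 2).mean(axis=0)
combined_per_vec = loss_per_vec + penalty_weight * penalty

# A single scalar summary: the per-instance breakdown is gone.
combined_total = combined_per_vec.sum()

print("per-instance loss:    ", loss_per_vec)
print("per-instance combined:", combined_per_vec)
print("combined over all:    ", combined_total)
```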

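As a workaround, the plugin can be defined inside the factorize method, where it can read factorize's local variables (fac, target, loss, append, penalty_weight) through the closure: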

    hc = []  # per-step, per-vector-instance convergence records

    # Plugin recording historic convergence data for every vector instance.
    @gradient_descent_plugin(every=1)
    def historic_convergence(state: GradientDescentState):
        # TODO: use external validation set
        # Per-vector-instance loss (sum_vec=False keeps instances separate).
        loss_val = ab.to_numpy(
            loss(fac(), target, sum_vec=False, vectorized_along_last=append)
        )
        # Per-vector-instance penalty (sum_vec=False keeps instances separate).
        penalty_val = ab.to_numpy(
            fac.penalty(sum_leafs=True, sum_vec=False)
        )

        for i, (loss_i, penalty_i) in enumerate(zip(loss_val, penalty_val)):
            hc.append(
                dict(
                    step=state.step,
                    vec=i,
                    loss=loss_i,
                    penalty=penalty_i,
                    loss_and_penalty=loss_i + penalty_weight * penalty_i,
                )
            )

It would be better to support this as an external plugin, which requires giving plugins access to these internal variables of factorize.
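One possible shape for such an external plugin is sketched below in plain Python/NumPy. Everything in it is hypothetical and not part of the factorize API: a toy optimizer, run_gradient_descent, fits one scalar per vector instance and, at every step, hands each plugin a PluginContext carrying the per-instance loss, the per-instance penalty and the penalty weight. That lets historic_convergence live outside the optimizer.

```python
from dataclasses import dataclass
from typing import Callable, List

import numpy as np


@dataclass
class PluginContext:
    """Hypothetical snapshot of optimizer internals handed to every plugin."""
    step: int
    loss_per_vec: np.ndarray     # one loss value per vector instance
    penalty_per_vec: np.ndarray  # one penalty value per vector instance
    penalty_weight: float


Plugin = Callable[[PluginContext], None]


def run_gradient_descent(target, penalty_weight, lr, max_steps, plugins: List[Plugin]):
    """Toy stand-in for factorize: fit one scalar x[i] per vector instance
    (column i of target) under an L2 penalty penalty_weight * x[i]**2."""
    x = np.zeros(target.shape[-1])
    for step in range(max_steps):
        residual = x - target  # broadcasts x over the rows of target
        loss_per_vec = (residual ** 2).mean(axis=0)
        penalty_per_vec = x ** 2
        ctx = PluginContext(step, loss_per_vec, penalty_per_vec, penalty_weight)
        for plugin in plugins:
            plugin(ctx)  # plugins see the internals through the context
        grad = 2 * residual.mean(axis=0) + 2 * penalty_weight * x
        x = x - lr * grad
    return x


# An external plugin: defined outside the optimizer, fed through the context.
history = []


def historic_convergence(ctx: PluginContext) -> None:
    for i, (loss_i, penalty_i) in enumerate(zip(ctx.loss_per_vec, ctx.penalty_per_vec)):
        history.append(dict(
            step=ctx.step,
            vec=i,
            loss=float(loss_i),
            penalty=float(penalty_i),
            loss_and_penalty=float(loss_i + ctx.penalty_weight * penalty_i),
        ))


target = np.array([[1.0, -2.0],
                   [3.0, 0.0]])
run_gradient_descent(target, penalty_weight=0.1, lr=0.1, max_steps=3,
                     plugins=[historic_convergence])
```

The context object is the only coupling between optimizer and plugin, so exposing a new internal quantity means adding a field rather than changing every plugin signature.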

Second, a user of the historic_convergence plugin would want access to the collected hc data afterwards, so there should also be a way for a plugin to return information.
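For returning information, one option is to let the optimizer collect whatever each plugin returns and hand it back to the caller together with the result. The sketch below assumes a hypothetical toy optimizer, run_with_plugins (not the factorize API), where plugins are registered by name and every non-None return value is stored under that name.

```python
from typing import Any, Callable, Dict, List, Optional

import numpy as np

# Hypothetical plugin signature: (step, per-instance loss) -> optional payload.
PluginFn = Callable[[int, np.ndarray], Optional[Any]]


def run_with_plugins(target, lr, max_steps, plugins: Dict[str, PluginFn]):
    """Toy optimizer: fit one scalar per vector instance (column of target) and
    return it together with every non-None value each plugin returned."""
    x = np.zeros(target.shape[-1])
    outputs: Dict[str, List[Any]] = {name: [] for name in plugins}
    for step in range(max_steps):
        residual = x - target
        loss_per_vec = (residual ** 2).mean(axis=0)
        for name, plugin in plugins.items():
            out = plugin(step, loss_per_vec)
            if out is not None:
                outputs[name].append(out)
        x = x - lr * 2 * residual.mean(axis=0)  # gradient step on the MSE
    return x, outputs


def historic_convergence(step: int, loss_per_vec: np.ndarray) -> List[dict]:
    # Return this step's records instead of appending to an outer list.
    return [dict(step=step, vec=i, loss=float(loss_i))
            for i, loss_i in enumerate(loss_per_vec)]


x, plugin_outputs = run_with_plugins(
    np.array([[1.0, -2.0], [3.0, 0.0]]),
    lr=0.1,
    max_steps=3,
    plugins={"historic_convergence": historic_convergence},
)
# Flatten the per-step lists into one list of records.
hc = [record for batch in plugin_outputs["historic_convergence"] for record in batch]
```

Returning values keeps the plugin itself stateless. An alternative design would be plugins as stateful objects that expose their collected data through a method once optimization finishes.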
