Puzzles about Inconsistency between code and article #3

Open
QJ-Chen opened this issue Sep 28, 2022 · 4 comments

@QJ-Chen

QJ-Chen commented Sep 28, 2022

# losses.sliced_sm
import torch
import torch.autograd as autograd

def sliced_score_estimation(score_net, samples, n_particles=1):
    dup_samples = samples.unsqueeze(0).expand(n_particles, *samples.shape).contiguous().view(-1, *samples.shape[1:])
    dup_samples.requires_grad_(True)
    vectors = torch.randn_like(dup_samples)
    vectors = vectors / torch.norm(vectors, dim=-1, keepdim=True)

    grad1 = score_net(dup_samples)  # H, estimation of score
    gradv = torch.sum(grad1 * vectors)  # project H with v
    loss1 = torch.sum(grad1 * vectors, dim=-1) ** 2 * 0.5  # second term of J(\theta) 
    grad2 = autograd.grad(gradv, dup_samples, create_graph=True)[0] # grad of h w.r.t samples(z)
    loss2 = torch.sum(vectors * grad2, dim=-1)

    loss1 = loss1.view(n_particles, -1).mean(dim=0)
    loss2 = loss2.view(n_particles, -1).mean(dim=0)

    loss = loss1 + loss2
    return loss.mean(), loss1.mean(), loss2.mean()

# losses.vae.elbo_ssm
z = imp_encoder(X)
ssm_loss, *_ = sliced_score_estimation_vr(functools.partial(score, dup_X), z, n_particles=n_particles)

To my understanding, grad1 is the estimation of score $h = S_{m}(x;\theta)$ and loss2 is the first term of $J(\theta)$, which is $v^{T}\nabla_{x}h(x;\theta)v$. But in the code, it seems to be calculated as $v^{T}\nabla_{z}h(x;\theta)v$.
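
For context, the sliced score matching objective under discussion is, up to notation,

$$J(\theta) = \mathbb{E}_{p(v)}\,\mathbb{E}_{p(x)}\left[ v^{T}\nabla_{x} h(x;\theta)\, v + \frac{1}{2}\left(v^{T} h(x;\theta)\right)^{2} \right],$$

where $h(x;\theta)$ is the score model and $v$ is a random projection vector, so loss2 corresponds to the first term and loss1 to the second. The question is whether the $\nabla_{x}$ in the first term should be taken with respect to the data $x$ or the latent $z$ that is actually passed in as samples.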

@cnut1648

Yes @chen-qj, I noticed this too. Did you figure out why?

@ifgovh

ifgovh commented Nov 18, 2022

I noticed another issue: the multiplication of vectors with grad1/grad2 in the code is element-wise, but in the paper it is matrix multiplication. Or do I misunderstand the theory?

@dongdongunique

> To my understanding, grad1 is the estimation of score $h = S_{m}(x;\theta)$ and loss2 is the first term of $J(\theta)$, which is $v^{T}\nabla_{x}h(x;\theta)v$. But in the code, it seems to be calculated as $v^{T}\nabla_{z}h(x;\theta)v$.

The author is not using score matching to learn the data distribution of $x$; instead, he uses score matching to compute the gradient of the entropy of the implicit distribution. So the code computes the gradient with respect to $z$ instead of $x$.
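
To make that concrete, here is a minimal, hypothetical sketch (assumed names imp_encoder and score_net; not the repository's actual code) of why the score is learned over $z$: with a reparameterized sample $z = g_{\phi}(\epsilon, x)$, the entropy gradient is $\nabla_{\phi}H(q_{\phi}) = -\mathbb{E}\left[\nabla_{z}\log q_{\phi}(z|x)^{T}\,\partial z/\partial\phi\right]$, and the trained score net stands in for $\nabla_{z}\log q_{\phi}(z|x)$.

import torch

def entropy_gradient_surrogate(X, imp_encoder, score_net):
    # z is a reparameterized sample, so it carries gradients w.r.t. the encoder parameters.
    z = imp_encoder(X)
    # The trained score net approximates grad_z log q(z|x); detach it so that
    # backpropagation only differentiates through dz/dphi.
    score = score_net(z).detach()
    # Backward through this scalar yields -E[ score^T dz/dphi ], i.e. the entropy gradient.
    return -(score * z).sum(dim=-1).mean()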

@dongdongunique

> I noticed another issue: the multiplication of vectors with grad1/grad2 in the code is element-wise, but in the paper it is matrix multiplication. Or do I misunderstand the theory?

They are equivalent. If you flatten the data into one dimension, the element-wise product followed by the sum over the last dimension is exactly the quadratic form $v^{T}(\nabla h)\,v$ from the paper.
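
A quick standalone check (toy code, not from the repository) that the element-wise route in the code matches the explicit matrix form, for a single flattened sample:

import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
W = torch.randn(3, 3)
h = lambda x: torch.tanh(x @ W)        # toy stand-in for the score net

z = torch.randn(3, requires_grad=True)
v = torch.randn(3)
v = v / v.norm()

# Element-wise route, as in the repo code
gradv = torch.sum(h(z) * v)
grad2 = torch.autograd.grad(gradv, z, create_graph=True)[0]
elementwise = torch.sum(v * grad2)

# Explicit matrix route: v^T (dh/dz) v
J = jacobian(h, z)
matrix_form = v @ J @ v

print(torch.allclose(elementwise, matrix_form))  # True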
