Making reparam an instance method in SumLayer and ExponentialFamilyArray classes #2

Closed (wants to merge 5 commits)
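In short, the diff removes the `reparam_function` methods, which built and returned a closure over `self`, and replaces them with ordinary `reparam` instance methods that take the unconstrained parameters directly; the `self.reparam = self.reparam_function()` plumbing in the constructors therefore disappears as well. A minimal sketch of the before/after pattern (illustrative class names and projection, not the repository's actual code):

```python
import torch

# Before: the method manufactures and returns a function object.
class LayerBefore:
    def reparam_function(self):
        def reparam(params):
            return torch.sigmoid(params)  # some projection onto the constrained domain (illustrative)
        return reparam

# After: the projection is a plain instance method.
class LayerAfter:
    def reparam(self, params):
        return torch.sigmoid(params)  # same projection, called directly as self.reparam(params)
```
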
src/EinsumNetwork/ExponentialFamilyArray.py (37 changes: 13 additions & 24 deletions)

@@ -75,12 +75,6 @@ def __init__(self, num_var, num_dims, array_shape, num_stats, use_em):
         self._online_em_stepsize = None
         self._online_em_counter = 0
 
-        # if em is switched off, we re-parametrize the expectation parameters
-        # self.reparam holds the function object for this task
-        self.reparam = None
-        if not self._use_em:
-            self.reparam = self.reparam_function()
-
     # --------------------------------------------------------------------------------
     # The following functions need to be implemented to specify an exponential family.
 
@@ -151,16 +145,17 @@ def project_params(self, params):
         """
         raise NotImplementedError
 
-    def reparam_function(self):
+    def reparam(self, params):
         """
         Re-parameterize parameters, in order that they stay in their constrained domain.
 
         When we are not using the EM, we need to transform unconstrained (real-valued) parameters to the constrained set
-        of the expectation parameter. This function should return such a function (i.e. the return value should not be
+        of the expectation parameter.
+        This function should return such a function (i.e. the return value should not be
         a projection, but a function which does the projection).
 
-        :return: function object f which takes as input unconstrained parameters (Tensor) and returns re-parametrized
-                 parameters.
+        :param params: unconstrained parameters (Tensor) to be projected
+        :return: re-parametrized parameters.
         """
         raise NotImplementedError
 
@@ -406,12 +401,10 @@ def project_params(self, phi):
         phi_project[..., self.num_dims:] += mu2
         return phi_project
 
-    def reparam_function(self):
-        def reparam(params_in):
-            mu = params_in[..., 0:self.num_dims].clone()
-            var = self.min_var + torch.sigmoid(params_in[..., self.num_dims:]) * (self.max_var - self.min_var)
-            return torch.cat((mu, var + mu**2), -1)
-        return reparam
+    def reparam(self, params_in):
+        mu = params_in[..., 0:self.num_dims].clone()
+        var = self.min_var + torch.sigmoid(params_in[..., self.num_dims:]) * (self.max_var - self.min_var)
+        return torch.cat((mu, var + mu**2), -1)
 
     def sufficient_statistics(self, x):
         if len(x.shape) == 2:
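The Gaussian case makes the purpose of `reparam` concrete: the first `num_dims` entries are read as the mean, the remaining entries are squashed by a sigmoid into `[min_var, max_var]`, and the result is returned as expectation parameters `(mu, var + mu**2)`, i.e. `(E[x], E[x**2])`. A standalone sketch of the same computation (the `num_dims`, `min_var` and `max_var` values below are illustrative, not necessarily the repository's defaults):

```python
import torch

def normal_reparam(params_in, num_dims=1, min_var=1e-4, max_var=10.0):
    """Map unconstrained reals to Gaussian expectation parameters (E[x], E[x^2])."""
    mu = params_in[..., 0:num_dims]
    var = min_var + torch.sigmoid(params_in[..., num_dims:]) * (max_var - min_var)
    return torch.cat((mu, var + mu**2), -1)

phi = normal_reparam(torch.randn(8, 2))
var = phi[..., 1] - phi[..., 0] ** 2    # recover the variance from the expectation parameters
assert (var > 0.0).all() and (var <= 10.0).all()  # the sigmoid keeps it inside [min_var, max_var]
```
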
@@ -465,10 +458,8 @@ def default_initializer(self):
     def project_params(self, phi):
         return torch.clamp(phi, 0.0, self.N)
 
-    def reparam_function(self):
-        def reparam(params):
-            return torch.sigmoid(params * 0.1) * float(self.N)
-        return reparam
+    def reparam(self, params):
+        return torch.sigmoid(params * 0.1) * float(self.N)
 
     def sufficient_statistics(self, x):
         if len(x.shape) == 2:
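For the binomial array the same idea is a one-liner: a scaled sigmoid maps any real input into `(0, N)`, consistent with the `torch.clamp(phi, 0.0, self.N)` used by `project_params`. A quick range check (the value of `N` here is assumed for illustration):

```python
import torch

N = 255                                        # assumed value, e.g. 8-bit pixel intensities
params = 50.0 * torch.randn(1000)              # arbitrary unconstrained reals
phi = torch.sigmoid(params * 0.1) * float(N)   # the reparam from the diff above
assert (phi >= 0.0).all() and (phi <= N).all()
```
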
@@ -534,10 +525,8 @@ def project_params(self, phi):
         phi = phi / torch.sum(phi, -1, keepdim=True)
         return phi.reshape(self.num_var, *self.array_shape, self.num_dims * self.K)
 
-    def reparam_function(self):
-        def reparam(params):
-            return torch.nn.functional.softmax(params, -1)
-        return reparam
+    def reparam(self, params):
+        return torch.nn.functional.softmax(params, -1)
 
     def sufficient_statistics(self, x):
         if len(x.shape) == 2:
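The categorical projection is simply a softmax over the last dimension. With `reparam` as an instance method, a caller no longer fetches a function object first; a hedged sketch of the call-site difference, using a toy stand-in class rather than the repository's actual classes:

```python
import torch

class ToyCategoricalArray:
    """Illustrative stand-in for CategoricalArray; only reparam is sketched."""
    def reparam(self, params):
        return torch.nn.functional.softmax(params, -1)

layer = ToyCategoricalArray()
theta = torch.randn(10, 4)        # unconstrained parameters
# old style:  phi = layer.reparam_function()(theta)
phi = layer.reparam(theta)        # new style: one direct method call
assert torch.allclose(phi.sum(-1), torch.ones(10))
```
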
src/EinsumNetwork/FactorizedLeafLayer.py (4 changes: 2 additions & 2 deletions)

@@ -145,8 +145,8 @@ def em_update(self):
     def project_params(self, params):
         self.ef_array.project_params(params)
 
-    def reparam_function(self):
-        return self.ef_array.reparam_function()
+    def reparam(self, params):
+        return self.ef_array.reparam(params)
 
     # --------------------------------------------------------------------------------
 
src/EinsumNetwork/SumLayer.py (40 changes: 18 additions & 22 deletions)

@@ -34,11 +34,6 @@ def __init__(self, params_shape, normalization_dims, use_em, params_mask=None):
         self.online_em_stepsize = None
         self._online_em_counter = 0
 
-        # if EM is not used, we reparametrize
-        self.reparam = None
-        if not self._use_em:
-            self.reparam = self.reparam_function()
-
     # --------------------------------------------------------------------------------
     # The following functions need to be implemented in derived classes.
 
@@ -193,27 +188,28 @@ def em_update(self, _triggered=False):
         self.params.data = self.params / (self.params.sum(self.normalization_dims, keepdim=True))
         self.params.grad = None
 
-    def reparam_function(self):
+    def reparam(self, params_in):
         """
         Reparametrization function, transforming unconstrained parameters into valid sum-weight
         (non-negative, normalized).
+
+        :params_in params: unconstrained parameters (Tensor) to be projected
+        :return: re-parametrized parameters.
         """
-        def reparam(params_in):
-            other_dims = tuple(i for i in range(len(params_in.shape)) if i not in self.normalization_dims)
-
-            permutation = other_dims + self.normalization_dims
-            unpermutation = tuple(c for i in range(len(permutation)) for c, j in enumerate(permutation) if j == i)
-
-            numel = functools.reduce(lambda x, y: x * y, [params_in.shape[i] for i in self.normalization_dims])
-
-            other_shape = tuple(params_in.shape[i] for i in other_dims)
-            params_in = params_in.permute(permutation)
-            orig_shape = params_in.shape
-            params_in = params_in.reshape(other_shape + (numel,))
-            out = softmax(params_in, -1)
-            out = out.reshape(orig_shape).permute(unpermutation)
-            return out
-        return reparam
+        other_dims = tuple(i for i in range(len(params_in.shape)) if i not in self.normalization_dims)
+
+        permutation = other_dims + self.normalization_dims
+        unpermutation = tuple(c for i in range(len(permutation)) for c, j in enumerate(permutation) if j == i)
+
+        numel = functools.reduce(lambda x, y: x * y, [params_in.shape[i] for i in self.normalization_dims])
+
+        other_shape = tuple(params_in.shape[i] for i in other_dims)
+        params_in = params_in.permute(permutation)
+        orig_shape = params_in.shape
+        params_in = params_in.reshape(other_shape + (numel,))
+        out = softmax(params_in, -1)
+        out = out.reshape(orig_shape).permute(unpermutation)
+        return out
 
     def project_params(self, params):
         """Currently not required."""
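The body of `SumLayer.reparam` is unchanged apart from the de-indentation: it applies a softmax jointly over an arbitrary set of `normalization_dims` by moving those dimensions to the back, flattening them into a single axis, normalizing, and restoring the original layout, so the sum-weights come out non-negative and normalized over those dimensions. A self-contained sketch of that idea (a simplified stand-in, not the repository's exact code):

```python
import torch

def softmax_over_dims(t, norm_dims):
    """Jointly normalize a tensor over an arbitrary set of dimensions.

    Moves norm_dims to the back, flattens them into one axis, applies a
    softmax there, and restores the original layout.
    """
    norm_dims = tuple(norm_dims)
    other_dims = tuple(i for i in range(t.dim()) if i not in norm_dims)
    perm = other_dims + norm_dims
    inverse = tuple(perm.index(i) for i in range(t.dim()))   # undo the permutation
    other_shape = tuple(t.shape[i] for i in other_dims)
    flat = t.permute(perm).reshape(other_shape + (-1,))
    out = torch.softmax(flat, -1)
    return out.reshape([t.shape[i] for i in perm]).permute(inverse)

w = softmax_over_dims(torch.randn(2, 3, 4), norm_dims=(1, 2))
assert torch.allclose(w.sum(dim=(1, 2)), torch.ones(2))      # normalized over dims 1 and 2
```
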