LassoNet has been implemented by its authors, but only for feed-forward neural networks with ReLU activation. Here, we implement the idea more generally.
Define a PyTorch network `G` (i.e. some class inheriting from `torch.nn.Module`) with arbitrary architecture (i.e. a `forward` method). `G` must fulfill:
- its first layer is of type `torch.nn.Linear` and is called `G.W1`.
- it needs the attributes `G.D_in` and `G.D_out`, which are the input and output dimensions of the network.
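As a minimal sketch, a network satisfying these requirements could look as follows (the class name, hidden size, and activation are illustrative, not prescribed by this repository):

```python
import torch

class G(torch.nn.Module):
    """Toy network meeting the LassoNet wrapper requirements."""
    def __init__(self, D_in=10, D_out=2, H=32):
        super().__init__()
        self.D_in = D_in    # required attribute: input dimension
        self.D_out = D_out  # required attribute: output dimension
        self.W1 = torch.nn.Linear(D_in, H)  # required: first layer, named W1
        self.relu = torch.nn.ReLU()
        self.W2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        return self.W2(self.relu(self.W1(x)))
```

Any architecture can follow after `W1`; only the first linear layer and the two dimension attributes are relied upon by the wrapper.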
The LassoNet based on `G` is then initialized via `model = LassoNet(G, lambda_, M)`, where `lambda_` and `M` are as in the paper.
- See `example.py` for a simple example of how to define `G` and how to train LassoNet.
- See `example_mnist.py` for an example using the MNIST dataset.