Do MLPs share weight by accident? #37

bluenote10 opened this issue Dec 22, 2022 · 0 comments

I noticed that self.in_mlps uses this expression:

self.in_mlps = nn.ModuleList([mlp(1, hidden_size, 3)] * 2)

As a result of how list multiplication works in Python, self.in_mlps[0] is self.in_mlps[1], i.e., both entries of the list refer to the same module instance.
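
For illustration, here is a minimal, self-contained sketch of the difference between the two constructions (using a toy nn.Linear in place of the repo's mlp helper, purely for brevity):

import torch.nn as nn

# List multiplication copies the reference, not the module:
shared = nn.ModuleList([nn.Linear(4, 4)] * 2)
print(shared[0] is shared[1])                       # True  -> same module instance
print(sum(p.numel() for p in shared.parameters()))  # 20    -> weights counted once

# A comprehension builds a fresh module per entry:
separate = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
print(separate[0] is separate[1])                      # False -> independent modules
print(sum(p.numel() for p in separate.parameters()))   # 40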

In terms of the network structure, this means the two MLPs share their weights. That is a bit unexpected, because self.in_mlps[0] is applied to the pitch input and self.in_mlps[1] to the loudness input. I haven't looked into the official ddsp implementation, but at least from the paper I would not assume that these two MLPs are supposed to share weights. Or are their semantics similar enough to justify weight sharing?

For comparison, here is the torchinfo summary of the model as currently implemented (note that "Sequential: 2-2" is labeled "(recursive)" because torchinfo detects that its weights are shared with "Sequential: 2-1"):

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
DDSP                                     [16, 64000, 1]            --
├─ModuleList: 1-1                        --                        --
│    └─Sequential: 2-1                   [16, 500, 512]            --
│    │    └─Linear: 3-1                  [16, 500, 512]            1,024
│    │    └─LayerNorm: 3-2               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-3               [16, 500, 512]            --
│    │    └─Linear: 3-4                  [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-5               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-6               [16, 500, 512]            --
│    │    └─Linear: 3-7                  [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-8               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-9               [16, 500, 512]            --
│    └─Sequential: 2-2                   [16, 500, 512]            (recursive)
│    │    └─Linear: 3-10                 [16, 500, 512]            (recursive)
│    │    └─LayerNorm: 3-11              [16, 500, 512]            (recursive)
│    │    └─LeakyReLU: 3-12              [16, 500, 512]            --
│    │    └─Linear: 3-13                 [16, 500, 512]            (recursive)
│    │    └─LayerNorm: 3-14              [16, 500, 512]            (recursive)
│    │    └─LeakyReLU: 3-15              [16, 500, 512]            --
│    │    └─Linear: 3-16                 [16, 500, 512]            (recursive)
│    │    └─LayerNorm: 3-17              [16, 500, 512]            (recursive)
│    │    └─LeakyReLU: 3-18              [16, 500, 512]            --
├─GRU: 1-2                               [16, 500, 512]            2,362,368
├─Sequential: 1-3                        [16, 500, 512]            --
│    └─Linear: 2-3                       [16, 500, 512]            263,680
│    └─LayerNorm: 2-4                    [16, 500, 512]            1,024
│    └─LeakyReLU: 2-5                    [16, 500, 512]            --
│    └─Linear: 2-6                       [16, 500, 512]            262,656
│    └─LayerNorm: 2-7                    [16, 500, 512]            1,024
│    └─LeakyReLU: 2-8                    [16, 500, 512]            --
│    └─Linear: 2-9                       [16, 500, 512]            262,656
│    └─LayerNorm: 2-10                   [16, 500, 512]            1,024
│    └─LeakyReLU: 2-11                   [16, 500, 512]            --
├─ModuleList: 1-4                        --                        --
│    └─Linear: 2-12                      [16, 500, 81]             41,553
│    └─Linear: 2-13                      [16, 500, 65]             33,345
├─Reverb: 1-5                            [16, 64000, 1]            16,002
==========================================================================================
Total params: 3,774,740
Trainable params: 3,774,740
Non-trainable params: 0
Total mult-adds (G): 18.93
==========================================================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 443.52
Params size (MB): 15.10
Estimated Total Size (MB): 458.68
==========================================================================================

This is how the model changes when replacing the expression with [mlp(1, config.hidden_size, 3) for _ in range(2)] to give them individual weights (note the increased parameter count; a small check to verify this is shown after the summary below):

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
DDSP                                     [16, 64000, 1]            --
├─ModuleList: 1-1                        --                        --
│    └─Sequential: 2-1                   [16, 500, 512]            --
│    │    └─Linear: 3-1                  [16, 500, 512]            1,024
│    │    └─LayerNorm: 3-2               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-3               [16, 500, 512]            --
│    │    └─Linear: 3-4                  [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-5               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-6               [16, 500, 512]            --
│    │    └─Linear: 3-7                  [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-8               [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-9               [16, 500, 512]            --
│    └─Sequential: 2-2                   [16, 500, 512]            --
│    │    └─Linear: 3-10                 [16, 500, 512]            1,024
│    │    └─LayerNorm: 3-11              [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-12              [16, 500, 512]            --
│    │    └─Linear: 3-13                 [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-14              [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-15              [16, 500, 512]            --
│    │    └─Linear: 3-16                 [16, 500, 512]            262,656
│    │    └─LayerNorm: 3-17              [16, 500, 512]            1,024
│    │    └─LeakyReLU: 3-18              [16, 500, 512]            --
├─GRU: 1-2                               [16, 500, 512]            2,362,368
├─Sequential: 1-3                        [16, 500, 512]            --
│    └─Linear: 2-3                       [16, 500, 512]            263,680
│    └─LayerNorm: 2-4                    [16, 500, 512]            1,024
│    └─LeakyReLU: 2-5                    [16, 500, 512]            --
│    └─Linear: 2-6                       [16, 500, 512]            262,656
│    └─LayerNorm: 2-7                    [16, 500, 512]            1,024
│    └─LeakyReLU: 2-8                    [16, 500, 512]            --
│    └─Linear: 2-9                       [16, 500, 512]            262,656
│    └─LayerNorm: 2-10                   [16, 500, 512]            1,024
│    └─LeakyReLU: 2-11                   [16, 500, 512]            --
├─ModuleList: 1-4                        --                        --
│    └─Linear: 2-12                      [16, 500, 81]             41,553
│    └─Linear: 2-13                      [16, 500, 65]             33,345
├─Reverb: 1-5                            [16, 64000, 1]            16,002
==========================================================================================
Total params: 4,304,148
Trainable params: 4,304,148
Non-trainable params: 0
Total mult-adds (G): 18.93
==========================================================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 640.13
Params size (MB): 17.22
Estimated Total Size (MB): 657.41
==========================================================================================
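
As a quick way to verify which variant is active, one could run a check like the following on an instantiated model (the variable name model and its construction are assumptions here, not part of the repo's API):

# Hypothetical sanity check on an instantiated model:
m0, m1 = model.in_mlps
print(m0 is m1)  # True with the current `* 2` construction, False with the comprehension

# Equivalent check on the underlying parameter storage:
p0 = next(m0.parameters())
p1 = next(m1.parameters())
print(p0.data_ptr() == p1.data_ptr())  # same storage pointer => shared weights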

I'm just wondering whether this is an accident that happens to work okay-ish, or whether the weights were shared deliberately as an optimization.
