From 68a894a5875bfd958b8254afd3bbb23db9c2e813 Mon Sep 17 00:00:00 2001
From: Piotr Dabkowski
Date: Tue, 2 Aug 2022 11:34:10 +0200
Subject: [PATCH] Fix uninitialized parameter in conformer relative attention.
 (#18368)

`torch.Tensor` creates an uninitialized tensor (as via `torch.empty`). This
leads to nondeterministic behavior, poor initialization, and NaNs if you get
an unlucky init. The paper does not specify the initialization for the bias
terms, so zero seems like a good choice: no bias initially.

In practice the uninitialized memory is usually populated with zeros, so this
fix stays close to the behavior most runs already saw:

```
>>> torch.Tensor(100, 100).sum()
tensor(0.)
>>> torch.Tensor(100, 100).sum()
tensor(nan)
>>> torch.Tensor(100, 100).sum()
tensor(0.)
```
---
 .../models/wav2vec2_conformer/modeling_wav2vec2_conformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index a8ff9b5b20fe00..4c4962b155c35c 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -670,8 +670,8 @@ def __init__(self, config):
         self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+        self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+        self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
 
     def forward(
         self,
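
For reviewers, a minimal standalone sketch (not part of the patch) of what the one-line change buys us; the shapes here are illustrative stand-ins for the model's actual `num_heads` and `head_size`:

```python
import torch
import torch.nn as nn

# torch.Tensor(rows, cols) allocates uninitialized memory, exactly like
# torch.empty(rows, cols): the values are whatever bytes happen to sit in
# the allocation, so they can be zeros, garbage magnitudes, or NaN/inf.
old_bias = nn.Parameter(torch.Tensor(16, 64))   # pre-fix: nondeterministic

# torch.zeros(rows, cols) guarantees a deterministic all-zero start, so the
# relative-attention biases contribute nothing until training updates them.
new_bias = nn.Parameter(torch.zeros(16, 64))    # post-fix: deterministic

print(new_bias.sum())  # always tensor(0., grad_fn=<SumBackward0>)
print(old_bias.sum())  # run-dependent: 0., a garbage value, or nan
```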