From a861e69d8908500c06372c05b5d9a06858004943 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Fri, 18 Dec 2020 17:17:02 -0500 Subject: [PATCH] modfications to TDNNF semi-orthogonal error --- README.md | 23 ++++++++++++++++------- pytorch_tdnn/tdnn.py | 6 ++++-- pytorch_tdnn/tdnnf.py | 25 +++++++------------------ 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index ae8a460..002f916 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ pip install pytorch-tdnn ``` To install for development, clone the repository, and then run the following from -within the roor directory. +within the root directory. ```bash pip install -e . @@ -34,7 +34,7 @@ tdnn = TDNNLayer( y = tdnn(x) ``` -Here, `x` should have the shape `(batch_size, sequence_length, input_dim)`. +Here, `x` should have the shape `(batch_size, input_dim, sequence_length)`. **Note:** The `context` list should follow these constraints: * The length of the list should be 2 or an odd number. @@ -55,20 +55,29 @@ tdnnf = TDNNFLayer( 1, # time stride ) -y = tdnnf(x, training=True) +y = tdnnf(x, semi_ortho_step=True) ``` -The argument `training` is used to perform the semi-orthogonality step only during -the model training. If this call is made from within a `forward()` function of an -`nn.Module` class, `training` can be set to `self.training`. +The argument `semi_ortho_step` determines whether to take the step towards semi- +orthogonality for the constrained convolutional layers in the 3-stage splicing. +If this call is made from within a `forward()` function of an +`nn.Module` class, it can be set as follows to approximate Kaldi-style training +where the step is taken once every 4 iterations: + +```python +import random +semi_ortho_step = self.training and (random.uniform(0,1) < 0.25) +``` **Note:** Time stride should be greater than or equal to 0. For example, if the time stride is 1, a context of `[-1,1]` is used for each stage of splicing. ### Credits -* The TDNN implementation is based on: https://github.com/jonasvdd/TDNN. +* The TDNN implementation is based on: https://github.com/jonasvdd/TDNN and https://github.com/m-wiesner/nnet_pytorch. * Semi-orthogonal convolutions used in TDNN-F are based on: https://github.com/cvqluu/Factorized-TDNN. +* Thanks to [Matthew Wiesner](https://github.com/m-wiesner) for helpful discussions +about the implementations. This repository aims to wrap up these implementations in easy-installable PyPi packages, which can be used directly in PyTorch based neural network training. diff --git a/pytorch_tdnn/tdnn.py b/pytorch_tdnn/tdnn.py index cf1b7ad..acf3c0d 100644 --- a/pytorch_tdnn/tdnn.py +++ b/pytorch_tdnn/tdnn.py @@ -7,7 +7,8 @@ class TDNN(torch.nn.Module): def __init__(self, input_dim: int, output_dim: int, - context: list): + context: list, + bias: bool = True): """ Implementation of TDNN using the dilation argument of the PyTorch Conv1d class Due to its fastness the context has gained two constraints: @@ -44,7 +45,8 @@ def __init__(self, output_dim, kernel_size=kernel_size, dilation=dilation, - padding=padding + padding=padding, + bias=bias # will be set to False for semi-orthogonal TDNNF convolutions )) def forward(self, x): diff --git a/pytorch_tdnn/tdnnf.py b/pytorch_tdnn/tdnnf.py index 27e3be7..aa2491a 100644 --- a/pytorch_tdnn/tdnnf.py +++ b/pytorch_tdnn/tdnnf.py @@ -1,7 +1,5 @@ # This implementation is based on: https://github.com/cvqluu/Factorized-TDNN -import random - import torch import torch.nn.functional as F @@ -14,7 +12,7 @@ def __init__(self, input_dim: int, output_dim: int, context: list, - init: str = 'kaldi'): + init: str = 'xavier'): """ Semi-orthogonal convolutions. The forward function takes an additional parameter that specifies whether to take the semi-orthogonality step. @@ -23,7 +21,7 @@ def __init__(self, :param output_dim: The number of channels produced by the temporal convolution :param init: Initialization method for weight matrix (default = Kaldi-style) """ - super(SemiOrthogonalConv, self).__init__(input_dim, output_dim, context) + super(SemiOrthogonalConv, self).__init__(input_dim, output_dim, context, bias=False) self.init_method = init self.reset_parameters() @@ -38,7 +36,7 @@ def reset_parameters(self): elif self.init_method == 'xavier': # Use Xavier initialization torch.nn.init.xavier_normal_( - self.temporal_conv + self.temporal_conv.weight ) def step_semi_orth(self): @@ -72,12 +70,11 @@ def get_semi_orth_weight(M): ratio = trace_PP * P.shape[0] / (trace_P * trace_P) # the following is the tweak to avoid divergence (more info in Kaldi) - assert ratio > 0.99, "Ratio of traces is less than 0.99" + # assert ratio > 0.9, "Ratio of traces is less than 0.9" if ratio > 1.02: update_speed *= 0.5 if ratio > 1.1: update_speed *= 0.5 - scale2 = trace_PP/trace_P update = P - (torch.matrix_power(P, 0) * scale2) alpha = update_speed / scale2 @@ -106,12 +103,7 @@ def get_semi_orth_error(M): if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols M = M.T P = torch.mm(M, M.T) - PP = torch.mm(P, P.T) - trace_P = torch.trace(P) - trace_PP = torch.trace(PP) - scale2 = torch.sqrt(trace_PP/trace_P) ** 2 - update = P - (torch.matrix_power(P, 0) * scale2) - return torch.norm(update, p='fro') + return torch.norm(P, p='fro') def forward(self, x, semi_ortho_step = False): """ @@ -149,8 +141,6 @@ def __init__(self, self.bottleneck_dim = bottleneck_dim self.output_dim = output_dim - random.seed(0) - if time_stride == 0: context = [0] else: @@ -160,14 +150,13 @@ def __init__(self, self.factor2 = SemiOrthogonalConv(bottleneck_dim, bottleneck_dim, context) self.factor3 = TDNN(bottleneck_dim, output_dim, context) - def forward(self, x, training=True): + def forward(self, x, semi_ortho_step=True): """ :param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length] sequence length is the dimension of the arbitrary length data - :param training: True if model is in training phase + :param semi_ortho_step: if True, update parameter for semi-orthogonality :return: [batch_size, output_dim, out_seq_length] """ - semi_ortho_step = training and (random.uniform(0,1) < 0.25) x = self.factor1(x, semi_ortho_step=semi_ortho_step) x = self.factor2(x, semi_ortho_step=semi_ortho_step) x = self.factor3(x)