From 668f192e433c4219fb5aa35dda8cf1700e429e28 Mon Sep 17 00:00:00 2001 From: Jeremy Sadler <53983960+jezsadler@users.noreply.github.com> Date: Fri, 17 Nov 2023 01:26:29 +0000 Subject: [PATCH] Linting for dilations pt 1 --- src/omlt/io/onnx_parser.py | 8 ++-- src/omlt/neuralnet/layer.py | 88 +++++++++++++++++++++++++++---------- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/src/omlt/io/onnx_parser.py b/src/omlt/io/onnx_parser.py index 3ddb8005..1d1a3d6e 100644 --- a/src/omlt/io/onnx_parser.py +++ b/src/omlt/io/onnx_parser.py @@ -364,12 +364,12 @@ def _consume_conv_nodes(self, node, next_nodes): if "pads" in attr: pads = attr["pads"] else: - pads = 2*(len(input_output_size)-1)*[0] + pads = 2 * (len(input_output_size) - 1) * [0] if "dilations" in attr: dilations = attr["dilations"] else: - dilations = (len(input_output_size)-1)*[1] + dilations = (len(input_output_size) - 1) * [1] # Other attributes are not supported if attr["group"] != 1: @@ -379,8 +379,8 @@ def _consume_conv_nodes(self, node, next_nodes): # generate new nodes for the node output padding = [ - pads[i] + pads[i + len(input_output_size)-1] - for i in range(len(input_output_size)-1) + pads[i] + pads[i + len(input_output_size) - 1] + for i in range(len(input_output_size) - 1) ] output_size = [out_channels] for w, k, s, p in zip(input_output_size[1:], kernel_shape, strides, padding): diff --git a/src/omlt/neuralnet/layer.py b/src/omlt/neuralnet/layer.py index f38fc747..9cd0a039 100644 --- a/src/omlt/neuralnet/layer.py +++ b/src/omlt/neuralnet/layer.py @@ -291,7 +291,7 @@ def dilations(self): def dilated_kernel_shape(self): """Return the shape of the kernel after dilation""" dilated_dims = [ - self.dilations[i]*(self.kernel_shape[i]-1) + 1 + self.dilations[i] * (self.kernel_shape[i] - 1) + 1 for i in range(len(self.kernel_shape)) ] return tuple(dilated_dims) @@ -333,8 +333,7 @@ def kernel_index_with_input_indexes(self, out_d, out_r, out_c): # as this could require using a partial kernel # even though we loop over ALL kernel indexes. if not all( - input_index[i] < self.input_size[i] - and input_index[i] >= 0 + input_index[i] < self.input_size[i] and input_index[i] >= 0 for i in range(len(input_index)) ): continue @@ -498,25 +497,70 @@ def __init__( ) self.__kernel = kernel if self.dilations != [1, 1]: - dilate_rows = np.hstack([ - np.hstack([ - np.hstack([ - kernel[:, :, i, :].reshape(( - kernel.shape[0], kernel.shape[1], 1, kernel.shape[3])), - np.zeros(( - kernel.shape[0], kernel.shape[1], self.dilations[0] - 1, kernel.shape[3]))]) - for i in range(kernel.shape[2]-1)]), - kernel[:, :, -1, :].reshape((kernel.shape[0], kernel.shape[1], 1, kernel.shape[3])) - ]) - dilate_kernel = np.dstack([ - np.dstack([ - np.dstack([ - dilate_rows[:, :, :, i].reshape(( - dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1)), - np.zeros((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], self.dilations[1] - 1))]) - for i in range(dilate_rows.shape[3]-1)]), - dilate_rows[:, :, :, -1].reshape((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1)) - ]) + dilate_rows = np.hstack( + [ + np.hstack( + [ + np.hstack( + [ + kernel[:, :, i, :].reshape( + ( + kernel.shape[0], + kernel.shape[1], + 1, + kernel.shape[3] + ) + ), + np.zeros( + ( + kernel.shape[0], + kernel.shape[1], + self.dilations[0] - 1, + kernel.shape[3] + ) + ) + ] + ) + for i in range(kernel.shape[2] - 1) + ] + ), + kernel[:, :, -1, :].reshape( + (kernel.shape[0], kernel.shape[1], 1, kernel.shape[3]) + ), + ] + ) + dilate_kernel = np.dstack( + [ + np.dstack( + [ + np.dstack( + [ + dilate_rows[:, :, :, i].reshape( + ( + dilate_rows.shape[0], + dilate_rows.shape[1], + dilate_rows.shape[2], + 1 + ) + ), + np.zeros( + ( + dilate_rows.shape[0], + dilate_rows.shape[1], + dilate_rows.shape[2], + self.dilations[1] - 1 + ) + ) + ] + ) + for i in range(dilate_rows.shape[3]-1) + ] + ), + dilate_rows[:, :, :, -1].reshape( + (dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1) + ), + ] + ) self.__dilated_kernel = dilate_kernel else: self.__dilated_kernel = kernel