Skip to content

Commit

Permalink
Including dilations for 2D layers
Browse files Browse the repository at this point in the history
  • Loading branch information
jezsadler committed Nov 17, 2023
1 parent 149bc30 commit 6d0cf77
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
30 changes: 17 additions & 13 deletions src/omlt/io/onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,24 +359,29 @@ def _consume_conv_nodes(self, node, next_nodes):
f"Input/output size ({input_output_size}) first dimension must match input weights channels ({in_channels})."
)

# TODO: need to check pads and dilations also have correct dimensions. Also should
# add support for autopad.
if "pads" in attr:
pads = attr["pads"]
else:
pads = 2*(len(input_output_size)-1)*[0]

if "dilations" in attr:
dilations = attr["dilations"]
else:
dilations = (len(input_output_size)-1)*[1]

Check warning on line 372 in src/omlt/io/onnx_parser.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/io/onnx_parser.py#L372

Added line #L372 was not covered by tests

# Other attributes are not supported
if "dilations" in attr and attr["dilations"] != [1, 1]:
raise ValueError(
f"{node} has non-identity dilations ({attr['dilations']}). This is not supported."
)
if attr["group"] != 1:
raise ValueError(

Check warning on line 376 in src/omlt/io/onnx_parser.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/io/onnx_parser.py#L376

Added line #L376 was not covered by tests
f"{node} has multiple groups ({attr['group']}). This is not supported."
)
if "pads" in attr:
pads = attr["pads"]
else:
pads = 2*(len(input_output_size)-1)*[0]

# generate new nodes for the node output
padding = [
pads[i] + pads[i + len(input_output_size)-1]
for i in range(len(input_output_size)-1)]
for i in range(len(input_output_size)-1)
]
output_size = [out_channels]
for w, k, s, p in zip(input_output_size[1:], kernel_shape, strides, padding):
new_w = int((w - k + p) / s) + 1
Expand Down Expand Up @@ -404,6 +409,7 @@ def _consume_conv_nodes(self, node, next_nodes):
strides,
weights,
pads=pads,
dilations=dilations,
activation=activation,
input_index_mapper=transformer,
)
Expand Down Expand Up @@ -471,13 +477,10 @@ def _consume_pool_nodes(self, node, next_nodes):
kernel_shape = attr["kernel_shape"][1:]
strides = attr["strides"] if "strides" in attr else [1] * len(kernel_shape)
pads = attr["pads"] if "pads" in attr else None
dilations = attr["dilations"] if "dilations" in attr else None

# check only kernel shape, stride, storage order are set
# everything else is not supported
if "dilations" in attr and attr["dilations"] != [1, 1]:
raise ValueError(
f"{node.name} has non-identity dilations ({attr['dilations']}). This is not supported."
)
if ("auto_pad" in attr) and (attr["auto_pad"] != "NOTSET"):
raise ValueError(

Check warning on line 485 in src/omlt/io/onnx_parser.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/io/onnx_parser.py#L485

Added line #L485 was not covered by tests
f"{node.name} has autopad set ({attr['auto_pad']}). This is not supported."
Expand Down Expand Up @@ -520,6 +523,7 @@ def _consume_pool_nodes(self, node, next_nodes):
tuple(kernel_shape),
kernel_depth,
pads=pads,
dilations=dilations,
activation=activation,
input_index_mapper=transformer,
)
Expand Down
61 changes: 59 additions & 2 deletions src/omlt/neuralnet/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ class Layer2D(Layer):
stride of the kernel.
pads : matrix-like
Padding for the kernel. Given as [left, bottom, right, top]
dilations : matrix-like
Dilations of the kernel
activation : str or None
activation function name
input_index_mapper : IndexMapper or None
Expand All @@ -240,6 +242,7 @@ def __init__(
strides,
*,
pads=None,
dilations=None,
activation=None,
input_index_mapper=None,
):
Expand All @@ -254,6 +257,10 @@ def __init__(
self.__pads = [0, 0, 0, 0]
else:
self.__pads = pads
if dilations is None:
self.__dilations = [1, 1]
else:
self.__dilations = dilations

@property
def strides(self):
Expand All @@ -275,6 +282,20 @@ def kernel_depth(self):
"""Return the depth of the kernel"""
raise NotImplementedError()

@property
def dilations(self):
    """Dilation factors applied to this layer's kernel (defaults to [1, 1])."""
    return self.__dilations

@property
def dilated_kernel_shape(self):
    """Return the effective kernel shape after dilation.

    Each kernel dimension k expands to d*(k - 1) + 1, where d is the
    dilation factor for that dimension.
    """
    dims = []
    for i, kernel_dim in enumerate(self.kernel_shape):
        dims.append(self.dilations[i] * (kernel_dim - 1) + 1)
    return tuple(dims)

def kernel_index_with_input_indexes(self, out_d, out_r, out_c):
"""
Returns an iterator over the index within the kernel and input index
Expand All @@ -290,9 +311,9 @@ def kernel_index_with_input_indexes(self, out_d, out_r, out_c):
the output column.
"""
kernel_d = self.kernel_depth
[kernel_r, kernel_c] = self.kernel_shape
[kernel_r, kernel_c] = self.dilated_kernel_shape
[rows_stride, cols_stride] = self.__strides
[pads_row, pads_col] = self.__pads[:1]
[pads_row, pads_col] = self.__pads[:2]
start_in_d = 0
start_in_r = out_r * rows_stride - pads_row
start_in_c = out_c * cols_stride - pads_col
Expand Down Expand Up @@ -362,6 +383,8 @@ class PoolingLayer2D(Layer2D):
stride of the kernel.
pads : matrix-like
Padding for the kernel. Given as [left, bottom, right, top]
dilations : matrix-like
Dilations of the kernel
pool_func : str
name of function used to pool values in a kernel to a single value.
transpose : bool
Expand All @@ -385,6 +408,7 @@ def __init__(
kernel_depth,
*,
pads=None,
dilations=None,
activation=None,
input_index_mapper=None,
):
Expand All @@ -393,6 +417,7 @@ def __init__(
output_size,
strides,
pads=pads,
dilations=dilations,
activation=activation,
input_index_mapper=input_index_mapper,
)
Expand Down Expand Up @@ -442,6 +467,8 @@ class ConvLayer2D(Layer2D):
the cross-correlation kernel.
pads : matrix-like
Padding for the kernel. Given as [left, bottom, right, top]
dilations : matrix-like
Dilations of the kernel
activation : str or None
activation function name
input_index_mapper : IndexMapper or None
Expand All @@ -456,6 +483,7 @@ def __init__(
kernel,
*,
pads=None,
dilations=None,
activation=None,
input_index_mapper=None,
):
Expand All @@ -464,10 +492,34 @@ def __init__(
output_size,
strides,
pads=pads,
dilations=dilations,
activation=activation,
input_index_mapper=input_index_mapper,
)
self.__kernel = kernel
if self.dilations != [1, 1]:
dilate_rows = np.hstack([
np.hstack([
np.hstack([
kernel[:, :, i, :].reshape((
kernel.shape[0], kernel.shape[1], 1, kernel.shape[3])),
np.zeros((
kernel.shape[0], kernel.shape[1], self.dilations[0] - 1, kernel.shape[3]))])
for i in range(kernel.shape[2]-1)]),
kernel[:, :, -1, :].reshape((kernel.shape[0], kernel.shape[1], 1, kernel.shape[3]))
])
dilate_kernel = np.dstack([
np.dstack([
np.dstack([
dilate_rows[:, :, :, i].reshape((
dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1)),
np.zeros((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], self.dilations[1] - 1))])
for i in range(dilate_rows.shape[3]-1)]),
dilate_rows[:, :, :, -1].reshape((dilate_rows.shape[0], dilate_rows.shape[1], dilate_rows.shape[2], 1))
])
self.__dilated_kernel = dilate_kernel

Check warning on line 520 in src/omlt/neuralnet/layer.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/neuralnet/layer.py#L520

Added line #L520 was not covered by tests
else:
self.__dilated_kernel = kernel

def kernel_with_input_indexes(self, out_d, out_r, out_c):
"""
Expand Down Expand Up @@ -504,6 +556,11 @@ def kernel(self):
"""Return the cross-correlation kernel"""
return self.__kernel

@property
def dilated_kernel(self):
    """Return the cross-correlation kernel with dilation applied.

    When the layer has non-identity dilations, this is the kernel with
    rows/columns of zeros inserted between the original weights;
    otherwise it is the original kernel unchanged.
    """
    return self.__dilated_kernel

Check warning on line 562 in src/omlt/neuralnet/layer.py

View check run for this annotation

Codecov / codecov/patch

src/omlt/neuralnet/layer.py#L562

Added line #L562 was not covered by tests

def __str__(self):
    """Return a concise, human-readable summary of this convolutional layer."""
    return (
        f"ConvLayer(input_size={self.input_size}, "
        f"output_size={self.output_size}, "
        f"strides={self.strides}, "
        f"kernel_shape={self.kernel_shape})"
    )

Expand Down

0 comments on commit 6d0cf77

Please sign in to comment.