Implement ConvXDTranspose #853
ds-hwang commented on Nov 21, 2024
This PR implements a unified transposed convolution covering 1D/2D/3D, SAME/VALID/CAUSAL and arbitrary padding, and arbitrary window, stride, and dilation.
SAME and VALID are equivalent to jax.lax.conv_transpose(). CAUSAL is defined in this PR.

Each Literal padding follows the formulas below (w = window, s = stride):
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
    pad_total = window+stride-2
    when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
    pad_total = window+stride-2 + max(window-stride, 0)
    when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
    pad_total = window+stride-2

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
    dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values are set to 1 for simplicity in showcasing output values.
In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
            pad|       |pad
    paddings: 0|0 0 1 1|0
              0 0 0 -> 0
                0 0 1 -> 1
                  0 1 1 -> 2
                    1 1 0 -> 2
* "VALID" padding=(2, 2)
             pad |       |pad
    paddings: 0 0|0 0 1 1|0 0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 1 -> 1
                    0 1 1 -> 2
                      1 1 0 -> 2
                        1 0 0 -> 1
* "CAUSAL" padding=(2, 0)
             pad |       |pad
    paddings: 0 0|0 0 1 1|
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 1 -> 1
                    0 1 1 -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
             pad |             |pad
    paddings: 0 0|0 * 0 * 1 * 1|0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 1 -> 1
                        0 1 0 -> 1
                          1 0 1 -> 2
                            0 1 0 -> 1
* "VALID" padding=(2, 2)
             pad |             |pad
    paddings: 0 0|0 * 0 * 1 * 1|0 0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 1 -> 1
                        0 1 0 -> 1
                          1 0 1 -> 2
                            0 1 0 -> 1
                              1 0 0 -> 1
* "CAUSAL" padding=(2, 1)
             pad |             |pad
    paddings: 0 0|0 * 0 * 1 * 1|0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 1 -> 1
                        0 1 0 -> 1
                          1 0 1 -> 2
                            0 1 0 -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
             pad |                   |pad
    paddings: 0 0|0 * * 0 * * 1 * * 1|0 0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 0 -> 0
                          0 0 1 -> 1
                            0 1 0 -> 1
                              1 0 0 -> 1
                                0 0 1 -> 1
                                  0 1 0 -> 1
                                    1 0 0 -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
             pad |                         |pad
    paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 0 -> 0
                        0 0 0 -> 0
                          0 0 0 -> 0
                            0 0 0 -> 0
                              0 0 1 -> 1
                                0 1 0 -> 1
                                  1 0 0 -> 1
                                    0 0 0 -> 0
                                      0 0 1 -> 1
                                        0 1 0 -> 1
                                          1 0 0 -> 1
                                            0 0 0 -> 0

Here is how to compute output_size, given the above example:
    1.        |_|  -(window-1)
    2.            |_______________________|  (input_size-1)*stride + 1
    3.        |_|                           |___|  + pad_total

So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                = input_size*stride - (window+stride-2) + pad_total
                = input_size*stride  <- "SAME", "CAUSAL"
                = input_size*stride + max(window-stride, 0)  <- "VALID"

On the other hand, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.

In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
               pad |             |pad
    paddings: 0 0 0|0 * 0 * 1 * 1|0 0
              0 * 0 * 0 -> 0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 1 -> 1
                      0 * 0 * 0 -> 0
                        0 * 1 * 1 -> 2
                          0 * 0 * 0 -> 0
                            1 * 1 * 0 -> 2
* "VALID" padding=(4, 4)
                 pad |             |pad
    paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
              0 * 0 * 0 -> 0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 0 -> 0
                      0 * 0 * 1 -> 1
                        0 * 0 * 0 -> 0
                          0 * 1 * 1 -> 2
                            0 * 0 * 0 -> 0
                              1 * 1 * 0 -> 2
                                0 * 0 * 0 -> 0
                                  1 * 0 * 0 -> 1
* "CAUSAL" padding=(4, 1)
                 pad |             |pad
    paddings: 0 0 0 0|0 * 0 * 1 * 1|0
              0 * 0 * 0 -> 0
                0 * 0 * 0 -> 0
                  0 * 0 * 0 -> 0
                    0 * 0 * 0 -> 0
                      0 * 0 * 1 -> 1
                        0 * 0 * 0 -> 0
                          0 * 1 * 1 -> 2
                            0 * 0 * 0 -> 0
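The padding formulas and the output_size derivation in the description can be checked numerically with a small sketch. The helper names below (`dilate_window_size`, `transpose_padding`, `transpose_output_size`) are hypothetical and are not the functions added by this PR; only the formulas are taken from the description, and the ceil/floor terms are written with integer arithmetic.

```python
def dilate_window_size(window: int, dilation: int) -> int:
    """Effective window size once the kernel is dilated."""
    return (window - 1) * dilation + 1


def transpose_padding(window: int, stride: int, padding: str, dilation: int = 1):
    """Return (pad_low, pad_high) per the SAME/VALID/CAUSAL formulas (hypothetical helper)."""
    w = dilate_window_size(window, dilation)
    if padding == "SAME":
        total = w + stride - 2
        # (total + 1) // 2 is ceil(total / 2); total // 2 is floor(total / 2).
        return min(w - 1, (total + 1) // 2), max(stride - 1, total // 2)
    if padding == "VALID":
        return w - 1, max(stride - 1, w - 1)
    if padding == "CAUSAL":
        return w - 1, stride - 1
    raise ValueError(f"Unknown padding: {padding}")


def transpose_output_size(
    input_size: int, window: int, stride: int, padding: str, dilation: int = 1
) -> int:
    """output_size = input_size*stride - (window+stride-2) + pad_total."""
    w = dilate_window_size(window, dilation)
    pad_total = sum(transpose_padding(window, stride, padding, dilation))
    return input_size * stride - (w + stride - 2) + pad_total


# Values match the illustrated cases with input_size=4.
print(transpose_padding(3, 2, "SAME"))                      # (2, 1)
print(transpose_padding(3, 2, "CAUSAL", dilation=2))        # (4, 1)
print(transpose_output_size(4, 3, 2, "VALID"))              # 9
print(transpose_output_size(4, 3, 2, "VALID", dilation=2))  # 11
```

Each returned pair agrees with the padding shown in the corresponding diagram, and the output sizes agree with the number of window rows drawn there.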
@ruomingp could you approve it? from 926
Thanks, Dongseong. One question...
In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
             pad |             |pad
    paddings: 0 0|0 * 0 * 1 * 1|0
              0 0 0 -> 0
                0 0 0 -> 0
                  0 0 0 -> 0
                    0 0 0 -> 0
                      0 0 1 -> 1
                        0 1 0 -> 1
                          1 0 1 -> 2
                            0 1 0 -> 1
I don't understand why stride=2 but the adjacent conv transpose windows shown are only one position apart from each other.
Good question. In transposed convolution, stride=2 is not the window stride but the input dilation.
Given input |0 0 1 1|, stride=2 dilates it to |0 * 0 * 1 * 1|, and padding=(2,1) then pads it to 0 0|0 * 0 * 1 * 1|0.
That's all there is to transposed convolution. After this, an ordinary convolution with window=3 and window_stride=1 takes place, which is why adjacent windows in the illustration are one position apart.
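The dilate, pad, then convolve sequence described in this reply can be sketched in plain Python lists. This is an illustration only, not the PR's implementation: `dilate_input` and `conv_transpose_1d` are hypothetical names, and the `*` holes of the diagrams are represented as zeros.

```python
def dilate_input(inputs, stride):
    """Insert (stride - 1) zeros between neighboring input elements."""
    dilated = []
    for i, x in enumerate(inputs):
        if i > 0:
            dilated.extend([0] * (stride - 1))
        dilated.append(x)
    return dilated


def conv_transpose_1d(inputs, kernel, stride, padding):
    """Dilate by `stride`, pad, then slide the kernel with window_stride=1."""
    pad_low, pad_high = padding
    padded = [0] * pad_low + dilate_input(inputs, stride) + [0] * pad_high
    window = len(kernel)
    # One window per output position; the kernel is applied without flipping.
    return [
        sum(padded[i + k] * kernel[k] for k in range(window))
        for i in range(len(padded) - window + 1)
    ]


# "SAME" case from the description: window=3, stride=2, padding=(2, 1), all-ones kernel.
print(dilate_input([0, 0, 1, 1], stride=2))
# [0, 0, 0, 0, 1, 0, 1]
print(conv_transpose_1d([0, 0, 1, 1], [1, 1, 1], stride=2, padding=(2, 1)))
# [0, 0, 0, 0, 1, 1, 2, 1]
```

The second printed list matches the eight window sums in the quoted "SAME" illustration.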
Here is my understanding:
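# Imports assumed here; the original comment does not show them.
import jax
import jax.numpy as jnp
from numpy.testing import assert_allclose
inputs = jnp.asarray([0, 1, 2, 3, 4]).reshape([1, -1, 1])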
kernel = jnp.asarray([1, 2, 4]).reshape([-1, 1, 1])
# padded_inputs = [_, _, 0, 1, 2, 3, 4]
# windows = [[_, _, 0],
# [0, 0, 4],
# [0, 2, 8]
# [1, 4,12]
# [2, 6,16]
# [3, 8, _]
# strided = [[_, _, 0],
# [0, 0, 4],
# [0, 2, 8]
# [1, 4,12]
# [2, 6,16]
# [3, 8, _]
# sum = [0, 0, 4, 2, 9, 4,14, 6,19].
outputs = jax.lax.conv_transpose(inputs, rhs=kernel, padding=((2, 0),), strides=(2,))
assert_allclose(outputs, jnp.asarray([0, 0, 4, 2, 9, 4, 14, 6, 19]).reshape((1, -1, 1)))
Note that in this case, I only see 6 convolution windows that affect the outputs (5 windows if the input length is 4). So I'm confused by the 8 windows in the above comments.
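The two ways of counting windows in this thread can be compared side by side. The sketch below is an illustration in plain Python, not code from the PR or from jax: `gather_view` slides the kernel over the dilated and padded input (one window per output element), while `scatter_view` lets each input element deposit a scaled copy of the kernel into the output (one kernel copy per input element). Both function names are hypothetical.

```python
def gather_view(inputs, kernel, stride, padding):
    """Dilate, pad, then slide the kernel with window_stride=1."""
    pad_low, pad_high = padding
    dilated = []
    for i, x in enumerate(inputs):
        if i > 0:
            dilated.extend([0] * (stride - 1))
        dilated.append(x)
    padded = [0] * pad_low + dilated + [0] * pad_high
    window = len(kernel)
    return [
        sum(padded[i + k] * kernel[k] for k in range(window))
        for i in range(len(padded) - window + 1)
    ]


def scatter_view(inputs, kernel, stride, padding):
    """Each input element adds a scaled kernel copy at an offset of j * stride."""
    pad_low, pad_high = padding
    window = len(kernel)
    out_len = (len(inputs) - 1) * stride + 1 + pad_low + pad_high - (window - 1)
    out = [0] * out_len
    for j, x in enumerate(inputs):
        for k in range(window):
            # Input j sits at padded position pad_low + j * stride and is seen
            # through kernel tap k by the window that starts at position i.
            i = pad_low + j * stride - k
            if 0 <= i < out_len:
                out[i] += x * kernel[k]
    return out


inputs, kernel = [0, 1, 2, 3, 4], [1, 2, 4]
print(gather_view(inputs, kernel, stride=2, padding=(2, 0)))
# [0, 0, 4, 2, 9, 4, 14, 6, 19]
print(scatter_view(inputs, kernel, stride=2, padding=(2, 0)))
# [0, 0, 4, 2, 9, 4, 14, 6, 19]
```

For this example the gather view has nine windows, one per output element, and the scatter view has five kernel copies, one per input element, yet the outputs agree with the values asserted in the comment above.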
Approving to unblock follow-up work, but let's also resolve the question about the conv transpose windows...
Thank you for the review!