Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ConvXDTranspose #853

Merged
merged 1 commit into from
Nov 25, 2024
Merged

Implement ConvXDTranspose #853

merged 1 commit into from
Nov 25, 2024

Conversation

ds-hwang
Copy link
Contributor

@ds-hwang ds-hwang commented Nov 21, 2024

This PR implements unified transpose convolution covering 1D/2D/3D, SAME/VALID/CAUSAL and arbitrary
padding, arbitrary window, stride, and dilation.

SAME and VALID is equivalent to jax.lax.conv_transpose(). CAUSAL is defined in this PR.

Each Literal padding follows the formulas below,
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
     pad_total = window+stride-2
     when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1)) 
     pad_total = window+stride-2 + max(window-stride, 0)
     when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
     pad_total = window+stride-2

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
    dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values are set
to 1 for simplicity in showcasing output values.

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size, given the above example,
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"

OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. For example, 
when window=3 and dilation=2, dilate_window=5.

In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0

@ds-hwang ds-hwang marked this pull request as ready for review November 22, 2024 22:57
This PR implements unified transpose convolution covering 1D/2D/3D,
SAME/VALID/CAUSAL and arbitrary padding, arbitrary window, stride, and
dilation.

SAME and VALID is equivalent to jax.lax.conv_transpose(). CAUSAL is defined in
this PR.

Each Literal padding follows the formulas below,
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
    pad_total = window+stride-2
    when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
    pad_total = window+stride-2 + max(window-stride, 0)
    when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
    pad_total = window+stride-2

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
    dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values
are set to 1 for simplicity in showcasing output values.

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size, given the above example,
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"

OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.

In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
@ds-hwang
Copy link
Contributor Author

@ruomingp could you approve it? from 926

Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Dongseong. One question...

Comment on lines +1719 to +1730
In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
pad | |pad
paddings: 0 0|0 * 0 * 1 * 1|0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 0 -> 0
0 0 1 -> 1
0 1 0 -> 1
1 0 1 -> 2
0 1 0 -> 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why stride=2 but the adjacent conv transpose windows shown are only one position apart from each other.

Copy link
Contributor Author

@ds-hwang ds-hwang Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, in transposed convolution, stride=2 is not window stride but input dilation.

Given input |0 0 1 1|
stride=2 dilates it to |0 * 0 * 1 * 1|
and padding=(2,1) add padding as 0 0|0 * 0 * 1 * 1|0.
That's all transposed convolution is about. After this, window=3 and window_stride=1 convolution takes place.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is my understanding:

        inputs = jnp.asarray([0, 1, 2, 3, 4]).reshape([1, -1, 1])
        kernel = jnp.asarray([1, 2, 4]).reshape([-1, 1, 1])
        # padded_inputs = [_, _, 0, 1, 2, 3, 4]
        # windows =      [[_, _, 0],
        #                    [0, 0, 4],
        #                       [0, 2, 8]
        #                          [1, 4,12]
        #                             [2, 6,16]
        #                                [3, 8, _]
        # strided =      [[_, _, 0],
        #                       [0, 0, 4],
        #                             [0, 2, 8]
        #                                   [1, 4,12]
        #                                         [2, 6,16]
        #                                               [3, 8, _]
        # sum =                 [0, 0, 4, 2, 9, 4,14, 6,19].
        outputs = jax.lax.conv_transpose(inputs, rhs=kernel, padding=((2, 0),), strides=(2,))
        assert_allclose(outputs, jnp.asarray([0, 0, 4, 2, 9, 4, 14, 6, 19]).reshape((1, -1, 1)))

Note that in this case, I only see 6 convolution windows that affect the outputs (5 windows if the input length is 4). So I'm confused by the 8 windows in the above comments.

Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving to unblock follow-up work, but let's also resolve the question about the conv transpose windows...

@ds-hwang
Copy link
Contributor Author

Thank you for review!

@ds-hwang ds-hwang added this pull request to the merge queue Nov 25, 2024
Merged via the queue into apple:main with commit bad0f0f Nov 25, 2024
10 checks passed
@ds-hwang ds-hwang deleted the convt branch November 25, 2024 17:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants