Skip to content

Commit

Permalink
Implement ConvXDTranspose
Browse files Browse the repository at this point in the history
This PR implements unified transpose convolution covering 1D/2D/3D,
SAME/VALID/CAUSAL and arbitrary padding, arbitrary window, stride, and
dilation.

SAME and VALID is equivalent to jax.lax.conv_transpose(). CAUSAL is defined in
this PR.

Each Literal padding follows the formulas below,
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
    pad_total = window+stride-2
    when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
    pad_total = window+stride-2 + max(window-stride, 0)
    when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
    pad_total = window+stride-2

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
    dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values
are set to 1 for simplicity in showcasing output values.

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size, given the above example,
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"

OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.

In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
  • Loading branch information
ds-hwang committed Nov 22, 2024
1 parent d0edead commit 12d9e6f
Show file tree
Hide file tree
Showing 2 changed files with 1,641 additions and 190 deletions.
Loading

0 comments on commit 12d9e6f

Please sign in to comment.