Skip to content

Commit

Permalink
Implement the ConvTranspose operation
Browse files Browse the repository at this point in the history
  • Loading branch information
mayjs committed Aug 6, 2023
1 parent 07b7a9e commit 8ec1e4e
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
103 changes: 103 additions & 0 deletions wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ lazy_static! {
include_str!("../templates/endomorphism/broadcast.wgsl"),
)
.unwrap();
tera.add_raw_template(
"unpool/convtranspose.wgsl",
include_str!("../templates/unpool/convtranspose.wgsl"),
)
.unwrap();
tera
};
}
Expand Down Expand Up @@ -1384,6 +1389,104 @@ pub fn compile(
threads: (ceil(output_lengths[0], 256) as _, 1, 1),
}
}
"ConvTranspose" => {
    /* Inputs:
     * 1. X Input (N x C x H x W; Batch Size x Channels x Height x Width)
     * 2. W Kernel weights (C x M/group x kH x kW)
     * 3. B Bias (optional; applied per output channel in the shader)
     */
    log::debug!("{:?}", input_shapes);

    // Only 4-D input (2-D spatial convolution) is supported.
    if input_shapes[0].rank() != 4 {
        /* FIXME: We don't handle non-2D input for now */
        return Err(CompileError::InvalidInputShape {
            input_index: 0,
            input_shape: input_shapes[0].clone(),
        });
    }

    /* Step 1: Get the input dimensions */
    let input_height = input_shapes[0].dims[2] as i64;
    let input_width = input_shapes[0].dims[3] as i64;

    /* Step 2: Read attributes */
    /* TODO: auto_pad */
    // NOTE(review): `dilations` is used in the output-size calculation below
    // but is never inserted into the template context, so the shader
    // presumably only supports dilation == 1 — confirm.
    let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?;

    // Grouped transposed convolution is not implemented; reject group != 1.
    let group = node.get_attribute_value("group", Some(1))?;
    if group != 1 {
        return Err(CompileError::InvalidAttributeValue {
            attribute: "group".into(),
            value: group.to_string(),
            opset_version,
        });
    }

    // Infer the spatial kernel shape (kH, kW) from the trailing dimensions
    // of the weight tensor; used as the default for `kernel_shape`.
    let inferred_kernel_shape = input_shapes[1]
        .dims
        .iter()
        .skip(2)
        .map(|&x| x as i64)
        .collect::<Vec<_>>();

    // An explicitly given `kernel_shape` attribute must agree with the
    // shape inferred from the weight tensor.
    let kernel_shape =
        node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?;
    if inferred_kernel_shape != kernel_shape {
        log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape);
        return Err(CompileError::InvalidAttributeValue {
            attribute: "kernel_shape".to_string(),
            value: format!("{:?}", kernel_shape),
            opset_version,
        });
    }

    let output_padding = node.get_attribute_value("output_padding", Some(vec![0, 0]))?;

    // pads is [top, left, bottom, right] (begin/end per spatial axis).
    let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?;

    let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?;

    // Only the strides are needed by the WGSL template.
    context.insert("stride", &strides);

    // Expected output size per the ONNX ConvTranspose specification:
    //   out = stride * (in - 1) + output_padding
    //         + ((kernel - 1) * dilation + 1) - pad_begin - pad_end
    // NOTE(review): these values are only logged, never validated against
    // the declared output shape — confirm whether a mismatch should error.
    let output_height = strides[0] * (input_height - 1)
        + output_padding[0]
        + ((kernel_shape[0] - 1) * dilations[0] + 1)
        - pads[0]
        - pads[2];
    let output_width = strides[1] * (input_width - 1)
        + output_padding[1]
        + ((kernel_shape[1] - 1) * dilations[1] + 1)
        - pads[1]
        - pads[3];

    log::debug!(
        "Calculated output size: {:?}x{:?}",
        output_width,
        output_height
    );

    // One shader invocation per element of the output tensor.
    let (x_threads, workgroup_size_x) = workgroup_size(
        output_lengths[0],
        MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
        MAX_WORKGROUP_SIZE_X,
    )?;
    context.insert("workgroup_size_x", &workgroup_size_x);

    let scalar_type = agreed_type(input_shapes, output_shapes)?;

    // The WGSL template only supports floating-point tensors.
    if scalar_type.is_float() {
        NodeTemplate {
            scalar_type,
            template: "unpool/convtranspose.wgsl",
            threads: (x_threads, 1, 1),
        }
    } else {
        return Err(CompileError::UnimplementedVariant {
            variant: "Non-Float".into(),
            op: "ConvTranspose".into(),
        });
    }
}
op => return Err(CompileError::UnimplementedOp(op.to_string())),
};

Expand Down
75 changes: 75 additions & 0 deletions wonnx/templates/unpool/convtranspose.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{%- include "structs.wgsl" -%}

// ConvTranspose (2-D transposed convolution): each invocation computes one
// output element by gathering every input position whose kernel footprint
// covers that output coordinate.

// Input tensor, shape NxCxHxW
@group(0) @binding(0)
var<storage, read> input_tensor: Array;

// Kernel weight tensor, shape CxM/groupxkHxkW
@group(0) @binding(1)
var<storage, read> input_kernel_weights: Array;

{% if i_lens | length == 3 -%}
// Optional bias (present when the node has a third input), one scalar per
// output channel.
@group(0) @binding(2)
var<storage, read> input_bias: Array;

@group(0) @binding(3)
var<storage, read_write> output_0: Array;
{%- else -%}
@group(0) @binding(2)
var<storage, read_write> output_0: Array;
{%- endif %}

@compute @workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // Linear index of the output element this invocation produces.
    let output_idx = global_id.x;

    if (output_idx < {{ o_lens[0] }}u) {
        // Calculate the output coordinates we are responsible for by
        // decomposing the linear index into (batch, channel, y, x).
        let batch = output_idx / {{ o_chunks[0][0] }}u;
        var rest = output_idx % {{ o_chunks[0][0] }}u;

        let channel = rest / {{ o_chunks[0][1] }}u;
        rest = rest % {{ o_chunks[0][1] }}u;

        let y = rest / {{ o_chunks[0][2] }}u;
        let x = rest % {{ o_chunks[0][2] }}u;

        // Offset of this batch sample in the flattened input tensor.
        let sample_root_index = batch * {{ i_chunks[0][0] }}u;

        // Calculate the input coordinate range for our output coordinate;
        // select() keeps the unsigned arithmetic from underflowing at the
        // lower border.
        // NOTE(review): the max_in_y condition compares against
        // i_shape[0][3] (input width); for the y axis this looks like it
        // should be i_shape[0][2] (input height) — confirm with a
        // non-square input (both tests below use square 3x3 inputs).
        let min_in_y = select(0u, (y - {{ i_shape[1][2] }}u) / {{ stride[0] }}u, y > {{ i_shape[1][2] }}u);
        let max_in_y = select({{ i_shape[0][2] }}u - 1u, y / {{ stride[0] }}u, y / {{ stride[0] }}u < {{ i_shape[0][3] }}u);
        let min_in_x = select(0u, (x - {{ i_shape[1][3] }}u) / {{ stride[1] }}u, x > {{ i_shape[1][3] }}u);
        let max_in_x = select({{ i_shape[0][3] }}u - 1u, x / {{ stride[1] }}u, x / {{ stride[1] }}u < {{ i_shape[0][3] }}u);

        var result: Scalar = Scalar();

        // Now, go over each input channel and apply the corresponding kernel
        // for that channel to calculate the output piece by piece.
        for(var ichannel: u32 = 0u; ichannel < {{ i_shape[0][1] }}u; ichannel = ichannel + 1u) {
            // Base index for the 2D data in the input data
            let base_index = sample_root_index + ichannel * {{ i_chunks[0][1] }}u;
            // Get the starting position of the kernel for the given input and output channel
            let base_kernel_index = ichannel *{{ i_chunks[1][0] }}u + channel * {{ i_chunks[1][1] }}u;

            // Iterate over all potential input values
            for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) {
                for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) {
                    // Kernel offsets may wrap around (unsigned subtraction);
                    // the range check below rejects those out-of-footprint
                    // combinations.
                    let kernel_y = y - (in_y * {{ stride[0] }}u);
                    let kernel_x = x - (in_x * {{ stride[1] }}u);

                    if(kernel_y < {{ i_shape[1][2] }}u && kernel_x < {{ i_shape[1][3] }}u) {
                        result = result + (input_tensor.data[base_index + (in_y * {{ i_chunks[0][2] }}u) + in_x]
                            * input_kernel_weights.data[base_kernel_index + kernel_y * {{ i_chunks[1][2] }}u + kernel_x]);
                    }
                }
            }
        }
        {% if i_lens | length == 3 -%}
        // Apply Bias if specified
        result = result + input_bias.data[channel];
        {%- endif %}

        output_0.data[output_idx] = result;
    }
}
93 changes: 93 additions & 0 deletions wonnx/tests/convtranspose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::collections::HashMap;
use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor};
mod common;

#[test]
fn convtranspose_default() {
    // A single 3x3 input channel holding the values 0..9.
    let input_shape = vec![1, 1, 3, 3];
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect();

    // An all-ones 3x3 kernel that maps one input channel to two output
    // channels, so both output channels carry identical results.
    let kernel_shape = vec![1, 2, 3, 3];
    let data_w = vec![1.0; 18];

    let output_shape = vec![1, 2, 5, 5];

    let convtranspose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![attribute("kernel_shape", vec![3, 3])],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranspose_model))
        .expect("Session did not create");

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let result = pollster::block_on(session.run(&input_data)).unwrap();

    // Expected: the full (zero-padded) transposed convolution of 0..9 with a
    // 3x3 ones kernel, duplicated across the two output channels.
    assert_eq!(
        result["Y"],
        OutputTensor::F32(vec![
            0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0,
            20.0, 33.0, 24.0, 13.0, 6.0, 13.0, 21.0, 15.0, 8.0, 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0,
            15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, 20.0, 33.0, 24.0, 13.0, 6.0, 13.0,
            21.0, 15.0, 8.0,
        ])
    );
}

#[test]
fn convtranspose_strides() {
    // Same 3x3 input of 0..9 as in the default test, but with asymmetric
    // strides (3 along y, 2 along x).
    let input_shape = vec![1, 1, 3, 3];
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect();

    // All-ones 3x3 kernel producing two identical output channels.
    let kernel_shape = vec![1, 2, 3, 3];
    let data_w = vec![1.0; 18];

    let output_shape = vec![1, 2, 10, 8];
    let output_data = vec![
        0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0,
        1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0,
        9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0,
        8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0,
        0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 3.0, 3.0, 7.0,
        4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0,
        5.0, 5.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0,
        8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ];

    let convtranspose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![
                attribute("kernel_shape", vec![3, 3]),
                attribute("strides", vec![3, 2]),
            ],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranspose_model))
        .expect("Session did not create");

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let result = pollster::block_on(session.run(&input_data)).unwrap();
    assert_eq!(result["Y"], OutputTensor::F32(output_data));
}

0 comments on commit 8ec1e4e

Please sign in to comment.