diff --git a/wonnx/templates/matrix/pad.wgsl b/wonnx/templates/matrix/pad.wgsl index e4c6e313..f8631746 100644 --- a/wonnx/templates/matrix/pad.wgsl +++ b/wonnx/templates/matrix/pad.wgsl @@ -26,19 +26,30 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { var pad = false; {% for pad in pad_info %} - let id_{{ loop.index0 }} = d_{{ loop.index0 }} - - {{ pad.copy_start }}u; + var id_{{ loop.index0 }} = 0u; if (d_{{ loop.index0 }} < {{ pad.copy_start }}u) { - pad = true; - } - if (d_{{ loop.index0 }} > {{ pad.end_pad_start }}u) { - pad = true; + {% if mode == "reflect" %} + id_{{ loop.index0 }} = ({{ pad.copy_start }}u - d_{{ loop.index0 }}) % {{ i_shape[0][loop.index0] }}u; + {% else %} + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + pad = true; + {% endif %} } + else if (d_{{ loop.index0 }} > {{ pad.end_pad_start }}u) { + {% if mode == "reflect" %} + id_{{ loop.index0 }} = 2u * {{ pad.end_pad_start }}u - d_{{ loop.index0 }}; + {% else %} + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + pad = true; + {% endif %} + } else { + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + } {% endfor %} if (pad) { - output_0.data[gidx] = {{ scalar_type }}({{ constant_value }}); + output_0.data[gidx] = {{ scalar_type }}({{ constant_value }}); } else { let index = {%- for chunk in i_chunks | first -%}