diff --git a/wonnx/templates/matrix/lrn.wgsl b/wonnx/templates/matrix/lrn.wgsl index d18b3518..c3517ab2 100644 --- a/wonnx/templates/matrix/lrn.wgsl +++ b/wonnx/templates/matrix/lrn.wgsl @@ -19,5 +19,5 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { square_sum += I * I; } - output_0.data[c] = input_0.data[ c ] / pow({{bias}} + ({{alpha}} / {{size}}.0) * square_sum,{{beta}}); + output_0.data[c] = input_0.data[ c ] / pow({{ scalar_type }}({{ bias }}) + ({{ scalar_type }}({{ alpha }}) / {{ scalar_type }}({{ size }})) * square_sum, {{ scalar_type }}({{ beta }})); }