From 83be044022f15c9f8f9a4ecf48ce64fbaf66b4a3 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 27 Sep 2024 11:27:59 +0000 Subject: [PATCH] #12328: Optimize argmax bfloat16 comparison, saves llama 16us per token --- .../kernels/reader_argmax_interleaved.cpp | 65 ++++++++----------- .../reader_argmax_interleaved_multicore.cpp | 65 ++++++++----------- 2 files changed, 52 insertions(+), 78 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp index 42bbd549f1a1..da3a5f16e5bd 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp @@ -8,48 +8,35 @@ //#include "debug/dprint.h" -// Function to compare two bfloat16 values using integer arithmetic +// Optimized function to compare two bfloat16 values using integer arithmetic bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) { - // Extract signs - uint16_t sign_a = (bf16_a >> 15) & 0x1; - uint16_t sign_b = (bf16_b >> 15) & 0x1; - - uint16_t exp_a = (bf16_a >> 7) & 0xFF; - uint16_t exp_b = (bf16_b >> 7) & 0xFF; - - uint16_t man_a = bf16_a & 0x7F; - uint16_t man_b = bf16_b & 0x7F; - - // TODO: Investigate subnormal support - // uint16_t subnormal_a = (exp_a == 0x00); - // uint16_t subnormal_b = (exp_b == 0x00); - - // DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL(); - // DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL(); - // DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL(); - // DPRINT << HEX() << (man_a) << " man " << man_b << ENDL(); - - // If signs are different, the one without the sign bit is greater - if (sign_a != sign_b) { - // DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL(); - return sign_b > sign_a; + /* + bfloat16 format (16 bits total): + [Sign (1 bit)][Exponent (8 bits)][Mantissa (7 bits)] + bit 15 bits 14-7 bits 6-0 + + Comparison Logic: + - If signs differ: + - If bf16_a is positive (sign bit 0), it is greater. + - If bf16_a is negative (sign bit 1), it is not greater. + - If signs are the same: + - Positive numbers: higher bits mean greater value. + - Negative numbers: higher bits mean smaller value (reverse comparison). + */ + + // Check if signs are different + if ((bf16_a ^ bf16_b) & 0x8000) { + // Signs differ: if bf16_a is positive, it's greater + return (bf16_a & 0x8000) == 0; } - // If signs are the same, compare the exponent and mantissa - if (sign_a == 0) { // Positive numbers - if(exp_a == exp_b) { - // DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL(); - return man_a > man_b; - } - // DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL(); - return exp_a > exp_b; - } else { // Negative numbers - if(exp_a == exp_b) { - // DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL(); - return man_a < man_b; - } - // DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL(); - return exp_a < exp_b; + // Signs are the same + if (bf16_a & 0x8000) { + // Both negative: reverse comparison + return bf16_a < bf16_b; + } else { + // Both positive: regular comparison + return bf16_a > bf16_b; } } diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved_multicore.cpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved_multicore.cpp index ee69fe9b9bba..e5456c446e28 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved_multicore.cpp @@ -8,48 +8,35 @@ //#include "debug/dprint.h" -// Function to compare two bfloat16 values using integer arithmetic +// Optimized function to compare two bfloat16 values using integer arithmetic bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) { - // Extract signs - uint16_t sign_a = (bf16_a >> 15) & 0x1; - uint16_t sign_b = (bf16_b >> 15) & 0x1; - - uint16_t exp_a = (bf16_a >> 7) & 0xFF; - uint16_t exp_b = (bf16_b >> 7) & 0xFF; - - uint16_t man_a = bf16_a & 0x7F; - uint16_t man_b = bf16_b & 0x7F; - - // TODO: Investigate subnormal support - // uint16_t subnormal_a = (exp_a == 0x00); - // uint16_t subnormal_b = (exp_b == 0x00); - - // DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL(); - // DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL(); - // DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL(); - // DPRINT << HEX() << (man_a) << " man " << man_b << ENDL(); - - // If signs are different, the one without the sign bit is greater - if (sign_a != sign_b) { - // DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL(); - return sign_b > sign_a; + /* + bfloat16 format (16 bits total): + [Sign (1 bit)][Exponent (8 bits)][Mantissa (7 bits)] + bit 15 bits 14-7 bits 6-0 + + Comparison Logic: + - If signs differ: + - If bf16_a is positive (sign bit 0), it is greater. + - If bf16_a is negative (sign bit 1), it is not greater. + - If signs are the same: + - Positive numbers: higher bits mean greater value. + - Negative numbers: higher bits mean smaller value (reverse comparison). + */ + + // Check if signs are different + if ((bf16_a ^ bf16_b) & 0x8000) { + // Signs differ: if bf16_a is positive, it's greater + return (bf16_a & 0x8000) == 0; } - // If signs are the same, compare the exponent and mantissa - if (sign_a == 0) { // Positive numbers - if(exp_a == exp_b) { - // DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL(); - return man_a > man_b; - } - // DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL(); - return exp_a > exp_b; - } else { // Negative numbers - if(exp_a == exp_b) { - // DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL(); - return man_a < man_b; - } - // DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL(); - return exp_a < exp_b; + // Signs are the same + if (bf16_a & 0x8000) { + // Both negative: reverse comparison + return bf16_a < bf16_b; + } else { + // Both positive: regular comparison + return bf16_a > bf16_b; } }