Skip to content

Commit

Permalink
#12328: Optimize argmax bfloat16 comparison, saves llama 16us per token
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Sep 27, 2024
1 parent e53f39d commit 83be044
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,35 @@

//#include "debug/dprint.h"

// Function to compare two bfloat16 values using integer arithmetic
// Optimized function to compare two bfloat16 values using integer arithmetic
bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) {
// Extract signs
uint16_t sign_a = (bf16_a >> 15) & 0x1;
uint16_t sign_b = (bf16_b >> 15) & 0x1;

uint16_t exp_a = (bf16_a >> 7) & 0xFF;
uint16_t exp_b = (bf16_b >> 7) & 0xFF;

uint16_t man_a = bf16_a & 0x7F;
uint16_t man_b = bf16_b & 0x7F;

// TODO: Investigate subnormal support
// uint16_t subnormal_a = (exp_a == 0x00);
// uint16_t subnormal_b = (exp_b == 0x00);

// DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL();
// DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL();
// DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL();
// DPRINT << HEX() << (man_a) << " man " << man_b << ENDL();

// If signs are different, the one without the sign bit is greater
if (sign_a != sign_b) {
// DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL();
return sign_b > sign_a;
/*
bfloat16 format (16 bits total):
[Sign (1 bit)][Exponent (8 bits)][Mantissa (7 bits)]
bit 15 bits 14-7 bits 6-0
Comparison Logic:
- If signs differ:
- If bf16_a is positive (sign bit 0), it is greater.
- If bf16_a is negative (sign bit 1), it is not greater.
- If signs are the same:
- Positive numbers: higher bits mean greater value.
- Negative numbers: higher bits mean smaller value (reverse comparison).
*/

// Check if signs are different
if ((bf16_a ^ bf16_b) & 0x8000) {
// Signs differ: if bf16_a is positive, it's greater
return (bf16_a & 0x8000) == 0;
}

// If signs are the same, compare the exponent and mantissa
if (sign_a == 0) { // Positive numbers
if(exp_a == exp_b) {
// DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL();
return man_a > man_b;
}
// DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL();
return exp_a > exp_b;
} else { // Negative numbers
if(exp_a == exp_b) {
// DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL();
return man_a < man_b;
}
// DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL();
return exp_a < exp_b;
// Signs are the same
if (bf16_a & 0x8000) {
// Both negative: reverse comparison
return bf16_a < bf16_b;
} else {
// Both positive: regular comparison
return bf16_a > bf16_b;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,35 @@

//#include "debug/dprint.h"

// Function to compare two bfloat16 values using integer arithmetic
// Optimized function to compare two bfloat16 values using integer arithmetic
bool bfloat16_greater(uint16_t bf16_a, uint16_t bf16_b) {
// Extract signs
uint16_t sign_a = (bf16_a >> 15) & 0x1;
uint16_t sign_b = (bf16_b >> 15) & 0x1;

uint16_t exp_a = (bf16_a >> 7) & 0xFF;
uint16_t exp_b = (bf16_b >> 7) & 0xFF;

uint16_t man_a = bf16_a & 0x7F;
uint16_t man_b = bf16_b & 0x7F;

// TODO: Investigate subnormal support
// uint16_t subnormal_a = (exp_a == 0x00);
// uint16_t subnormal_b = (exp_b == 0x00);

// DPRINT << HEX() << (bf16_a) << " > " << bf16_b << ENDL();
// DPRINT << HEX() << (sign_a) << " signs " << sign_b << ENDL();
// DPRINT << HEX() << (exp_a) << " exp " << exp_b << ENDL();
// DPRINT << HEX() << (man_a) << " man " << man_b << ENDL();

// If signs are different, the one without the sign bit is greater
if (sign_a != sign_b) {
// DPRINT << "sign_b > sign_a: " << (int)(sign_b > sign_a) << ENDL();
return sign_b > sign_a;
/*
bfloat16 format (16 bits total):
[Sign (1 bit)][Exponent (8 bits)][Mantissa (7 bits)]
bit 15 bits 14-7 bits 6-0
Comparison Logic:
- If signs differ:
- If bf16_a is positive (sign bit 0), it is greater.
- If bf16_a is negative (sign bit 1), it is not greater.
- If signs are the same:
- Positive numbers: higher bits mean greater value.
- Negative numbers: higher bits mean smaller value (reverse comparison).
*/

// Check if signs are different
if ((bf16_a ^ bf16_b) & 0x8000) {
// Signs differ: if bf16_a is positive, it's greater
return (bf16_a & 0x8000) == 0;
}

// If signs are the same, compare the exponent and mantissa
if (sign_a == 0) { // Positive numbers
if(exp_a == exp_b) {
// DPRINT << "man_a > man_b: " << (int)(man_a > man_b) << ENDL();
return man_a > man_b;
}
// DPRINT << "exp_a > exp_b: " << (int)(exp_a > exp_b) << ENDL();
return exp_a > exp_b;
} else { // Negative numbers
if(exp_a == exp_b) {
// DPRINT << "man_a < man_b: " << (int)(man_a < man_b) << ENDL();
return man_a < man_b;
}
// DPRINT << "exp_a < exp_b: " << (int)(exp_a < exp_b) << ENDL();
return exp_a < exp_b;
// Signs are the same
if (bf16_a & 0x8000) {
// Both negative: reverse comparison
return bf16_a < bf16_b;
} else {
// Both positive: regular comparison
return bf16_a > bf16_b;
}
}

Expand Down

0 comments on commit 83be044

Please sign in to comment.