Skip to content

Commit

Permalink
#16175: Add DPRINT TileSlice support for int types
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-dma committed Jan 3, 2025
1 parent 66bb613 commit 8e36303
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 9 deletions.
3 changes: 2 additions & 1 deletion docs/source/tt-metalium/tools/kernel_print.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ An example of how to print data from a CB (in this case, ``CBIndex::c_25``) is s
to the current CB read or write pointer. This means that for printing a tile read from the front of the CB, the
``DPRINT`` call has to occur between the ``cb_wait_front`` and ``cb_pop_front`` calls. For printing a tile from the
back of the CB, the ``DPRINT`` call has to occur between the ``cb_reserve_back`` and ``cb_push_back`` calls. Currently supported data
formats for printing from CBs are ``DataFormat::Float32``, ``DataFormat::Float16_b``, ``DataFormat::Bfp8_b``, and ``DataFormat::Bfp4_b``.
formats for printing from CBs are ``DataFormat::Float32``, ``DataFormat::Float16_b``, ``DataFormat::Bfp8_b``, ``DataFormat::Bfp4_b``,
``DataFormat::Int8``, ``DataFormat::UInt8``, ``DataFormat::UInt16``, ``DataFormat::Int32``, and ``DataFormat::UInt32``.

.. code-block:: c++

Expand Down
125 changes: 119 additions & 6 deletions tests/tt_metal/tt_metal/debug_tools/dprint/test_print_tiles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,63 @@ static std::vector<uint32_t> GenerateInputTile(tt::DataFormat data_format) {
if (data_format == tt::DataFormat::Float32) {
u32_vec.resize(elements_in_tile);
for (int i = 0; i < u32_vec.size(); i++) {
float val = -12.3345 + static_cast<float>(i);
float val = -12.3345 + static_cast<float>(i); // Rebias to force some negative #s to be printed
u32_vec.at(i) = *reinterpret_cast<uint32_t*>(&val);
}
} else if (data_format == tt::DataFormat::Float16_b) {
std::vector<bfloat16> fp16b_vec(elements_in_tile);
for (int i = 0; i < fp16b_vec.size(); i++) {
uint16_t val = 0x3dfb + i;
uint16_t val = 0x3dfb + i; // Start at some known value (~0.1226) and increment for new numbers
fp16b_vec[i] = bfloat16(val);
}
u32_vec = pack_bfloat16_vec_into_uint32_vec(fp16b_vec);
} else if (data_format == tt::DataFormat::Bfp8_b) {
std::vector<float> float_vec(elements_in_tile);
for (int i = 0; i < float_vec.size(); i++) {
float_vec[i] = 0.012345 * i * (i % 32 == 0? -1 : 1);
float_vec[i] = 0.012345 * i * (i % 32 == 0 ? -1 : 1); // Small increments and force negatives for testing
}
u32_vec = pack_fp32_vec_as_bfp8_tiles(float_vec, true, false);
} else if (data_format == tt::DataFormat::Bfp4_b) {
std::vector<float> float_vec(elements_in_tile);
for (int i = 0; i < float_vec.size(); i++) {
float_vec[i] = 0.012345 * i * (i % 16 == 0? -1 : 1);
float_vec[i] = 0.012345 * i * (i % 16 == 0 ? -1 : 1); // Small increments and force negatives for testing
}
u32_vec = pack_fp32_vec_as_bfp4_tiles(float_vec, true, false);
} else {
} else if (data_format == tt::DataFormat::Int8) {
std::vector<int8_t> int8_vec(elements_in_tile);
for (int i = 0; i < int8_vec.size(); i++) {
int8_vec[i] = ((i / 2) % 256) - 128; // Force prints to be different (/2), within the int8 range (%256),
// and include negatives (-128) for testing purposes.
}
uint32_t datums_per_32 = sizeof(uint32_t) / tt::datum_size(data_format);
u32_vec.resize(elements_in_tile / datums_per_32);
std::memcpy(u32_vec.data(), int8_vec.data(), elements_in_tile * sizeof(int8_t));
} else if (data_format == tt::DataFormat::UInt8) {
std::vector<uint8_t> uint8_vec(elements_in_tile);
for (int i = 0; i < uint8_vec.size(); i++) {
uint8_vec[i] = ((i / 2) % 256); // Same as int8, just no negatives
}
uint32_t datums_per_32 = sizeof(uint32_t) / tt::datum_size(data_format);
u32_vec.resize(elements_in_tile / datums_per_32);
std::memcpy(u32_vec.data(), uint8_vec.data(), elements_in_tile * sizeof(uint8_t));
} else if (data_format == tt::DataFormat::UInt16) {
std::vector<uint16_t> uint16_vec(elements_in_tile);
for (int i = 0; i < uint16_vec.size(); i++) {
uint16_vec[i] = (i % 0x10000); // Force to within uint16 range
}
uint32_t datums_per_32 = sizeof(uint32_t) / tt::datum_size(data_format);
u32_vec.resize(elements_in_tile / datums_per_32);
std::memcpy(u32_vec.data(), uint16_vec.data(), elements_in_tile * sizeof(uint16_t));
} else if (data_format == tt::DataFormat::Int32) {
u32_vec.resize(elements_in_tile);
for (int i = 0; i < u32_vec.size(); i++) {
u32_vec[i] = (i % 2) ? i : i * -1; // Make every other number negative for printing purposes
}
} else if (data_format == tt::DataFormat::UInt32) {
u32_vec.resize(elements_in_tile);
for (int i = 0; i < u32_vec.size(); i++) {
u32_vec[i] = i;
}
}
return u32_vec;
}
Expand Down Expand Up @@ -89,7 +123,56 @@ static string GenerateExpectedData(tt::DataFormat data_format, std::vector<uint3
*reinterpret_cast<float *>(&float_vec[col * 32 + 16]),
*reinterpret_cast<float *>(&float_vec[col * 32 + 24]));
}
} else {
} else if (data_format == tt::DataFormat::Int8) {
int8_t* int8_ptr = reinterpret_cast<int8_t*>(input_tile.data());
for (uint32_t col = 0; col < 32; col += 8) {
data += fmt::format(
"\n{} {} {} {}",
int8_ptr[col * 32 + 0],
int8_ptr[col * 32 + 8],
int8_ptr[col * 32 + 16],
int8_ptr[col * 32 + 24]);
}
} else if (data_format == tt::DataFormat::UInt8) {
uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(input_tile.data());
for (uint32_t col = 0; col < 32; col += 8) {
data += fmt::format(
"\n{} {} {} {}",
uint8_ptr[col * 32 + 0],
uint8_ptr[col * 32 + 8],
uint8_ptr[col * 32 + 16],
uint8_ptr[col * 32 + 24]);
}
} else if (data_format == tt::DataFormat::UInt16) {
uint16_t* uint16_ptr = reinterpret_cast<uint16_t*>(input_tile.data());
for (uint32_t col = 0; col < 32; col += 8) {
data += fmt::format(
"\n{} {} {} {}",
uint16_ptr[col * 32 + 0],
uint16_ptr[col * 32 + 8],
uint16_ptr[col * 32 + 16],
uint16_ptr[col * 32 + 24]);
}
} else if (data_format == tt::DataFormat::Int32) {
int32_t* int32_ptr = reinterpret_cast<int32_t*>(input_tile.data());
for (uint32_t col = 0; col < 32; col += 8) {
data += fmt::format(
"\n{} {} {} {}",
int32_ptr[col * 32 + 0],
int32_ptr[col * 32 + 8],
int32_ptr[col * 32 + 16],
int32_ptr[col * 32 + 24]);
}
} else if (data_format == tt::DataFormat::UInt32) {
uint32_t* uint32_ptr = reinterpret_cast<uint32_t*>(input_tile.data());
for (uint32_t col = 0; col < 32; col += 8) {
data += fmt::format(
"\n{} {} {} {}",
uint32_ptr[col * 32 + 0],
uint32_ptr[col * 32 + 8],
uint32_ptr[col * 32 + 16],
uint32_ptr[col * 32 + 24]);
}
}
return data;
}
Expand Down Expand Up @@ -199,3 +282,33 @@ TEST_F(DPrintFixture, TestPrintTilesBfp8_b) {
[&](DPrintFixture* fixture, Device* device) { RunTest(fixture, device, tt::DataFormat::Bfp8_b); }, device);
}
}
// Runs the tile-print test with the Int8 data format on each device in the fixture's device list.
TEST_F(DPrintFixture, TestPrintTilesInt8) {
    const tt::DataFormat format = tt::DataFormat::Int8;
    for (Device* target : this->devices_) {
        // The fixture hands its own pointer and the device back to the callback.
        auto run_case = [&](DPrintFixture* fx, Device* dev) { RunTest(fx, dev, format); };
        this->RunTestOnDevice(run_case, target);
    }
}
// Runs the tile-print test with the Int32 data format on each device in the fixture's device list.
TEST_F(DPrintFixture, TestPrintTilesInt32) {
    const tt::DataFormat format = tt::DataFormat::Int32;
    for (Device* target : this->devices_) {
        // The fixture hands its own pointer and the device back to the callback.
        auto run_case = [&](DPrintFixture* fx, Device* dev) { RunTest(fx, dev, format); };
        this->RunTestOnDevice(run_case, target);
    }
}
// Runs the tile-print test with the UInt8 data format on each device in the fixture's device list.
TEST_F(DPrintFixture, TestPrintTilesUInt8) {
    const tt::DataFormat format = tt::DataFormat::UInt8;
    for (Device* target : this->devices_) {
        // The fixture hands its own pointer and the device back to the callback.
        auto run_case = [&](DPrintFixture* fx, Device* dev) { RunTest(fx, dev, format); };
        this->RunTestOnDevice(run_case, target);
    }
}
// Runs the tile-print test with the UInt16 data format on each device in the fixture's device list.
TEST_F(DPrintFixture, TestPrintTilesUInt16) {
    const tt::DataFormat format = tt::DataFormat::UInt16;
    for (Device* target : this->devices_) {
        // The fixture hands its own pointer and the device back to the callback.
        auto run_case = [&](DPrintFixture* fx, Device* dev) { RunTest(fx, dev, format); };
        this->RunTestOnDevice(run_case, target);
    }
}
// Runs the tile-print test with the UInt32 data format on each device in the fixture's device list.
TEST_F(DPrintFixture, TestPrintTilesUInt32) {
    const tt::DataFormat format = tt::DataFormat::UInt32;
    for (Device* target : this->devices_) {
        // The fixture hands its own pointer and the device back to the callback.
        auto run_case = [&](DPrintFixture* fx, Device* dev) { RunTest(fx, dev, format); };
        this->RunTestOnDevice(run_case, target);
    }
}
4 changes: 2 additions & 2 deletions tt_metal/hostdevcommon/dprint_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ static inline constexpr bool is_supported_format(const CommonDataFormat& format)
case CommonDataFormat::Float16_b: return true;
case CommonDataFormat::Float32: return true;
case CommonDataFormat::Int8:
case CommonDataFormat::Lf8:
case CommonDataFormat::UInt8:
case CommonDataFormat::UInt16:
case CommonDataFormat::UInt32:
case CommonDataFormat::Int32:
case CommonDataFormat::Int32: return true;
case CommonDataFormat::Lf8:
case CommonDataFormat::Invalid:
default: return false;
}
Expand Down
25 changes: 25 additions & 0 deletions tt_metal/impl/debug/dprint_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,31 @@ static void PrintTileSlice(ostringstream* stream, uint8_t* ptr) {
*stream << *reinterpret_cast<float*>(&bit_val);
break;
}
case tt::DataFormat::Int8: {
int8_t* data_ptr = reinterpret_cast<int8_t*>(data);
*stream << (int)data_ptr[i];
break;
}
case tt::DataFormat::UInt8: {
uint8_t* data_ptr = reinterpret_cast<uint8_t*>(data);
*stream << (unsigned int)data_ptr[i];
break;
}
case tt::DataFormat::UInt16: {
uint16_t* data_ptr = reinterpret_cast<uint16_t*>(data);
*stream << (unsigned int)data_ptr[i];
break;
}
case tt::DataFormat::Int32: {
int32_t* data_ptr = reinterpret_cast<int32_t*>(data);
*stream << (int)data_ptr[i];
break;
}
case tt::DataFormat::UInt32: {
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(data);
*stream << (unsigned int)data_ptr[i];
break;
}
default: break;
}
if (w + ts->slice_range.ws < ts->slice_range.w1) {
Expand Down

0 comments on commit 8e36303

Please sign in to comment.