Skip to content

Commit

Permalink
#9419: use memcpy to avoid mem misalignment (#10154)
Browse files: browse the repository at this point in the history
  • Loading branch information
ihamer-tt authored Jul 12, 2024
1 parent 9756f18 commit e2c5477
Showing 1 changed file with 74 additions and 32 deletions.
106 changes: 74 additions & 32 deletions tt_metal/impl/debug/dprint_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <cstddef>
#include <cstdint>
#include <thread>
#include <future>
Expand Down Expand Up @@ -199,7 +200,12 @@ static void PrintTileSlice(ostream& stream, uint8_t* ptr, int hart_id) {
return;
}

TileSliceHostDev<0>* ts = reinterpret_cast<TileSliceHostDev<0>*>(ptr);
TileSliceHostDev<0>ts_copy; // Make a copy since ptr might not be properly aligned
std::memcpy(&ts_copy, ptr, sizeof(TileSliceHostDev<0>));
TileSliceHostDev<0>* ts = &ts_copy;
TT_ASSERT(offsetof(TileSliceHostDev<0>, samples_) % sizeof(uint16_t) == 0, "TileSliceHostDev<0> samples_ field is not properly aligned");
uint16_t *samples_ = reinterpret_cast<uint16_t *>(ptr) + offsetof(TileSliceHostDev<0>, samples_) / sizeof(uint16_t);

if (ts->w0_ == 0xFFFF) {
stream << "BAD TILE POINTER" << std::flush;
stream << " count=" << ts->count_ << std::flush;
Expand All @@ -215,7 +221,7 @@ static void PrintTileSlice(ostream& stream, uint8_t* ptr, int hart_id) {
count_exceeded = true;
break;
}
stream << bfloat16_to_float(ts->samples_[i]);
stream << bfloat16_to_float(samples_[i]);
if (w + ts->ws_ < ts->w1_)
stream << " ";
i++;
Expand Down Expand Up @@ -754,42 +760,78 @@ bool DebugPrintServerContext::PeekOneHartNonBlocking(
TT_ASSERT(sz == 1);
break;
case DPrintUINT16:
stream << *reinterpret_cast<uint16_t*>(ptr);
TT_ASSERT(sz == 2);
break;
{
uint16_t value;
memcpy(&value, ptr, sizeof(uint16_t));
stream << value;
TT_ASSERT(sz == 2);
}
break;
case DPrintUINT32:
stream << *reinterpret_cast<uint32_t*>(ptr);
TT_ASSERT(sz == 4);
break;
{
uint32_t value;
memcpy(&value, ptr, sizeof(uint32_t));
stream << value;
TT_ASSERT(sz == 4);
}
break;
case DPrintUINT64:
stream << *reinterpret_cast<uint64_t*>(ptr);
TT_ASSERT(sz == 8);
break;
{
uint64_t value;
memcpy(&value, ptr, sizeof(uint64_t));
stream << value;
TT_ASSERT(sz == 8);
}
break;
case DPrintINT8:
// For int8_ts, we'll print a number instead of a char.
stream << (int) *reinterpret_cast<int8_t*>(ptr);
TT_ASSERT(sz == 1);
break;
{
int8_t value;
memcpy(&value, ptr, sizeof(int8_t));
stream << (int)value; // Cast to int to ensure it prints as a number, not a char
TT_ASSERT(sz == 1);
}
break;
case DPrintINT16:
stream << *reinterpret_cast<int16_t*>(ptr);
TT_ASSERT(sz == 2);
break;
{
int16_t value;
memcpy(&value, ptr, sizeof(int16_t));
stream << value;
TT_ASSERT(sz == 2);
}
break;
case DPrintINT32:
stream << *reinterpret_cast<int32_t*>(ptr);
TT_ASSERT(sz == 4);
break;
{
int32_t value;
memcpy(&value, ptr, sizeof(int32_t));
stream << value;
TT_ASSERT(sz == 4);
}
break;
case DPrintINT64:
stream << *reinterpret_cast<int64_t*>(ptr);
TT_ASSERT(sz == 8);
break;
{
int64_t value;
memcpy(&value, ptr, sizeof(int64_t));
stream << value;
TT_ASSERT(sz == 8);
}
break;
case DPrintFLOAT32:
stream << *reinterpret_cast<float*>(ptr);
TT_ASSERT(sz == 4);
break;
{
float value;
memcpy(&value, ptr, sizeof(float));
stream << value;
TT_ASSERT(sz == 4);
}
break;
case DPrintBFLOAT16:
stream << bfloat16_to_float(*reinterpret_cast<uint16_t*>(ptr));
TT_ASSERT(sz == 2);
break;
{
uint16_t rawValue;
memcpy(&rawValue, ptr, sizeof(uint16_t));
float value = bfloat16_to_float(rawValue);
stream << value;
TT_ASSERT(sz == 2);
}
break;
case DPrintCHAR:
stream << *reinterpret_cast<char*>(ptr);
TT_ASSERT(sz == 1);
Expand All @@ -801,7 +843,7 @@ bool DebugPrintServerContext::PeekOneHartNonBlocking(
PrintTypedUint32Array(stream, most_recent_setw, sz/4, reinterpret_cast<uint32_t*>(ptr));
break;
case DPrintRAISE:
sigval = *reinterpret_cast<uint32_t*>(ptr);
memcpy (&sigval, ptr, sizeof(uint32_t));
// Add this newly raised signal to the set of raised signals.
raise_wait_lock_.lock();
raised_signals_.insert(sigval);
Expand All @@ -811,7 +853,7 @@ bool DebugPrintServerContext::PeekOneHartNonBlocking(
break;
case DPrintWAIT:
{
sigval = *reinterpret_cast<uint32_t*>(ptr);
memcpy (&sigval, ptr, sizeof(uint32_t));
// Given that we break immediately on a wait, this core should never be waiting
// on multiple signals at the same time.
tuple<uint32_t, uint32_t, uint32_t, uint32_t> hart_key {chip_id, core.x, core.y, hart_id};
Expand Down

0 comments on commit e2c5477

Please sign in to comment.