#8837: Skip generating buffer page mapping for linear sharded reads
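In short: sharded reads previously always called generate_buffer_page_mapping, even when the shard layout lets every core's pages be relayed and copied as one linear run. After this change the mapping is built only for width-split shards. Below is a minimal sketch of the predicate involved, using stand-in types rather than the actual tt-metal API (ShardSpecView and both helpers are hypothetical names for illustration):

    // Hypothetical stand-ins for illustration; see buffer.shard_spec() in the diff below.
    #include <array>
    #include <cstdint>

    struct ShardSpecView {
        std::array<uint32_t, 2> shape_in_pages;  // {height, width} of one shard, in pages
        std::array<uint32_t, 2> tensor2d_shape;  // {height, width} of the whole tensor, in pages
    };

    // A shard is "width split" when it is narrower than the tensor, so the pages of
    // one shard row are not contiguous in host order.
    inline bool is_width_split(const ShardSpecView& s) {
        return s.shape_in_pages[1] != s.tensor2d_shape[1];
    }

    // Reads can skip the BufferPageMapping exactly when the shard is not width split:
    // each core then holds a contiguous run of host pages starting at buffer.address().
    inline bool can_read_linearly(const ShardSpecView& s) { return !is_width_split(s); }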
tt-aho committed May 25, 2024
1 parent e44c930 commit 5d683ba
Showing 2 changed files with 72 additions and 78 deletions.
128 changes: 65 additions & 63 deletions tt_metal/impl/dispatch/command_queue.cpp
@@ -86,16 +86,9 @@ void EnqueueReadInterleavedBufferCommand::add_prefetch_relay(HugepageDeviceComma

void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& command) {
uint32_t padded_page_size = align(this->buffer.page_size(), ADDRESS_ALIGNMENT);
- CoreCoord logical_core =
-     this->buffer_page_mapping.all_cores_[this->buffer_page_mapping.dev_page_to_core_mapping_[this->src_page_index]];
- CoreCoord core = this->buffer.device()->worker_core_from_logical_core(logical_core);
+ const CoreCoord worker_core = this->buffer.device()->worker_core_from_logical_core(this->core);
command.add_prefetch_relay_linear(
-     get_noc_unicast_encoding(core),
-     padded_page_size * this->pages_to_read,
-     this->buffer.address() +
-         this->buffer_page_mapping.host_page_to_local_shard_page_mapping_
-             [this->buffer_page_mapping.dev_page_to_host_page_mapping_[this->src_page_index].value()] *
-             padded_page_size);
+     get_noc_unicast_encoding(worker_core), padded_page_size * this->pages_to_read, this->buffer.address());
}

void EnqueueReadBufferCommand::process() {
@@ -224,30 +217,19 @@ void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& co
uint32_t data_size_bytes = this->pages_to_write * this->padded_page_size;
if (this->buffer_page_mapping.has_value()) {
const auto& page_mapping = this->buffer_page_mapping.value();
- uint32_t core_index = page_mapping.dev_page_to_core_mapping_[this->dst_page_index];
- bool width_page_padded =
-     page_mapping.core_shard_shape_[core_index][1] != buffer.shard_spec().shape_in_pages()[1];
- if (width_page_padded or this->width_split or
-     (this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size())) {
-     uint8_t* dst = command_sequence.reserve_space<uint8_t*, true>(data_size_bytes);
-     // TODO: Expose getter for cmd_write_offsetB?
-     uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
-     for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
-         ++dev_page) {
-         auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
-         if (host_page.has_value()) {
-             command_sequence.update_cmd_sequence(
-                 dst_offset,
-                 (char*)this->src + host_page.value() * this->buffer.page_size(),
-                 this->buffer.page_size());
-         }
-         dst_offset += this->padded_page_size;
+ uint8_t* dst = command_sequence.reserve_space<uint8_t*, true>(data_size_bytes);
+ // TODO: Expose getter for cmd_write_offsetB?
+ uint32_t dst_offset = dst - (uint8_t*)command_sequence.data();
+ for (uint32_t dev_page = this->dst_page_index; dev_page < this->dst_page_index + this->pages_to_write;
+     ++dev_page) {
+     auto& host_page = page_mapping.dev_page_to_host_page_mapping_[dev_page];
+     if (host_page.has_value()) {
+         command_sequence.update_cmd_sequence(
+             dst_offset,
+             (char*)this->src + host_page.value() * this->buffer.page_size(),
+             this->buffer.page_size());
+     }
}
- } else {
-     // There are no padded pages
-     uint32_t unpadded_src_offset =
-         page_mapping.dev_page_to_host_page_mapping_[this->dst_page_index].value() * this->buffer.page_size();
-     command_sequence.add_data((char*)this->src + unpadded_src_offset, data_size_bytes, data_size_bytes);
+     dst_offset += this->padded_page_size;
}
} else {
if (this->buffer.page_size() != this->padded_page_size and this->buffer.page_size() != this->buffer.size()) {
@@ -1316,20 +1298,38 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
uint32_t src_page_index = 0;

if (is_sharded(buffer.buffer_layout())) {
- auto buffer_page_mapping = generate_buffer_page_mapping(buffer);
+ bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
+ std::optional<BufferPageMapping> buffer_page_mapping = std::nullopt;
+ if (width_split) {
+     buffer_page_mapping = generate_buffer_page_mapping(buffer);
+ }
// Note that the src_page_index is the device page idx, not the host page idx
// Since we read core by core we are reading the device pages sequentially
- bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
+ const auto& cores = width_split ? buffer_page_mapping.value().all_cores_
+                                 : corerange_to_cores(
+                                       buffer.shard_spec().grid(),
+                                       buffer.num_cores(),
+                                       buffer.shard_spec().orientation() == ShardOrientation::ROW_MAJOR);
+ uint32_t num_total_pages = buffer.num_pages();
+ uint32_t max_pages_per_shard = buffer.shard_spec().size();
+ bool linear_page_copy = true;
for (uint32_t core_id = 0; core_id < buffer.num_cores(); ++core_id) {
- uint32_t num_pages_to_read =
-     buffer_page_mapping.core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
+ uint32_t num_pages_to_read;
+ if (width_split) {
+     num_pages_to_read =
+         buffer_page_mapping.value().core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
+ } else {
+     num_pages_to_read = min(num_total_pages, max_pages_per_shard);
+     num_total_pages -= num_pages_to_read;
+ }
if (num_pages_to_read > 0) {
- bool width_page_padded =
-     buffer_page_mapping.core_shard_shape_[core_id][1] != buffer.shard_spec().shape_in_pages()[1];
- bool linear_page_copy = !(width_split or width_page_padded);
- uint32_t host_page = buffer_page_mapping.core_host_page_indices_[core_id][0];
- src_page_index = buffer_page_mapping.host_page_to_dev_page_mapping_[host_page];
- unpadded_dst_offset = host_page * buffer.page_size();
+ if (width_split) {
+     uint32_t host_page = buffer_page_mapping.value().core_host_page_indices_[core_id][0];
+     src_page_index = buffer_page_mapping.value().host_page_to_dev_page_mapping_[host_page];
+     unpadded_dst_offset = host_page * buffer.page_size();
+ } else {
+     unpadded_dst_offset = src_page_index * buffer.page_size();
+ }

auto command = EnqueueReadShardedBufferCommand(
this->id,
@@ -1338,7 +1338,7 @@
dst,
this->manager,
this->expected_num_workers_completed,
- buffer_page_mapping,
+ cores[core_id],
src_page_index,
num_pages_to_read);

@@ -1349,7 +1349,10 @@
unpadded_dst_offset,
num_pages_to_read,
src_page_index,
- linear_page_copy));
+ width_split ? (*buffer_page_mapping).dev_page_to_host_page_mapping_
+             : vector<std::optional<uint32_t>>()));
+
+ src_page_index += num_pages_to_read;
this->enqueue_command(command, false);
this->increment_num_entries_in_completion_q();
}
@@ -1434,18 +1437,20 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
uint32_t dst_page_index = 0;

if (is_sharded(buffer.buffer_layout())) {
- bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
- bool height_sharded = buffer.buffer_layout() == TensorMemoryLayout::HEIGHT_SHARDED;
+ const bool width_split = buffer.shard_spec().shape_in_pages()[1] != buffer.shard_spec().tensor2d_shape[1];
std::optional<BufferPageMapping> buffer_page_mapping = std::nullopt;
- if (!height_sharded) {
+ if (width_split) {
buffer_page_mapping = generate_buffer_page_mapping(buffer);
}
- const auto& cores = !height_sharded ? buffer_page_mapping.value().all_cores_
-                                     : corerange_to_cores(
-                                           buffer.shard_spec().grid(),
-                                           buffer.num_cores(),
-                                           buffer.shard_spec().orientation() == ShardOrientation::ROW_MAJOR);
- TT_ASSERT(max_data_sizeB >= padded_page_size);
+ const auto& cores = width_split ? buffer_page_mapping.value().all_cores_
+                                 : corerange_to_cores(
+                                       buffer.shard_spec().grid(),
+                                       buffer.num_cores(),
+                                       buffer.shard_spec().orientation() == ShardOrientation::ROW_MAJOR);
+ TT_FATAL(
+     max_data_sizeB >= padded_page_size,
+     "Writing padded page size > {} is currently unsupported for sharded tensors.",
+     max_data_sizeB);
uint32_t num_total_pages = buffer.num_pages();
uint32_t max_pages_per_shard = buffer.shard_spec().size();

@@ -1455,7 +1460,7 @@
// Currently since writing sharded tensors uses write_linear, we write the padded pages on width
// Alternative write each page row into separate commands, or have a strided linear write
uint32_t num_pages;
- if (!height_sharded) {
+ if (width_split) {
num_pages =
buffer_page_mapping.value().core_shard_shape_[core_id][0] * buffer.shard_spec().shape_in_pages()[1];
if (num_pages == 0) {
@@ -1469,14 +1474,12 @@
}
uint32_t curr_page_idx_in_shard = 0;
while (num_pages != 0) {
- uint32_t data_offset_bytes =
-     (sizeof(CQPrefetchCmd) +
-      sizeof(CQDispatchCmd)); // data appended after CQ_PREFETCH_CMD_RELAY_INLINE +
-                              // CQ_DISPATCH_CMD_WRITE_PAGED
+ // data appended after CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_PAGED
+ uint32_t data_offset_bytes = (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd));
bool issue_wait = dst_page_index == 0; // only stall for the first write of the buffer
if (issue_wait) {
- data_offset_bytes *=
-     2; // commands prefixed with CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT
+ // commands prefixed with CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT
+ data_offset_bytes *= 2;
}
uint32_t space_available_bytes = std::min(
command_issue_limit - this->manager.get_issue_queue_write_ptr(this->id), max_prefetch_command_size);
@@ -1500,7 +1503,6 @@
bank_base_address,
buffer_page_mapping,
cores[core_id],
- width_split,
padded_page_size,
dst_page_index,
pages_to_write);
@@ -1701,7 +1703,7 @@ void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) {

void HWCommandQueue::copy_into_user_space(
const detail::ReadBufferDescriptor& read_buffer_descriptor, chip_id_t mmio_device_id, uint16_t channel) {
- const auto& [buffer_layout, page_size, padded_page_size, linear_page_copy, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
+ const auto& [buffer_layout, page_size, padded_page_size, dev_page_to_host_page_mapping, dst, dst_offset, num_pages_read, cur_dev_page_id] =
read_buffer_descriptor;

uint32_t padded_num_bytes = (num_pages_read * padded_page_size) + sizeof(CQDispatchCmd);
@@ -1747,7 +1749,7 @@

remaining_bytes_to_read -= bytes_xfered;

- if (linear_page_copy) {
+ if (dev_page_to_host_page_mapping.empty()) {
void* contiguous_dst = (void*)(uint64_t(dst) + contig_dst_offset);
if ((page_size % ADDRESS_ALIGNMENT) == 0) {
uint32_t data_bytes_xfered = bytes_xfered - offset_in_completion_q_data;
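For context on the copy_into_user_space change above: the descriptor no longer carries a linear_page_copy flag; an empty dev_page_to_host_page_mapping now means the linear path. A self-contained sketch of that host-side rule, with assumed names and signatures rather than the actual tt-metal implementation:

    // Sketch only: an empty device-to-host page mapping signals the linear path;
    // otherwise each padded device page is scattered to its host page, and padding
    // pages (nullopt entries) are skipped.
    #include <cstdint>
    #include <cstring>
    #include <optional>
    #include <vector>

    void copy_pages_to_host(
        const uint8_t* src,           // padded pages pulled from the completion queue
        uint8_t* dst,                 // user destination buffer
        uint32_t num_pages,
        uint32_t page_size,
        uint32_t padded_page_size,
        uint32_t first_dev_page,
        const std::vector<std::optional<uint32_t>>& dev_page_to_host_page_mapping) {
        if (dev_page_to_host_page_mapping.empty()) {
            // Linear copy: device pages already arrive in host order.
            for (uint32_t i = 0; i < num_pages; ++i) {
                std::memcpy(dst + (first_dev_page + i) * page_size,
                            src + i * padded_page_size, page_size);
            }
        } else {
            // Mapped copy: place each device page at its host page index.
            for (uint32_t i = 0; i < num_pages; ++i) {
                const auto& host_page = dev_page_to_host_page_mapping[first_dev_page + i];
                if (host_page.has_value()) {
                    std::memcpy(dst + host_page.value() * page_size,
                                src + i * padded_page_size, page_size);
                }
            }
        }
    }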
22 changes: 7 additions & 15 deletions tt_metal/impl/dispatch/command_queue.hpp
@@ -132,7 +132,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
private:
void add_prefetch_relay(HugepageDeviceCommand& command) override;
- BufferPageMapping buffer_page_mapping;
+ const CoreCoord core;

public:
EnqueueReadShardedBufferCommand(
@@ -142,7 +142,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
void* dst,
SystemMemoryManager& manager,
uint32_t expected_num_workers_completed,
- const BufferPageMapping& buffer_page_mapping,
+ const CoreCoord& core,
uint32_t src_page_index = 0,
std::optional<uint32_t> pages_to_read = std::nullopt) :
EnqueueReadBufferCommand(
@@ -154,7 +154,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
expected_num_workers_completed,
src_page_index,
pages_to_read),
- buffer_page_mapping(buffer_page_mapping) {}
+ core(core) {}
};

class EnqueueWriteShardedBufferCommand;
@@ -241,7 +241,6 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {

const std::optional<BufferPageMapping>& buffer_page_mapping;
const CoreCoord core;
- const bool width_split;

public:
EnqueueWriteShardedBufferCommand(
@@ -254,8 +253,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
uint32_t expected_num_workers_completed,
uint32_t bank_base_address,
const std::optional<BufferPageMapping>& buffer_page_mapping,
- const CoreCoord core,
- bool width_split,
+ const CoreCoord& core,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
std::optional<uint32_t> pages_to_write = std::nullopt) :
@@ -272,8 +270,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
dst_page_index,
pages_to_write),
buffer_page_mapping(buffer_page_mapping),
- core(core),
- width_split(width_split) {
+ core(core) {
;
}
};
@@ -416,7 +413,6 @@ struct ReadBufferDescriptor {
TensorMemoryLayout buffer_layout;
uint32_t page_size;
uint32_t padded_page_size;
- bool linear_page_copy;
vector<std::optional<uint32_t>> dev_page_to_host_page_mapping;
void* dst;
uint32_t dst_offset;
@@ -430,19 +426,15 @@ struct ReadBufferDescriptor {
uint32_t dst_offset,
uint32_t num_pages_read,
uint32_t cur_dev_page_id,
- bool linear_page_copy = true) :
+ const std::vector<std::optional<uint32_t>>& dev_page_to_host_page_mapping = {}) :
buffer_layout(buffer.buffer_layout()),
page_size(this->page_size = buffer.page_size()),
padded_page_size(padded_page_size),
dst(dst),
dst_offset(dst_offset),
num_pages_read(num_pages_read),
cur_dev_page_id(cur_dev_page_id),
- linear_page_copy(linear_page_copy) {
-     if (!linear_page_copy and is_sharded(this->buffer_layout)) {
-         this->dev_page_to_host_page_mapping = generate_buffer_page_mapping(buffer).dev_page_to_host_page_mapping_;
-     }
- }
+ dev_page_to_host_page_mapping(dev_page_to_host_page_mapping) {}
};

/*
Expand Down
