Skip to content

Commit

Permalink
#5449: Incorporate changes based on PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tooniz committed Mar 1, 2024
1 parent 5874d37 commit 0cf156a
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 61 deletions.
5 changes: 2 additions & 3 deletions docs/aspell-dictionary.pws
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
personal_ws-1.1 en 528
personal_ws-1.1 en 529
ABI
ADDI
API
Expand Down Expand Up @@ -61,9 +61,8 @@ EndTrace
EnqueueProgram
EnqueueReadBuffer
EnqueueRecordEvent
EnqueueWaitForEvent
EnqueueReadBuffer
EnqueueTrace
EnqueueWaitForEvent
EnqueueWriteBuffer
Enum
Enums
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ TEST_F(CommandQueueFixture, EnqueueTwoProgramTrace) {
EnqueueProgram(command_queue, op1, blocking);
EnqueueReadBuffer(command_queue, output, eager_output_data.data(), blocking);
}
if (!blocking) {
if (not blocking) {
// (Optional) wait for the last non-blocking command to finish
Finish(command_queue);
}
Expand Down Expand Up @@ -210,7 +210,7 @@ TEST_F(CommandQueueFixture, EnqueueMultiProgramTraceBenchmark) {
}
EnqueueReadBuffer(command_queue, output, eager_output_data.data(), blocking);
}
if (!blocking) {
if (not blocking) {
// (Optional) wait for the last non-blocking command to finish
Finish(command_queue);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, EnqueueOneProgramTrace) {
EnqueueWriteBuffer(data_movement_queue, input, input_data.data(), true);

BeginTrace(trace);
EnqueueProgram(trace.trace_queue(), simple_program, false);
EnqueueProgram(trace.queue(), simple_program, false);
EndTrace(trace);
// Instantiate a trace on a device queue
uint32_t trace_id = InstantiateTrace(trace, command_queue);
Expand Down Expand Up @@ -148,9 +148,9 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, EnqueueOneProgramTraceLoops) {
for (auto i = 0; i < num_loops; i++) {
EnqueueWriteBuffer(data_movement_queue, input, input_data.data(), true);

if (!trace_captured) {
if (not trace_captured) {
BeginTrace(trace);
EnqueueProgram(trace.trace_queue(), simple_program, false);
EnqueueProgram(trace.queue(), simple_program, false);
EndTrace(trace);
// Instantiate a trace on a device queue
trace_id = InstantiateTrace(trace, command_queue);
Expand Down Expand Up @@ -214,7 +214,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, EnqueueOneProgramTraceBenchmark) {
EnqueueProgram(command_queue, simple_program, blocking);
EnqueueReadBuffer(command_queue, output, eager_output_data.data(), blocking);
}
if (!blocking) {
if (not blocking) {
// (Optional) wait for the last non-blocking command to finish
Finish(command_queue);
}
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ void EnqueueProgram(CommandQueue& cq, std::variant<std::reference_wrapper<Progra
void Finish(CommandQueue& cq);

/**
* Begins capture on a trace, when the trace is in capture mode all programs push into the trace queue will not be executed.
* The capture must be later ended via EndTrace, and can be instantiated via InstantiateTrace on a device command queue.
* Begins capture on a trace. While the trace is in capture mode, all programs pushed into the trace queue will have their execution delayed until the trace is instantiated and enqueued.
* The capture must later be ended via EndTrace; the trace can then be instantiated via InstantiateTrace on a device command queue and finally scheduled for execution via EnqueueTrace.
*
* Return value: CommandQueue&
*
Expand All @@ -364,7 +364,7 @@ CommandQueue& BeginTrace(Trace &trace);

/**
* Completes capture on a trace. If the captured commands do not conform to the rules of the trace, the trace will be invalidated.
* This trace can later be instantiated via InstantiateTrace on a device command queue, and executed via EnqueueTrace on the same device command queue.
* This trace can later be instantiated via InstantiateTrace on a device command queue, and enqueued for execution via EnqueueTrace on the same device command queue.
*
* Return value: void
*
Expand All @@ -376,7 +376,7 @@ void EndTrace(Trace &trace);

/**
* Instantiates a trace on a device command queue, triggering the staging of traced commands and data to the device.
* Staging is a blocking operation and must be completed before the trace can be enqueued. A unique trace instance id is returned
* Staging is a blocking operation and must be completed before the trace can be enqueued for execution. A unique trace instance id is returned.
*
* Return value: uint32_t
*
Expand Down
65 changes: 21 additions & 44 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <algorithm> // for copy() and assign()
#include <iterator> // for back_inserter
#include <memory>
#include <string>

#include "debug_tools.hpp"
#include "dev_msgs.h"
Expand Down Expand Up @@ -537,7 +538,7 @@ void EnqueueProgramCommand::process() {
this->manager.issue_queue_reserve_back(cmd_size, this->command_queue_id);
this->manager.cq_write(cmd.data(), DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND, write_ptr);

bool tracing = this->trace.has_value() && !this->trace->get().trace_complete;
bool tracing = this->trace.has_value() and not this->trace->get().trace_complete;
vector<uint32_t> trace_host_data;
uint32_t start_addr = system_memory_temporary_storage_address;
constexpr static uint32_t padding_alignment = 16;
Expand All @@ -564,7 +565,7 @@ void EnqueueProgramCommand::process() {
start_addr + align(system_memory_temporary_storage_address - start_addr, DeviceCommand::PROGRAM_PAGE_SIZE);

array<uint32_t, 4> cb_data;
TT_ASSERT(cb_data.size() * sizeof(uint32_t) <= padding_alignment, "cb_data size is exceeds padding_alignment");
TT_ASSERT(cb_data.size() * sizeof(uint32_t) <= padding_alignment, "cb_data size exceeds padding_alignment");
for (const shared_ptr<CircularBuffer>& cb : program.circular_buffers()) {
for (const auto buffer_index : cb->buffer_indices()) {
cb_data = {
Expand Down Expand Up @@ -1009,7 +1010,7 @@ void HWCommandQueue::enqueue_wait_for_event(std::reference_wrapper<Event> event)

void HWCommandQueue::enqueue_trace() {
ZoneScopedN("HWCommandQueue_enqueue_trace");
TT_ASSERT(false, "Not implemented");
TT_THROW("Not implemented");
}

void HWCommandQueue::copy_into_user_space(uint32_t event, uint32_t read_ptr, chip_id_t mmio_device_id, uint16_t channel) {
Expand Down Expand Up @@ -1130,15 +1131,15 @@ Trace::Trace() : trace_complete(false), num_data_bytes(0) {
}

void Trace::record(const TraceNode& trace_node) {
TT_ASSERT(not this->trace_complete, "Cannot record any more for a completed trace");
TT_FATAL(not this->trace_complete, "Cannot record any more for a completed trace");
this->num_data_bytes += trace_node.num_data_bytes;
this->history.push_back(trace_node);
}

void Trace::validate() {
for (const auto& cmd : this->trace_queue().worker_queue) {
for (const auto& cmd : this->queue().worker_queue) {
if (cmd.blocking.has_value()) {
TT_ASSERT(!cmd.blocking.value(), "Blocking commands are not supported in traces");
TT_FATAL(cmd.blocking.value() == false, "Blocking commands are not supported in traces");
}
}
}
Expand All @@ -1159,7 +1160,7 @@ uint32_t Trace::instantiate(CommandQueue& cq) {
// - map the trace id to the DRAM buffer for later enqueue Trace

if (trace_instances.count(trace_id)) {
log_fatal("Trace ID {} already exists", trace_id);
TT_THROW("Trace ID " + std::to_string(trace_id) + " already exists");
}

trace_instances.insert(trace_id);
Expand Down Expand Up @@ -1236,7 +1237,7 @@ void EnqueueWriteBufferImpl(CommandQueue& cq, std::variant<std::reference_wrappe
void EnqueueProgram(CommandQueue& cq, std::variant < std::reference_wrapper<Program>, std::shared_ptr<Program> > program, bool blocking) {
detail::DispatchStateCheck(true);
if (cq.get_mode() != CommandQueue::CommandQueueMode::TRACE) {
TT_ASSERT(cq.id() == 0, "EnqueueProgram only supported on first command queue on device for time being.");
TT_FATAL(cq.id() == 0, "EnqueueProgram only supported on first command queue on device for time being.");
}
cq.run_command(CommandInterface{
.type = EnqueueCommandType::ENQUEUE_PROGRAM,
Expand Down Expand Up @@ -1331,23 +1332,12 @@ void FinishImpl(CommandQueue& cq) {

CommandQueue& BeginTrace(Trace& trace) {
TT_ASSERT(not trace.trace_complete, "Already completed this trace");
TT_ASSERT(trace.trace_queue().empty(), "Cannot begin trace on one that already captured commands");
return trace.trace_queue();
TT_ASSERT(trace.queue().empty(), "Cannot begin trace on one that already captured commands");
return trace.queue();
}

void EndTrace(Trace& trace) {
TT_ASSERT(not trace.trace_complete, "Already completed this trace");
// SystemMemoryManager& manager = trace.command_queue.manager;
// const uint32_t command_queue_id = trace.command_queue.id;
// TT_FATAL(
// trace.num_data_bytes + trace.history.size() * DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND <=
// manager.get_issue_queue_limit(command_queue_id),
// "Trace does not fit in issue queue");
// trace.trace_complete = true;
// manager.set_issue_queue_size(
// command_queue_id, trace.num_data_bytes + DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND * trace.history.size());
// trace.create_replay();
// manager.reset(trace.command_queue.id);
trace.trace_complete = true;
trace.validate();
}
Expand All @@ -1362,9 +1352,9 @@ uint32_t InstantiateTrace(Trace& trace, CommandQueue& cq) {
void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) {
detail::DispatchStateCheck(true);
TT_ASSERT(cq.trace(), "A trace has not been instantiated on this command queue yet!");
if (cq.trace()->trace_instances.count(trace_id) == 0)
log_fatal("Trace instance {} does not exist", trace_id);

if (cq.trace()->trace_instances.count(trace_id) == 0) {
TT_THROW("Trace instance " + std::to_string(trace_id) + " does not exist");
}
cq.run_command(CommandInterface{
.type = EnqueueCommandType::ENQUEUE_TRACE,
.blocking = blocking
Expand All @@ -1373,22 +1363,10 @@ void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) {

void EnqueueTraceImpl(CommandQueue& cq) {
// STUB: Run the trace in eager mode for now
auto& tq = cq.trace()->trace_queue();
auto& tq = cq.trace()->queue();
for (const auto& cmd : tq.worker_queue) {
cq.run_command_impl(cmd);
}

// Run the trace
// HWCommandQueue& command_queue = trace.command_queue;
// uint32_t trace_size = trace.history.size() * DeviceCommand::NUM_BYTES_IN_DEVICE_COMMAND + trace.num_data_bytes;
// command_queue.manager.issue_queue_reserve_back(trace_size, command_queue.id);
// command_queue.manager.issue_queue_push_back(trace_size, false, command_queue.id);

// // This will block because the wr toggles will be different between the host and the device
// if (blocking) {
// command_queue.manager.issue_queue_reserve_back(trace_size, command_queue.id);
// }
// cq.hw_command_queue().enqueue_trace();
}

CommandQueue::CommandQueue(Device* device, uint32_t id, CommandQueueMode mode) :
Expand All @@ -1405,7 +1383,7 @@ CommandQueue::CommandQueue(Device* device, uint32_t id, CommandQueueMode mode) :
CommandQueue::CommandQueue(Trace* trace) :
device_ptr(nullptr),
trace_ptr(trace),
cq_id(CommandQueue::TRACE_QUEUE_CQ_ID),
cq_id(-1),
mode(CommandQueueMode::TRACE),
worker_state(CommandQueueState::IDLE) {
TT_ASSERT(this->trace_ptr, "A valid trace must be provided for a trace mode queue");
Expand Down Expand Up @@ -1443,8 +1421,7 @@ void CommandQueue::wait_until_empty() {
}

void CommandQueue::set_mode(const CommandQueueMode& mode_) {
TT_ASSERT(
!this->trace_mode(), "Cannot change mode of a trace command queue, copy to a non-trace command queue instead!");
TT_ASSERT(not this->trace_mode(), "Cannot change mode of a trace command queue, copy to a non-trace command queue instead!");
this->mode = mode_;
if (this->async_mode()) {
this->start_worker();
Expand Down Expand Up @@ -1490,10 +1467,10 @@ void CommandQueue::run_worker() {

void CommandQueue::run_command(const CommandInterface& command) {
log_trace(LogDispatch, "CQ{} received {} in {} mode", this->cq_id, command.type, this->mode);
if (!this->passthrough_mode()) {
if (not this->passthrough_mode()) {
this->worker_queue.push(command);
if (command.blocking.has_value() and *command.blocking == true) {
TT_ASSERT(!this->trace_mode(), "Blocking commands cannot be traced!");
TT_ASSERT(not this->trace_mode(), "Blocking commands cannot be traced!");
this->wait_until_empty();
}
} else {
Expand Down Expand Up @@ -1556,7 +1533,7 @@ std::ostream& operator<<(std::ostream& os, EnqueueCommandType const& type) {
case EnqueueCommandType::ENQUEUE_WAIT_FOR_EVENT: os << "ENQUEUE_WAIT_FOR_EVENT"; break;
case EnqueueCommandType::FINISH: os << "FINISH"; break;
case EnqueueCommandType::FLUSH: os << "FLUSH"; break;
default: tt::log_fatal("Invalid command type!");
default: TT_THROW("Invalid command type!");
}
return os;
}
Expand All @@ -1566,7 +1543,7 @@ std::ostream& operator<<(std::ostream& os, CommandQueue::CommandQueueMode const&
case CommandQueue::CommandQueueMode::PASSTHROUGH: os << "PASSTHROUGH"; break;
case CommandQueue::CommandQueueMode::ASYNC: os << "ASYNC"; break;
case CommandQueue::CommandQueueMode::TRACE: os << "TRACE"; break;
default: tt::log_fatal("Invalid CommandQueueMode type!");
default: TT_THROW("Invalid CommandQueueMode type!");
}
return os;
}
5 changes: 1 addition & 4 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ class Trace {

public:
Trace();
CommandQueue& trace_queue() const { return *cq; };
CommandQueue& queue() const { return *cq; };
uint32_t instantiate(CommandQueue& cq); // return a unique trace id
};

Expand Down Expand Up @@ -499,8 +499,6 @@ class HWCommandQueue {
void completion_wrap(uint32_t event);
void launch(launch_msg_t& msg);
friend void EnqueueTraceImpl(CommandQueue& cq);
// friend void EndTrace(Trace& trace);
// friend Trace BeginTrace(CommandQueue& cq);
friend void EnqueueProgramImpl(CommandQueue& cq, std::variant < std::reference_wrapper<Program>, std::shared_ptr<Program> > program, bool blocking);
friend void EnqueueReadBufferImpl(CommandQueue& cq, std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer> > buffer, void* dst, bool blocking);
friend void EnqueueWriteBufferImpl(CommandQueue& cq, std::variant<std::reference_wrapper<Buffer>, std::shared_ptr<Buffer> > buffer, const void* src, bool blocking);
Expand Down Expand Up @@ -568,7 +566,6 @@ class CommandQueue {
RUNNING = 1,
TERMINATE = 2,
};
constexpr static uint32_t TRACE_QUEUE_CQ_ID = 401;
friend class Trace;
friend void EnqueueTraceImpl(CommandQueue& cq);
friend uint32_t InstantiateTrace(Trace& trace, CommandQueue& cq);
Expand Down

0 comments on commit 0cf156a

Please sign in to comment.