Skip to content

Commit

Permalink
feedbakc
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmithtt committed Aug 7, 2024
1 parent f2fddea commit b91c0a0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion runtime/lib/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ std::vector<TensorDesc> getProgramInputs(Flatbuffer binary,
std::uint32_t programIndex) {
std::vector<TensorDesc> inputs;
auto const *program = getBinary(binary)->programs()->Get(programIndex);
assert(program->device_programs()->size() == 1);
assert(program->device_programs()->size() == 1 && "Currently only one device is supported");
for (auto const *input : *program->device_programs()->Get(0)->inputs()) {
TensorDesc desc;
desc.shape = {input->desc()->shape()->begin(),
Expand Down
12 changes: 6 additions & 6 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ static std::pair<std::shared_ptr<::tt::tt_metal::Buffer>,
std::shared_ptr<::tt::tt_metal::Event>>
prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor,
void *data, ::tt::target::TensorRef const *tensorRef) {
if (TensorDesc const *hostTensorDesc = std::get_if<TensorDesc>(&metalTensor);
hostTensorDesc) {
if (std::holds_alternative<TensorDesc>(metalTensor)) {
// todo assert that tensorDesc matches hostTensorDesc
std::shared_ptr<::tt::tt_metal::Buffer> buffer =
createBufferFromTensorRef(device, tensorRef);
Expand All @@ -146,10 +145,11 @@ prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor,
::tt::tt_metal::EnqueueWriteBuffer(cq, buffer, data, blocking);
::tt::tt_metal::EnqueueRecordEvent(cq, event);
return std::make_pair(buffer, event);
} else if (std::shared_ptr<::tt::tt_metal::Buffer> const *buffer =
std::get_if<std::shared_ptr<::tt::tt_metal::Buffer>>(
&metalTensor);
buffer) {
} else if (std::holds_alternative<std::shared_ptr<::tt::tt_metal::Buffer>>(
metalTensor)) {
std::shared_ptr<::tt::tt_metal::Buffer> buffer =
std::get<std::shared_ptr<::tt::tt_metal::Buffer>>(
metalTensor);
throw std::runtime_error("Input from buffer not supported yet");
}
assert(false && "Unsupported tensor type");
Expand Down

0 comments on commit b91c0a0

Please sign in to comment.