Skip to content

Commit

Permalink
#14245: fix some pybind issue that didn't pass the new parameter in p…
Browse files Browse the repository at this point in the history
…ython
  • Loading branch information
llongTT committed Dec 4, 2024
1 parent d51b60c commit c565d87
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,16 @@ void bind_interleaved_to_sharded(
tt::tt_metal::ShardOrientation shard_orientation,
const std::optional<ttnn::DataType>& output_dtype,
uint8_t queue_id,
const bool keep_l1_aligned) -> ttnn::Tensor {
return self(queue_id, input_tensor, grid, shard_shape, shard_scheme, shard_orientation, output_dtype);
const std::optional<bool>& keep_l1_aligned) -> ttnn::Tensor {
return self(
queue_id,
input_tensor,
grid,
shard_shape,
shard_scheme,
shard_orientation,
output_dtype,
keep_l1_aligned);
},
py::arg("input_tensor").noconvert(),
py::arg("grid"),
Expand All @@ -50,8 +58,8 @@ void bind_interleaved_to_sharded(
const MemoryConfig& sharded_memory_config,
const std::optional<ttnn::DataType>& output_dtype,
uint8_t queue_id,
const bool keep_l1_aligned) -> ttnn::Tensor {
return self(queue_id, input_tensor, sharded_memory_config, output_dtype);
const std::optional<bool>& keep_l1_aligned) -> ttnn::Tensor {
return self(queue_id, input_tensor, sharded_memory_config, output_dtype, keep_l1_aligned);
},
py::arg("input_tensor").noconvert(),
py::arg("sharded_memory_config"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ void bind_sharded_to_interleaved(
const std::optional<MemoryConfig>& memory_config,
const std::optional<DataType>& output_dtype,
uint8_t queue_id,
const bool is_l1_aligned) -> ttnn::Tensor {
const std::optional<bool>& is_l1_aligned) -> ttnn::Tensor {
return self(
queue_id,
input_tensor,
memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG),
output_dtype);
output_dtype,
is_l1_aligned);
},
py::arg("input_tensor").noconvert(),
py::arg("memory_config") = std::nullopt,
Expand Down

0 comments on commit c565d87

Please sign in to comment.