diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index b602a686e727..afb4db5290c0 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -718,35 +718,19 @@ void xla_mark_sharding(const at::Tensor& input, xla::OpSharding sharding) {
   XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
 }
 
-void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment, c10::List<at::IntArrayRef> group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
+void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef>& tile_assignment, c10::List<at::IntArrayRef>& group_assignment, c10::List<at::IntArrayRef>& replication_groups, int64_t sharding_type) {
   std::cout << "at xla_mark_sharding_dynamo_custom_op0" << std::endl;
   std::cout << "input: " << input << std::endl;
-  // std::cout << "tile_assignment: " << tile_assignment << std::endl;
+  std::cout << "tile_assignment.size(): " << tile_assignment.size() << std::endl;
 
   std::cout << "converting tile_assignment_py" << std::endl;
-  // const py::list& tile_assignment_py = py::cast(tile_assignment[0]);
-  const py::list& tile_assignment_py = py::cast(torch::lazy::ToVector<int64_t>(tile_assignment[0]));
-
-  // std::cout << "group_assignment: " << group_assignment << std::endl;
-  std::cout << "converting group_assignment_py" << std::endl;
-  const py::list& group_assignment_py = py::cast(group_assignment);
-
-  // std::cout << "replication_groups: " << replication_groups << std::endl;
-  std::cout << "converting replication_groups_py" << std::endl;
-  const py::list& replication_groups_py = py::cast(replication_groups);
-
-  std::cout << "at xla_mark_sharding_dynamo_custom_op1" << std::endl;
-
-  const xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding(
-      tile_assignment_py, group_assignment_py, replication_groups_py,
-      ShardingUtil::ShardingType(sharding_type));
-
-
-  std::cout << "at xla_mark_sharding_dynamo_custom_op2" << std::endl;
-
-  xla_mark_sharding(input, op_sharding);
+  py::list tile_assignment_py = py::list();
+  for (const at::IntArrayRef t : tile_assignment) {
+    // auto t_vec = XlaHelpers::I64List(t);
+    // tile_assignment_py.append(py::cast(t_vec));
+  }
 
-  std::cout << "at xla_mark_sharding_dynamo_custom_op3" << std::endl;
+  std::cout << "tile_assignment_py.size(): " << tile_assignment_py.size() << std::endl;
 }
 
 // Macro for defining a function that will be run at static initialization time to define a library of operators in the namespace.
@@ -1671,6 +1655,12 @@ void InitXlaModuleBindings(py::module m) {
       .def(py::init([](const py::list& tile_assignment,
                        const py::list& group_assignment,
                        const py::list& replication_groups, int sharding_type) {
+        std::cout << "at OpSharding" << std::endl;
+        // std::cout << "tile_assignment:" << tile_assignment << std::endl;
+        // auto vec = tile_assignment.cast<std::vector<std::vector<int64_t>>>(tile_assignment);
+        // std::cout << "casted: " << vec << std::endl;
+        // std::cout << "group_assignment:" << group_assignment << std::endl;
+        // std::cout << "replication_groups:" << replication_groups << std::endl;
         return ShardingUtil::CreateOpSharding(
             tile_assignment, group_assignment, replication_groups,
             ShardingUtil::ShardingType(sharding_type));
@@ -1679,12 +1669,39 @@ void InitXlaModuleBindings(py::module m) {
                     xla::OpSharding sharding) {
     xla_mark_sharding(input, sharding);
   });
-  // m.def("_xla_mark_sharding_dynamo_custom_op",
-  //       [](const at::Tensor& input, xla::OpSharding sharding) {
-  //         // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type);
-  //         // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type
-  //         at::IntArrayRef tile_assignment =
-  //       });
+  m.def("_xla_mark_sharding_dynamo_custom_op",
+        [](const at::Tensor& input, const py::list& tile_assignment,
+           const py::list& group_assignment,
+           const py::list& replication_groups, int sharding_type) {
+          // xla_mark_sharding_dynamo_custom_op(input, tile_assignment, group_assignment, replication_groups, sharding_type);
+          // at::IntArrayRef tile_assignment, at::IntArrayRef group_assignment, c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type
+          // auto sharding_type = sharding.type();
+          std::cout << "WONJOO: in pybind::_xla_mark_sharding_dynamo_custom_op" << std::endl;
+          std::cout << "WONJOO: tile_assignment: " << tile_assignment << std::endl;
+          std::cout << "WONJOO: tile_assignment.size(): " << tile_assignment.size() << std::endl;
+          std::cout << "WONJOO: group_assignment: " << group_assignment << std::endl;
+          std::cout << "WONJOO: group_assignment.size(): " << group_assignment.size() << std::endl;
+          std::cout << "WONJOO: replication_groups: " << replication_groups << std::endl;
+          std::cout << "WONJOO: replication_groups.size(): " << replication_groups.size() << std::endl;
+          std::cout << "WONJOO: sharding_type: " << sharding_type << std::endl;
+
+          c10::List<at::IntArrayRef> time_assignment_list = c10::List<at::IntArrayRef>();
+          for (auto t : tile_assignment) {
+            time_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> group_assignment_list = c10::List<at::IntArrayRef>();
+          for (auto t : group_assignment) {
+            group_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> replication_groups_list = c10::List<at::IntArrayRef>();
+          for (auto t : replication_groups) {
+            replication_groups_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          xla_mark_sharding_dynamo_custom_op(input, time_assignment_list, group_assignment_list, replication_groups_list, sharding_type);
+        });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     xtensor->ClearShardingSpec();
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index f08b2370d7d5..178e813efbcb 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -107,9 +107,12 @@ def get_op_sharding(self,
 
     # If flatten = True, return the flattened version of OpSharding
     # print each return type to debug
-    print(tile_assignment.tolist())
-    print(group_assignment)
-    print(replication_groups)
+    print(f'tile_assignment.tolist(): {tile_assignment.tolist()}')
+    print(f'type(tile_assignment.tolist()): {type(tile_assignment.tolist())}')
+    print(f'group_assignment: {group_assignment}')
+    print(f'type(group_assignment): {type(group_assignment)}')
+    print(f'replication_groups: {replication_groups}')
+    print(f'type(replication_groups): {type(replication_groups)}')
     if flatten:
       return (tile_assignment.tolist(), group_assignment, replication_groups,
               int(sharding_type))
@@ -520,7 +523,8 @@ def mark_sharding_dynamo_custom_op(
     t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
   """
-  Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
+  Same functionality as `mark_sharding` above, except this variant uses the custom
+  mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
   """
   num_devices = xr.global_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
@@ -535,12 +539,11 @@ def mark_sharding_dynamo_custom_op(
 
   tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
      partition_spec, flatten = True)
 
-  print('about to call xla_mark_sharding_dynamo_custom_op')
   if isinstance(t, XLAShardedTensor):
-    torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type)
+    # torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, op_sharding)
+    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type)
     return t
 
-  torch.ops.xla.xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type)
-  print('xla_mark_sharding_dynamo_custom_op call finished')
+  torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type)
   return XLAShardedTensor(t)
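
For reviewers, a minimal sketch of how this new entry point would be exercised. This is not part of the diff: the 1-D mesh, the 'data' axis name, and the tensor shape are illustrative assumptions, and in the current WIP state the C++ xla_mark_sharding_dynamo_custom_op only logs its inputs rather than applying the sharding, so this shows the intended call pattern only.

# Hypothetical usage sketch (not part of this diff). Assumes an XLA-backed
# device and the runtime API already used above (xr.global_runtime_device_count).
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
# Illustrative 1-D mesh over all devices with a single 'data' axis.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.randn(8, 4).to(xm.xla_device())
# Shard dim 0 across the 'data' axis; dim 1 stays replicated (None).
# Internally this flattens the OpSharding args (get_op_sharding(..., flatten=True))
# and forwards them to the _XLAC binding added in this diff.
xs.mark_sharding_dynamo_custom_op(t, mesh, ('data', None))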