Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavya01 committed May 10, 2024
1 parent a87b782 commit bf05d1b
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 9 deletions.
6 changes: 3 additions & 3 deletions test/test_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def test_gpu_custom_call_triton_add(self):
grid = (size // block_size,)
payload = torch_triton.triton_call(
x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
output = torch_xla._XLAC._xla_gpu_custom_call(
[x,y], payload, [output.shape],
[torch.int64])
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
[output.shape], [torch.int64])
output_torch = x + y
self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu()))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2320,7 +2320,8 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_register_custom_call_target",
[](const std::string& fn_name, const py::capsule& function_ptr,
const std::string& platform) {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(fn_name, function_ptr.get_pointer(), platform);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
fn_name, function_ptr.get_pointer(), platform);
});
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/gpu_custom_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs,
xla::Shape output_shape,
const std::string& payload)
: XlaNode(xla_gpu_custom_call, inputs, std::move(output_shape),
/*num_outputs=*/output_shape.tuple_shapes_size(), torch::lazy::MHash(payload)),
/*num_outputs=*/output_shape.tuple_shapes_size(),
torch::lazy::MHash(payload)),
payload_(payload) {}

torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const {
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
xla::XlaOp iou_threshold);

std::vector<xla::XlaOp> BuildGpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload);
std::vector<xla::XlaOp> BuildGpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload);

} // namespace torch_xla

Expand Down
8 changes: 7 additions & 1 deletion torch_xla/experimental/torch_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from jax._src.lib import gpu_triton as lib_triton
import torch_xla

torch_xla._XLAC._xla_register_custom_call_target('triton_kernel_call', lib_triton._cuda_triton.get_custom_call(), 'CUDA')
# Register the 'triton_kernel_call' custom call target for the CUDA
# platform, using the implementation provided by jaxlib.
torch_xla._XLAC._xla_register_custom_call_target(
'triton_kernel_call', lib_triton._cuda_triton.get_custom_call(), 'CUDA')

Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]

Expand Down Expand Up @@ -102,6 +106,8 @@ def get_or_create_triton_kernel(
args,
dump,
) -> Tuple[lib_triton.TritonKernel, Any]:
# Extract the compilation parameters and the compiled TTIR/PTX from the
# compiled Triton kernel.
ttir = compiled_kernel.asm['ttir']
ptx = compiled_kernel.asm['ptx']
if (dump):
Expand Down

0 comments on commit bf05d1b

Please sign in to comment.