Triton #6798 (Merged)

Commits (67):

6dccf0a  Update infra_triggers.tf (ManfeiBai)
9828123  Skeleton trition support (bhavya01)
99bf48d  Merge branch 'master' into triton (bhavya01)
b89e558  Fix bugs (bhavya01)
64189bd  Fix custom call invocation (bhavya01)
0c208ef  Refactor to include gpu custom call and create triton dir (bhavya01)
b553ba7  Lint fixes (bhavya01)
c5129e6  python lint fix (bhavya01)
48e7127  Updated base image for CI (bhavya01)
e04fc97  Update github workflow gcr image (bhavya01)
37bf127  Merge branch 'master' into custom (bhavya01)
6061895  Remove xrt build and test file (bhavya01)
f59ddbf  Add temporary test to run triton kernel (bhavya01)
158aed4  Fix tests (bhavya01)
87b92c5  Update payload for xla gpu custom call (bhavya01)
847ccc5  Update gpu runner (bhavya01)
eca6d52  Merge branch 'master' into triton (bhavya01)
2348ca3  Extract payload from triton kernel programatically (bhavya01)
110c8c6  Merge branch 'master' into triton (bhavya01)
a226150  Lint fixes (bhavya01)
4c1f4f5  Only build triton files for GPU (bhavya01)
431f822  build pytorch for ampere gpus (bhavya01)
4bade16  c++ lint fix (bhavya01)
1c5b47d  Python lint fix (bhavya01)
3138a92  Fix torch cuda arch list (bhavya01)
3f00cfd  Use a bigger machine for CI build (bhavya01)
e729cfb  Add triton test to run_tests.sh (bhavya01)
8e304c0  Update triton env variable (bhavya01)
27bdc3a  Set up a separate CI for triton tests (bhavya01)
9a3ef84  Fix github workflow to add _triton.yml (bhavya01)
ade444d  Rebuild torch xla for triton tests (bhavya01)
cb0bb85  Create a separate CI tab for triton tests (bhavya01)
015b1ad  Separate build and test phase for triton (bhavya01)
a18028a  Fix flags for docker run container (bhavya01)
993ee92  Update triton.yml to output docker image (bhavya01)
a87b782  Add a python binding to register custom calls and remove jax files (bhavya01)
bf05d1b  Fix lint (bhavya01)
4582fe8  Merge main (bhavya01)
9680167  Merge master (bhavya01)
a7b94c6  Merge master after updating (bhavya01)
e14636a  Update CI to use cuda plugin (bhavya01)
256d819  Install jaxlib while setting up triton tests (bhavya01)
c616e64  Install triton package while running triton tests (bhavya01)
60b8d18  Experimental: Build pytorch with cuda (bhavya01)
2bde624  Revert build pytorch with CUDA (bhavya01)
e6c4e0a  Merge branch 'master' into triton (bhavya01)
14ee545  Remove ansible path for triton CI (bhavya01)
25acb26  Style fixes (bhavya01)
6b0ac18  [Experimental] test new CI (bhavya01)
4d97150  [Experimental] Set XLA_CUDA=0 for cuda arch in ansible (bhavya01)
e079049  [Experimental] Update CI to build pytorch cuda with ansible (bhavya01)
d9c89b6  Update CI (bhavya01)
7a6c809  Fix CI workflow file (bhavya01)
6b1954d  Fix CI workflow (bhavya01)
21797a6  Fix the wheels installed for tests requiring torch cuda (bhavya01)
e6e89d3  Add compute_capability=8.6 for xla cuda plugin (bhavya01)
ac45fe1  update TORCH_CUDA_ARCH_LIST (bhavya01)
f828fbb  Experimental build torch and torch_xla cuda wheels (bhavya01)
ac56c00  Merge branch 'master' into triton (bhavya01)
c3b8653  Update build_and_test.yml (bhavya01)
a1168c6  Update dlpack test to only use one device (bhavya01)
39551a2  Remove compute capability 8.6 from cuda plugin (bhavya01)
35e0869  Remove triton.sh (bhavya01)
f95d898  Default empty torch_cuda_arch_list in ansible config (bhavya01)
291104d  Merge branch 'master' into triton (bhavya01)
f5c9b1a  Revert CI changes (bhavya01)
5b23969  Revert CI changes pt2 (bhavya01)
New file (+36 lines): torch_xla/csrc/ops/gpu_custom_call.cpp
#include "torch_xla/csrc/ops/gpu_custom_call.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {

GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs,
                             xla::Shape output_shape,
                             const std::string& payload)
    : XlaNode(xla_gpu_custom_call, inputs, std::move(output_shape),
              /*num_outputs=*/1, torch::lazy::MHash(payload)),
      payload_(payload) {}

torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const {
  return torch::lazy::MakeNode<GpuCustomCall>(operands, xla_shape(), payload_);
}

XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const {
  std::vector<xla::XlaOp> inputs;
  inputs.reserve(operands().size());
  for (auto& operand : operands()) {
    inputs.push_back(loctx->GetOutputOp(operand));
  }
  xla::XlaOp output = BuildGpuCustomCall(inputs, xla_shape(), payload_);
  return ReturnOp(output, loctx);
}

std::string GpuCustomCall::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", " << payload_;
  return ss.str();
}

}  // namespace torch_xla
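BuildGpuCustomCall is declared in xla_lower_util.h and its implementation is not part of this excerpt. As a rough sketch of what that lowering step amounts to, assuming XLA's CustomCall builder API and a hypothetical target name "triton_kernel_call", it could look roughly like this:

// Sketch only: the real BuildGpuCustomCall lives in xla_lower_util.cpp and is
// not shown in this diff; the "triton_kernel_call" target name is an
// assumption, not taken from this excerpt.
#include "xla/client/xla_builder.h"

xla::XlaOp BuildGpuCustomCallSketch(const std::vector<xla::XlaOp>& inputs,
                                    const xla::Shape& output_shape,
                                    const std::string& payload) {
  // Emit an HLO CustomCall whose opaque string carries the serialized Triton
  // kernel payload; a runtime-side handler deserializes it and launches the
  // kernel on the compute stream.
  return xla::CustomCall(inputs.front().builder(), "triton_kernel_call",
                         inputs, output_shape, /*opaque=*/payload);
}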
New file (+26 lines): torch_xla/csrc/ops/gpu_custom_call.h
#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
// TODO: Merge GPU and TPU custom call.

class GpuCustomCall : public XlaNode {
 public:
  // Make a GPU custom call with payload, e.g., Triton.
  GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
                const std::string& payload);

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  std::string ToString() const override;

 private:
  std::string payload_;
};

}  // namespace torch_xla

#endif  // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
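For context, a caller constructs this node with torch::lazy::MakeNode, the same way Clone does in the .cpp file above. A minimal, hypothetical helper (the actual frontend entry point added by this PR is not shown in this excerpt):

// Hypothetical helper, for illustration only: builds the IR node from traced
// inputs and a serialized kernel payload, mirroring GpuCustomCall::Clone.
torch::lazy::NodePtr MakeGpuCustomCall(torch::lazy::OpList inputs,
                                       xla::Shape output_shape,
                                       const std::string& payload) {
  // num_outputs is fixed to 1 in the constructor, and the payload is hashed
  // into the node identity so different kernels never share a cached
  // computation.
  return torch::lazy::MakeNode<GpuCustomCall>(inputs, std::move(output_shape),
                                              payload);
}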
New Bazel BUILD file (+93 lines) for the Triton directory added in this PR:
load("//bazel:rules_def.bzl", "cc_proto_library")

cc_library(
    name = "cuda_vendor",
    hdrs = [
        "gpu_vendor.h",
    ],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_cuda//cuda:cudnn_header",
    ],
)

proto_library(
    name = "triton_proto",
    srcs = ["triton.proto"],
)

cc_proto_library(
    name = "triton_cc_proto",
    deps = [":triton_proto"],
)

cc_library(
    name = "cuda_gpu_kernel_helpers",
    srcs = [
        "gpu_kernel_helpers.cpp",
    ],
    hdrs = [
        "gpu_kernel_helpers.h",
    ],
    copts = [
        "-fexceptions",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":cuda_vendor",
        "@xla//xla/tsl/cuda:cupti",
        "@xla//xla/tsl/cuda:cusolver",
        "@xla//xla/tsl/cuda:cusparse",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_cuda//cuda:cublas_headers",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

cc_library(
    name = "triton_utils",
    srcs = ["triton_utils.cpp"],
    hdrs = ["triton_utils.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":cuda_gpu_kernel_helpers",
        ":cuda_vendor",
        ":triton_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@zlib",
    ],
)

cc_library(
    name = "triton_kernels",
    srcs = ["triton_kernels.cpp"],
    hdrs = ["triton_kernels.h"],
    deps = [
        ":cuda_gpu_kernel_helpers",
        ":cuda_vendor",
        ":triton_utils",
        ":triton_cc_proto",
        "@xla//xla/service:custom_call_target_registry",
        "@xla//xla/service:custom_call_status",
        "@xla//xla/stream_executor/gpu:asm_compiler",
        "@xla//xla/tsl/cuda:cudart",
        "@tsl//tsl/platform:env",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
    ],
)
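The triton_kernels target depends on XLA's custom_call_target_registry and custom_call_status, which suggests the runtime side registers a CUDA custom-call handler that receives the opaque payload emitted at lowering time. A minimal sketch of that registration pattern, assuming the status-based GPU custom-call ABI and the same hypothetical "triton_kernel_call" target name as above:

// Sketch of XLA's custom-call registration pattern; the handler body and the
// "triton_kernel_call" name are assumptions, not taken from this diff.
#include <cuda.h>

#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"

// Status-based GPU custom-call ABI: stream, buffer pointers (inputs followed
// by outputs), and the opaque payload string emitted at lowering time.
void TritonKernelCall(CUstream stream, void** buffers, const char* opaque,
                      size_t opaque_len, XlaCustomCallStatus* status) {
  // A real handler would deserialize `opaque` (e.g. the triton.proto message),
  // compile or look up the cached CUBIN, launch it on `stream`, and report
  // failures via XlaCustomCallStatusSetFailure().
}

XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("triton_kernel_call",
                                         &TritonKernelCall, "CUDA");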
Review comment: I expect we'll have more than one test kernel going forward. What do you think about putting all kernel payloads for Triton (and Pallas) in a separate YAML file?
cc @jiawenliu64