Triton #6798
Changes from 47 commits
.circleci/triton.sh (new file)
@@ -0,0 +1,36 @@
#!/bin/bash

set -ex

source .circleci/common.sh
PYTORCH_DIR=/tmp/pytorch
XLA_DIR=$PYTORCH_DIR/xla
clone_pytorch $PYTORCH_DIR $XLA_DIR

# Use bazel cache
USE_CACHE=1

pushd $PYTORCH_DIR
checkout_torch_pin_if_available

if ! install_deps_pytorch_xla $XLA_DIR $USE_CACHE; then
  exit 1
fi

apply_patches

# Put stdout back into blocking mode (CI workaround for lost output).
python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
# Install nightly jaxlib, the CUDA 12 PJRT plugin packages, and jax.
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

# Build PyTorch itself with CUDA enabled, targeting sm_86 GPUs.
export PATH=$PATH:/usr/local/cuda-12.1/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
export USE_CUDA=1
export TORCH_CUDA_ARCH_LIST='8.6'
python setup.py install

# Build torch_xla against the same CUDA configuration.
XLA_DIR=$PYTORCH_DIR/xla
export TF_CUDA_COMPUTE_CAPABILITIES="compute_86"
export XLA_CUDA=1
build_torch_xla $XLA_DIR
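After a build like this finishes, a quick check can confirm that both frameworks actually see the GPU. This is a minimal sketch, not part of the PR; it assumes the container exposes a CUDA device and that PJRT_DEVICE=CUDA is set in the environment.

# Minimal sanity check (not from the PR): confirm the rebuilt PyTorch is
# CUDA-enabled and that PyTorch/XLA selected the CUDA PJRT device.
import torch
from torch_xla import runtime as xr

# False would mean PyTorch was built without USE_CUDA=1 or no GPU is visible.
print("torch sees CUDA:", torch.cuda.is_available())

# Expected to report "CUDA" when PJRT_DEVICE=CUDA is set (an assumption here).
print("XLA device type:", xr.device_type())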
New GitHub Actions workflow for the Triton build and test jobs (new file)
@@ -0,0 +1,96 @@
Review thread attached to this workflow:

- Same, can we augment the existing GPU flow to have Triton as well?
- The problem with using the existing GPU flow is that it builds PyTorch without the CUDA dependency, which doesn't work with Triton, since Triton uses PyTorch to detect whether there is a GPU device. It would be great if we could use this for now while I figure out how to merge the two CI workflows.
- Same comment as the triton.sh. @will-cromar
- Is it feasible to merge your new tests into the existing workflow and upgrade the GPUs we use there? It's probably time we use more modern GPUs for all of our CI tests anyway.
- You can upgrade the runner here: xla/.github/workflows/build_and_test.yml, line 56 in 8a1ada8
- Let me check with pytorch folks first.

on:
  pull_request:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'torch_xla/experimental/torch_triton.py'
  push:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'torch_xla/experimental/torch_triton.py'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

jobs:
  build-triton:
    runs-on: linux.24xlarge
    timeout-minutes: 300
    outputs:
      docker-image: ${{ steps.upload-docker-image.outputs.docker-image }}
    env:
      DOCKER_IMAGE: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
      ECR_DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base
      WORKDIR: /triton_dir
    steps:
      - name: Setup Linux
        uses: pytorch/test-infra/.github/actions/setup-linux@main
      - name: Setup SSH (Click me for login details)
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
          instructions: |
            Tests are done inside the container, to start an interactive session run:
            docker exec -it $(docker container ps --format '{{.ID}}') bash
      - name: Checkout repo
        uses: actions/checkout@v3
      - name: Download docker image from GCR
        shell: bash
        run: docker pull "${DOCKER_IMAGE}"
      - name: Start the container
        shell: bash
        run: |
          pid=$(docker run --privileged -t -d -w "${WORKDIR}" "${DOCKER_IMAGE}")
          docker cp "${GITHUB_WORKSPACE}/." "$pid:$WORKDIR"
          echo "pid=${pid}" >> "${GITHUB_ENV}"
      - name: Build and Test
        shell: bash
        run: |
          docker exec --privileged "${pid}" bash -c ".circleci/triton.sh"
      - name: Push built docker image to ECR
        id: upload-docker-image
        shell: bash
        run: |
          export COMMIT_DOCKER_IMAGE="${ECR_DOCKER_IMAGE_BASE}:triton-${GITHUB_SHA}"
          time docker commit "${pid}" "${COMMIT_DOCKER_IMAGE}"
          time docker push "${COMMIT_DOCKER_IMAGE}"
          echo "docker-image=${COMMIT_DOCKER_IMAGE}" >> "${GITHUB_OUTPUT}"
      - name: Teardown Linux
        uses: pytorch/test-infra/.github/actions/teardown-linux@main
        if: always()
  test-triton:
    runs-on: linux.g5.4xlarge.nvidia.gpu
    timeout-minutes: 300
    needs: build-triton
    env:
      DOCKER_IMAGE: ${{ needs.build-triton.outputs.docker-image }}
      WORKDIR: /triton_dir
    steps:
      - name: Setup Linux
        uses: pytorch/test-infra/.github/actions/setup-linux@main
      - name: Setup SSH (Click me for login details)
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
        with:
          github-secret: ${{ secrets.GITHUB_TOKEN }}
          instructions: |
            Tests are done inside the container, to start an interactive session run:
            docker exec -it $(docker container ps --format '{{.ID}}') bash
      - name: Download and run docker image from GCR
        shell: bash
        run: |
          echo "DOCKER_IMAGE: ${DOCKER_IMAGE}"
          docker pull "${DOCKER_IMAGE}"
          pid=$(docker run --shm-size=16g ${GPU_FLAG:-} -t -d -w "$WORKDIR" "${DOCKER_IMAGE}")
          echo "pid=${pid}" >> "${GITHUB_ENV}"
      - name: Test
        shell: bash
        run: |
          docker exec --privileged "${pid}" bash -c 'TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python test/test_triton.py'
      - name: Teardown Linux
        uses: pytorch/test-infra/.github/actions/teardown-linux@main
        if: always()
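The review thread above points out that Triton relies on PyTorch to discover the GPU, which is why this build cannot reuse the CPU-only PyTorch from the existing GPU flow. A rough illustration of that dependency follows; it is a hedged sketch, not code from this PR or from Triton itself.

# Sketch: Triton consults torch's CUDA runtime to find the active device,
# so a CPU-only PyTorch build makes Triton unusable even on a GPU machine.
import torch

if not torch.cuda.is_available():
    raise RuntimeError("Triton needs a CUDA-enabled PyTorch build and a visible GPU")

idx = torch.cuda.current_device()
print("Triton kernels would target:", torch.cuda.get_device_name(idx))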
test/test_triton.py (new file)
@@ -0,0 +1,68 @@
import logging
import sys
import unittest

import torch
from torch import nn as nn

import torch_xla.experimental.torch_triton as torch_triton

import torch_xla
from torch_xla import runtime as xr

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
  # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
  # There are multiple 'programs' processing different data. We identify which program
  # we are here:
  pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
  # This program will process inputs that are offset from the initial data.
  # For instance, if you had a vector of length 256 and block_size of 64, the programs
  # would each access the elements [0:64, 64:128, 128:192, 192:256].
  # Note that offsets is a list of pointers:
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # Create a mask to guard memory operations against out-of-bounds accesses.
  mask = offsets < n_elements
  # Load x and y from DRAM, masking out any extra elements in case the input is not a
  # multiple of the block size.
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  # Write x + y back to DRAM.
  tl.store(output_ptr + offsets, output, mask=mask)


class TritonTest(unittest.TestCase):

  @unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.")
  def test_gpu_custom_call_triton_add(self):
    size = 16

    x = torch.arange(size, dtype=torch.int64).to("xla")
    y = torch.arange(size, dtype=torch.int64).to("xla")
    output = torch.empty_like(x)
    block_size = 8
    grid = (triton.cdiv(size, block_size),)
    payload = torch_triton.triton_call(
        x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
    output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                                  [output.shape], [torch.int64])
    output_torch = x + y
    self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu()))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
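For contrast with the XLA custom-call path exercised above, the same add_kernel can be launched through Triton's usual grid-launch syntax on native CUDA tensors, as in the linked Triton tutorial. This sketch assumes the add_kernel and imports defined in the test above plus a plain CUDA device, and is not part of the PR.

# Plain-Triton launch of the same kernel, bypassing XLA entirely; assumes
# add_kernel, torch, and triton from the test above and a CUDA device.
size = 16
x = torch.arange(size, dtype=torch.int64, device="cuda")
y = torch.arange(size, dtype=torch.int64, device="cuda")
output = torch.empty_like(x)

block_size = 8
grid = (triton.cdiv(size, block_size),)
# kernel[grid](...) hands the launch to Triton's own runtime instead of
# serializing it into a payload for _xla_gpu_custom_call.
add_kernel[grid](x, y, output, size, BLOCK_SIZE=block_size)
assert torch.equal(output, x + y)

The PR's path instead packages the launch as a payload for torch_xla._XLAC._xla_gpu_custom_call, which the GpuCustomCall node below lowers into an XLA GPU custom call.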
torch_xla/csrc/ops/gpu_custom_call.cpp (new file)
@@ -0,0 +1,37 @@
#include "torch_xla/csrc/ops/gpu_custom_call.h" | ||
|
||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/ops/xla_ops.h" | ||
#include "torch_xla/csrc/xla_lower_util.h" | ||
|
||
namespace torch_xla { | ||
|
||
GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs, | ||
xla::Shape output_shape, | ||
const std::string& payload) | ||
: XlaNode(xla_gpu_custom_call, inputs, std::move(output_shape), | ||
/*num_outputs=*/output_shape.tuple_shapes_size(), | ||
torch::lazy::MHash(payload)), | ||
payload_(payload) {} | ||
|
||
torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const { | ||
return torch::lazy::MakeNode<GpuCustomCall>(operands, xla_shape(), payload_); | ||
} | ||
|
||
XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const { | ||
std::vector<xla::XlaOp> inputs; | ||
inputs.reserve(operands().size()); | ||
for (auto& operand : operands()) { | ||
inputs.push_back(loctx->GetOutputOp(operand)); | ||
} | ||
auto output = BuildGpuCustomCall(inputs, xla_shape(), payload_); | ||
return ReturnOps(output, loctx); | ||
} | ||
|
||
std::string GpuCustomCall::ToString() const { | ||
std::stringstream ss; | ||
ss << XlaNode::ToString() << ", " << payload_; | ||
return ss.str(); | ||
} | ||
|
||
} // namespace torch_xla |
torch_xla/csrc/ops/gpu_custom_call.h (new file)
@@ -0,0 +1,26 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
// TODO: Merge GPU and TPU custom call.

class GpuCustomCall : public XlaNode {
 public:
  // Make a GPU custom call with payload, e.g., Triton.
  GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
                const std::string& payload);

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  std::string ToString() const override;

 private:
  std::string payload_;
};

}  // namespace torch_xla

#endif  // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
Review thread on the new build script:

- Why do we need a new script to build torch-xla?
- Triton only works with GPUs with compute capability greater than 7, and the existing CI GPUs are at 5.2. So we need to rebuild PyTorch with CUDA support for the new GPUs.
- Comment at the top to suggest why. @will-cromar, can you review this part? Appreciate it.
- Please don't add any new build scripts. These should have been deleted a long time ago, but nobody has had time to do a full refactor. There's already an ansible setting, cuda_compute_capabilities, to update the compute capabilities, which is set here: xla/.github/workflows/_build_plugin.yml, line 42 in 77bbf7f.
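The compute-capability constraint raised in this thread can be checked directly from PyTorch. The sketch below is not part of the PR; the threshold mirrors the reviewer's statement, and the (8, 6) example value is an assumption based on the sm_86 settings in the build script.

# Sketch: gate Triton usage on the compute capability PyTorch reports.
import torch

major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 6) for sm_86
if major < 7:
    # Covers the older CI GPUs mentioned above (compute capability 5.2).
    print(f"sm_{major}{minor} is too old for Triton")
else:
    print(f"sm_{major}{minor} can run Triton kernels")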