Contribute To PyTorch/XLA

We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the issue thread, and we're happy to provide guidance. You are very welcome to pick issues labeled good first issue and help wanted.

If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without prior discussion might result in it being rejected, because we might be taking the core in a different direction than you are aware of.

Building Manually

We recommend using our prebuilt Docker image to start your development work. If you want to use VSCode with Docker, please refer to this config.

  • Setup Development Docker Image

    docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu
    docker run --privileged --name ptxla -it -d -e "TERM=xterm-256color" us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu
    docker exec --privileged -it ptxla /bin/bash

    All of the commands below are assumed to be run inside the Docker container.

  • Clone the PyTorch repo as per the instructions.

    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch/
  • Clone the PyTorch/XLA repo inside the pytorch/ folder:

    git clone --recursive https://github.com/pytorch/xla.git
  • Build PyTorch

    python setup.py develop
  • Build PyTorch/XLA

    cd xla/
    python setup.py develop
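After both builds finish, one quick way to confirm that Python picks up the freshly built packages is to import them and print where they were loaded from. The sketch below uses a hypothetical helper, check_py_modules, which is not part of the repo; it reads the interpreter from a PYTHON variable that defaults to python, the same command used for the builds above.

```shell
# Hypothetical helper (not part of the repo): import each Python module and
# print the file it was loaded from; return non-zero on the first failure.
check_py_modules() {
  local py="${PYTHON:-python}"
  local mod
  for mod in "$@"; do
    "$py" -c "import ${mod}; print('${mod}:', ${mod}.__file__)" || return 1
  done
}
```

For example, check_py_modules torch torch_xla should print paths inside your pytorch/ and pytorch/xla/ checkouts when both were installed with python setup.py develop. Paths pointing somewhere else (for example a site-packages wheel) suggest an older install is shadowing your build.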

Build PyTorch/XLA from source with GPU support

Please refer to this guide.

Before Submitting A Pull Request:

In the pytorch/xla repo, we enforce a coding style for both C++ and Python files. Please format your code before submitting a pull request.

C++ Style Guide

pytorch/xla uses clang-format-11 with a customized style config. If your PR touches the C++ source files, please run the following command before submitting a PR.

# How to install: sudo apt install clang-format-11
# If your PR only changes foo.cpp, run the following in the xla/ folder
clang-format-11 -i -style=file /PATH/TO/foo.cpp
# To format all C++ files, run the following in the xla/ folder
find -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs clang-format-11 -i -style=file
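When a PR touches only a few C++ files, formatting just those files is faster than running over the whole tree. The sketch below uses two hypothetical helpers, filter_cpp_files and format_changed_cpp. It assumes your branch is based on origin/master, that you run it from the xla/ folder, and that GNU xargs is available (for -r).

```shell
# Hypothetical helper: read paths (one per line) on stdin and keep only the
# extensions targeted by the find command above (.cpp, .cc, .h).
filter_cpp_files() {
  grep -E '\.(cpp|cc|h)$' || true
}

# Hypothetical helper: format only the C++ files that differ from the
# merge-base with a base branch (default: origin/master), including
# uncommitted edits to tracked files. Deleted files are skipped.
format_changed_cpp() {
  local base
  base=$(git merge-base "${1:-origin/master}" HEAD) || return 1
  git diff --name-only --diff-filter=d "$base" \
    | filter_cpp_files \
    | xargs -r clang-format-11 -i -style=file
}
```

Untracked new files do not appear in git diff, so format those with the single-file command above, or git add them first.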

Python Style Guide

pytorch/xla uses yapf (specifically version 0.30.0, since newer versions may not be backward compatible) with a customized style config. If your PR touches the Python source files, please run the following command before submitting a PR.

# How to install: pip install yapf==0.30.0
yapf --recursive -i *.py test/ scripts/ torch_xla/ benchmarks/
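Because the output depends on the pinned yapf release, it can help to check the installed version before reformatting. This is a sketch with hypothetical helper names, assuming that yapf --version prints a line of the form yapf 0.30.0.

```shell
# Hypothetical helper: succeed only if the given `yapf --version` output
# (assumed format: "yapf 0.30.0") reports the pinned version.
yapf_is_pinned_version() {
  [ "$(printf '%s\n' "$1" | awk '{print $2}')" = "0.30.0" ]
}

# Hypothetical wrapper: run the formatting command above only with yapf 0.30.0.
format_python() {
  if ! command -v yapf >/dev/null 2>&1 \
      || ! yapf_is_pinned_version "$(yapf --version 2>&1)"; then
    echo "yapf 0.30.0 is required: pip install yapf==0.30.0" >&2
    return 1
  fi
  yapf --recursive -i *.py test/ scripts/ torch_xla/ benchmarks/
}
```

Run format_python from the xla/ folder so the relative paths resolve as in the original command.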

Running the Tests

To run the tests, follow one of the options below:

  • Run on local CPU:

    export PJRT_DEVICE=CPU
  • Run on Cloud TPU:

    export PJRT_DEVICE=TPU
  • Run on GPU (set NUM_GPU to the number of GPUs to use):

    export PJRT_DEVICE=CUDA GPU_NUM_DEVICES=${NUM_GPU}
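If you switch between these targets often, the exports above can be wrapped in a small function. This is a sketch; set_pjrt_device is a hypothetical name, and it only sets the variables listed above (clearing GPU_NUM_DEVICES for non-GPU targets).

```shell
# Hypothetical helper: export the PJRT variables for one target.
#   set_pjrt_device cpu | tpu | gpu <num_gpus>
set_pjrt_device() {
  case "$1" in
    cpu) export PJRT_DEVICE=CPU; unset GPU_NUM_DEVICES ;;
    tpu) export PJRT_DEVICE=TPU; unset GPU_NUM_DEVICES ;;
    gpu)
      if [ -z "$2" ]; then
        echo "usage: set_pjrt_device gpu <num_gpus>" >&2
        return 1
      fi
      export PJRT_DEVICE=CUDA GPU_NUM_DEVICES="$2"
      ;;
    *)
      echo "unknown target '$1' (expected cpu, tpu or gpu)" >&2
      return 1
      ;;
  esac
}
```

Because the function uses export, call it in the same shell that will run the tests (for example, set_pjrt_device gpu 4), not in a subshell.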

For more details on configuring the runtime, please refer to this doc.

If you plan to build from source, and hence use the latest PyTorch/TPU code base, we suggest selecting the Nightly builds when you create a Cloud TPU instance.

Then run test/run_tests.sh and test/cpp/run_tests.sh to verify the setup is working.
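To run both suites in one go and stop at the first failing one, a small wrapper like the hypothetical run_test_scripts below can be used. It assumes the scripts can be invoked with bash from the xla/ folder and that PJRT_DEVICE has already been exported as described above.

```shell
# Hypothetical helper: run each test script with bash, in order, and stop at
# the first one that exits with a non-zero status.
run_test_scripts() {
  local script
  for script in "$@"; do
    echo "==> ${script}"
    if ! bash "${script}"; then
      echo "FAILED: ${script}" >&2
      return 1
    fi
  done
  echo "All test scripts passed."
}
```

Usage: run_test_scripts test/run_tests.sh test/cpp/run_tests.sh. Running the Python suite first means a broken build is usually caught before the slower C++ tests start.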

Useful materials

  1. OP Lowering Guide
  2. CODEGEN MIGRATION GUIDE
  3. Dynamo Integration Guide

Sharp Edges

  • If local changes aren't visible, uninstall the existing packages with pip uninstall torch_xla and pip uninstall torch, then rebuild PyTorch and PyTorch/XLA with python setup.py develop or python setup.py install.
  • PJRT errors when running on TPU, such as "The PJRT plugin has PJRT API version 0.34. The framework PJRT API version is 0.40.": update your libtpu.so and make sure its directory is in your LD_LIBRARY_PATH. You can download a new libtpu.so from Google Cloud, where the builds are sorted by date. Download the newest one and install it with pip install libtpu...whl.
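For the LD_LIBRARY_PATH part of the last point, a helper like the hypothetical add_to_ld_library_path below prepends a directory only if it is not already listed, so sourcing it repeatedly does not grow the variable.

```shell
# Hypothetical helper: prepend a directory to LD_LIBRARY_PATH unless it is
# already one of its colon-separated entries.
add_to_ld_library_path() {
  case ":${LD_LIBRARY_PATH}:" in
    *":$1:"*) ;;  # already present, nothing to do
    *) export LD_LIBRARY_PATH="$1${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" ;;
  esac
}
```

Usage: add_to_ld_library_path /path/to/libtpu_dir, where /path/to/libtpu_dir is a placeholder for the directory containing your libtpu.so. The export only affects the current shell, so add the call to your shell profile if you want it to persist.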