diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml new file mode 100644 index 000000000000..ff5dab9c6798 --- /dev/null +++ b/.github/workflows/torch_xla2.yml @@ -0,0 +1,43 @@ +on: + pull_request: + branches: + - master + - r[0-9]+.[0-9]+ + paths: + - 'experimental/torch_xla2/**' + push: + branches: + - master + - r[0-9]+.[0-9]+ + paths: + - 'experimental/torch_xla2/**' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + torchxla2-cpu: + runs-on: ubuntu-20.04 + steps: + - name: Checkout repo + uses: actions/checkout@v4 + with: + sparse-checkout: | + experimental/torch_xla2 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install + shell: bash + working-directory: experimental/torch_xla2 + run: | + pip install -e . + pip install pytest + - name: Run tests + working-directory: experimental/torch_xla2 + shell: bash + run: | + pytest test/ \ No newline at end of file diff --git a/experimental/torch_xla2/test/test_conv.py b/experimental/torch_xla2/test/test_conv.py new file mode 100644 index 000000000000..de6873f0c0e4 --- /dev/null +++ b/experimental/torch_xla2/test/test_conv.py @@ -0,0 +1,79 @@ +import torch +from torch import nn +import torch_xla2 +from . import test_base + +class CustomConv1(torch.nn.Module): + + def __init__( + self, + channels_conv1=3, + width_conv1=3, + channels_conv2=5, + width_conv2=5, + hidden_layer_size=50, + ): + super(CustomConv1, self).__init__() + self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1) + self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2) + self.fc1 = nn.Linear(hidden_layer_size, 2) + + def forward(self, x): + x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2) + x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2) + x = torch.flatten(x, 1) + x = nn.functional.softmax(self.fc1(x), dim=1) + return x + + +class CustomConv2(nn.Module): + + def __init__(self): + super().__init__() + inp = 4 + out = 16 + + self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1) + + # This is supposed to be a squeeze and excitation block. + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + + self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid()) + + def forward(self, x): + x = self.conv(x) + + b = x.shape[0] + ap = self.avg_pool(x).view(b, -1) + ap = self.scale(ap) + ap = ap.view(b, -1, 1, 1) + + return x * ap + + +class ConvTest(test_base.TestCase): + + def test_conv1(self): + m = CustomConv1() + arg = torch.randn((20, 1, 50)) + res = m(arg) + + jax_weights, jax_func = torch_xla2.extract_jax(m) + arg = torch_xla2.tensor.t2j(arg) + res2 = jax_func(jax_weights, (arg, )) + res2_torch = torch_xla2.tensor.j2t(res2) + self.assertTrue(torch.allclose(res, res2_torch)) + + def test_conv2(self): + m = CustomConv2() + arg = torch.randn((20, 4, 50, 100)) + res = m(arg) + jax_weights, jax_func = torch_xla2.extract_jax(m) + arg = torch_xla2.tensor.t2j(arg) + res2 = jax_func(jax_weights, (arg, )) + res2_torch = torch_xla2.tensor.j2t(res2) + self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4)) + + +if __name__ == '__main__': + test_base.main() \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index 94e50b47c95a..2dd6340f1665 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -1,13 +1,16 @@ import jax import torch -import torch._functorch +from torch._functorch import make_functional +from torch.utils import _pytree as pytree from torch_xla2 import tensor +from torch_xla2 import export, ops, ops_registry, tensor, tf_integration def extract_jax(mod: torch.nn.Module): """Returns a pytree of jax.ndarray and a jax callable.""" - func, weights, buffer = torch._functorch.make_functional_with_buffers(mod) + func, weights, buffer = make_functional.make_functional_with_buffers(mod) states = (weights, buffer) + states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states) @jax.jit def jax_func(states, inputs): diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py index aa0643c61bc4..bb47775e5204 100644 --- a/experimental/torch_xla2/torch_xla2/ops.py +++ b/experimental/torch_xla2/torch_xla2/ops.py @@ -2,7 +2,6 @@ """Torch ops implemented using jax.""" import sys -import flax import jax from jax import numpy as jnp import numpy as np @@ -508,8 +507,11 @@ def create_default_conv_dimension_numbers(num_spatial_dims): ) if bias is not None: - # TODO(qihqi): this is wrong - bias = bias.reshape(bias.shape + (1,)) + # TODO(qihqi): bias always on channel? + if len(bias.shape) == 1: + shape = [1] * len(res.shape) + shape[1] = bias.shape[0] + bias = bias.reshape(tuple(shape)) res = res + bias return res