-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
132 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
on: | ||
pull_request: | ||
branches: | ||
- master | ||
- r[0-9]+.[0-9]+ | ||
paths: | ||
- 'experimental/torch_xla2/**' | ||
push: | ||
branches: | ||
- master | ||
- r[0-9]+.[0-9]+ | ||
paths: | ||
- 'experimental/torch_xla2/**' | ||
workflow_dispatch: | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
torchxla2-cpu: | ||
runs-on: ubuntu-20.04 | ||
steps: | ||
- name: Checkout repo | ||
uses: actions/checkout@v4 | ||
with: | ||
sparse-checkout: | | ||
experimental/torch_xla2/** | ||
- name: Setup Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.10' | ||
- name: Install | ||
shell: bash | ||
working-directory: experimental/torch_xla2 | ||
run: | | ||
pip install -e . | ||
pip install pytest | ||
- name: Run tests | ||
working-directory: experimental/torch_xla2 | ||
shell: bash | ||
run: | | ||
pytest test/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import torch | ||
from torch import nn | ||
import torch_xla2 | ||
from . import test_base | ||
|
||
class CustomConv1(torch.nn.Module): | ||
|
||
def __init__( | ||
self, | ||
channels_conv1=3, | ||
width_conv1=3, | ||
channels_conv2=5, | ||
width_conv2=5, | ||
hidden_layer_size=50, | ||
): | ||
super(CustomConv1, self).__init__() | ||
self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1) | ||
self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2) | ||
self.fc1 = nn.Linear(hidden_layer_size, 2) | ||
|
||
def forward(self, x): | ||
x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2) | ||
x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2) | ||
x = torch.flatten(x, 1) | ||
x = nn.functional.softmax(self.fc1(x), dim=1) | ||
return x | ||
|
||
|
||
class CustomConv2(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
inp = 4 | ||
out = 16 | ||
|
||
self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1) | ||
|
||
# This is supposed to be a squeeze and excitation block. | ||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) | ||
|
||
self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid()) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
|
||
b = x.shape[0] | ||
ap = self.avg_pool(x).view(b, -1) | ||
ap = self.scale(ap) | ||
ap = ap.view(b, -1, 1, 1) | ||
|
||
return x * ap | ||
|
||
|
||
class ConvTest(test_base.TestCase): | ||
|
||
def test_conv1(self): | ||
m = CustomConv1() | ||
arg = torch.randn((20, 1, 50)) | ||
res = m(arg) | ||
|
||
jax_weights, jax_func = torch_xla2.extract_jax(m) | ||
arg = torch_xla2.tensor.t2j(arg) | ||
res2 = jax_func(jax_weights, (arg, )) | ||
res2_torch = torch_xla2.tensor.j2t(res2) | ||
self.assertTrue(torch.allclose(res, res2_torch)) | ||
|
||
def test_conv2(self): | ||
m = CustomConv2() | ||
arg = torch.randn((20, 4, 50, 100)) | ||
res = m(arg) | ||
jax_weights, jax_func = torch_xla2.extract_jax(m) | ||
arg = torch_xla2.tensor.t2j(arg) | ||
res2 = jax_func(jax_weights, (arg, )) | ||
res2_torch = torch_xla2.tensor.j2t(res2) | ||
self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4)) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_base.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters