diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index bd1f9be5b0e..e7621777386 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -3,10 +3,14 @@ on: branches: - master - r[0-9]+.[0-9]+ + paths-ignore: + - 'experimental/torch_xla2/**' push: branches: - master - r[0-9]+.[0-9]+ + paths-ignore: + - 'experimental/torch_xla2/**' workflow_dispatch: concurrency: diff --git a/experimental/torch_xla2/LICENSE b/experimental/torch_xla2/LICENSE new file mode 100644 index 00000000000..1d064b89dc7 --- /dev/null +++ b/experimental/torch_xla2/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2023, pytorch-tpu + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md new file mode 100644 index 00000000000..97f336ea09e --- /dev/null +++ b/experimental/torch_xla2/README.md @@ -0,0 +1,3 @@ +# torchxla2 + +This directory contains the torch_xla2 package, which currently lives inside the top-level git repository. diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt new file mode 100644 index 00000000000..6c6cb4d208e --- /dev/null +++ b/experimental/torch_xla2/dev-requirements.txt @@ -0,0 +1,8 @@ +-r requirements.txt +pytest +yapf +tabulate +transformers +tf-nightly +--pre -f https://download.pytorch.org/whl/nightly/torch_nightly.html +torchvision \ No newline at end of file diff --git a/experimental/torch_xla2/format.sh b/experimental/torch_xla2/format.sh new file mode 100755 index 00000000000..08efc04b399 --- /dev/null +++ b/experimental/torch_xla2/format.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -ex + +yapf --recursive -i *.py test torch_xla2 \ No newline at end of file diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml new file mode 100644 index 00000000000..112e169e2c6 --- /dev/null +++ b/experimental/torch_xla2/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[project] +version = "0.0.1" +name = "torch_xla2" +dependencies = [ + "torch>=2.1", + "jax>=0.4.24", + "jaxlib", +] + +requires-python = ">=3.10" +license = {file = "LICENSE"} diff --git a/experimental/torch_xla2/requirements.txt b/experimental/torch_xla2/requirements.txt new file mode 100644 index 00000000000..fa373ce770f --- /dev/null +++ b/experimental/torch_xla2/requirements.txt @@ -0,0 +1,5 @@ +jax==0.4.24.dev20240202 +jaxlib==0.4.24.dev20240202 +numpy==1.26.3 +torch==2.2.0 +typing_extensions==4.9.0 diff --git a/experimental/torch_xla2/test/BUILD b/experimental/torch_xla2/test/BUILD new file mode 100644 index 00000000000..b04071c18df --- /dev/null +++ b/experimental/torch_xla2/test/BUILD @@ -0,0 +1,31 @@ +# TODO(hanq): describe this package. + +load( + "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl", + "py_library", + "py_test", +) + +package( + default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +py_library( + name = "test_base", + srcs = ["test_base.py"], + deps = [ + "//testing/pybase", + ], +) + +py_test( + name = "test_core_aten_ops", + srcs = ["test_core_aten_ops.py"], + deps = [ + ":test_base", + "//third_party/py/absl:app", + "//third_party/py/torch/google/_torx", + ], +) diff --git a/experimental/torch_xla2/test/__init__.py b/experimental/torch_xla2/test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/torch_xla2/test/llama/BUILD b/experimental/torch_xla2/test/llama/BUILD new file mode 100644 index 00000000000..5fd0fdf4b96 --- /dev/null +++ b/experimental/torch_xla2/test/llama/BUILD @@ -0,0 +1,25 @@ +# TODO(hanq): describe this package. 
+load( + "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl", + "py_test", +) + +package( + default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +py_test( + name = "test_llama", + srcs = [ + "llama_model.py", + "test_llama.py", + ], + deps = [ + "//third_party/py/jax", + "//third_party/py/torch:pytorch", + "//third_party/py/torch/google/_torx", + "//third_party/py/torch/google/_torx/test:test_base", + ], +) diff --git a/experimental/torch_xla2/test/llama/__init__.py b/experimental/torch_xla2/test/llama/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/torch_xla2/test/llama/llama_model.py b/experimental/torch_xla2/test/llama/llama_model.py new file mode 100644 index 00000000000..790afed0258 --- /dev/null +++ b/experimental/torch_xla2/test/llama/llama_model.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# This file is copied from https://github.com/pytorch-labs/gpt-fast +# This is used for unit test purposes +from dataclasses import dataclass +import math +from typing import Optional + +import torch +from torch import Tensor +import torch.nn as nn +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict( + block_size=16384, + vocab_size=32000, + n_layer=32, + dim=4096, + rope_base=1000000, + ), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict( + n_layer=48, + n_head=64, + dim=8192, + vocab_size=32000, + n_local_heads=8, + intermediate_size=22016, + rope_base=1000000, + ), # CodeLlama-34B-Python-hf + "70B": dict( + n_layer=80, + n_head=64, + dim=8192, + n_local_heads=8, + intermediate_size=28672, + ), +} + + +class KVCache(nn.Module): + + def __init__( + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == 
k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = ( + config.n_head + 2 * config.n_local_heads + ) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, 
_ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] + - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py new file mode 100644 index 00000000000..e6f86ff22c0 --- /dev/null +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -0,0 +1,57 @@ +import unittest +import jax +import torch +from torch._functorch.make_functional import make_functional_with_buffers +from torch_xla2 import tensor, ops # pylint: disable=unused-import + +from .. import test_base +from . 
import llama_model +from torch.utils import _pytree as pytree + + +class LlamaTest(test_base.TestCase): + + def test_can_run(self): + sample_args = ( + torch.randint(0, 32000, (1, 2048)), + torch.arange(0, 2048), + ) + sample_args = pytree.tree_map(tensor.move_to_device, sample_args) + + model_args = llama_model.ModelArgs( + block_size=2048, + vocab_size=32000, + n_layer=2, + n_head=4, + dim=256, + ) + m = llama_model.Transformer(model_args) + m.to(torch.bfloat16) + m.setup_caches(1, 2048) + m_func, weights, buffer = make_functional_with_buffers(m) + + causal_mask = tensor.move_to_device(m.causal_mask) + freqs_cis = tensor.move_to_device(m.freqs_cis) + weights = pytree.tree_map(tensor.move_to_device, weights) + buffer = pytree.tree_map(tensor.move_to_device, buffer) + + @jax.jit + def m_func_jit(weights, buffer, args, causal_mask, freqs_cis): + weights, buffer, args, causal_mask, freqs_cis = tensor.wrap( + (weights, buffer, args, causal_mask, freqs_cis) + ) + m_func.stateless_model.freqs_cis = freqs_cis + m_func.stateless_model.causal_mask = causal_mask + res = m_func(weights, buffer, *args) + res = tensor.unwrap(res) + return res + + args = weights, buffer, sample_args, causal_mask, freqs_cis + args = tensor.unwrap(args) + # print(m_func_jit.lower(*args).as_text()) + res = m_func_jit(*args) + jax.block_until_ready(res) + + +if __name__ == "__main__": + test_base.main() diff --git a/experimental/torch_xla2/test/test_base.py b/experimental/torch_xla2/test/test_base.py new file mode 100644 index 00000000000..71f4dc97b67 --- /dev/null +++ b/experimental/torch_xla2/test/test_base.py @@ -0,0 +1,4 @@ +import unittest + +TestCase = unittest.TestCase +main = unittest.main diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py new file mode 100644 index 00000000000..dcb3f7bf00c --- /dev/null +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -0,0 +1,4514 @@ +import unittest + +import torch +from torch_xla2 import ops # pylint: disable=unused-import +from torch_xla2 import ops_registry +from torch_xla2 import tensor + +from . 
import test_base +from torch.utils import _pytree as pytree + + +def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): + if isinstance(output1, torch.Tensor): + testcase.assertIsInstance(output2, torch.Tensor) + output2_cpu = output2.detach().cpu() + if output2_cpu.dtype != output1.dtype: + output2_cpu = output2_cpu.to(output1.dtype) + testcase.assertTrue( + torch.allclose( + output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan + ) + ) + elif isinstance(output1, (tuple, list)): + testcase.assertIsInstance(output2, (tuple, list)) + testcase.assertEqual(len(output1), len(output2)) + for o1, o2 in zip(output1, output2): + diff_output(testcase, o1, o2, rtol, atol) + else: + testcase.assertEqual(output1, output2) + + +def run_export_and_compare( + testcase, func, args, kwargs, atol=1e-3, rtol=1e-5, equal_nan=True +): + with testcase.subTest("torch_eval"): + res = func(*args, **kwargs) + with testcase.subTest("torch_xla2_eval"): + args2, kwargs2 = pytree.tree_map_only( + torch.Tensor, tensor.move_to_device, (args, kwargs) + ) + res2 = func(*args2, **kwargs2) + res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) + # import pdb; pdb.set_trace() + with testcase.subTest("torch_xla2_diff:" + str(atol)): + diff_output( + testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan + ) + + +class TestCoreAtenOps(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + ops_registry.print_missing_ops() + + def setUp(self): + super().setUp() + torch.manual_seed(0) + + def test_aten_abs_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) + + def test_aten_abs_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) + + def test_aten_abs_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) + + def test_aten_acos_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acos, args, kwargs) + + def test_aten_acos_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acos, args, kwargs) + + def test_aten_acos_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acos, args, kwargs) + + def test_aten_acosh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) + + def test_aten_acosh_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) + + def test_aten_acosh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_0(self): + args = ( + torch.randn((1, 3, 10)).to(torch.float32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_1(self): + args = ( + torch.randn((1, 3, 10)).to(torch.float16), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def 
test_aten_unsqueeze_2(self): + args = ( + torch.randint(0, 10, (1, 3, 10)).to(torch.int32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_3(self): + args = ( + torch.randn((1, 3, 10)).to(torch.float32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_4(self): + args = ( + torch.randn((1, 3, 10)).to(torch.float16), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_5(self): + args = ( + torch.randint(0, 10, (1, 3, 10)).to(torch.int32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + def test_aten_unsqueeze_6(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + def test_aten_unsqueeze_7(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + def test_aten_unsqueeze_8(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) + + @unittest.skip + def test_aten__adaptive_avg_pool2d_0(self): + args = ( + torch.randn((1, 3, 1, 10)).to(torch.float32), + [ + 1, + 5, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._adaptive_avg_pool2d, args, kwargs + ) + + @unittest.skip + def test_aten__adaptive_avg_pool2d_1(self): + args = ( + torch.randn((1, 3, 10, 10)).to(torch.float32), + [ + 5, + 5, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._adaptive_avg_pool2d, args, kwargs + ) + + @unittest.skip + def test_aten_squeeze_dim_0(self): + args = ( + torch.randn((1, 3, 1, 5)).to(torch.float32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dim_1(self): + args = ( + torch.randn((1, 3, 1, 5)).to(torch.float32), + -2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dim_2(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dim_3(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dim_4(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) + + @unittest.skip + def test_aten__adaptive_avg_pool3d_0(self): + args = ( + torch.randn((1, 3, 10, 10, 10)).to(torch.float32), + [ + 5, + 5, + 5, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._adaptive_avg_pool3d, args, kwargs + ) + + @unittest.skip + def test_aten__adaptive_avg_pool3d_1(self): + args = ( + torch.randn((1, 3, 10, 10, 10)).to(torch.float16), + [ + 5, + 5, + 5, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._adaptive_avg_pool3d, args, kwargs + ) + + def test_aten_add_Scalar_0(self): + args = ( + 
torch.randn((10, 10)).to(torch.float32), + 0.1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) + + def test_aten_add_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) + + def test_aten_add_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) + + def test_aten_add_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) + + def test_aten_add_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) + + def test_aten_add_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) + + @unittest.skip + def test_aten_addmm_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) + + @unittest.skip + def test_aten_addmm_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) + + @unittest.skip + def test_aten_addmm_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) + + def test_aten_alias_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) + + def test_aten_alias_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) + + def test_aten_alias_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) + + def test_aten_amax_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) + + def test_aten_amax_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) + + def test_aten_amax_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) + + def test_aten_amin_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.amin, args, kwargs) + + def test_aten_amin_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, 
torch.ops.aten.amin, args, kwargs) + + def test_aten_amin_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.amin, args, kwargs) + + def test_aten_any_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any, args, kwargs) + + def test_aten_any_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any, args, kwargs) + + def test_aten_any_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any, args, kwargs) + + def test_aten_any_dim_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_any_dim_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_any_dim_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_any_dims_0(self): + args = (torch.randn((10, 10)).to(torch.float32), 0) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_any_dims_1(self): + args = (torch.randn((10, 10)).to(torch.float16), 0) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_any_dims_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32), 0) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) + + def test_aten_argmax_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) + + def test_aten_argmax_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) + + def test_aten_argmax_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) + + def test_aten_argmin_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) + + def test_aten_argmin_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) + + def test_aten_argmin_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) + + @unittest.skip + def test_aten_as_strided_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) + + @unittest.skip + def test_aten_as_strided_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) + + @unittest.skip + def test_aten_as_strided_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + 
run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) + + @unittest.skip + def test_aten_as_strided_copy_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 5, + 5, + ], + [ + 2, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) + + @unittest.skip + def test_aten_as_strided_copy_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 5, + 5, + ], + [ + 2, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) + + @unittest.skip + def test_aten_as_strided_copy_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 5, + 5, + ], + [ + 2, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) + + def test_aten_asin_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) + + def test_aten_asin_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) + + def test_aten_asin_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) + + def test_aten_asinh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) + + def test_aten_asinh_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) + + def test_aten_asinh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) + + def test_aten_atan_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) + + def test_aten_atan_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) + + @unittest.skip + def test_aten_atan_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) + + def test_aten_atan2_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atan2, args, kwargs) + + def test_aten_atan2_1(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atan2, args, kwargs) + + def test_aten_atanh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) + + def test_aten_atanh_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) + + def test_aten_atanh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) + + @unittest.skip + def test_aten_avg_pool2d_0(self): + args = ( + torch.randn((1, 3, 1, 10)).to(torch.float32), + [ + 1, + 2, + ], + [ + 1, + 2, + ], + ) + kwargs = dict() + 
run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) + + @unittest.skip + def test_aten_avg_pool2d_1(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float32), + [ + 2, + 2, + ], + [ + 1, + 1, + ], + [ + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) + + @unittest.skip + def test_aten_avg_pool3d_0(self): + args = ( + torch.randn((1, 3, 10, 10, 10)).to(torch.float32), + [ + 2, + 2, + 2, + ], + [ + 2, + 2, + 2, + ], + [ + 0, + 0, + 0, + ], + False, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.avg_pool3d, args, kwargs) + + def test_aten_bitwise_and_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_and.Scalar, args, kwargs + ) + + def test_aten_bitwise_and_Tensor_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_and.Tensor, args, kwargs + ) + + def test_aten_bitwise_and_Tensor_1(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_and.Tensor, args, kwargs + ) + + def test_aten_bitwise_and_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_and.Tensor, args, kwargs + ) + + def test_aten_bitwise_and_Tensor_3(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_and.Tensor, args, kwargs + ) + + def test_aten_bitwise_or_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.bitwise_or.Scalar, args, kwargs) + + def test_aten_bitwise_xor_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.bitwise_xor.Scalar, args, kwargs + ) + + @unittest.skip + def test_aten_bmm_0(self): + args = ( + torch.randn((10, 10, 10)).to(torch.float32), + torch.randn((10, 10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) + + @unittest.skip + def test_aten_bmm_1(self): + args = ( + torch.randn((10, 10, 10)).to(torch.float16), + torch.randn((10, 10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) + + @unittest.skip + def test_aten_bmm_2(self): + args = ( + torch.randint(0, 10, (10, 10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) + + @unittest.skip + def test_aten_cat_0(self): + args = ( + [ + torch.randn((10, 10)).to(torch.float32), + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) + + @unittest.skip + def test_aten_cat_1(self): + args = ( + [ + torch.randn((10, 10)).to(torch.float32), + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) + + @unittest.skip + def test_aten_cat_2(self): + args = ( + [ + 
torch.randn((10, 10)).to(torch.float32), + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) + + @unittest.skip + def test_aten__cdist_forward_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + 1.0, + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._cdist_forward, args, kwargs) + + def test_aten_ceil_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) + + def test_aten_ceil_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) + + def test_aten_ceil_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) + + def test_aten_clamp_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) + + def test_aten_clamp_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) + + def test_aten_clamp_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) + + def test_aten_clamp_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((1,)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) + + def test_aten_clamp_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((1,)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) + + def test_aten_clamp_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (1,)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) + + def test_aten_clone_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) + + def test_aten_clone_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) + + def test_aten_clone_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) + + def test_aten_constant_pad_nd_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) + + def test_aten_constant_pad_nd_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) + + def test_aten_constant_pad_nd_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) + + def test_aten_convolution_0(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float32), + torch.randn((2, 2, 2)).to(torch.float32), + None, + [ + 2, + ], + [ + 0, + ], + [ + 1, + 
], + False, + [ + 0, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) + + @unittest.skip + def test_aten_convolution_1(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float16), + torch.randn((2, 2, 2)).to(torch.float16), + None, + [ + 2, + ], + [ + 0, + ], + [ + 1, + ], + False, + [ + 0, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) + + def test_aten_convolution_2(self): + args = ( + torch.randint(0, 10, (3, 2, 10)).to(torch.int32), + torch.randint(0, 10, (2, 2, 2)).to(torch.int32), + None, + [ + 2, + ], + [ + 0, + ], + [ + 1, + ], + False, + [ + 0, + ], + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) + + def test_aten_cos_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) + + def test_aten_cos_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) + + def test_aten_cos_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) + + def test_aten_cosh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) + + def test_aten_cosh_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) + + def test_aten_cosh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) + + @unittest.skip + def test_aten_cumsum_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) + + @unittest.skip + def test_aten_cumsum_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) + + @unittest.skip + def test_aten_cumsum_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) + + @unittest.skip + def test_aten_diagonal_0(self): + args = (torch.randn((10, 20)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) + + @unittest.skip + def test_aten_diagonal_1(self): + args = (torch.randn((10, 20)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) + + @unittest.skip + def test_aten_diagonal_2(self): + args = (torch.randint(0, 10, (10, 20)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) + + def test_aten_div_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) + + def test_aten_div_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.5, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) + + def test_aten_div_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.5, + ) + kwargs = dict() + run_export_and_compare(self, 
torch.ops.aten.div.Scalar, args, kwargs) + + @unittest.skip + def test_aten_div_Scalar_mode_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = { + "rounding_mode": "trunc", + } + run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) + + @unittest.skip + def test_aten_div_Scalar_mode_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = { + "rounding_mode": "trunc", + } + run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) + + def test_aten_div_Scalar_mode_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = { + "rounding_mode": "trunc", + } + run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) + + def test_aten_div_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) + + def test_aten_div_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) + + def test_aten_div_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) + + @unittest.skip + def test_aten_div_Tensor_mode_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = { + "rounding_mode": "trunc", + } + run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs) + + @unittest.skip + def test_aten_div_Tensor_mode_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = { + "rounding_mode": "trunc", + } + run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs) + + def test_aten_embedding_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randint(0, 10, (10,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) + + def test_aten_embedding_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randint(0, 10, (10,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) + + def test_aten_embedding_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) + + def test_aten_eq_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) + + def test_aten_eq_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) + + def test_aten_eq_Scalar_2(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) + + def test_aten_eq_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) + + def 
test_aten_eq_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) + + def test_aten_eq_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) + + @unittest.skip + def test_aten_erf_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) + + @unittest.skip + def test_aten_erf_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) + + @unittest.skip + def test_aten_erf_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) + + def test_aten_exp_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) + + def test_aten_exp_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) + + def test_aten_exp_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) + + def test_aten_expand_0(self): + args = ( + torch.randn((10, 1)).to(torch.float32), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) + + def test_aten_expand_1(self): + args = ( + torch.randn((10, 1)).to(torch.float16), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) + + def test_aten_expand_2(self): + args = ( + torch.randint(0, 10, (10, 1)).to(torch.int32), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) + + @unittest.skip + def test_aten_expand_copy_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) + + @unittest.skip + def test_aten_expand_copy_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) + + @unittest.skip + def test_aten_expand_copy_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 10, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) + + def test_aten_expm1_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) + + def test_aten_expm1_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs, rtol=1e-3) + + def test_aten_expm1_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) + + @unittest.skip + def test_aten_fill_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) + + 
@unittest.skip + def test_aten_fill_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) + + @unittest.skip + def test_aten_fill_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) + + @unittest.skip + def test_aten_fill_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn(()).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) + + @unittest.skip + def test_aten_fill_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn(()).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) + + @unittest.skip + def test_aten_fill_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, ()).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) + + def test_aten_flip_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) + + def test_aten_flip_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) + + def test_aten_flip_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) + + def test_aten_floor_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) + + def test_aten_floor_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) + + def test_aten_floor_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) + + @unittest.skip + def test_aten_floor_divide_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.floor_divide, args, kwargs) + + @unittest.skip + def test_aten_floor_divide_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.floor_divide, args, kwargs) + + @unittest.skip + def test_aten_fmod_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) + + @unittest.skip + def test_aten_fmod_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) + + @unittest.skip + def test_aten_fmod_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) + + @unittest.skip + def test_aten_fmod_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 
10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) + + @unittest.skip + def test_aten_fmod_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) + + @unittest.skip + def test_aten_full_like_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) + + @unittest.skip + def test_aten_full_like_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) + + @unittest.skip + def test_aten_full_like_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) + + @unittest.skip + def test_aten_gather_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) + + @unittest.skip + def test_aten_gather_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) + + @unittest.skip + def test_aten_gather_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) + + @unittest.skip + def test_aten_ge_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ge_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ge_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ge_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) + + @unittest.skip + def test_aten_ge_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) + + @unittest.skip + def test_aten_ge_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) + + @unittest.skip + def test_aten_gelu_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) + + @unittest.skip + def test_aten_gelu_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) + + @unittest.skip + def 
test_aten_glu_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) + + @unittest.skip + def test_aten_glu_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) + + @unittest.skip + def test_aten_grid_sampler_2d_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + torch.randn((1, 2, 2, 2)).to(torch.float32), + 0, + 0, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.grid_sampler_2d, args, kwargs) + + def test_aten_gt_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) + + def test_aten_gt_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) + + def test_aten_gt_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) + + def test_aten_gt_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) + + def test_aten_gt_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) + + def test_aten_gt_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) + + def test_aten_hardtanh_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) + + def test_aten_hardtanh_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) + + def test_aten_hardtanh_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) + + def test_aten_index_put_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + torch.randint(0, 10, (1,)).to(torch.int64), + ], + torch.randn((10,)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) + + def test_aten_index_put_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + torch.randint(0, 10, (1,)).to(torch.int64), + ], + torch.randn((10,)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) + + def test_aten_index_put_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + torch.randint(0, 10, (1,)).to(torch.int64), + ], + torch.randint(0, 10, (10,)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) + + def test_aten_index_select_0(self): + args = ( + torch.randn((2, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (2,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, 
torch.ops.aten.index_select, args, kwargs) + + def test_aten_index_select_1(self): + args = ( + torch.randn((2, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (2,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + + def test_aten_index_select_2(self): + args = ( + torch.randint(0, 10, (2, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (2,)).to(torch.int64), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) + + @unittest.skip + def test_aten_index_Tensor_0(self): + args = ( + torch.randn((2, 10)).to(torch.float32), + [ + torch.randint(0, 10, (2,)).to(torch.int64), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) + + @unittest.skip + def test_aten_index_Tensor_1(self): + args = ( + torch.randn((2, 10)).to(torch.float16), + [ + torch.randint(0, 10, (2,)).to(torch.int64), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) + + @unittest.skip + def test_aten_index_Tensor_2(self): + args = ( + torch.randint(0, 10, (2, 10)).to(torch.int32), + [ + torch.randint(0, 10, (2,)).to(torch.int64), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) + + def test_aten_isinf_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) + + def test_aten_isinf_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) + + def test_aten_isinf_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) + + def test_aten_isnan_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) + + def test_aten_isnan_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) + + def test_aten_isnan_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) + + def test_aten_le_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) + + def test_aten_le_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) + + def test_aten_le_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) + + def test_aten_le_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) + + def test_aten_le_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) + + def test_aten_le_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs 
= dict() + run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) + + @unittest.skip + def test_aten_leaky_relu_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) + + @unittest.skip + def test_aten_leaky_relu_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) + + @unittest.skip + def test_aten_lift_fresh_copy_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) + + @unittest.skip + def test_aten_lift_fresh_copy_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) + + @unittest.skip + def test_aten_lift_fresh_copy_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) + + def test_aten_log_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log, args, kwargs) + + def test_aten_log_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log, args, kwargs) + + def test_aten_log_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log, args, kwargs) + + def test_aten_log10_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log10, args, kwargs) + + def test_aten_log10_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log10, args, kwargs, + atol=0.001, rtol=0.001) + + def test_aten_log10_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log10, args, kwargs) + + def test_aten_log1p_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log1p, args, kwargs) + + def test_aten_log1p_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log1p, args, kwargs, + atol=0.001, rtol=0.001) + + def test_aten_log1p_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log1p, args, kwargs) + + def test_aten_log2_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log2, args, kwargs) + + def test_aten_log2_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log2, args, kwargs, + atol=0.001, rtol=0.001) + + def test_aten_log2_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.log2, args, kwargs) + + def test_aten__log_softmax_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._log_softmax, args, kwargs) + + def test_aten_logical_and_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 
torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) + + def test_aten_logical_and_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) + + def test_aten_logical_and_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) + + @unittest.skip + def test_aten_logical_not_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) + + @unittest.skip + def test_aten_logical_not_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) + + @unittest.skip + def test_aten_logical_not_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) + + @unittest.skip + def test_aten_logical_or_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) + + @unittest.skip + def test_aten_logical_or_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) + + @unittest.skip + def test_aten_logical_or_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) + + def test_aten_logical_xor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) + + def test_aten_logical_xor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) + + def test_aten_logical_xor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) + + @unittest.skip + def test_aten_logit_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) + + @unittest.skip + def test_aten_logit_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) + + @unittest.skip + def test_aten_logit_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) + + def test_aten_lt_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) + + def test_aten_lt_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = 
dict() + run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) + + def test_aten_lt_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) + + def test_aten_lt_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) + + def test_aten_lt_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) + + def test_aten_lt_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) + + @unittest.skip + def test_aten_masked_fill_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.bool), + 0.123, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.masked_fill.Scalar, args, kwargs + ) + + @unittest.skip + def test_aten_masked_fill_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.bool), + 0.123, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.masked_fill.Scalar, args, kwargs + ) + + @unittest.skip + def test_aten_masked_fill_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randn((10, 10)).to(torch.bool), + 0.123, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.masked_fill.Scalar, args, kwargs + ) + + def test_aten_max_dim_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) + + def test_aten_max_dim_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) + + def test_aten_max_dim_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) + + @unittest.skip + def test_aten_max_pool2d_with_indices_0(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float32), + [ + 2, + 2, + ], + [ + 1, + 1, + ], + [ + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + ) + + @unittest.skip + def test_aten_max_pool2d_with_indices_1(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float16), + [ + 2, + 2, + ], + [ + 1, + 1, + ], + [ + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + ) + + @unittest.skip + def test_aten_max_pool2d_with_indices_2(self): + args = ( + torch.randint(0, 10, (3, 2, 10)).to(torch.int32), + [ + 2, + 2, + ], + [ + 1, + 1, + ], + [ + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + ) + + @unittest.skip + def test_aten_max_pool3d_with_indices_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + [ + 2, + 2, + 2, + ], + [ + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool3d_with_indices, args, 
kwargs + ) + + @unittest.skip + def test_aten_max_pool3d_with_indices_1(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float16), + [ + 2, + 2, + 2, + ], + [ + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool3d_with_indices, args, kwargs + ) + + @unittest.skip + def test_aten_max_pool3d_with_indices_2(self): + args = ( + torch.randint(0, 10, (1, 3, 2, 10)).to(torch.int32), + [ + 2, + 2, + 2, + ], + [ + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.max_pool3d_with_indices, args, kwargs + ) + + def test_aten_maximum_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) + + def test_aten_maximum_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) + + def test_aten_maximum_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) + + @unittest.skip + def test_aten_mean_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) + + @unittest.skip + def test_aten_mean_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) + + @unittest.skip + def test_aten_mean_dim_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mean.dim, args, kwargs) + + @unittest.skip + def test_aten_mean_dim_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mean.dim, args, kwargs) + + def test_aten_min_dim_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) + + def test_aten_min_dim_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) + + def test_aten_min_dim_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) + + def test_aten_minimum_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) + + def test_aten_minimum_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) + + def test_aten_minimum_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) + + @unittest.skip + def test_aten_mm_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, 
torch.ops.aten.mm, args, kwargs) + + @unittest.skip + def test_aten_mm_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) + + @unittest.skip + def test_aten_mm_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) + + def test_aten_mul_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) + + def test_aten_mul_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) + + def test_aten_mul_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) + + def test_aten_mul_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) + + def test_aten_mul_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) + + def test_aten_mul_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) + + @unittest.skip + def test_aten__native_batch_norm_legit_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + None, + None, + torch.randn((10,)).to(torch.float32), + torch.randn((10,)).to(torch.float32), + False, + 1.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._native_batch_norm_legit, args, kwargs + ) + + @unittest.skip + def test_aten__native_batch_norm_legit_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + None, + None, + torch.randn((10,)).to(torch.float16), + torch.randn((10,)).to(torch.float16), + False, + 1.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._native_batch_norm_legit, args, kwargs + ) + + @unittest.skip + def test_aten__native_batch_norm_legit_no_stats_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + torch.randn((1, 3, 2, 10)).to(torch.float32), + torch.randn((1, 3, 2, 10)).to(torch.float32), + True, + 0.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._native_batch_norm_legit.no_stats, args, kwargs + ) + + @unittest.skip + def test_aten__native_batch_norm_legit_no_stats_1(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float16), + torch.randn((1, 3, 2, 10)).to(torch.float16), + torch.randn((1, 3, 2, 10)).to(torch.float16), + True, + 0.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._native_batch_norm_legit.no_stats, args, kwargs + ) + + def test_aten__native_batch_norm_legit_no_training_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + None, + None, + torch.randn((10,)).to(torch.float32), + torch.randn((10,)).to(torch.float32), + 1.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, 
torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs + ) + + @unittest.skip + def test_aten__native_batch_norm_legit_no_training_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + None, + None, + torch.randn((10,)).to(torch.float16), + torch.randn((10,)).to(torch.float16), + 1.0, + 1.0, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs + ) + + @unittest.skip + def test_aten_native_dropout_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1.0, + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs) + + @unittest.skip + def test_aten_native_dropout_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1.0, + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs) + + @unittest.skip + def test_aten_native_group_norm_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + None, + None, + 1, + 3, + 20, + 1, + 0.0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) + + @unittest.skip + def test_aten_native_group_norm_1(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float16), + None, + None, + 1, + 3, + 20, + 1, + 0.0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) + + @unittest.skip + def test_aten_native_layer_norm_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + [ + 1, + 3, + 2, + 10, + ], + None, + None, + 0.0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs) + + @unittest.skip + def test_aten_ne_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ne_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ne_Scalar_2(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) + + @unittest.skip + def test_aten_ne_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) + + @unittest.skip + def test_aten_ne_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) + + @unittest.skip + def test_aten_ne_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) + + @unittest.skip + def test_aten_neg_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) + + @unittest.skip + def test_aten_neg_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) + + @unittest.skip + def test_aten_neg_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = 
dict() + run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) + + @unittest.skip + def test_aten_nonzero_0(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) + + @unittest.skip + def test_aten_nonzero_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) + + @unittest.skip + def test_aten_nonzero_2(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) + + @unittest.skip + def test_aten__pdist_forward_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1.0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._pdist_forward, args, kwargs) + + def test_aten_permute_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) + + def test_aten_permute_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) + + def test_aten_permute_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) + + @unittest.skip + def test_aten_permute_copy_0(self): + args = ( + torch.randn((2, 2, 2)).to(torch.float32), + [ + 1, + 2, + 0, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) + + @unittest.skip + def test_aten_permute_copy_1(self): + args = ( + torch.randn((2, 2, 2)).to(torch.float16), + [ + 1, + 2, + 0, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) + + @unittest.skip + def test_aten_permute_copy_2(self): + args = ( + torch.randint(0, 10, (2, 2, 2)).to(torch.int32), + [ + 1, + 2, + 0, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) + + @unittest.skip + def test_aten_pixel_shuffle_0(self): + args = ( + torch.randn((1, 3, 10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) + + @unittest.skip + def test_aten_pixel_shuffle_1(self): + args = ( + torch.randn((1, 3, 10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) + + @unittest.skip + def test_aten_pixel_shuffle_2(self): + args = ( + torch.randint(0, 10, (1, 3, 10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) + + @unittest.skip + def test_aten_pow_Scalar_0(self): + args = ( + 1.123, + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs) + + def test_aten_pow_Tensor_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1.2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) + + def test_aten_pow_Tensor_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1.2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) + + def test_aten_pow_Tensor_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 
10)).to(torch.int32), + 1.2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) + + @unittest.skip + def test_aten_pow_Scalar_1(self): + args = (10000, torch.randn(16 * 8)) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs) + + @unittest.skip + def test_aten_pow_Tensor_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) + + @unittest.skip + def test_aten_pow_Tensor_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) + + @unittest.skip + def test_aten_pow_Tensor_Tensor_2(self): + args = ( + torch.randint(0, 5, (10, 10)).to(torch.int32), + torch.randint(0, 5, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) + + @unittest.skip + def test_aten_prod_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) + + @unittest.skip + def test_aten_prod_1(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) + + @unittest.skip + def test_aten_prod_dim_int_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) + + @unittest.skip + def test_aten_prod_dim_int_1(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) + + @unittest.skip + def test_aten_reciprocal_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) + + @unittest.skip + def test_aten_reciprocal_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) + + @unittest.skip + def test_aten_reciprocal_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) + + def test_aten_reflection_pad1d_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad1d_1(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad2d_0(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float32), + [ + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad2d_1(self): + args = ( + torch.randint(0, 10, (3, 2, 10)).to(torch.int32), + [ + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad3d_0(self): + args = ( + 
torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float32), + [ + 1, + 2, + 1, + 2, + 1, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad3d_1(self): + args = ( + torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float16), + [ + 1, + 2, + 1, + 2, + 1, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) + + @unittest.skip + def test_aten_reflection_pad3d_2(self): + args = ( + torch.randint(0, 10, (3, 3, 3, 3, 3, 3)).to(torch.int32), + [ + 1, + 2, + 1, + 2, + 1, + 2, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) + + def test_aten_relu_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) + + def test_aten_relu_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) + + def test_aten_relu_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) + + @unittest.skip + def test_aten_remainder_Scalar_0(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) + + @unittest.skip + def test_aten_remainder_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) + + @unittest.skip + def test_aten_remainder_Scalar_2(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) + + @unittest.skip + def test_aten_remainder_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) + + @unittest.skip + def test_aten_remainder_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) + + @unittest.skip + def test_aten_replication_pad2d_0(self): + args = ( + torch.randn((3, 2, 10)).to(torch.float32), + [ + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) + + @unittest.skip + def test_aten_replication_pad2d_1(self): + args = ( + torch.randint(0, 10, (3, 2, 10)).to(torch.int32), + [ + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) + + @unittest.skip + def test_aten_replication_pad3d_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + [ + 1, + 1, + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) + + @unittest.skip + def test_aten_replication_pad3d_1(self): + args = ( + torch.randint(0, 10, (1, 3, 2, 10)).to(torch.int32), + [ + 1, + 1, + 1, + 1, + 1, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) + + @unittest.skip + def test_aten_resize__0(self): + args = ( + torch.randn((2, 5, 
10)).to(torch.float32), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) + + @unittest.skip + def test_aten_resize__1(self): + args = ( + torch.randn((2, 5, 10)).to(torch.float16), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) + + @unittest.skip + def test_aten_resize__2(self): + args = ( + torch.randint(0, 10, (2, 5, 10)).to(torch.int32), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) + + @unittest.skip + def test_aten_roll_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) + + @unittest.skip + def test_aten_roll_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) + + @unittest.skip + def test_aten_roll_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) + + @unittest.skip + def test_aten_round_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.round, args, kwargs) + + @unittest.skip + def test_aten_round_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.round, args, kwargs) + + @unittest.skip + def test_aten_round_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.round, args, kwargs) + + def test_aten_rsqrt_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) + + @unittest.skip + def test_aten_rsqrt_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) + + @unittest.skip + def test_aten_rsqrt_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) + + def test_aten_rsub_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) + + def test_aten_rsub_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) + + def test_aten_rsub_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) + + @unittest.skip + def test_aten_scatter_add_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) + + @unittest.skip + def test_aten_scatter_add_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + 
run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) + + @unittest.skip + def test_aten_scatter_add_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (2, 2)).to(torch.int64), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) + + @unittest.skip + def test_aten_scatter_reduce_two_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randn((10, 10)).to(torch.float32), + "sum", + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.scatter_reduce.two, args, kwargs + ) + + @unittest.skip + def test_aten_scatter_reduce_two_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randn((10, 10)).to(torch.float16), + "sum", + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.scatter_reduce.two, args, kwargs + ) + + @unittest.skip + def test_aten_scatter_reduce_two_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randint(0, 10, (10, 10)).to(torch.int32), + "sum", + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.scatter_reduce.two, args, kwargs + ) + + @unittest.skip + def test_aten_scatter_src_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) + + @unittest.skip + def test_aten_scatter_src_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) + + @unittest.skip + def test_aten_scatter_src_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) + + @unittest.skip + def test_aten_scatter_value_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) + + @unittest.skip + def test_aten_scatter_value_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) + + @unittest.skip + def test_aten_scatter_value_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + torch.randint(0, 10, (10, 10)).to(torch.int64), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) + + @unittest.skip + def test_aten_select_copy_int_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) + + @unittest.skip + def test_aten_select_copy_int_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, 
torch.ops.aten.select_copy.int, args, kwargs) + + @unittest.skip + def test_aten_select_copy_int_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) + + def test_aten_select_int_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) + + def test_aten_select_int_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) + + def test_aten_select_int_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) + + @unittest.skip + def test_aten_select_scatter_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randint(0, 10, (10,)).to(torch.int64), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) + + @unittest.skip + def test_aten_select_scatter_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randint(0, 10, (10,)).to(torch.int64), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) + + @unittest.skip + def test_aten_select_scatter_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10,)).to(torch.int64), + 1, + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) + + @unittest.skip + def test_aten_sigmoid_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) + + @unittest.skip + def test_aten_sigmoid_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) + + @unittest.skip + def test_aten_sigmoid_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) + + @unittest.skip + def test_aten_sign_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) + + @unittest.skip + def test_aten_sign_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) + + @unittest.skip + def test_aten_sign_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) + + def test_aten_sin_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) + + def test_aten_sin_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) + + def test_aten_sin_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) + + def test_aten_sinh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) + + def test_aten_sinh_1(self): + args = 
(torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) + + def test_aten_sinh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) + + @unittest.skip + def test_aten_slice_copy_Tensor_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_slice_copy_Tensor_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_slice_copy_Tensor_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_slice_scatter_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) + + @unittest.skip + def test_aten_slice_scatter_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) + + @unittest.skip + def test_aten_slice_scatter_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) + + def test_aten_slice_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) + + def test_aten_slice_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) + + def test_aten_slice_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) + + def test_aten__softmax_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs) + + @unittest.skip + def test_aten__softmax_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs) + + def _compare_sorted_result(self, args): + res = torch.ops.aten.sort(*args) + with self.subTest("torch_xla2_eval"): + args2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, args) + res2 = torch.ops.aten.sort(*args2) + + # The second element of the returned tuple is the indices of the sorted + # values. These might not be identical between torch and 
jax; but both can be correct + diff_output(self, res[0], res2[0].torch(), 1e-3, 1e-5) + + def test_aten_sort_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + ) + self._compare_sorted_result(args) + + def test_aten_sort_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 1, + ) + self._compare_sorted_result(args) + + def test_aten_sort_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 1, + ) + self._compare_sorted_result(args) + + @unittest.skip + def test_aten_split_copy_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_split_copy_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_split_copy_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 2, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) + + @unittest.skip + def test_aten_split_with_sizes_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 1, + 2, + 3, + 4, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) + + @unittest.skip + def test_aten_split_with_sizes_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 1, + 2, + 3, + 4, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) + + @unittest.skip + def test_aten_split_with_sizes_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 1, + 2, + 3, + 4, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) + + def test_aten_sqrt_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) + + def test_aten_sqrt_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) + + def test_aten_sqrt_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) + + @unittest.skip + def test_aten_squeeze_copy_dim_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_copy_dim_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_copy_dim_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dims_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) + + @unittest.skip + def test_aten_squeeze_dims_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) + + @unittest.skip + def 
test_aten_squeeze_dims_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 0, + 1, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) + + def test_aten_stack_0(self): + args = ( + [ + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) + + def test_aten_stack_1(self): + args = ( + [ + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) + + def test_aten_stack_2(self): + args = ( + [ + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) + + def test_aten_sub_Scalar_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) + + def test_aten_sub_Scalar_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) + + def test_aten_sub_Scalar_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0.123, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) + + def test_aten_sub_Tensor_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) + + def test_aten_sub_Tensor_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) + + def test_aten_sub_Tensor_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) + + def test_aten_sum_dim_IntList_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) + + def test_aten_sum_dim_IntList_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) + + def test_aten_sum_dim_IntList_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + None, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) + + @unittest.skip + def test_aten_tan_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tan, args, kwargs) + + @unittest.skip + def test_aten_tan_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare( + self, + torch.ops.aten.tan, + args, + kwargs, + rtol=0.001, + atol=0.01, + ) + + @unittest.skip + def test_aten_tan_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tan, args, kwargs) + + def 
test_aten_tanh_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) + + def test_aten_tanh_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) + + def test_aten_tanh_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) + + def test_aten_topk_0(self): + args = ( + torch.arange(0, 100).reshape(10, 10).to(torch.float32), + 1, + 1, + False, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) + + def test_aten_topk_1(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 1, + 1, + True, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) + + def test_aten_topk_2(self): + args = ( + torch.arange(0, 100).reshape(10, 10).to(torch.int32), + 1, + 1, + False, + False, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) + + def test_aten_topk_3(self): + args = ( + torch.arange(0, 100).reshape(10, 10).to(torch.int32), + 3, + 0, + False, + True, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) + + def test_aten_topk_4(self): + args = ( + torch.arange(0, 100).reshape(10, 10).to(torch.int32), + 3, + 0, + True, + True, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) + @unittest.skip + def test_aten_transpose_copy_int_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.transpose_copy.int, args, kwargs + ) + + @unittest.skip + def test_aten_transpose_copy_int_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.transpose_copy.int, args, kwargs + ) + + @unittest.skip + def test_aten_transpose_copy_int_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + 0, + 1, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.transpose_copy.int, args, kwargs + ) + + @unittest.skip + def test_aten_tril_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) + + @unittest.skip + def test_aten_tril_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) + + @unittest.skip + def test_aten_tril_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) + + def test_aten_trunc_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) + + def test_aten_trunc_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) + + def test_aten_trunc_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) + + @unittest.skip + def test_aten_unbind_copy_int_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, 
kwargs) + + @unittest.skip + def test_aten_unbind_copy_int_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) + + @unittest.skip + def test_aten_unbind_copy_int_2(self): + args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_copy_0(self): + args = ( + torch.randn((2, 0, 2)).to(torch.float32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_copy_1(self): + args = ( + torch.randn((2, 0, 2)).to(torch.float16), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) + + @unittest.skip + def test_aten_unsqueeze_copy_2(self): + args = ( + torch.randint(0, 10, (2, 0, 2)).to(torch.int32), + 1, + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) + + def test_aten_upsample_bilinear2d_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + [ + 3, + 20, + ], + False, + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.upsample_bilinear2d, args, kwargs + ) + + def test_aten_upsample_nearest2d_0(self): + args = ( + torch.randn((1, 3, 2, 10)).to(torch.float32), + [ + 3, + 20, + ], + ) + kwargs = dict() + run_export_and_compare( + self, torch.ops.aten.upsample_nearest2d, args, kwargs + ) + + @unittest.skip + def test_aten_var_correction_0(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) + + @unittest.skip + def test_aten_var_correction_1(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) + + @unittest.skip + def test_aten_var_correction_2(self): + args = (torch.randn((10, 10)).to(torch.float32),) + kwargs = dict(correction=0) + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) + + @unittest.skip + def test_aten_var_correction_3(self): + args = (torch.randn((10, 10)).to(torch.float16),) + kwargs = dict(correction=0) + run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) + + def test_aten_view_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 1, + 100, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view, args, kwargs) + + def test_aten_view_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 1, + 100, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view, args, kwargs) + + def test_aten_view_2(self): + args = ( + torch.randint(0, 10, (10, 10)).to(torch.int32), + [ + 1, + 100, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view, args, kwargs) + + @unittest.skip + def test_aten_view_copy_0(self): + args = ( + torch.randn((10, 10)).to(torch.float32), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) + + @unittest.skip + def test_aten_view_copy_1(self): + args = ( + torch.randn((10, 10)).to(torch.float16), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) + + @unittest.skip + def test_aten_view_copy_2(self): + args = ( + torch.randint(0, 10, (10, 
10)).to(torch.int32), + [ + 2, + 5, + 10, + ], + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) + + def test_aten_where_self_0(self): + args = ( + torch.randn((10, 10)).to(torch.bool), + torch.randn((10, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs) + + +if __name__ == "__main__": + test_base.main() diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py new file mode 100644 index 00000000000..94e50b47c95 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -0,0 +1,19 @@ +import jax +import torch +import torch._functorch +from torch_xla2 import tensor + + +def extract_jax(mod: torch.nn.Module): + """Returns a pytree of jax.ndarray and a jax callable.""" + func, weights, buffer = torch._functorch.make_functional_with_buffers(mod) + states = (weights, buffer) + + @jax.jit + def jax_func(states, inputs): + (states, inputs) = tensor.wrap((states, inputs)) + weights, buffer = states + res = func(weights, buffer, *inputs) + return tensor.unwrap(res) + + return states, jax_func diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py new file mode 100644 index 00000000000..9d70be87d55 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -0,0 +1,221 @@ +# pylint: disable +"""Utilities for exporting a torch program to jax/stablehlo.""" +import copy +from typing import Any, Dict, Tuple +import jax +import torch +from torch.fx import _pytree as fx_pytree +from torch_xla2 import ops_registry, tensor +from torch.utils import _pytree as pytree + + +class JaxProgram: + + def _wrap_inputs(self, xs, allow_torch_tensor=False): + def convert(t): + if isinstance(t, tensor.XLATensor2): + return t + if isinstance(t, torch.Tensor): + if allow_torch_tensor: + return tensor.move_to_device(t) + else: + raise ValueError('Regular torch.Tensor is not allowed.') + if isinstance(t, jax.Array): + return tensor.XLATensor2(t) + return t + + return jax.tree_util.tree_map(convert, xs) + + def _unwrap_outputs(self, xs): + def convert(t): + if isinstance(t, tensor.XLATensor2): + return t.jax() + if isinstance(t, torch.Tensor): + raise ValueError('Regular torch.Tensor is not allowed.') + return t + + return jax.tree_util.tree_map(convert, xs) + + def __init__( + self, + exported_program, + param_buffer_values, + ordered_tensor_constants, + ): + + self.param_buffer_values = self._wrap_inputs( + param_buffer_values, allow_torch_tensor=True + ) + self.ordered_tensor_constants = self._wrap_inputs( + ordered_tensor_constants, allow_torch_tensor=True + ) + self.exported_program = exported_program + + def __hash__(self): + return hash(self.exported_program) + + @property + def example_inputs(self): + args, kwargs = self.exported_program.example_inputs + args = pytree.tree_map(tensor.t2j, args) + kwargs = pytree.tree_map(tensor.t2j, kwargs) + return args, kwargs + + def flatten_inputs(self, args, kwargs): + if args is None: + args = tuple() + if kwargs is None: + kwargs = {} + + if (in_spec := self.exported_program.call_spec.in_spec) is not None: + if ( + in_spec.type == tuple + and len(in_spec.children_specs) == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # NOTE: this is the case where in_spec is for both args and kwargs + return fx_pytree.tree_flatten_spec((args, kwargs), in_spec) + return 
fx_pytree.tree_flatten_spec(args, in_spec) + return copy.deepcopy(args) + + def unflatten_outputs(self, res): + return pytree.tree_unflatten(res, self.exported_program.call_spec.out_spec) + + def __call__(self, *args, **kwargs): + + inputs = self.flatten_inputs(args, kwargs) + res = self.flatten_callable(*inputs) + res = self.unflatten_outputs(res) + + return res + + @property + def flatten_callable(self): + def func(*inputs: jax.Array): + nonlocal self + inputs = self._wrap_inputs(inputs) + num_mutations = len( + self.exported_program.graph_signature.buffers_to_mutate + ) + res = torch.fx.Interpreter(self.exported_program.graph_module).run( + *self.param_buffer_values, + *inputs, + *self.ordered_tensor_constants, + enable_io_processing=False, + ) + res = res[num_mutations:] + res = self._unwrap_outputs(res) + return res + + return func + + def jit(self, *args, **kwargs): + """Returns `jax.jit(self, *args, **kwargs)`.""" + return jax.jit(self, *args, **kwargs) + + def jit_lower(self, *args, **kwargs): + """Returns `jax.jit(self, *args, **kwargs).lower(...)` with example_inputs used in export.""" + example_args, example_kwargs = self.example_inputs + return self.jit(*args, **kwargs).lower(*example_args, **example_kwargs) + + +def exported_program_to_jax_program(ep): + """exported_program_to_jax_program. + + Args: + ep: torch.export.ExportedProgram + + Returns: + JaxProgram + + """ + if torch.__version__ >= '2.2': + ep = ep.run_decompositions() + + param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers + param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys) + + if hasattr(ep.graph_signature, 'lifted_tensor_constants'): + ordered_tensor_constants = tuple( + ep.tensor_constants[name] + for name in ep.graph_signature.lifted_tensor_constants + ) + else: + ordered_tensor_constants = tuple() + + return JaxProgram(ep, param_buffer_values, ordered_tensor_constants) + + +class JaxInterpreter(torch.fx.Interpreter): + """Experimental.""" + + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: + if not isinstance( + target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ): + return super().call_function(target, args, kwargs) + + print('Running ', target.name(), '--------') + + op = ops_registry.lowerings.lookup(target) + if op is None: + print(target.name(), target.tags) + raise RuntimeError('No lowering found for', target.name()) + return op.func(*args, **kwargs) + + def run_node(self, n) -> Any: + res = super().run_node(n) + if n.op == 'call_function': + if hasattr(res, 'shape'): + print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape) + return res + + +_extra_decomp = {} + + +def exported_program_to_jax(exported_program): + """returns a pytree of jax arrays(state), and + + a callable(func) that is jax function. + + func(state, input) would be how you call it. 
+ """ + if torch.__version__ >= '2.2': + # torch version 2.1 didn't expose this yet + exported_program = exported_program.run_decompositions() + exported_program = exported_program.run_decompositions(_extra_decomp) + param_buffer_keys = ( + exported_program.graph_signature.parameters + + exported_program.graph_signature.buffers + ) + param_buffer_values = tuple( + exported_program.state_dict[key] for key in param_buffer_keys + ) + + if hasattr(exported_program.graph_signature, 'lifted_tensor_constants'): + ordered_tensor_constants = tuple( + exported_program.tensor_constants[name] + for name in exported_program.graph_signature.lifted_tensor_constants + ) + else: + ordered_tensor_constants = tuple() + + num_mutations = len(exported_program.graph_signature.buffers_to_mutate) + + def func(states, inputs): + param_buffer_values, ordered_tensor_constants = states + res = JaxInterpreter(exported_program.graph_module).run( + *param_buffer_values, + *inputs, + *ordered_tensor_constants, + enable_io_processing=False, + ) + res = res[num_mutations:] + return res + + state = pytree.tree_map_only( + torch.Tensor, tensor.t2j, (param_buffer_values, ordered_tensor_constants) + ) + return state, func diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py new file mode 100644 index 00000000000..a6f153d72e1 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/ops.py @@ -0,0 +1,1241 @@ +# pylint: disable +"""Torch ops implemented using jax.""" +import sys + +import jax +from jax import numpy as jnp +import numpy as np +import torch +from torch_xla2 import ops_registry +from torch_xla2 import tensor + + +class TorchFunctionLowering: + + def __init__(self, func, is_jax_func, should_jit=False): + if is_jax_func and should_jit: + func = jax.jit(func) + self.func = func + self.is_jax_func = is_jax_func + + def __call__(self, *args, **kwargs): + if self.is_jax_func: + (args, kwargs) = tensor.unwrap((args, kwargs)) + res = self.func(*args, **kwargs) + if self.is_jax_func: + res = tensor.wrap(res) + return res + + +def op(aten_op, is_jax_func=True): + """if is_jax_func is true, then the function it will register + + should takes jax array as input and returns jax array. 
+ + Which means we need to wrap it + """ + + def inner(func): + ops_registry.lowerings.register( + aten_op, TorchFunctionLowering(func, is_jax_func) + ) + return func + + return inner + + +@op(torch.ops.aten.view) +@op(torch.ops.aten._unsafe_view) +def _aten_unsafe_view(x, shape): + return jnp.reshape(x, shape) + + +@op(torch.ops.aten.add) +def _aten_add(x, y): + """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): + + assert x.dtype == y.dtype, (x.dtype, y.dtype) + """ + try: + return x + y + except Exception as e: + import pdb + + pdb.set_trace() + + +@op(torch.ops.aten.add_, is_jax_func=False) +def _aten_add_inplace(self, other, *, alpha): + if isinstance(other, XLATensor2): + self._elem += alpha * other._elem + else: + self._elem += alpha * other + return self + + +@op(torch.ops.aten.copy_, is_jax_func=False) +def _aten_copy(x, y, memory_format=None): + if isinstance(x, XLATensor2): + x._elem = y._elem + elif isinstance(x, SliceView): + x.mutate(y) + return x + + +@op(torch.ops.aten.clone) +def _aten_clone(x, memory_format=None): + return jnp.copy(x) + + +@op(torch.ops.aten.full) +def _aten_full(size, value, **kwargs): + return jnp.full(size, value) + + +@op(torch.ops.aten.index_copy) +def _aten_index_copy(x, dim, indexes, source): + # return jax.lax.scatter(x, index, dim) + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x.at[dim].set(source) + + +@op(torch.ops.aten.select) +@op(torch.ops.aten.index_select) +def _aten_index_select(x, dim, indexes): + dims = [] + for i in range(len(x.shape)): + if i == dim: + dims.append(indexes) + else: + dims.append(slice(None, None, None)) + return x[tuple(dims)] + + +@op(torch.ops.aten.mean) +def _aten_mean(x, dim, keepdim): + return jnp.mean(x, dim, keepdims=keepdim) + + +def _torch_binary_scalar_type(scalar, tensor): + if "float" in str(tensor.dtype): + return tensor.dtype + + if isinstance(scalar, int): + if "int" in str(tensor.dtype): + return tensor.dtype + + return jnp.float32 + + +@op(torch.ops.aten.sub) +def _aten_sub(x, y): + if isinstance(x, float): + dtype = _torch_binary_scalar_type(x, y) + x = jnp.array(x, dtype=dtype) + if isinstance(y, float): + dtype = _torch_binary_scalar_type(y, x) + y = jnp.array(y, dtype=dtype) + return x - y + + +@op(torch.ops.aten.mm) +def _aten_mm(x, y): + res = x @ y + assert res.dtype == jnp.bfloat16 + return res + + +@op(torch.ops.aten.mul) +def _aten_mul(x, y): + return x * y + + +@op(torch.ops.aten.silu) +def _aten_silu(x): + return jax.nn.silu(x) + + +@op(torch.ops.aten.t) +def _aten_t(x): + return jnp.transpose(x) + + +@op(torch.ops.aten.transpose) +def _aten_transpose(x, dim0, dim1): + shape = list(range(len(x.shape))) + shape[dim0], shape[dim1] = shape[dim1], shape[dim0] + return jnp.transpose(x, shape) + + +@op(torch.ops.aten.triu) +def _aten_triu(m, k): + return jnp.triu(m, k) + + +@op(torch.ops.aten.slice) +def _aten_slice(self, dim=0, start=None, end=None, step=1): + if end == sys.maxsize: + end = self.shape[dim] + sl = slice(start, end, step) + dims = [] + for i in range(len(self.shape)): + if i == dim: + dims.append(sl) + else: + dims.append(slice(None, None, None)) + return self[tuple(dims)] + + +@op(torch.ops.aten.detach) +def _aten_detach(self): + return self + + +@op(torch.ops.aten.view_as_real) +def _aten_view_as_real(x): + real = jnp.real(x) + im = jnp.imag(x) + res = jnp.stack([real, im], -1) + return res + + +@op(torch.ops.aten.stack) +def _aten_stack(tensors, dim=0): + return 
jnp.stack(tensors, dim) + + +@op(torch.ops.aten._softmax) +def _aten_softmax(x, dim, halftofloat): + return jax.nn.softmax(x, dim) + + +@op(torch.ops.aten.pow) +def _aten_pow(x, y): + if isinstance(y, int): + y = float(y) + if isinstance(y, jnp.ndarray): + y = y.astype(jnp.astype(jnp.bfloat16)) + return jnp.power(x, y) + + +@op(torch.ops.aten.view_as_complex) +def _aten_view_as_complex(input): + if input.dtype == jnp.bfloat16: + input = input.astype(jnp.float32) + x, y = input[..., 0], input[..., 1] + return jax.lax.complex(x, y) + + +@op(torch.ops.aten.div) +def _aten_div(x, y, rounding_mode=""): + if rounding_mode == "trunc": + return jnp.floor_divide(x, y) + return x / y + + +@op(torch.ops.aten.bmm) +def _aten_bmm(x, y): + res = x @ y + return res + # return jnp.einsum('bnm,bmk->bnk', x, y) + + +@op(torch.ops.aten.embedding) +# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) +def _aten_embedding(a, w, padding_idx=-1): + return jnp.take(a, w, axis=0) + + +@op(torch.ops.aten.rsqrt) +def _aten_rsqrt(x): + return jax.lax.rsqrt(x) + + +@op(torch.ops.aten.expand) +def _aten_expand(x, dims): + def fix_dims(d, xs): + if d == -1: + return xs + return d + dims = [fix_dims(p, s) for p, s in zip(dims, x.shape)] + return jnp.broadcast_to(x, dims) + + +@op(torch.ops.aten.dot) +def _aten_dot(x, y): + return jnp.dot(x, y) + + +@op(torch.ops.aten._to_copy) +def _aten__to_copy(self, **kwargs): + dtype = tensor.t2j_dtype(kwargs["dtype"]) + if dtype != self.dtype: + return self.astype(dtype) + return jnp.copy(self) + + +@op(torch.ops.aten.empty) +def _aten_empty(sizes, **kwargs): + return jnp.zeros(sizes) + + +@op(torch.ops.aten.index_put_) +@op(torch.ops.aten.index_put) +def _aten_index_put(self, indexes, values): + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self.at[indexes].set(values) + + +@op(torch.ops.aten.index) +@op(torch.ops.aten._unsafe_index) +@op(torch.ops.aten.index.Tensor) +def _aten_index(self, indexes): + indexes = [slice(None, None, None) if i is None else i for i in indexes] + indexes = tuple(indexes) + return self[indexes] + + +@op(torch.ops.aten.split) +@op(torch.ops.aten.split_with_sizes) +def split_with_sizes(x, sizes, dim): + """Splits an array `x` into sub-arrays based on static sizes `sizes`. + + Args: + x: The input array to split. + sizes: A 1D array of integer sizes for each sub-array. + + Returns: + A list of sub-arrays. 
+ """ + if isinstance(sizes, int): + # split equal size + new_sizes = [sizes] * (x.shape[dim] // sizes) + sizes = new_sizes + rank = x.ndim + splits = np.cumsum(sizes) # Cumulative sum for split points + + def make_range(rank, dim, start, end): + res = [slice(None, None, None)] * rank + res[dim] = slice(start, end) + return tuple(res) + + return [ + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) + ] + + +@op(torch.ops.aten.permute) +def permute(t, dims): + return jnp.transpose(t, dims) + + +@op(torch.ops.aten.unsqueeze) +@op(torch.ops.aten.unsqueeze.default) +def _aten_unsqueeze(self, dim): + if dim < 0: + dim += self.ndim + 1 + return jnp.expand_dims(self, dim) + + +@op(torch.ops.aten.ne) +def _aten_ne(x, y): + return jnp.not_equal(x, y) + + +@op(torch.ops.aten.cumsum) +def _aten_cumsum(x, y): + try: + return jnp.cumsum(x, y) + except Exception as e: + import pdb + + pdb.set_trace() + + +@op(torch.ops.aten.native_layer_norm) +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): + """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. + + Args: + input: The input tensor. + normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + output: The normalized tensor. + mean: The calculated mean tensor. + std: The calculated standard deviation tensor. + """ + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + axis = [i for i, d in enumerate(input.shape) if d in normalized_shape] + + # Calculate mean and standard deviation + mean = jnp.mean(input, axis=axis, keepdims=True) + var = jnp.var(input, axis=axis, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) + + # Normalize the input + norm_x = (input - mean) * rstd + + # Apply affine transformation (if provided) + if weight is not None: + norm_x *= weight + if bias is not None: + norm_x += bias + return norm_x, mean, rstd + + +# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +@op(torch.ops.aten.addmm) +def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): + self *= beta + self += alpha * jnp.matmul(mat1, mat2) + return self + + +@op(torch.ops.aten.gelu) +def _aten_gelu(self, *, approximate="none"): + approx = approximate == "tanh" + return jax.nn.gelu(self, approx) + + +@op(torch.ops.aten.squeeze) +def _aten_squeeze_dim(self, dim): + """Squeezes a Jax tensor by removing a single dimension of size 1. + + Args: + self: The input tensor. + dim: The dimension to squeeze. + + Returns: + The squeezed tensor with the specified dimension removed if it is 1, + otherwise the original tensor is returned. 
+ """ + + # Validate input arguments + if not isinstance(self, jnp.ndarray): + raise TypeError(f"Expected a Jax tensor, got {type(self)}.") + if not isinstance(dim, int): + raise TypeError(f"Expected dim to be an int, got {type(dim)}.") + + # Check if the specified dimension has size 1 + if self.shape[dim] != 1: + return self + + # Use slicing to remove the dimension if it is 1 + new_shape = list(self.shape) + new_shape.pop(dim) + return self.reshape(new_shape) + + +@op(torch.ops.aten.convolution) +def _aten_convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + if transposed: + raise NotImplementedError("Transposed convolution is not implemented.") + + def make_padding(padding): + return ((p, p) for p in padding) + + def create_default_conv_dimension_numbers(num_spatial_dims): + # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 + # (batch dimension, feature dimension, spatial dimensions...) + lhs_spec = [0, 1] + # (out feature dimension, in feature dimension, spatial dimensions...) + rhs_spec = [0, 1] + # (batch dimension, feature dimension, spatial dimensions...) + out_spec = [0, 1] + for i in range(0, num_spatial_dims): + lhs_spec.append(i + 2) + rhs_spec.append(i + 2) + out_spec.append(i + 2) + return jax.lax.ConvDimensionNumbers( + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) + + res = jax.lax.conv_general_dilated( + input, + weight, + stride, + make_padding(padding), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, + ) + + if bias is not None: + # TODO(qihqi): this is wrong + bias = bias.reshape(bias.shape + (1,)) + res = res + bias + return res + + +@op(torch.ops.aten._native_batch_norm_legit_no_training) +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): + if weight is None: + weight = jnp.ones_like(running_mean) + if bias is None: + bias = jnp.zeros_like(running_mean) + + def broadcast(t): + return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,)) + + a = input - broadcast(running_mean) + b = broadcast(jnp.sqrt(running_var + eps)) + return ( + a / b * broadcast(weight) + broadcast(bias), + jnp.array([]), + jnp.array([]), + ) + + +@op(torch.ops.aten.relu) +def _aten_relu(self): + return jax.nn.relu(self) + + +@op(torch.ops.aten.cat) +def _aten_cat(tensors, dims=0): + return jnp.concatenate(tensors, dims) + + +@op(torch.ops.aten.max_pool2d_with_indices) +def _aten_max_pool2d_with_indices( + self, kernel_size, stride, padding=0, dilation=1, ceil_mode=False +): + stride = stride if stride else [1, 1] + if not isinstance(padding, (list, tuple)): + padding = [padding, padding] + + def build_ceil_mode_padding(): + ceil_mode_padding = [(0, 0), (0, 0)] + for i in range(len(padding)): + left_padding = padding[0] + input_size = self.shape[2 + i] + output_size_rem = ( + input_size + 2 * left_padding - kernel_size[i] + ) % stride[i] + right_padding = left_padding + if ceil_mode and output_size_rem != 0: + extra_padding = stride[i] - output_size_rem + new_output_size = ( + input_size + + left_padding + + right_padding + + extra_padding + - kernel_size[i] + + stride[i] + - 1 + ) // stride[i] + 1 + if (new_output_size - 1) * stride[i] < input_size + left_padding: + right_padding += extra_padding + ceil_mode_padding.append((left_padding, right_padding)) + return ceil_mode_padding + + 
ceil_mode_padding = build_ceil_mode_padding() + if not all([p == (0, 0) for p in ceil_mode_padding]): + self = jnp.pad( + self, + ceil_mode_padding, + "constant", + constant_values=-jnp.inf, + ) + batch_result = jax.lax.reduce_window( + self, + -jnp.inf, + jax.lax.max, + window_dimensions=[1, 1] + kernel_size, + window_strides=[1, 1] + stride, + padding="VALID", + ) + + # TODO: compute indices from batch_result + # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/pooling.cpp#L259 + + return batch_result, None + + +# TODO add more ops + +@op(torch.ops.aten.min) +def _aten_min(x, axis=None): + return jnp.min(x, axis=axis), jnp.argmin(x, axis=axis).astype(jnp.int64) + +@op(torch.ops.aten.amin) +def _aten_amin(x, axis=None): + return jnp.min(x, axis=axis) + +@op(torch.ops.aten.argmin) +def _aten_amin(x, axis=None): + return jnp.argmin(x, axis=axis) + +@op(torch.ops.aten.sin) +def _aten_sin(x): + return jnp.sin(x) + +@op(torch.ops.aten.sym_size) +def _aten_sym_size(x, dim): + return x.shape[dim] + +@op(torch.ops.aten.var) +@op(torch.ops.prims.var) +def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): + return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) + +@op(torch.ops.prims.broadcast_in_dim) +def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): + return jax.lax.broadcast_in_dim( + t, shape, broadcast_dimensions=broadcast_dimensions + ) + + +# aten.native_group_norm -- should use decomp table +# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + +@op(torch.ops.aten.native_group_norm) +def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): + """Group Normalization implementation in JAX. + + Args: + input: Input tensor. Expected shape (batch_size, channels, ... spatial dims + ...) + weight: Optional scaling (gamma) parameter. Shape (channels,) + bias: Optional shifting (beta) parameter. Shape (channels,) + N: Batch size. + C: Number of channels. + HxW: Product of spatial dimensions (number of elements per channel after + flattening). + group: Number of groups for Group Normalization. + eps: Small value added for numerical stability. 
+ + Returns: + A tuple of (normalized_output, mean, rstd) + """ + + input_shape = input.shape + + # Reshape for group-wise normalization + reshaped_input = jnp.reshape(input, (1, N * group, -1)) + + # **Core Group Normalization** + def group_norm_body(x): # Function to apply within each group + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon + normalized = (x - mean) * rstd + return normalized, mean, rstd + + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) + + # Reshape back to original input shape + output = jnp.reshape(normalized, input_shape) + + # **Affine transformation** + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting + if weight is not None and bias is not None: + output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) + elif weight is not None: + output = output * weight.reshape(affine_shape) + elif bias is not None: + output = output + bias.reshape(affine_shape) + + # Reshape mean and rstd + mean = jnp.reshape(group_mean, (N, group)) + rstd = jnp.reshape(group_rstd, (N, group)) + + return output, mean, rstd + + +@op(torch.ops.aten.linalg_vector_norm) +def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): + """Calculates the vector norm along specified dimensions. + + Args: + self: The input tensor. + ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. + Default is 2 (Euclidean norm). + dim: Dimensions along which to calculate the norm. If None, the norm is + calculated over all dimensions. + keepdim: Whether to keep the reduced dimensions. + dtype: Optional data type for the output. + + Returns: + The tensor containing the calculated vector norms. + """ + + if ord not in {2, float("inf"), float("-inf"), "fro"}: + raise ValueError( + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) + + # Special cases (for efficiency and clarity) + if ord == 2: # Euclidean norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + elif ord == float("inf"): + result = jnp.max(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == float("-inf"): + result = jnp.min(jnp.abs(self), axis=dim, keepdims=keepdim) + + elif ord == "fro": # Frobenius norm + result = jnp.sqrt(jnp.sum(jnp.abs(self) ** 2, axis=dim, keepdims=keepdim)) + + else: # General case (e.g., ord = 1, ord = 3) + result = jnp.sum(jnp.abs(self) ** ord, axis=dim, keepdims=keepdim) ** ( + 1.0 / ord + ) + + # (Optional) dtype conversion + if dtype is not None: + result = result.astype(dtype) + + return result + + +# aten.reflection_pad1d +@op(torch.ops.aten.reflection_pad1d) +def _aten_reflection_pad1d(input, padding): + rank = len(input.shape) + pad_size = [(0, 0)] * rank + pad_size[-1] = padding + return jnp.pad(input, pad_size, mode="reflect") + + +# aten.alias +@op(torch.ops.aten.alias) +def _aten_alias(self, *args): + return self + + +# aten.sinh +@op(torch.ops.aten.sinh) +def _aten_sinh(self): + return jnp.sinh(self) + + +# aten.native_layer_norm_backward +@op(torch.ops.aten.native_layer_norm_backward) +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): + """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. + + Args: + grad_out: The gradient of the output tensor. + input: The input tensor. 
+ normalized_shape: A list of integer dimensions to be normalized over. + weight: Optional weight tensor for the affine transformation. + bias: Optional bias tensor for the affine transformation. + eps: A small epsilon value for numerical stability. + + Returns: + A tuple of (grad_input, grad_weight, grad_bias). + """ + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) + + +# aten.reflection_pad3d_backward +# aten.reflection_pad2d + + +# aten.atanh +@op(torch.ops.aten.atanh) +def _aten_atanh(self): + return jnp.arctanh(self) + + +# aten.bitwise_not +@op(torch.ops.aten.bitwise_not) +def _aten_bitwise_not(self): + return ~self + + +# aten.embedding_dense_backward + + +# aten.sum +@op(torch.ops.aten.sum) +def _aten_sum(self, dim=None, keepdim=False, dtype=None): + return jnp.sum(self, axis=dim, keepdims=keepdim, dtype=dtype) + + +# aten.sqrt +@op(torch.ops.aten.sqrt) +def _aten_sqrt(self): + return jnp.sqrt(self) + + +# aten.tanh +@op(torch.ops.aten.tanh) +def _aten_tanh(self): + return jnp.tanh(self) + + +# aten.ceil +@op(torch.ops.aten.ceil) +def _aten_ceil(self): + return jnp.ceil(self) + + +# aten.asin +@op(torch.ops.aten.asin) +def _aten_asin(self): + return jnp.arcsin(self) + + +# aten.minimum +@op(torch.ops.aten.minimum) +def _aten_minimum(self, other): + return jnp.minimum(self, other) + + +# aten.max_pool2d_backward + +# aten.scatter_add +# aten.logical_not + +# aten.sign +# aten.sigmoid + + +# implement aten.asinh in jax +@op(torch.ops.aten.asinh) +def _aten_asinh(self): + return jnp.arcsinh(self) + + +# aten.atan +@op(torch.ops.aten.atan) +def _aten_atan(self): + return jnp.arctan(self) + + +# aten.scatter_reduce +# aten.acos +@op(torch.ops.aten.acos) +def _aten_acos(self): + return jnp.arccos(self) + + +# aten.sym_storage_offset +# aten.native_layer_norm_backward +# aten.max_pool3d_with_indices + + +# aten.gt +@op(torch.ops.aten.gt) +def _aten_gt(self, other): + return self > other + + +# aten.pixel_shuffle +# aten.sym_stride +# aten.lt +@op(torch.ops.aten.lt) +def _aten_lt(self, other): + return self < other + + +# aten.avg_pool2d +# aten.sym_numel +# aten.reciprocal +# aten.scatter + + +# aten.acosh +@op(torch.ops.aten.acosh) +def _aten_acosh(self): + return jnp.arccosh(self) + + +# aten.avg_pool2d_backward +# aten.col2im +# aten.avg_pool3d +# aten.round + + +# aten.max +@op(torch.ops.aten.max) +def _aten_max(self, dim=None, keepdim=False): + return jnp.max(self, axis=dim, keepdims=keepdim), jnp.argmax( + self, axis=dim, keepdims=keepdim + ) + + +# aten.maximum +@op(torch.ops.aten.maximum) +def _aten_maximum(self, other): + return jnp.maximum(self, other) + + +# aten.abs +@op(torch.ops.aten.abs) +def _aten_abs(self): + return jnp.abs(self) + + +# generate aten.amax only +@op(torch.ops.aten.amax) +def _aten_amax(self, dim=None, keepdim=False): + return jnp.amax(self, axis=dim, keepdims=keepdim) + + +# aten.any +@op(torch.ops.aten.any) +def _aten_any(self, dim=None, keepdim=False): + return jnp.any(self, axis=dim, keepdims=keepdim) + + +# aten.arange +@op(torch.ops.aten.arange) +def _aten_arange( + start, + end=None, + step=1, + dtype=None, + layout=None, + device=None, + pin_memory=False, +): + return jnp.arange( + start, + end, + step, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +# aten.argmax +@op(torch.ops.aten.argmax) +def _aten_argmax(self, dim=None, keepdim=False): + return jnp.argmax(self, axis=dim, keepdims=keepdim) + + +# aten.as_strided + + +# aten.atan2 +@op(torch.ops.aten.atan2) 
+def _aten_atan2(self, other): + return jnp.arctan2(self, other) + + +# aten.bitwise_and +@op(torch.ops.aten.bitwise_and) +def _aten_bitwise_and(self, other): + return self & other + + +# aten.bitwise_or +@op(torch.ops.aten.bitwise_or) +def _aten_bitwise_or(self, other): + return self | other + + +# aten.bitwise_xor +@op(torch.ops.aten.bitwise_xor) +def _aten_bitwise_xor(self, other): + return self ^ other + + +# aten.clamp +@op(torch.ops.aten.clamp) +def _aten_clamp(self, min=None, max=None): + return jnp.clip(self, min, max) + + +# aten.constant_pad_nd +@op(torch.ops.aten.constant_pad_nd) +def _aten_constant_pad_nd(input, padding, value=0): + # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) + # means last dim get padded 1 in front and 1 in back; + # and second last dim get padded 2 in front and 2 in back. + # Jax padding tuple of 2-tuple: the same padding is + # [(0, 0), ..., (2,2), (1,1)] + m = len(padding) + rev_padding = [(padding[i - 1], padding[i]) for i in range(m - 1, 0, -2)] + pad_dim = tuple(([(0, 0)] * (len(input.shape) - m // 2)) + rev_padding) + return jnp.pad(input, pad_dim, mode="constant", constant_values=value) + + +# aten.convolution_backward +@op(torch.ops.aten.copy) +def _aten_copy(x): + return jnp.copy(x) + + +# aten.cos +@op(torch.ops.aten.cos) +def _aten_cos(input): + return jnp.cos(input) + + +# aten.cosh +@op(torch.ops.aten.cosh) +def _aten_cosh(input): + return jnp.cosh(input) + + +# aten.diagonal +# aten.empty_strided +# aten.eq +@op(torch.ops.aten.eq) +def _aten_eq(input1, input2): + return input1 == input2 + + +# aten.erf +# aten.exp +@op(torch.ops.aten.exp) +def _aten_exp(input): + return jnp.exp(input) + + +# aten.expm1 +@op(torch.ops.aten.expm1) +def _aten_expm1(input): + return jnp.expm1(input) + + +# aten.fill +# aten.flip +@op(torch.ops.aten.flip) +def _aten_flip(input, dims): + if dims is not None: + return jnp.flip(input, tuple(dims)) + else: + return jnp.flip(input) + + +# aten.floor +@op(torch.ops.aten.floor) +def _aten_floor(input): + return jnp.floor(input) + + +# aten.fmod +# aten.gather +# aten.ge +@op(torch.ops.aten.ge) +def _aten_ge(self, other): + return self >= other + + +# aten.hardtanh +@op(torch.ops.aten.hardtanh) +def _aten_hardtanh(input, min_val=-1., max_val=1., inplace=False): + return jnp.clip(input, min_val, max_val) + +# aten.isinf +@op(torch.ops.aten.isinf) +def _aten_isinf(input): + return jnp.isinf(input) + +# aten.isnan +@op(torch.ops.aten.isnan) +def _aten_isnan(input): + return jnp.isnan(input) + +@op(torch.ops.aten.le) +def _aten_le(self, other): + return self <= other + +# aten.leaky_relu +@op(torch.ops.aten.leaky_relu) +def _aten_leaky_relu(x, negative_slope): + return jax.nn.leaky_relu(x, negative_slope) +# aten.log +@op(torch.ops.aten.log) +def _aten_log(x): + return jnp.log(x) + +# aten.log10 +@op(torch.ops.aten.log10) +def _aten_log10(x): + return jnp.log10(x) + +# aten.log1p +@op(torch.ops.aten.log1p) +def _aten_log1p(x): + return jnp.log1p(x) + +# aten.log2 +@op(torch.ops.aten.log2) +def _aten_log2(x): + return jnp.log2(x) + +# aten.logical_and +@op(torch.ops.aten.logical_and) +def _aten_logical_and(self, other): + return jnp.logical_and(self, other) + + +# aten.logical_or +@op(torch.ops.aten.logical_or) +def _aten_logical_or(self, other): + return jnp.logical_or(self, other) + + +# aten.logical_not +@op(torch.ops.aten.logical_not) +def _aten_logical_not(self): + return jnp.logical_not(self) + + +# aten.log_softmax +@op(torch.ops.aten._log_softmax) +def _aten_log_softmax(self, axis=-1, half_to_float=False): + 
return jax.nn.log_softmax(self, axis) + + +# aten.max_pool3d_backward +# aten.logical_xor +@op(torch.ops.aten.logical_xor) +def _aten_logical_xor(self, other): + return jnp.logical_xor(self, other) + + +# aten.max_pool2d_with_indices_backward +# aten.native_dropout +# aten.native_group_norm_backward +# aten.neg +# aten.nonzero +# aten.prod + +# aten.rand +# aten.randn +# aten.randperm +# aten.reflection_pad3d +# aten.remainder +# aten.repeat +# aten.replication_pad2d +# aten.replication_pad3d +# aten.roll +# aten.scalar_tensor +# aten.select_scatter +# aten.slice_scatter + + +# aten.sort +# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) +@op(torch.ops.aten.sort) +def _aten_sort(a, dim=-1, descending=False, stable=False): + return ( + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), + ) + + +# aten.sym_size + + +# aten.topk +@op(torch.ops.aten.topk) +def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): + """JAX top-k implementation using jax.lax.top_k for improved efficiency. + + Args: + input: The input JAX array. + k: The number of top elements to return. + dim: The dimension along which to find the top-k. If None, operates on the + flattened array. + largest: If True, returns the largest k elements. Otherwise, smallest k. + sorted: If True, returns the elements in sorted order. + + Returns: + A tuple (values, indices) containing: + - values: The top k values. + - indices: The indices of the top k values in the original array. + """ + if dim is None: + input = input.flatten() + dim = 0 + + if not largest: + input = -input # Find top-k of negated input if we want the smallest + + transpose_shape = None + if dim != -1 and dim != len(input.shape) - 1: + transpose_shape = list(range(len(input.shape))) + transpose_shape[dim], transpose_shape[-1] = ( + transpose_shape[-1], transpose_shape[dim]) + input = jnp.transpose(input, transpose_shape) + + values, indices = jax.lax.top_k(input, k) + + if sorted: + values = jnp.sort(values, descending=True) + indices = jnp.take_along_axis(indices, + jnp.argsort(values, axis=-1, descending=True), axis=-1) + + if not largest: + values = -values # Negate values back if we found smallest + + if transpose_shape is not None: + values = jnp.transpose(values, transpose_shape) + indices = jnp.transpose(indices, transpose_shape) + + return values, indices + + +# aten.trunc +@op(torch.ops.aten.trunc) +def _aten_trunc(a): + return jnp.trunc(a) + + +# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d +# despite those being core aten ops, they also have decompositions. +# here we are using torch decompositions. 
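+ # (These decompositions are wired up in tensor.py later in this diff:
+ #  decomp.get_decompositions([torch.ops.aten.upsample_nearest2d]) is merged into
+ #  core_aten_decompositions(), so the decomposed ops fall back to the lowerings
+ #  defined above.)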
+ + +# aten.where +@op(torch.ops.aten.where) +def _aten_where(condition, x, y): + return jnp.where(condition, x, y) + + +# aten.to.dtype +@op(torch.ops.aten.to.dtype) +def _aten_to_dtype(a, dtype): + jaxdtype = tensor.t2j_dtype(dtype) + return a.astype(jaxdtype) + + +# aten.to.device diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py new file mode 100644 index 00000000000..7915c360783 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/ops_registry.py @@ -0,0 +1,46 @@ +import torch + + +class LoweringRegistry: + + def __init__(self): + self.registered_ops = {} + + def lookup(self, op_or_name): + candidate = self.registered_ops.get(op_or_name) + if candidate is None: + if isinstance(op_or_name, torch._ops.OpOverloadPacket): + candidate = self.registered_ops.get(op_or_name.default) + if isinstance(op_or_name, torch._ops.OpOverload): + candidate = self.registered_ops.get(op_or_name.overloadpacket) + return candidate + + def register(self, op, lowering): + self.registered_ops[op] = lowering + + +lowerings = LoweringRegistry() + + +def _all_core_ops(): + """Yields all core ops.""" + import torch._ops + + for k, v in torch.ops.aten.__dict__.items(): + if k.startswith('__'): + continue + if k.startswith('_'): + continue + if isinstance(v, torch._ops.OpOverloadPacket): + for overload in v.overloads(): + op = getattr(v, overload) + if torch.Tag.core in op.tags: + yield v + break + + +def print_missing_ops(): + core_aten = set(_all_core_ops()) + existing = set(lowerings.registered_ops.keys()) + for v in core_aten - existing: + print(v) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py new file mode 100644 index 00000000000..6b6ac010f63 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -0,0 +1,273 @@ +import functools +import jax +from jax import dlpack as jaxdl +import jax.numpy as jnp +import numpy +import torch +import torch._decomp as decomp +from torch_xla2 import ops_registry +import torch.utils._python_dispatch as torch_dispatch +import torch.utils._pytree as torch_pytree +import torch.utils.dlpack as torchdl + + +class XLADispatchMode(torch_dispatch.TorchDispatchMode): + + def __torch_dispatch__(self, fn, types, args=(), kwargs=None): + if fn in constructors: + args, kwargs = unwrap((args, kwargs)) + res = constructors[fn](*args, **kwargs) + return wrap(res) + return fn(*args, **kwargs) + + +def _aten_arange( + start, + end, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False +): + return jnp.arange(start, end, 1) + + +constructors = { + torch.ops.aten.arange.default: functools.partial(_aten_arange, 0), + torch.ops.aten.arange.start: _aten_arange, +} + + +def wrap(jaxarray): + return torch_pytree.tree_map_only(jnp.ndarray, XLATensor2, jaxarray) + + +def unwrap(torchtensors): + return torch_pytree.tree_map_only(XLATensor2, lambda x: x._elem, torchtensors) + + +def t2j(t): + if isinstance(t, XLATensor2): + return t._elem + if t.dtype == torch.bool: + t = t.to(torch.int8) + + if not t.is_contiguous(): + t = t.contiguous() + + try: + dl = torchdl.to_dlpack(t) + res = jaxdl.from_dlpack(dl) + except Exception: + # https://github.com/google/jax/issues/7657 + # https://github.com/google/jax/issues/17784 + if t.dtype == torch.bfloat16: + nparray = ( + t.detach().to(torch.float32).numpy() + ) # numpy don't support bfloat16 + else: + nparray = t.detach().numpy() + res = jnp.asarray(nparray) + if t.dtype == torch.bfloat16: + 
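+      # the value was widened to float32 for the numpy fallback above; narrow it back to bfloat16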
res = res.astype(jnp.bfloat16) + + if t.dtype == torch.bool: + res = res.astype(jnp.bool_) + return res + + +def j2t(x): + try: + dl = jaxdl.to_dlpack(x) + res = torchdl.from_dlpack(dl) + except Exception: + res = torch.from_numpy(numpy.asarray(x)) + if x.dtype == jnp.bool_: + res = res.to(torch.bool) + return res + + +def t2j_dtype(dtype): + return { + torch.bfloat16: jnp.bfloat16, + torch.double: jnp.double, + torch.float32: jnp.float32, + torch.half: jnp.float16, + torch.long: jnp.int64, + torch.int32: jnp.int32, + torch.int16: jnp.int16, + torch.bool: jnp.bool_, + }.get(dtype) + + +def j2t_dtype(dtype): + return { + jnp.bfloat16: torch.bfloat16, + jnp.double: torch.double, + jnp.float32: torch.float32, + jnp.float16: torch.half, + jnp.int64: torch.long, + jnp.int32: torch.int32, + jnp.int16: torch.int16, + jnp.bool_: torch.bool, + }.get(dtype) + + +def move_to_device(t): + return XLATensor2(t2j(t)) + + +EXTRA_DECOMP = decomp.get_decompositions([torch.ops.aten.upsample_nearest2d]) +CORE_ATEN_DECOMP = decomp.core_aten_decompositions() +CORE_ATEN_DECOMP.update(EXTRA_DECOMP) + + +class XLATensor2(torch.Tensor): + + @staticmethod + def __new__(cls, elem): + dtype = j2t_dtype(elem.dtype) + shape = list(elem.shape) + for i, s in enumerate(shape): + if not isinstance(s, int): + shape[i] = 1 + if dtype is None: + dtype = torch.float32 + return torch.Tensor._make_subclass( + cls, + torch.empty(shape, dtype=dtype, device="meta"), + require_grad=False, + ) + + def __init__(self, elem: jax.Array): + super().__init__() + self._elem = elem + + def __str__(self): + return "XLATensor2({} {})".format(str(type(self._elem)), str(self._elem)) + + def __jax_array__(self): + return self._elem + + @property + def shape(self): + return self._elem.shape + + @property + def ndim(self): + return len(self._elem.shape) + + def flatten(self, start_dim=0, end_dim=-1): + if end_dim == -1: + end_dim = self.ndim + new_shape = ( + self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim:] + ) + new_elem = jnp.reshape(self._elem, new_shape) + return XLATensor2(new_elem) + # return torch.reshape(self, new_shape) + + def __setitem__(self, key, val): + key = unwrap(key) + self._elem = self._elem.at[key].set(val._elem) + + def type_as(self, other): + self._elem = self._elem.astype(other._elem.dtype) + return self + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + print("running...", func.name(), types) + for a in torch_pytree.tree_flatten(args)[0]: + if isinstance(a, XLATensor2): + print(" ", a._elem.shape) + else: + print(" ", a) + lowering = ops_registry.lowerings.lookup(func) + if lowering is None: + if func in CORE_ATEN_DECOMP: + with XLADispatchMode(): + return CORE_ATEN_DECOMP[func](*args, **kwargs) + else: + print(func.name(), func.tags) + raise RuntimeError("No lowering found for", func.name()) + res = lowering(*args, **kwargs) + print("output:") + for a in torch_pytree.tree_flatten(res)[0]: + if isinstance(a, XLATensor2): + print(" ", a._elem.shape) + debug_accuracy(func, args, kwargs, res) + return res + + def detach(self): + return XLATensor2(jax.lax.stop_gradient(self.jax())) + + def numpy(self) -> numpy.ndarray: + import numpy as np + + return np.array(self._elem) + + def jax(self) -> jax.Array: + return self._elem + + def torch(self) -> torch.Tensor: + return j2t(self.jax()) + + +# TODO: slice of slice should also be another slice +class SliceView(XLATensor2): + + def __init__(self, orig_tensor, slice): + 
self._orig_tensor = orig_tensor + self._slice = slice + + def numpy(self): + return self._orig_tensor.numpy()[self._slice] + + def jax(self): + return self._orig_tensor.jax()[self._slice] + + def torch(self): + return self._orig_tensor.torch()[self.slice] + + def mutate(self, slice, new_content): + self._orig_tensor._elem = self._orig_tensor.at[slice].set(new_content) + + +_DEBUG_ACCURACY = False + + +def debug_accuracy(func, args, kwargs, current_output): + if not _DEBUG_ACCURACY: + return True + + args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( + torch.Tensor, lambda x: j2t(x._elem), (args, kwargs, current_output) + ) + expected_out = func(*args_torch, **kwargs_torch) + + flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) + flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) + + for ex, real in zip(flattened_expected_out, flattened_current_out): + if ex.dtype != real.dtype: + ex = ex.to(real.dtype) + try: + if ( + _DEBUG_ACCURACY + and isinstance(ex, torch.Tensor) + and not torch.allclose(ex, real, atol=1e-3, equal_nan=True) + ): + import pdb + + pdb.set_trace() + except: + import pdb + + pdb.set_trace() + + return True diff --git a/experimental/torch_xla2/torch_xla2/tf_integration.py b/experimental/torch_xla2/torch_xla2/tf_integration.py new file mode 100644 index 00000000000..2e3cd265252 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/tf_integration.py @@ -0,0 +1,127 @@ +# pylint: disable +import os +from typing import Any, Tuple + +from jax.experimental import jax2tf +import tensorflow as tf +import torch +from torch_xla2 import export + + +def exported_program_to_tf_function( + ep, enable_xla=True +): + jax_program = export.exported_program_to_jax_program(ep) + + example_inputs = jax_program.flatten_inputs(*jax_program.example_inputs) + input_signature = [ + tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") + for i, t in enumerate(example_inputs) + ] + tf_f = tf.function( + jax2tf.convert( + jax_program.flatten_callable, + with_gradient=False, + enable_xla=enable_xla, + ), + autograph=False, + input_signature=input_signature, + ) + return tf_f + + +def exported_program_to_tf_module( + ep: torch.export.ExportedProgram, enable_xla=True +) -> tf.Module: + tfm = tf.Module() + tfm.f = exported_program_to_tf_function(ep, enable_xla) + return tfm + + +def save_exported_program_as_tf_saved_model( + ep: torch.export.ExportedProgram, + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, +): + """This function will export and save a pytorch ExportedProgram to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. 
+ """ + tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) + signatures = { + serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) + } + save_options = tf.saved_model.SaveOptions( + function_aliases={ + function_alias: tfm.f, + } + ) + tf.saved_model.save( + tfm, + saved_model_dir, + signatures=signatures, + options=save_options, + ) + + +def save_torch_module_as_tf_saved_model( + torch_model: torch.nn.Module, + args: Tuple[Any], + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, +): + """This function will export and save a pytorch nn.Module to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. + """ + ep = torch.export.export(torch_model, args) + save_exported_program_as_tf_saved_model( + ep, saved_model_dir, serving_key, function_alias, enable_xla + ) + + +def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): + tfm = exported_program_to_tf_module(ep) + tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [tf_concrete_func], tfm + ) + tflite_model = converter.convert() + return tflite_model + + +def torch_module_to_tflite_flatbuffer( + torch_model: torch.nn.Module, args: Tuple[Any] +): + ep = torch.export.export(torch_model, args) + return exported_program_to_tflite_flatbuffer(ep)