diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py new file mode 100644 index 00000000000..ed14e636e5c --- /dev/null +++ b/experimental/torch_xla2/test/test_ops.py @@ -0,0 +1,680 @@ +import unittest + +import torch +from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, ops) +from torch.utils import _pytree as pytree +from torch_xla2 import tensor + +skiplist = { + "__getitem__", + "__rmatmul__", + "__rpow__", + "_native_batch_norm_legit", + "_segment_reduce", + "_upsample_bilinear2d_aa", + "addbmm", + "addmm", + "addmv", + "addr", + "all", + "allclose", + "amax", + "amin", + "aminmax", + "angle", + "any", + "argmax", + "argmin", + "argsort", + "as_strided", + "as_strided_scatter", + "baddbmm", + "bernoulli", + "bincount", + "bitwise_left_shift", + "bitwise_right_shift", + "block_diag", + "broadcast_tensors", + "broadcast_to", + "bucketize", + "byte", + "cat", + "cauchy", + "cdist", + "cholesky", + "cholesky_inverse", + "cholesky_solve", + "combinations", + "complex", + "constant_pad_nd", + "copysign", + "corrcoef", + "count_nonzero", + "cov", + "cross", + "cummax", + "cummin", + "cumprod", + "cumsum", + "diag", + "diag_embed", + "diagflat", + "diagonal_copy", + "diagonal_scatter", + "diff", + "digamma", + "dist", + "div", + "empty", + "empty_like", + "empty_permuted", + "empty_strided", + "equal", + "erfc", + "erfinv", + "exp2", + "expand", + "exponential", + "fft.fft2", + "fft.fft", + "fft.fftn", + "fft.hfft2", + "fft.hfft", + "fft.hfftn", + "fft.ifft2", + "fft.ifft", + "fft.ifftn", + "fft.ihfft2", + "fft.ihfft", + "fft.ihfftn", + "fft.irfft2", + "fft.irfft", + "fft.irfftn", + "fft.rfft2", + "fft.rfft", + "fft.rfftn", + "floor_divide", + "fmax", + "fmin", + "frexp", + "full_like", + "gather", + "gcd", + "geometric", + "geqrf", + "grid_sampler_2d", + "heaviside", + "histc", + "histogram", + "histogramdd", + "hypot", + "i0", + "igamma", + "igammac", + "index_copy", + "index_fill", + "index_put", + "index_reduce", + "index_select", + "isclose", + "isin", + "item", + "kthvalue", + "lcm", + "lerp", + "lgamma", + "linalg.cholesky", + "linalg.cholesky_ex", + "linalg.cond", + "linalg.cross", + "linalg.det", + "linalg.eig", + "linalg.eigh", + "linalg.eigvals", + "linalg.eigvalsh", + "linalg.householder_product", + "linalg.inv", + "linalg.inv_ex", + "linalg.ldl_factor", + "linalg.ldl_factor_ex", + "linalg.ldl_solve", + "linalg.lstsq", + "linalg.lu", + "linalg.lu_factor", + "linalg.lu_factor_ex", + "linalg.lu_solve", + "linalg.matrix_norm", + "linalg.matrix_power", + "linalg.matrix_rank", + "linalg.multi_dot", + "linalg.norm", + "linalg.pinv", + "linalg.qr", + "linalg.slogdet", + "linalg.solve", + "linalg.solve_ex", + "linalg.solve_triangular", + "linalg.svd", + "linalg.svdvals", + "linalg.tensorinv", + "linalg.tensorsolve", + "linalg.vander", + "linalg.vector_norm", + "linspace", + "log_normal", + "log_softmax", + "logaddexp2", + "logaddexp", + "logcumsumexp", + "logdet", + "logspace", + "logsumexp", + "lu", + "lu_solve", + "lu_unpack", + "masked.amax", + "masked.amin", + "masked.argmax", + "masked.argmin", + "masked.cumprod", + "masked.cumsum", + "masked.log_softmax", + "masked.logaddexp", + "masked.logsumexp", + "masked.mean", + "masked.median", + "masked.norm", + "masked.normalize", + "masked.prod", + "masked_scatter", + "masked_select", + "masked.softmax", + "masked.softmin", + "masked.std", + "masked.sum", + "masked.var", + "matrix_exp", + "matmul", + "max_pool2d_with_indices_backward", + "max", + "median", + "min", + "mode", + "multinomial", + "mvlgamma", + "nanmedian", + "nanquantile", + "nansum", + "narrow_copy", + "narrow", + "native_batch_norm", + "native_layer_norm", + "new_empty", + "new_empty_strided", + "nextafter", + "nn.functional.adaptive_avg_pool1d", + "nn.functional.adaptive_avg_pool2d", + "nn.functional.adaptive_avg_pool3d", + "nn.functional.adaptive_max_pool1d", + "nn.functional.adaptive_max_pool2d", + "nn.functional.adaptive_max_pool3d", + "nn.functional.alpha_dropout", + "nn.functional.avg_pool1d", + "nn.functional.avg_pool2d", + "nn.functional.avg_pool3d", + "nn.functional.batch_norm", + "nn.functional.bilinear", + "nn.functional.binary_cross_entropy", + "nn.functional.conv2d", + "nn.functional.conv3d", + "nn.functional.conv_transpose1d", + "nn.functional.conv_transpose2d", + "nn.functional.conv_transpose3d", + "nn.functional.cosine_embedding_loss", + "nn.functional.cosine_similarity", + "nn.functional.cross_entropy", + "nn.functional.ctc_loss", + "nn.functional.dropout2d", + "nn.functional.dropout3d", + "nn.functional.dropout", + "nn.functional.embedding_bag", + "nn.functional.embedding", + "nn.functional.feature_alpha_dropout", + "nn.functional.fractional_max_pool2d", + "nn.functional.fractional_max_pool3d", + "nn.functional.gaussian_nll_loss", + "nn.functional.grid_sample", + "nn.functional.group_norm", + "nn.functional.hinge_embedding_loss", + "nn.functional.instance_norm", + "nn.functional.interpolate", + "nn.functional.layer_norm", + "nn.functional.leaky_relu", + "nn.functional.linear", + "nn.functional.logsigmoid", + "nn.functional.margin_ranking_loss", + "nn.functional.max_pool1d", + "nn.functional.max_pool2d", + "nn.functional.max_pool3d", + "nn.functional.max_unpool1d", + "nn.functional.max_unpool2d", + "nn.functional.max_unpool3d", + "nn.functional.multi_head_attention_forward", + "nn.functional.multi_margin_loss", + "nn.functional.multilabel_margin_loss", + "nn.functional.multilabel_soft_margin_loss", + "nn.functional.nll_loss", + "nn.functional.normalize", + "nn.functional.one_hot", + "nn.functional.pad", + "nn.functional.pairwise_distance", + "nn.functional.pdist", + "nn.functional.pixel_shuffle", + "nn.functional.pixel_unshuffle", + "nn.functional.poisson_nll_loss", + "nn.functional.rrelu", + "nn.functional.scaled_dot_product_attention", + "nn.functional.softmin", + "nn.functional.unfold", + "nn.functional.upsample_nearest", + "nonzero", + "nonzero_static", + "norm", + "normal", + "ones_like", + "ormqr", + "pca_lowrank", + "pinverse", + "polar", + "polygamma", + "prod", + "put", + "qr", + "quantile", + "rand_like", + "randint_like", + "randn_like", + "renorm", + "repeat_interleave", + "resize_", + "resize_as_", + "rot90", + "rsub", + "scatter_add", + "scatter", + "scatter_reduce", + "searchsorted", + "select", + "select_scatter", + "signbit", + "softmax", + "sort", + "special.airy_ai", + "special.bessel_j0", + "special.bessel_j1", + "special.bessel_y0", + "special.bessel_y1", + "special.chebyshev_polynomial_t", + "special.chebyshev_polynomial_u", + "special.erfcx", + "special.hermite_polynomial_h", + "special.hermite_polynomial_he", + "special.i0e", + "special.i1", + "special.i1e", + "special.laguerre_polynomial_l", + "special.log_ndtr", + "special.modified_bessel_i0", + "special.modified_bessel_i1", + "special.modified_bessel_k0", + "special.modified_bessel_k1", + "special.ndtri", + "special.polygamma", + "special.scaled_modified_bessel_k0", + "special.scaled_modified_bessel_k1", + "special.spherical_bessel_j0", + "special.zeta", + "squeeze", + "stft", + "sub", + "sum", + "svd", + "svd_lowrank", + "take_along_dim", + "take", + "tensor_split", + "to_sparse", + "topk", + "trace", + "triangular_solve", + "triu", + "unbind", + "unfold_copy", + "unfold", + "uniform", + "unique_consecutive", + "unique", + "unravel_index", + "var_mean", + "zero_", + "zeros_like", + "argwhere", + "cumulative_trapezoid", + "expand_as", + "mean", + "nanmean", + "trapezoid", + "trapz", + "H", + "T", + "__radd__", + "__rand__", + "__rdiv__", + "__rmod__", + "__rmul__", + "__ror__", + "__rsub__", + "__rxor__", + "_softmax_backward_data", + "abs", + "acos", + "acosh", + "addcdiv", + "addcmul", + "arange", + "asin", + "asinh", + "atan2", + "atan", + "atanh", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "bfloat16", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", + "bmm", + "bool", + "broadcast_shapes", + "cartesian_prod", + "cdouble", + "ceil", + "cfloat", + "chalf", + "char", + "chunk", + "clamp", + "clamp_max", + "clamp_min", + "clone", + "column_stack", + "conj", + "conj_physical", + "contiguous", + "cos", + "cosh", + "deg2rad", + "diagonal", + "dot", + "double", + "dsplit", + "dstack", + "einsum", + "eq", + "erf", + "exp", + "expm1", + "eye", + "fft.fftshift", + "fft.ifftshift", + "fill", + "flatten", + "flip", + "fliplr", + "flipud", + "float", + "float_power", + "floor", + "fmod", + "frac", + "full", + "ge", + "gradient", + "gt", + "half", + "hsplit", + "hstack", + "index_add", + "inner", + "int", + "isfinite", + "isinf", + "isnan", + "isneginf", + "isposinf", + "isreal", + "kron", + "ldexp", + "le", + "linalg.diagonal", + "linalg.vecdot", + "log10", + "log1p", + "log2", + "log", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "long", + "lt", + "mH", + "mT", + "masked_fill", + "maximum", + "meshgrid", + "minimum", + "movedim", + "msort", + "mul", + "mv", + "nan_to_num", + "native_dropout_backward", + "ne", + "neg", + "new_full", + "new_ones", + "new_zeros", + "nn.functional.binary_cross_entropy_with_logits", + "nn.functional.celu", + "nn.functional.conv1d", + "nn.functional.elu", + "nn.functional.gelu", + "nn.functional.glu", + "nn.functional.hardshrink", + "nn.functional.hardsigmoid", + "nn.functional.hardswish", + "nn.functional.hardtanh", + "nn.functional.huber_loss", + "nn.functional.kl_div", + "nn.functional.l1_loss", + "nn.functional.local_response_norm", + "nn.functional.mish", + "nn.functional.mse_loss", + "nn.functional.prelu", + "nn.functional.relu6", + "nn.functional.relu", + "nn.functional.selu", + "nn.functional.silu", + "nn.functional.smooth_l1_loss", + "nn.functional.soft_margin_loss", + "nn.functional.softplus", + "nn.functional.softshrink", + "nn.functional.softsign", + "nn.functional.tanhshrink", + "nn.functional.threshold", + "nn.functional.triplet_margin_loss", + "nn.functional.triplet_margin_with_distance_loss", + "nn.functional.upsample_bilinear", + "ones", + "outer", + "permute", + "positive", + "pow", + "rad2deg", + "randint", + "randn", + "ravel", + "real", + "reciprocal", + "remainder", + "repeat", + "reshape_as", + "reshape", + "resolve_conj", + "resolve_neg", + "roll", + "round", + "rsqrt", + "scalar_tensor", + "sgn", + "short", + "sigmoid", + "sign", + "signal.windows.bartlett", + "signal.windows.blackman", + "signal.windows.cosine", + "signal.windows.exponential", + "signal.windows.gaussian", + "signal.windows.general_cosine", + "signal.windows.general_hamming", + "signal.windows.hamming", + "signal.windows.hann", + "signal.windows.kaiser", + "signal.windows.nuttall", + "sin", + "sinc", + "sinh", + "slice", + "slice_scatter", + "sparse.mm", + "sparse.sampled_addmm", + "special.entr", + "special.ndtr", + "special.xlog1py", + "split", + "split_with_sizes", + "sqrt", + "square", + "stack", + "std", + "std_mean", + "sum_to_size", + "t", + "tan", + "tanh", + "tensordot", + "tile", + "to", + "transpose", + "tril", + "tril_indices", + "triu_indices", + "true_divide", + "trunc", + "unflatten", + "unsafe_chunk", + "unsafe_split", + "unsqueeze", + "var", + "vdot", + "view_as_complex", + "view_as", + "view_copy", + "view", + "vsplit", + "vstack", + "where", + "xlogy", + "zeros", +} + + +def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): + if isinstance(output1, torch.Tensor): + testcase.assertIsInstance(output2, torch.Tensor) + output2_cpu = output2.detach().cpu() + if output2_cpu.dtype != output1.dtype: + output2_cpu = output2_cpu.to(output1.dtype) + testcase.assertEqual(output1.shape, output2_cpu.shape) + testcase.assertEqual(output1.dtype, output2_cpu.dtype) + testcase.assertTrue( + torch.allclose( + output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan)) + elif isinstance(output1, (tuple, list)): + testcase.assertIsInstance(output2, (tuple, list)) + testcase.assertEqual(len(output1), len(output2)) + for o1, o2 in zip(output1, output2): + diff_output(testcase, o1, o2, rtol, atol) + else: + testcase.assertEqual(output1, output2) + + +def run_export_and_compare(testcase, + func, + sample_input, + atol=1e-3, + rtol=1e-5, + equal_nan=True, + ignore_indices=False): + with testcase.subTest("torch_eval"): + res = func(sample_input.input, *sample_input.args, **sample_input.kwargs) + with testcase.subTest("torch_xla2_eval"): + input2, args2, kwargs2 = pytree.tree_map_only( + torch.Tensor, tensor.move_to_device, + (sample_input.input, sample_input.args, sample_input.kwargs)) + res2 = func(input2, *args2, **kwargs2) + res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) + with testcase.subTest("torch_xla2_diff:" + str(atol)): + if ignore_indices and isinstance(res, tuple) and len(res) == 2: + diff_output( + testcase, + res[0], + res2[0], + atol=atol, + rtol=rtol, + equal_nan=equal_nan) + else: + diff_output( + testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan) + + +ops_to_test = list(filter(lambda x: x.name not in skiplist, op_db)) + + +class TestOpInfo(TestCase): + + @classmethod + def setUpClass(cls): + print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) + + @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) + def test_reference_eager(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype) + for sample_input in sample_inputs: + t = sample_input.input + if isinstance(t, torch.Tensor) and t.is_sparse: + continue + run_export_and_compare(self, op, sample_input) + + +instantiate_device_type_tests(TestOpInfo, globals()) + +if __name__ == '__main__': + unittest.main() diff --git a/experimental/torch_xla2/test_requirements.txt b/experimental/torch_xla2/test_requirements.txt index aab35f51dab..c8596327236 100644 --- a/experimental/torch_xla2/test_requirements.txt +++ b/experimental/torch_xla2/test_requirements.txt @@ -1,4 +1,5 @@ pytest immutabledict sentencepiece -pytest-xdist \ No newline at end of file +pytest-xdist +expecttest \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py index 334c63bb423..e4aece8f6a8 100644 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ b/experimental/torch_xla2/torch_xla2/_ops.py @@ -52,17 +52,12 @@ def _aten_unsafe_view(x, shape): @op(torch.ops.aten.add) -def _aten_add(x, y): +def _aten_add(x, y, *, alpha=1): """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): assert x.dtype == y.dtype, (x.dtype, y.dtype) """ - try: - return x + y - except Exception as e: - import pdb - - pdb.set_trace() + return x + y * alpha @op(torch.ops.aten.add_, is_jax_func=False) @@ -239,6 +234,7 @@ def _aten_div(x, y, rounding_mode=""): res = jnp.trunc(res) return res + @op(torch.ops.aten.div_, is_jax_func=False) def _aten_div_(x, y, rounding_mode=""): x._elem = _aten_div(x._elem, y._elem, rounding_mode) @@ -376,14 +372,9 @@ def _aten_ne(x, y): @op(torch.ops.aten.cumsum) def _aten_cumsum(x, y, dtype=None): - try: - dtype = tensor.t2j_dtype(dtype) - res = jnp.cumsum(x, y, dtype) - return res - except Exception as e: - import pdb - - pdb.set_trace() + dtype = tensor.t2j_dtype(dtype) + res = jnp.cumsum(x, y, dtype) + return res @op(torch.ops.aten.native_layer_norm) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 78e0849537a..6adacedbbc0 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -7,10 +7,3 @@ # Keys are OpOverload, value is a callable that takes # XLATensor2 all_ops = {} - - -all_ops[torch.ops.aten.add.Tensor] = op_base.BinaryOpWithPromotion(jnp.add) -all_ops[torch.ops.aten.sub.Tensor] = op_base.BinaryOpWithPromotion(jnp.subtract) -all_ops[torch.ops.aten.sub.Scalar] = op_base.BinaryOpWithPromotion(jnp.subtract) -all_ops[torch.ops.aten.mul.Tensor] = op_base.BinaryOpWithPromotion(jnp.multiply) -all_ops[torch.ops.aten.div.Tensor] = op_base.BinaryOpWithPromotion(jnp.divide)