diff --git a/test/test_operations.py b/test/test_operations.py
index 4f3cc1bed22..187fb62e8ff 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -67,22 +67,6 @@ def _is_on_eager_debug_mode():
                                    'skip on eager debug mode')
 
 
-def _skipIfFunctionalization(value=True, reason=""):
-  verb = "is" if value else "is not"
-  reason = f" Reason: {reason}" if reason else ""
-  return unittest.skipIf(
-      XLA_DISABLE_FUNCTIONALIZATION is value,
-      f'Works only when functionalization {verb} disabled.{reason}.')
-
-
-def skipIfFunctionalizationEnabled(reason):
-  return _skipIfFunctionalization(value=False, reason=reason)
-
-
-def skipIfFunctionalizationDisabled(reason):
-  return _skipIfFunctionalization(value=True, reason=reason)
-
-
 def _gen_tensor(*args, **kwargs):
   return torch.randn(*args, **kwargs)
 
@@ -994,8 +978,8 @@ def func(a, b):
 
   # TODO - upstream behavior has changed and results in expected DestroyXlaTensor
   # counter as of 11/13/2023. Re-enable after reviewing the change.
-  # @skipIfFunctionalizationDisabled("metrics differ")
-  @unittest.skip
+  @unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION,
+                   'Metrics differ when functionalization is disabled.')
   def test_set(self):
     met.clear_all()
 
@@ -1013,7 +997,8 @@ def test_set(self):
     # shouldn't crash
     self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))
 
-  @skipIfFunctionalizationDisabled("metrics differ")
+  @unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION,
+                   'Metrics differ when functionalization is disabled.')
   def test_replace_xla_tensor(self):
     met.clear_all()
 
@@ -1356,7 +1341,8 @@ def test_fn(t, c):
         ), dtype=torch.int64)
     self.runAtenTest([token_type_ids, cat_ids], test_fn)
 
-  @skipIfFunctionalizationEnabled("views do not exist")
+  @unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION,
+                   'When functionalization is enabled, views do not exist.')
   def test_save_view_alias_check(self):
 
     class Nested(object):
@@ -1512,63 +1498,6 @@ def test_fn(r):
 
     self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn)
 
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_gap(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, (4, 4), (8, 1))
-
-    self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_gap_no_unit_stride(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, (4, 4), (8, 2))
-
-    self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_overlap(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, (4, 4), (2, 1))
-
-    self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_overlap_and_gap(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, (4, 4), (4, 2))
-
-    self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_overlap_zero_stride(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, (4, 4), (0, 1))
-
-    self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_gap_no_unit_stride(self):
-
-    def test_fn(r):
-      x = r.view(8, 4)
-      return torch.as_strided(r, (4, 4), (6, 2))
-
-    self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)
-
-  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
-  def test_as_strided_with_empty_args(self):
-
-    def test_fn(r):
-      return torch.as_strided(r, tuple(), tuple())
-
-    self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)
-
   def test_basic_bfloat16(self):
 
     def test_fn(s):
diff --git a/test/test_ops.py b/test/test_ops.py
index 12b874593bd..a3db0a91cd1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -29,7 +29,6 @@ def __new__(cls, name, variant_test_name=""):
     {
         AllowedOpInfoEntry('abs'),
         AllowedOpInfoEntry('add'),
-        AllowedOpInfoEntry('as_strided'),
         AllowedOpInfoEntry('mul'),
         AllowedOpInfoEntry('sub'),
         AllowedOpInfoEntry('addmm'),
@@ -350,12 +349,6 @@ def __new__(cls, name, variant_test_name=""):
         # AllowedOpInfoEntry('var_mean'),
         # AllowedOpInfoEntry('pow'),  # for int64 don't work, likely rounding issue
         # AllowedOpInfoEntry('__rpow__'),
-
-        # In theory, this should work.
-        # However, the problem is the way we prepare the reference (CPU) inputs:
-        # we clone them. If they were a view, they are not anymore.
-        #
-        # AllowedOpInfoEntry('as_strided', 'partial_views'),
     }))
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 63232fbac70..25b7e1a7b50 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -697,88 +697,24 @@ at::Tensor XLANativeFunctions::as_strided_copy(
   // This function actually operates on the tensor's storage. Since XLA does not
   // expose the actual storage, we use the originally allocated tensor.
   const at::Tensor& base = bridge::GetXlaTensor(self)->Base();
-  at::Tensor tensor = base.defined() ? base : self;
-
-  auto dim = size.size();
-  auto itemsize = tensor.dtype().itemsize();
-  int64_t storage_size =
-      at::detail::computeStorageNbytes(size, stride, itemsize);
-
-  XLA_CHECK(tensor.numel() * itemsize >= storage_size)
-      << "as_strided: storage not big enough for size " << size << ": "
-      << storage_size << " (needed) vs " << tensor.numel() << " (actual).";
-
-  if (dim == 0 && tensor.numel() > 0) {
-    // If there's no specified dimension, return the first element of the
-    // storage. This behavior is consistent with eager.
-    return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0);
-  }
-
-  if (storage_size == 0) {
-    // Return an empty tensor, if no storage is actually needed.
-    return empty_symint(c10::fromIntArrayRefSlow(size), tensor.scalar_type(),
-                        /* layout= */ c10::nullopt, tensor.device(),
-                        /* pin_memory= */ c10::nullopt,
-                        /* memory_format= */ c10::nullopt);
-  }
-
-  // At this point, the following is true:
-  XLA_CHECK(storage_size > 0);
-  XLA_CHECK(tensor.numel() > 0);
-  XLA_CHECK(dim > 0);
-
-  // Index tensor for gathering the needed elements into contiguous data.
-  //
-  // PyTorch/XLA, by default, assumes dense and contiguous data. However, when
-  // specifying strides, that might not be the case.
-  //
-  // Therefore, we gather the elements selected by following the size, stride,
-  // and storage offset, materializing it into contiguous elements.
-  //
-  // In order to accomplish that, we create an index tensor. Specifically, we
-  // create an n-dimensional tensor (n is the number of dimensions of the
-  // output) of indices. Each element represent the at which position of the
-  // flattened tensor the desired element is in.
-
-  // Example: arange(13).as_strided((2, 2, 2), (3, 4, 5))
-  //
-  // Start with a 1-element n-dimensional tensor, initialized with 0:
-  //
-  //   [[[0]]]
-  //
-  std::vector<int64_t> view_shape(dim, 1);
-  auto index_tensor =
-      at::tensor({storage_offset.value_or(self.storage_offset())},
-                 at::TensorOptions().dtype(at::kLong))
-          .view(view_shape);
-
-  // Then, add to the index_tensor the offset value introduced for each possible
-  // index of that corresponding dimension.
-  //
-  // - Iteration i=0:
-  //   [[[0]]] + [[[0 * 3]], [[1 * 3]]]
-  //   = [[[0 * 3]], [[1 * 3]]]
-  //   = [[[0]], [[3]]]
-  //
-  // - Iteration i=1:
-  //   [[[0]], [[3]]] + [[[0 * 4], [1 * 4]]]
-  //   = [[[0 + 0 * 4], [0 + 1 * 4]], [[3 + 0 * 4], [3 + 1 * 4]]]
-  //   = [[[0], [4]], [[3], [7]]]
-  //
-  // - Iteration i=2:
-  //   [[[0], [4]], [[3], [7]]] + [[[0 * 5, 1 * 5]]]
-  //   =[[[0 + 0 * 5, 0 + 1 * 5], [4 + 0 * 5, 4 + 1 * 5]],
-  //     [[3 + 0 * 5, 3 + 1 * 5], [7 + 0 * 5, 7 + 1 * 5]]]
-  //   =[[[0, 5], [4, 9]], [[3, 8], [7, 12]]]
-  for (int i = 0; i < dim; i++) {
-    auto vshape = view_shape;
-    vshape[i] = size[i];
-    index_tensor =
-        index_tensor.add((at::arange(size[i]) * stride[i]).view(vshape));
-  }
-
-  // Finally, index the tensor with the computed indices.
-  return take(tensor, index_tensor.to(tensor.device()));
+  const at::Tensor& tensor = base.defined() ? base : self;
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
+  auto xsize = XlaHelpers::I64List(size);
+  auto xstride = XlaHelpers::I64List(stride);
+  if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
+                                    storage_offset.value_or(0))) {
+    return at::native::call_fallback_fn<
+        &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
+                                                      storage_offset);
+  }
+  // Sets the base tensor as tensor.
+  // Even though this function copies (without aliasing) tensor, it's still
+  // treated as a view function in the functionalization layer.
+  return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
      tensor_methods::as_strided(self_tensor, std::move(xsize),
                                 std::move(xstride),
                                 XlaHelpers::I64Optional(storage_offset)),
      tensor));
 }
 
 at::Tensor XLANativeFunctions::as_strided_scatter(
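
For reference, the gather-based emulation that the removed comment block walks through can be sketched in plain PyTorch. This is only an illustrative sketch of the index-tensor trick being deleted here (the helper name as_strided_via_gather is made up for the example); it reproduces the arange(13).as_strided((2, 2, 2), (3, 4, 5)) worked example from the removed comments.

```python
import torch


def as_strided_via_gather(tensor, size, stride, storage_offset=0):
  # Illustrative only: emulate as_strided by materializing the selected
  # elements into a new contiguous tensor, as the removed lowering did.
  dim = len(size)
  view_shape = [1] * dim
  # Start from the storage offset, then add each dimension's contribution.
  index = torch.tensor([storage_offset], dtype=torch.long).view(view_shape)
  for i in range(dim):
    vshape = list(view_shape)
    vshape[i] = size[i]
    index = index + (torch.arange(size[i]) * stride[i]).view(vshape)
  # index[i0, ..., ik] == storage_offset + sum(ij * stride[j]); gather those
  # flat positions from the base tensor.
  return torch.take(tensor, index)


r = torch.arange(13)
print(as_strided_via_gather(r, (2, 2, 2), (3, 4, 5)))
# [[[0, 5], [4, 9]], [[3, 8], [7, 12]]] -- matches the worked example above
print(torch.as_strided(r, (2, 2, 2), (3, 4, 5)))  # same values
```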