From d1396305a3661220a404a4ab14fb7b3fb9a6f028 Mon Sep 17 00:00:00 2001
From: cyy
Date: Fri, 22 Mar 2024 12:31:23 +0800
Subject: [PATCH] Change ATEN generator argument type to const
 std::optional<Generator>&

---
 torch_patches/.torch_pin         |  1 +
 torch_xla/csrc/aten_xla_type.cpp | 41 ++++++++++++++++++--------------
 2 files changed, 24 insertions(+), 18 deletions(-)
 create mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
new file mode 100644
index 00000000000..06eab6aef1a
--- /dev/null
+++ b/torch_patches/.torch_pin
@@ -0,0 +1 @@
+#120076
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 116635c688b..27e148b3c4d 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -920,7 +920,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -932,7 +932,8 @@ at::Tensor XLANativeFunctions::bernoulli(
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
+    const at::Tensor& self, double p,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -944,7 +945,7 @@ at::Tensor XLANativeFunctions::bernoulli(
 
 at::Tensor& XLANativeFunctions::bernoulli_(
     at::Tensor& self, const at::Tensor& p,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -1347,7 +1348,8 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
 }
 
 at::Tensor& XLANativeFunctions::exponential_(
-    at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
+    at::Tensor& self, double lambd,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2030,7 +2032,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
 
 at::Tensor XLANativeFunctions::multinomial(
     const at::Tensor& self, int64_t num_samples, bool replacement,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   XLA_CHECK(num_samples > 0)
       << "Multinomial number of samples must be greater than 0";
   XLA_CHECK(at::isFloatingType(self.scalar_type()))
@@ -2344,8 +2346,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
       bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, double std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2356,8 +2359,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
       tensor_methods::normal(bridge::GetXlaTensor(mean), std));
 }
 
-at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    double mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2368,9 +2372,9 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
       tensor_methods::normal(mean, bridge::GetXlaTensor(std)));
 }
 
-at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
-                                      const at::Tensor& std,
-                                      c10::optional<at::Generator> generator) {
+at::Tensor XLANativeFunctions::normal(
+    const at::Tensor& mean, const at::Tensor& std,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2383,7 +2387,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
 
 at::Tensor& XLANativeFunctions::normal_(
     at::Tensor& self, double mean, double std,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2527,7 +2531,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
 // The value generated should be within (from, to].
 at::Tensor& XLANativeFunctions::random_(
     at::Tensor& self, int64_t from, c10::optional<int64_t> to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2547,7 +2551,8 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (0, to].
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
+    at::Tensor& self, int64_t to,
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2563,7 +2568,7 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (self_type_min, self_type_max).
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, c10::optional<at::Generator> generator) {
+    at::Tensor& self, const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2736,7 +2741,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
 at::Tensor XLANativeFunctions::rrelu_with_noise(
     const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
     const at::Scalar& upper, bool training,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     // The fallback path for rrelu_with_noise when training=true is wrong
@@ -3275,7 +3280,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
 
 at::Tensor& XLANativeFunctions::uniform_(
     at::Tensor& self, double from, double to,
-    c10::optional<at::Generator> generator) {
+    const std::optional<at::Generator>& generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
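
Note, not part of the patch: the sketch below is a minimal, self-contained C++ illustration of the convention the diff adopts, i.e. taking the generator as const std::optional<at::Generator>& (so the optional is not copied at each call) and only taking the CPU fallback when a generator was both supplied and defined. Generator, sample_by_value, and sample_by_cref are hypothetical stand-ins for illustration; the real entry points are the XLANativeFunctions overrides in the hunks above, which dispatch through at::native::call_fallback_fn.

// Standalone sketch only; names here are hypothetical, not the ATen/XLA API.
#include <iostream>
#include <optional>

struct Generator {
  bool is_defined = false;
  bool defined() const { return is_defined; }  // loosely mirrors at::Generator::defined()
};

// Old style (what the patch removes): the optional is copied at every call.
void sample_by_value(std::optional<Generator> gen) { (void)gen; }

// New style (what the patch adds): binds to the caller's optional without a
// copy; temporaries and std::nullopt still bind to the const reference.
void sample_by_cref(const std::optional<Generator>& gen) {
  // Same guard the XLA kernels use: fall back only when a generator is
  // both present and defined.
  if (gen.has_value() && gen->defined()) {
    std::cout << "user-supplied generator -> CPU fallback path\n";
  } else {
    std::cout << "no usable generator -> native XLA path\n";
  }
}

int main() {
  sample_by_cref(std::nullopt);     // no generator: native path
  sample_by_cref(Generator{true});  // defined generator: fallback path
  sample_by_cref(Generator{});      // undefined generator: native path
}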