From 174183e105633cbfbbcafccf106b96c873f75d2c Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 29 Feb 2024 17:26:06 +0800 Subject: [PATCH] Change ATEN generator argument type to const std::optional& --- torch_patches/.torch_pin | 1 + torch_xla/csrc/aten_xla_type.cpp | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 14 deletions(-) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 000000000000..06eab6aef1a3 --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#120076 diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 36791d0fc056..e9733a204c4d 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -829,7 +829,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, c10::optional generator) { + const at::Tensor& self, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -841,7 +841,7 @@ at::Tensor XLANativeFunctions::bernoulli( } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, double p, c10::optional generator) { + const at::Tensor& self, double p, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -853,7 +853,7 @@ at::Tensor XLANativeFunctions::bernoulli( at::Tensor& XLANativeFunctions::bernoulli_( at::Tensor& self, const at::Tensor& p, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -1244,7 +1244,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, } at::Tensor& XLANativeFunctions::exponential_( - at::Tensor& self, double lambd, c10::optional generator) { + at::Tensor& self, double lambd, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -1927,7 +1927,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, at::Tensor XLANativeFunctions::multinomial( const at::Tensor& self, int64_t num_samples, bool replacement, - c10::optional generator) { + const std::optional& generator) { XLA_CHECK(num_samples > 0) << "Multinomial number of samples must be greater than 0"; XLA_CHECK(at::isFloatingType(self.scalar_type())) @@ -2242,7 +2242,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2254,7 +2254,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, } at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2267,7 +2267,7 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, const at::Tensor& std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2280,7 +2280,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, at::Tensor& XLANativeFunctions::normal_( at::Tensor& self, double mean, double std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2417,7 +2417,7 @@ std::tuple XLANativeFunctions::qr( // The value generated should be within (from, to]. at::Tensor& XLANativeFunctions::random_( at::Tensor& self, int64_t from, c10::optional to, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2437,7 +2437,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (0, to]. at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, int64_t to, c10::optional generator) { + at::Tensor& self, int64_t to, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2453,7 +2453,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (self_type_min, self_type_max). at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, c10::optional generator) { + at::Tensor& self, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2594,7 +2594,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::Tensor XLANativeFunctions::rrelu_with_noise( const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { // The fallback path for rrelu_with_noise when training=true is wrong @@ -3133,7 +3133,7 @@ std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, at::Tensor& XLANativeFunctions::uniform_( at::Tensor& self, double from, double to, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback,