From 2bf2f79e3cbd51c0000dd3a4e21e178f89fd697f Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 29 Feb 2024 17:26:06 +0800 Subject: [PATCH 1/2] Change ATEN generator argument type to const std::optional& --- torch_patches/.torch_pin | 1 + torch_xla/csrc/aten_xla_type.cpp | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 14 deletions(-) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 00000000000..06eab6aef1a --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#120076 diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 996a1b72eb3..282d797022a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -893,7 +893,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, c10::optional generator) { + const at::Tensor& self, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -905,7 +905,7 @@ at::Tensor XLANativeFunctions::bernoulli( } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, double p, c10::optional generator) { + const at::Tensor& self, double p, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -917,7 +917,7 @@ at::Tensor XLANativeFunctions::bernoulli( at::Tensor& XLANativeFunctions::bernoulli_( at::Tensor& self, const at::Tensor& p, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -1308,7 +1308,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, } at::Tensor& XLANativeFunctions::exponential_( - at::Tensor& self, double lambd, c10::optional generator) { + at::Tensor& self, double lambd, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -1991,7 +1991,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, at::Tensor XLANativeFunctions::multinomial( const at::Tensor& self, int64_t num_samples, bool replacement, - c10::optional generator) { + const std::optional& generator) { XLA_CHECK(num_samples > 0) << "Multinomial number of samples must be greater than 0"; XLA_CHECK(at::isFloatingType(self.scalar_type())) @@ -2306,7 +2306,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2318,7 +2318,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, } at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2331,7 +2331,7 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, const at::Tensor& std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2344,7 +2344,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, at::Tensor& XLANativeFunctions::normal_( at::Tensor& self, double mean, double std, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2481,7 +2481,7 @@ std::tuple XLANativeFunctions::qr( // The value generated should be within (from, to]. at::Tensor& XLANativeFunctions::random_( at::Tensor& self, int64_t from, c10::optional to, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2501,7 +2501,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (0, to]. at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, int64_t to, c10::optional generator) { + at::Tensor& self, int64_t to, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2517,7 +2517,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (self_type_min, self_type_max). at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, c10::optional generator) { + at::Tensor& self, const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2690,7 +2690,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::Tensor XLANativeFunctions::rrelu_with_noise( const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { // The fallback path for rrelu_with_noise when training=true is wrong @@ -3229,7 +3229,7 @@ std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, at::Tensor& XLANativeFunctions::uniform_( at::Tensor& self, double from, double to, - c10::optional generator) { + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, From f78d082393730ac98cee5e2d8cea2717803c5b4c Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 1 Mar 2024 07:59:17 +0800 Subject: [PATCH 2/2] Format code --- torch_xla/csrc/aten_xla_type.cpp | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 282d797022a..bc6a64c8d47 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -905,7 +905,8 @@ at::Tensor XLANativeFunctions::bernoulli( } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, double p, const std::optional& generator) { + const at::Tensor& self, double p, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -1308,7 +1309,8 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, } at::Tensor& XLANativeFunctions::exponential_( - at::Tensor& self, double lambd, const std::optional& generator) { + at::Tensor& self, double lambd, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2305,8 +2307,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim)); } -at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal( + const at::Tensor& mean, double std, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2317,8 +2320,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, tensor_methods::normal(bridge::GetXlaTensor(mean), std)); } -at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal( + double mean, const at::Tensor& std, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2329,9 +2333,9 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, tensor_methods::normal(mean, bridge::GetXlaTensor(std))); } -at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, - const at::Tensor& std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal( + const at::Tensor& mean, const at::Tensor& std, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2501,7 +2505,8 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (0, to]. at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, int64_t to, const std::optional& generator) { + at::Tensor& self, int64_t to, + const std::optional& generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback,