From 7ac3d24e99089d7da80dafa11168639f2e042f1b Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Mon, 25 Mar 2024 17:51:56 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"[Reland]Change=20ATEN=20generator=20a?= =?UTF-8?q?rgument=20type=20to=20const=20std::optional& generator) { + const at::Tensor& self, c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -932,8 +932,7 @@ at::Tensor XLANativeFunctions::bernoulli( } at::Tensor XLANativeFunctions::bernoulli( - const at::Tensor& self, double p, - const std::optional& generator) { + const at::Tensor& self, double p, c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -945,7 +944,7 @@ at::Tensor XLANativeFunctions::bernoulli( at::Tensor& XLANativeFunctions::bernoulli_( at::Tensor& self, const at::Tensor& p, - const std::optional& generator) { + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -1348,8 +1347,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, } at::Tensor& XLANativeFunctions::exponential_( - at::Tensor& self, double lambd, - const std::optional& generator) { + at::Tensor& self, double lambd, c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2032,7 +2030,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, at::Tensor XLANativeFunctions::multinomial( const at::Tensor& self, int64_t num_samples, bool replacement, - const std::optional& generator) { + c10::optional generator) { XLA_CHECK(num_samples > 0) << "Multinomial number of samples must be greater than 0"; XLA_CHECK(at::isFloatingType(self.scalar_type())) @@ -2346,9 +2344,8 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim)); } -at::Tensor XLANativeFunctions::normal( - const at::Tensor& mean, double std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2359,9 +2356,8 @@ at::Tensor XLANativeFunctions::normal( tensor_methods::normal(bridge::GetXlaTensor(mean), std)); } -at::Tensor XLANativeFunctions::normal( - double mean, const at::Tensor& std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2372,9 +2368,9 @@ at::Tensor XLANativeFunctions::normal( tensor_methods::normal(mean, bridge::GetXlaTensor(std))); } -at::Tensor XLANativeFunctions::normal( - const at::Tensor& mean, const at::Tensor& std, - const std::optional& generator) { +at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, + const at::Tensor& std, + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2387,7 +2383,7 @@ at::Tensor XLANativeFunctions::normal( at::Tensor& XLANativeFunctions::normal_( at::Tensor& self, double mean, double std, - const std::optional& generator) { + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2531,7 +2527,7 @@ std::tuple XLANativeFunctions::qr( // The value generated should be within (from, to]. at::Tensor& XLANativeFunctions::random_( at::Tensor& self, int64_t from, c10::optional to, - const std::optional& generator) { + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< @@ -2551,8 +2547,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (0, to]. at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, int64_t to, - const std::optional& generator) { + at::Tensor& self, int64_t to, c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2568,7 +2563,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (self_type_min, self_type_max). at::Tensor& XLANativeFunctions::random_( - at::Tensor& self, const std::optional& generator) { + at::Tensor& self, c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2741,7 +2736,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::Tensor XLANativeFunctions::rrelu_with_noise( const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, - const std::optional& generator) { + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { // The fallback path for rrelu_with_noise when training=true is wrong @@ -3280,7 +3275,7 @@ std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, at::Tensor& XLANativeFunctions::uniform_( at::Tensor& self, double from, double to, - const std::optional& generator) { + c10::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback,