Revert "[Reland]Change ATEN generator argument type to const std::opt…
Browse files Browse the repository at this point in the history
…ional<Ge…"

This reverts commit 84e7feb.
alanwaketan authored Mar 26, 2024
1 parent 22fe1dc commit 7ac3d24
Showing 1 changed file with 18 additions and 23 deletions.
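Every hunk in the diff below makes the same mechanical change: the reverted commit had switched the generator parameter of these ATEN overrides from pass-by-value c10::optional<at::Generator> to pass-by-const-reference const std::optional<at::Generator>&, and this revert restores the by-value form. The following self-contained C++ sketch is not from the commit; it only illustrates the two parameter-passing conventions and the has_value()/defined() guard the file uses to choose between the native XLA path and the CPU fallback. The Generator struct is a hypothetical stand-in for at::Generator, which would require the ATen headers, and c10::optional is modeled with std::optional, which has the same interface.

#include <iostream>
#include <optional>

// Hypothetical stand-in for at::Generator (the real class lives in ATen).
struct Generator {
  bool defined() const { return true; }
};

// Convention this revert restores: the optional generator is taken by value.
void bernoulli_by_value(std::optional<Generator> generator) {
  // Guard used throughout aten_xla_type.cpp: only take the CPU fallback
  // when the caller actually supplied a defined generator.
  if (generator.has_value() && generator->defined()) {
    std::cout << "CPU fallback path\n";
  } else {
    std::cout << "native XLA path\n";
  }
}

// Convention the reverted commit had introduced: pass by const reference,
// which avoids copying the optional on every dispatch.
void bernoulli_by_const_ref(const std::optional<Generator>& generator) {
  if (generator.has_value() && generator->defined()) {
    std::cout << "CPU fallback path\n";
  } else {
    std::cout << "native XLA path\n";
  }
}

int main() {
  bernoulli_by_value(std::nullopt);     // no generator -> native XLA path
  bernoulli_by_const_ref(Generator{});  // defined generator -> fallback path
  return 0;
}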
41 changes: 18 additions & 23 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -920,7 +920,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, const std::optional<at::Generator>& generator) {
+    const at::Tensor& self, c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -932,8 +932,7 @@ at::Tensor XLANativeFunctions::bernoulli(
 }
 
 at::Tensor XLANativeFunctions::bernoulli(
-    const at::Tensor& self, double p,
-    const std::optional<at::Generator>& generator) {
+    const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -945,7 +944,7 @@ at::Tensor XLANativeFunctions::bernoulli(
 
 at::Tensor& XLANativeFunctions::bernoulli_(
     at::Tensor& self, const at::Tensor& p,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -1348,8 +1347,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
 }
 
 at::Tensor& XLANativeFunctions::exponential_(
-    at::Tensor& self, double lambd,
-    const std::optional<at::Generator>& generator) {
+    at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2032,7 +2030,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
 
 at::Tensor XLANativeFunctions::multinomial(
     const at::Tensor& self, int64_t num_samples, bool replacement,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   XLA_CHECK(num_samples > 0)
       << "Multinomial number of samples must be greater than 0";
   XLA_CHECK(at::isFloatingType(self.scalar_type()))
@@ -2346,9 +2344,8 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
       bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
 }
 
-at::Tensor XLANativeFunctions::normal(
-    const at::Tensor& mean, double std,
-    const std::optional<at::Generator>& generator) {
+at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
+                                      c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2359,9 +2356,8 @@ at::Tensor XLANativeFunctions::normal(
       tensor_methods::normal(bridge::GetXlaTensor(mean), std));
 }
 
-at::Tensor XLANativeFunctions::normal(
-    double mean, const at::Tensor& std,
-    const std::optional<at::Generator>& generator) {
+at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
+                                      c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2372,9 +2368,9 @@ at::Tensor XLANativeFunctions::normal(
       tensor_methods::normal(mean, bridge::GetXlaTensor(std)));
 }
 
-at::Tensor XLANativeFunctions::normal(
-    const at::Tensor& mean, const at::Tensor& std,
-    const std::optional<at::Generator>& generator) {
+at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
+                                      const at::Tensor& std,
+                                      c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2387,7 +2383,7 @@ at::Tensor XLANativeFunctions::normal(
 
 at::Tensor& XLANativeFunctions::normal_(
     at::Tensor& self, double mean, double std,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2531,7 +2527,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
 // The value generated should be within (from, to].
 at::Tensor& XLANativeFunctions::random_(
     at::Tensor& self, int64_t from, c10::optional<int64_t> to,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<
@@ -2551,8 +2547,7 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (0, to].
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, int64_t to,
-    const std::optional<at::Generator>& generator) {
+    at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2568,7 +2563,7 @@ at::Tensor& XLANativeFunctions::random_(
 
 // The value generated should be in (self_type_min, self_type_max).
 at::Tensor& XLANativeFunctions::random_(
-    at::Tensor& self, const std::optional<at::Generator>& generator) {
+    at::Tensor& self, c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,
@@ -2741,7 +2736,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
 at::Tensor XLANativeFunctions::rrelu_with_noise(
     const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
     const at::Scalar& upper, bool training,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     // The fallback path for rrelu_with_noise when training=true is wrong
@@ -3280,7 +3275,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,
 
 at::Tensor& XLANativeFunctions::uniform_(
     at::Tensor& self, double from, double to,
-    const std::optional<at::Generator>& generator) {
+    c10::optional<at::Generator> generator) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   if (generator.has_value() && generator->defined()) {
     return at::native::call_fallback_fn<&xla_cpu_fallback,