Skip to content

Commit

Permalink
Change ATEN generator argument type to const std::optional<Generator>&
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyever committed Feb 29, 2024
1 parent 5405c4d commit 174183e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#120076
28 changes: 14 additions & 14 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
}

at::Tensor XLANativeFunctions::bernoulli(
const at::Tensor& self, c10::optional<at::Generator> generator) {
const at::Tensor& self, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand All @@ -841,7 +841,7 @@ at::Tensor XLANativeFunctions::bernoulli(
}

at::Tensor XLANativeFunctions::bernoulli(
const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
const at::Tensor& self, double p, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -853,7 +853,7 @@ at::Tensor XLANativeFunctions::bernoulli(

at::Tensor& XLANativeFunctions::bernoulli_(
at::Tensor& self, const at::Tensor& p,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand Down Expand Up @@ -1244,7 +1244,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
}

at::Tensor& XLANativeFunctions::exponential_(
at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
at::Tensor& self, double lambd, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -1927,7 +1927,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,

at::Tensor XLANativeFunctions::multinomial(
const at::Tensor& self, int64_t num_samples, bool replacement,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
XLA_CHECK(num_samples > 0)
<< "Multinomial number of samples must be greater than 0";
XLA_CHECK(at::isFloatingType(self.scalar_type()))
Expand Down Expand Up @@ -2242,7 +2242,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
}

at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2254,7 +2254,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
}

at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2267,7 +2267,7 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,

at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
const at::Tensor& std,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2280,7 +2280,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,

at::Tensor& XLANativeFunctions::normal_(
at::Tensor& self, double mean, double std,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -2417,7 +2417,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
// The value generated should be within (from, to].
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, int64_t from, c10::optional<int64_t> to,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2437,7 +2437,7 @@ at::Tensor& XLANativeFunctions::random_(

// The value generated should be in (0, to].
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
at::Tensor& self, int64_t to, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand All @@ -2453,7 +2453,7 @@ at::Tensor& XLANativeFunctions::random_(

// The value generated should be in (self_type_min, self_type_max).
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, c10::optional<at::Generator> generator) {
at::Tensor& self, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -2594,7 +2594,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
at::Tensor XLANativeFunctions::rrelu_with_noise(
const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
const at::Scalar& upper, bool training,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
// The fallback path for rrelu_with_noise when training=true is wrong
Expand Down Expand Up @@ -3133,7 +3133,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,

at::Tensor& XLANativeFunctions::uniform_(
at::Tensor& self, double from, double to,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down

0 comments on commit 174183e

Please sign in to comment.