Skip to content

Commit

Permalink
Change ATen generator argument type to const std::optional<Generator>&
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyever committed Mar 23, 2024
1 parent 42ecc29 commit d139630
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#120076
41 changes: 23 additions & 18 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
}

at::Tensor XLANativeFunctions::bernoulli(
const at::Tensor& self, c10::optional<at::Generator> generator) {
const at::Tensor& self, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand All @@ -932,7 +932,8 @@ at::Tensor XLANativeFunctions::bernoulli(
}

at::Tensor XLANativeFunctions::bernoulli(
const at::Tensor& self, double p, c10::optional<at::Generator> generator) {
const at::Tensor& self, double p,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -944,7 +945,7 @@ at::Tensor XLANativeFunctions::bernoulli(

at::Tensor& XLANativeFunctions::bernoulli_(
at::Tensor& self, const at::Tensor& p,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand Down Expand Up @@ -1347,7 +1348,8 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self,
}

at::Tensor& XLANativeFunctions::exponential_(
at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
at::Tensor& self, double lambd,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -2030,7 +2032,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,

at::Tensor XLANativeFunctions::multinomial(
const at::Tensor& self, int64_t num_samples, bool replacement,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
XLA_CHECK(num_samples > 0)
<< "Multinomial number of samples must be greater than 0";
XLA_CHECK(at::isFloatingType(self.scalar_type()))
Expand Down Expand Up @@ -2344,8 +2346,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
}

at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
c10::optional<at::Generator> generator) {
at::Tensor XLANativeFunctions::normal(
const at::Tensor& mean, double std,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2356,8 +2359,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std,
tensor_methods::normal(bridge::GetXlaTensor(mean), std));
}

at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
c10::optional<at::Generator> generator) {
at::Tensor XLANativeFunctions::normal(
double mean, const at::Tensor& std,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2368,9 +2372,9 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std,
tensor_methods::normal(mean, bridge::GetXlaTensor(std)));
}

at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,
const at::Tensor& std,
c10::optional<at::Generator> generator) {
at::Tensor XLANativeFunctions::normal(
const at::Tensor& mean, const at::Tensor& std,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2383,7 +2387,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean,

at::Tensor& XLANativeFunctions::normal_(
at::Tensor& self, double mean, double std,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -2527,7 +2531,7 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::qr(
// The value generated should be within (from, to].
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, int64_t from, c10::optional<int64_t> to,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<
Expand All @@ -2547,7 +2551,8 @@ at::Tensor& XLANativeFunctions::random_(

// The value generated should be in (0, to].
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, int64_t to, c10::optional<at::Generator> generator) {
at::Tensor& self, int64_t to,
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand All @@ -2563,7 +2568,7 @@ at::Tensor& XLANativeFunctions::random_(

// The value generated should be in (self_type_min, self_type_max).
at::Tensor& XLANativeFunctions::random_(
at::Tensor& self, c10::optional<at::Generator> generator) {
at::Tensor& self, const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down Expand Up @@ -2736,7 +2741,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
at::Tensor XLANativeFunctions::rrelu_with_noise(
const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower,
const at::Scalar& upper, bool training,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
// The fallback path for rrelu_with_noise when training=true is wrong
Expand Down Expand Up @@ -3275,7 +3280,7 @@ std::vector<at::Tensor> XLANativeFunctions::unbind_copy(const at::Tensor& self,

at::Tensor& XLANativeFunctions::uniform_(
at::Tensor& self, double from, double to,
c10::optional<at::Generator> generator) {
const std::optional<at::Generator>& generator) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (generator.has_value() && generator->defined()) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
Expand Down

0 comments on commit d139630

Please sign in to comment.