diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d355d6c378f..6e98726063f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3667,7 +3667,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( // our XLA lowering. XlaDeviceType hw_type = static_cast(grad_output_tensor->GetDevice().type()); - if (!CheckTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON) { + if (!CheckTpuDevice(hw_type) && !CheckNeuronDevice(hw_type)) { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(upsample_nearest2d_backward)>::call(grad_output, output_size, diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 8e60c235a4b..016f125c332 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -32,7 +32,7 @@ bool IsSparseGather(const xla::Shape& input_shape, // to avoid gather on a single float on TPU. XlaDeviceType hw_type = static_cast(bridge::GetCurrentDevice().type()); - if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) { + if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) { // XLA_DENSE_GATHER_FACTOR can be used to finely control the // sparsity check. static int dense_gather_factor = diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 71c2a63e686..a6490778477 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -116,4 +116,16 @@ bool CheckTpuDevice(XlaDeviceType hw_type) { return false; } +bool CheckNeuronDevice(XlaDeviceType hw_type) { + if (hw_type == XlaDeviceType::NEURON) { + return true; + } + + std::string pjrt_device = runtime::sys_util::GetEnvString("PJRT_DEVICE", ""); + if (hw_type == XlaDeviceType::SPMD) { + return pjrt_device == "NEURON"; + } + return false; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index 6006796a42f..385eef905a5 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -57,6 +57,9 @@ bool GetLockSpmdConfig(); // TODO(yeounoh) - see if we need to check for AOT compilation device type. bool CheckTpuDevice(XlaDeviceType hw_type); +// Return true if the physical device type is NEURON. +bool CheckNeuronDevice(XlaDeviceType hw_type); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_DEVICE_H_ diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index f56af984197..923f1152c9d 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -129,7 +129,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( if (UseBF16()) { return xla::PrimitiveType::BF16; } - if (DowncastBF16() || hw_type == XlaDeviceType::NEURON) { + if (DowncastBF16() || CheckNeuronDevice(hw_type)) { return xla::PrimitiveType::F32; } return xla::PrimitiveType::F64; @@ -137,11 +137,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; case xla::PrimitiveType::U16: - return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16 - : xla::PrimitiveType::U32; + return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32 + : xla::PrimitiveType::U16; case xla::PrimitiveType::S16: - return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16 - : xla::PrimitiveType::S32; + return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32 + : xla::PrimitiveType::S16; case xla::PrimitiveType::S64: return xla::PrimitiveType::S64; case xla::PrimitiveType::U64: diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index 97fa335d9d6..0f5417ed8d1 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -271,7 +271,7 @@ xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input, XlaDeviceType hw_type = static_cast(bridge::GetCurrentDevice().type()); - if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) { + if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) { // TPU uses custom call implementation resized = xla::CustomCall(input.builder(), target, {tinput}, resized_shape,