Skip to content

Commit

Permalink
CUDA is same as GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Oct 10, 2023
1 parent 96ec8bb commit 6d3a97a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions torch_xla/csrc/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ std::string XlaDeviceTypeToString(XlaDeviceType hw_type) {
return "CPU";
case XlaDeviceType::GPU:
return "GPU";
case XlaDeviceType::CUDA:
return "CUDA";
case XlaDeviceType::ROCM:
return "ROCM";
case XlaDeviceType::TPU:
return "TPU";
case XlaDeviceType::XPU:
Expand Down Expand Up @@ -59,6 +63,12 @@ torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) {
} else if (device_spec_parts[0] == "CPU") {
device_type->type =
static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::CPU);
} else if (device_spec_parts[0] == "ROCM") {
device_type->type =
static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::ROCM);
} else if (device_spec_parts[0] == "CUDA") {
device_type->type =
static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::CUDA);
} else if (device_spec_parts[0] == "GPU") {
device_type->type =
static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::GPU);
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace torch_xla {
// TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToServer`
// until after the paritioning pass. This avoids transfering the full input
// tensor to the device.
enum class XlaDeviceType { CPU, GPU, TPU, XPU, NEURON, SPMD };
enum class XlaDeviceType { CPU, CUDA, ROCM, GPU, TPU, XPU, NEURON, SPMD };

struct DeviceType : public torch::lazy::BackendDeviceType {
DeviceType() { type = static_cast<int>(XlaDeviceType::CPU); }
Expand Down

0 comments on commit 6d3a97a

Please sign in to comment.