diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 6e1936c258a..fb6b6bac634 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -204,6 +204,21 @@ def xla_device_hw(device: Union[str, torch.device]) -> str: return real_device.split(':')[0] +def xla_device_kind(device: Optional[Union[str, torch.device]] = None) -> str: + """Returns vendor-dependent string that uniquely identifies the kind of + device. + + Args: + device (string or torch.device): The xla device + + Returns: + A vendor-dependent device kind string. + """ + if device is None: + device = torch_xla._XLAC._xla_get_default_device() + return torch_xla._XLAC._xla_device_kind(str(device)) + + def xla_replication_devices( local_devices: Optional[List[torch.device]] = None) -> List[str]: real_devices = xla_real_devices(local_devices) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b348c898974..7f14d684e0e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1497,6 +1497,13 @@ void InitXlaModuleBindings(py::module m) { return xla_devices; }, py::arg("devices") = std::nullopt); + m.def( + "_xla_device_kind", + [](const std::string& device) { + auto xla_device = bridge::AtenDeviceToXlaDevice(device).toString(); + return runtime::GetComputationClient()->GetDeviceKind(xla_device); + }, + py::arg("device") = ""); m.def("_xla_set_replication_devices", [](const std::vector& devices) { auto replication_devices = diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 22940ee6595..fcd1adcf51e 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -358,6 +358,8 @@ class ComputationClient { virtual torch_xla::DeviceType GetDeviceType() const = 0; + virtual std::string GetDeviceKind(const std::string& device) = 0; + virtual xla::PjRtPlatformId GetPlatformID() const = 0; virtual absl::StatusOr LookupAddressableDevice( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index fd9e81bcb2d..4a2e528e26d 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -638,6 +638,10 @@ int IfrtComputationClient::GetNumProcesses() const { return max_process_index + 1; }; +std::string IfrtComputationClient::GetDeviceKind(const std::string& device) { + return std::string(StringToIfrtDevice(device)->Kind()); +} + const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> IfrtComputationClient::GetDeviceAttributes(const std::string& device) { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index fd34021393d..c83a705abbb 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -88,6 +88,8 @@ class IfrtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + std::string GetDeviceKind(const std::string& device) override; + xla::PjRtPlatformId GetPlatformID() const override { return client_->platform_id(); } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index eefd1251667..30e648919b3 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -941,6 +941,10 @@ int PjRtComputationClient::GetNumProcesses() const { return max_process_index + 1; }; +std::string PjRtComputationClient::GetDeviceKind(const std::string& device) { + return std::string(StringToPjRtDevice(device)->device_kind()); +} + const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute> PjRtComputationClient::GetDeviceAttributes(const std::string& device) { diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index ca2257f8295..6530ce768b4 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -95,6 +95,8 @@ class PjRtComputationClient : public ComputationClient { absl::AsciiStrToUpper(client_->platform_name())); }; + std::string GetDeviceKind(const std::string& device) override; + xla::PjRtPlatformId GetPlatformID() const override { return client_->platform_id(); }