Skip to content

Commit

Permalink
Add xm.xla_device_kind() to return XLA device kind string. (#8493)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcuiaws authored Dec 17, 2024
1 parent 0154850 commit 0121444
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,21 @@ def xla_device_hw(device: Union[str, torch.device]) -> str:
return real_device.split(':')[0]


def xla_device_kind(device: Optional[Union[str, torch.device]] = None) -> str:
"""Returns vendor-dependent string that uniquely identifies the kind of
device.
Args:
device (string or torch.device): The xla device
Returns:
A vendor-dependent device kind string.
"""
if device is None:
device = torch_xla._XLAC._xla_get_default_device()
return torch_xla._XLAC._xla_device_kind(str(device))


def xla_replication_devices(
local_devices: Optional[List[torch.device]] = None) -> List[str]:
real_devices = xla_real_devices(local_devices)
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,13 @@ void InitXlaModuleBindings(py::module m) {
return xla_devices;
},
py::arg("devices") = std::nullopt);
m.def(
"_xla_device_kind",
[](const std::string& device) {
auto xla_device = bridge::AtenDeviceToXlaDevice(device).toString();
return runtime::GetComputationClient()->GetDeviceKind(xla_device);
},
py::arg("device") = "");
m.def("_xla_set_replication_devices",
[](const std::vector<std::string>& devices) {
auto replication_devices =
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ class ComputationClient {

virtual torch_xla::DeviceType GetDeviceType() const = 0;

virtual std::string GetDeviceKind(const std::string& device) = 0;

virtual xla::PjRtPlatformId GetPlatformID() const = 0;

virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,10 @@ int IfrtComputationClient::GetNumProcesses() const {
return max_process_index + 1;
};

std::string IfrtComputationClient::GetDeviceKind(const std::string& device) {
return std::string(StringToIfrtDevice(device)->Kind());
}

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
IfrtComputationClient::GetDeviceAttributes(const std::string& device) {
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class IfrtComputationClient : public ComputationClient {
absl::AsciiStrToUpper(client_->platform_name()));
};

std::string GetDeviceKind(const std::string& device) override;

xla::PjRtPlatformId GetPlatformID() const override {
return client_->platform_id();
}
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,10 @@ int PjRtComputationClient::GetNumProcesses() const {
return max_process_index + 1;
};

std::string PjRtComputationClient::GetDeviceKind(const std::string& device) {
return std::string(StringToPjRtDevice(device)->device_kind());
}

const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>
PjRtComputationClient::GetDeviceAttributes(const std::string& device) {
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class PjRtComputationClient : public ComputationClient {
absl::AsciiStrToUpper(client_->platform_name()));
};

std::string GetDeviceKind(const std::string& device) override;

xla::PjRtPlatformId GetPlatformID() const override {
return client_->platform_id();
}
Expand Down

0 comments on commit 0121444

Please sign in to comment.