Skip to content

Commit

Permalink
Implement GetMemoryInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed May 20, 2024
1 parent 961c22a commit d9d0eaa
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 13 deletions.
12 changes: 12 additions & 0 deletions test/pjrt/test_runtime_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
from absl.testing import absltest, parameterized
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
Expand Down Expand Up @@ -252,6 +253,17 @@ def test_execute_time_metric(self):
f"Expected exectue time of {i} to take more than "
f"{expected_time_seconds} seconds, got {v / 1e9} seconds")

@staticmethod
def _memory_usage():
  """Return the raw memory-info mapping for the current XLA device.

  Calls the private `_xla_memory_info` binding with the string form of
  `torch_xla.device()` and returns its result unchanged.
  """
  device_str = str(torch_xla.device())
  return torch_xla._XLAC._xla_memory_info(device_str)

# TODO: Create a public API and test that instead
def test_memory_usage(self):
  """Check that every multiprocess result exposes the memory-info keys.

  Runs `_memory_usage` through `pjrt.run_multiprocess` and asserts that each
  returned value contains both 'bytes_used' and 'bytes_limit'.
  """
  per_process_info = pjrt.run_multiprocess(self._memory_usage)
  for memory_info in per_process_info.values():
    # Same checks, same order as before: used first, then limit.
    for expected_key in ('bytes_used', 'bytes_limit'):
      self.assertIn(expected_key, memory_info)


if __name__ == '__main__':
absltest.main()
10 changes: 5 additions & 5 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ py::dict GetMemoryInfo(const std::string& device_str) {
runtime::GetComputationClient()->GetMemoryInfo(device.toString());
}
auto py_dict = py::dict();
py_dict["kb_free"] = mem_info.kb_free;
py_dict["kb_total"] = mem_info.kb_total;
py_dict["bytes_used"] = mem_info.bytes_used;
py_dict["bytes_limit"] = mem_info.bytes_limit;
return py_dict;
}

Expand Down Expand Up @@ -1825,9 +1825,9 @@ void InitXlaModuleBindings(py::module m) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_memory_info", [](const std::string& device) -> py::object {
return GetMemoryInfo(device);
});
py::class_<runtime::ComputationClient::MemoryInfo>(m, "MemoryInfo");
m.def("_xla_memory_info",
[](const std::string& device) { return GetMemoryInfo(device); });
m.def(
"_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ class ComputationClient {
struct ExecuteReplicatedOptions : public ClientExecuteOptions {};

struct MemoryInfo {
int64_t kb_free = 0;
int64_t kb_total = 0;
int64_t bytes_used = 0;
int64_t bytes_limit = 0;
};

virtual ~ComputationClient() {}
Expand Down
15 changes: 14 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "absl/strings/ascii.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_hash.h"
Expand Down Expand Up @@ -926,5 +925,19 @@ std::map<std::string, Metric> PjRtComputationClient::GetMetrics() const {
return {};
}

// Returns allocator memory statistics for the PjRt device named by `device`.
//
// The result is brace-initialized from the allocator's `bytes_in_use` and
// `bytes_limit`, in that order. The call fails an XLA check when `device` is
// the SPMD virtual device string, for which no allocator stats are reported
// here.
ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
    const std::string& device) {
  XLA_CHECK_NE(device, spmd_device_str)
      << "MemoryInfo not supported for SPMD virtual device.";
  xla::PjRtDevice* pjrt_device =
      PjRtComputationClient::StringToPjRtDevice(device);
  // NOTE(review): `.value()` presumably terminates/throws with the error
  // status if the backend cannot provide allocator stats -- confirm that this
  // is the desired failure mode for such backends.
  tsl::AllocatorStats stats = pjrt_device->GetAllocatorStats().value();

  return {
      stats.bytes_in_use,
      // `bytes_limit` is optional: a backend may not know its limit.
      // Dereferencing it unconditionally is undefined behaviour when it is
      // empty, so fall back to 0 instead.
      stats.bytes_limit.value_or(0),
  };
}

} // namespace runtime
} // namespace torch_xla
6 changes: 1 addition & 5 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,7 @@ class PjRtComputationClient : public ComputationClient {

bool CoordinatorInitialized() const override;

// NOT IMPLEMENTED

MemoryInfo GetMemoryInfo(const std::string& device) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
};
MemoryInfo GetMemoryInfo(const std::string& device) override;

private:
std::unique_ptr<xla::PjRtClient> client_;
Expand Down

0 comments on commit d9d0eaa

Please sign in to comment.