Skip to content

Commit

Permalink
Re-run linter due to wrong version
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Oct 25, 2023
1 parent fd75240 commit 2b5f14a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 97 deletions.
7 changes: 4 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {

c10::optional<at::ScalarType> PromoteIntegralType(
at::ScalarType src_dtype, const c10::optional<at::ScalarType>& opt_dtype) {
return opt_dtype.has_value() ? opt_dtype.value()
: at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong
: opt_dtype;
return opt_dtype.has_value()
? opt_dtype.value()
: at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong
: opt_dtype;
}

bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) {
Expand Down
176 changes: 83 additions & 93 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,28 +969,27 @@ void InitXlaModuleBindings(py::module m) {
ShardingUtil::ShardingType(sharding_type)),
global_shape, minibatch);
}));
m.def(
"_xla_tensors_from_aten",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices,
const std::optional<std::vector<XLATensor::ShardingSpecPtr>>&
shardings) {
std::vector<at::Tensor> result;
{
NoGilSection nogil;
std::vector<at::Tensor> xla_tensors =
GetXlaTensorsFromAten(tensors, devices, shardings);
result.reserve(xla_tensors.size());
for (size_t i = 0; i < xla_tensors.size(); ++i) {
result.push_back(torch::autograd::make_variable(
xla_tensors[i],
/*requires_grad=*/tensors.at(i).requires_grad()));
m.def("_xla_tensors_from_aten",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices,
const std::optional<std::vector<XLATensor::ShardingSpecPtr>>&
shardings) {
std::vector<at::Tensor> result;
{
NoGilSection nogil;
std::vector<at::Tensor> xla_tensors =
GetXlaTensorsFromAten(tensors, devices, shardings);
result.reserve(xla_tensors.size());
for (size_t i = 0; i < xla_tensors.size(); ++i) {
result.push_back(torch::autograd::make_variable(
xla_tensors[i],
/*requires_grad=*/tensors.at(i).requires_grad()));
}
}
}
return result;
},
py::arg("tensors"), py::arg("devices"),
py::arg("shardings") = py::none());
return result;
},
py::arg("tensors"), py::arg("devices"),
py::arg("shardings") = py::none());
m.def("_xla_get_cpu_tensors", [](const std::vector<at::Tensor>& tensors) {
std::vector<at::Tensor> result;
{
Expand Down Expand Up @@ -1290,51 +1289,45 @@ void InitXlaModuleBindings(py::module m) {
}
return list;
});
m.def(
"_xla_set_rng_seed",
[](uint64_t seed, const std::string& device) {
SetRngSeed(seed, device);
},
py::arg("seed") = 101, py::arg("device") = "");
m.def(
"_xla_get_rng_seed",
[](const std::string& device) { return GetRngSeed(device); },
py::arg("device") = "");
m.def(
"_xla_sync_multi",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool wait,
bool sync_xla_data) {
NoGilSection nogil;
SyncTensors(tensors, devices, wait, sync_xla_data);
},
py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
py::arg("sync_xla_data") = true);
m.def(
"_xla_warm_up_cache",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices) {
NoGilSection nogil;
SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false,
/*warm_up_cache_only=*/true);
},
py::arg("tensors"), py::arg("devices"));
m.def(
"_xla_sync_live_tensors",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
SyncLiveTensors(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def(
"_xla_step_marker",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_xla_set_rng_seed",
[](uint64_t seed, const std::string& device) {
SetRngSeed(seed, device);
},
py::arg("seed") = 101, py::arg("device") = "");
m.def("_xla_get_rng_seed",
[](const std::string& device) { return GetRngSeed(device); },
py::arg("device") = "");
m.def("_xla_sync_multi",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool wait,
bool sync_xla_data) {
NoGilSection nogil;
SyncTensors(tensors, devices, wait, sync_xla_data);
},
py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
py::arg("sync_xla_data") = true);
m.def("_xla_warm_up_cache",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices) {
NoGilSection nogil;
SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false,
/*warm_up_cache_only=*/true);
},
py::arg("tensors"), py::arg("devices"));
m.def("_xla_sync_live_tensors",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
SyncLiveTensors(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_xla_step_marker",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
NoGilSection nogil;
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_get_stablehlo",
[](const std::vector<at::Tensor>& tensors, const std::string& device,
const std::vector<std::string>& devices,
Expand Down Expand Up @@ -1371,19 +1364,18 @@ void InitXlaModuleBindings(py::module m) {
}
return retlist;
});
m.def(
"_xla_wait_device_ops",
[](const std::vector<std::string>& devices) {
NoGilSection nogil;
XLAGraphExecutor::Get()->WaitDeviceOps(devices);
if (UseVirtualDevice()) {
std::vector<std::string> spmd_device = {"SPMD:0"};
runtime::GetComputationClient()->WaitDeviceOps(spmd_device);
} else {
runtime::GetComputationClient()->WaitDeviceOps(devices);
}
},
py::arg("devices"));
m.def("_xla_wait_device_ops",
[](const std::vector<std::string>& devices) {
NoGilSection nogil;
XLAGraphExecutor::Get()->WaitDeviceOps(devices);
if (UseVirtualDevice()) {
std::vector<std::string> spmd_device = {"SPMD:0"};
runtime::GetComputationClient()->WaitDeviceOps(spmd_device);
} else {
runtime::GetComputationClient()->WaitDeviceOps(devices);
}
},
py::arg("devices"));
m.def("_xla_counter_names", []() {
auto counter_names = torch::lazy::GetCounterNames();
auto xla_counter_names = runtime::metrics::GetCounterNames();
Expand Down Expand Up @@ -1448,23 +1440,21 @@ void InitXlaModuleBindings(py::module m) {
torch::lazy::MetricsArena::Get()->ResetMetrics();
runtime::metrics::ClearMetrics();
});
m.def(
"_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_memory_info", [](const std::string& device) -> py::object {
return GetMemoryInfo(device);
});
m.def(
"_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision
? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);
m.def("_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(
use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);

py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/custom_mark_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace torch_xla {
class CustomMarkSharding : public XlaNode {
public:
// Make a custom call to Sharding.
CustomMarkSharding(const torch::lazy::Value& input, const torch::lazy::Value& sharding);
CustomMarkSharding(const torch::lazy::Value& input,
const torch::lazy::Value& sharding);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

Expand Down

0 comments on commit 2b5f14a

Please sign in to comment.