Skip to content

Commit

Permalink
better names
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 8, 2023
1 parent 40360de commit cdca4a8
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 31 deletions.
2 changes: 1 addition & 1 deletion test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ std::vector<at::Tensor> Fetch(
std::vector<at::Tensor> tensors;
for (auto& literal : literals) {
tensors.push_back(MakeTensorFromXlaLiteral(
literal, GetHostScalarType(literal.shape().element_type())));
literal, MaybeUpcastForHost(literal.shape().element_type())));
}
return tensors;
}
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/convert_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
if (from == xla::PrimitiveType::PRED) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
op = ConvertTo(op, from,
GetDevicePrimitiveType(xla::PrimitiveType::U8, xla_device),
MaybeDowncastForDevice(xla::PrimitiveType::U8, xla_device),
&xla_device);
}
return op;
Expand Down
17 changes: 8 additions & 9 deletions torch_xla/csrc/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
}


xla::PrimitiveType GetDevicePrimitiveType(
xla::PrimitiveType MaybeDowncastForDevice(
xla::PrimitiveType type, const torch::lazy::BackendDevice& device) {
XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
switch (type) {
Expand Down Expand Up @@ -197,7 +197,13 @@ xla::PrimitiveType GetDevicePrimitiveType(
}
}

at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type) {
xla::PrimitiveType MaybeDowncastForDevice(
at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) {
xla::PrimitiveType xla_type = XlaTypeFromTorchType(scalar_type);
return MaybeDowncastForDevice(xla_type, device);
}

at::ScalarType MaybeUpcastForHost(xla::PrimitiveType xla_type) {
at::ScalarType scalar_type = TorchTypeFromXlaType(xla_type);
switch (scalar_type) {
case at::ScalarType::BFloat16:
Expand All @@ -214,11 +220,4 @@ at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type) {
}
}

xla::PrimitiveType GetXlaTypeFromTensorType(
at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) {
// XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
xla::PrimitiveType xla_type = XlaTypeFromTorchType(scalar_type);
return GetDevicePrimitiveType(xla_type, device);
}

}
11 changes: 6 additions & 5 deletions torch_xla/csrc/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type);

xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type);

// TODO better names
xla::PrimitiveType GetDevicePrimitiveType(
// Downcast type to be compatible with device if necessary.
xla::PrimitiveType MaybeDowncastForDevice(
xla::PrimitiveType type, const torch::lazy::BackendDevice& device);

at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type);

xla::PrimitiveType GetXlaTypeFromTensorType(
xla::PrimitiveType MaybeDowncastForDevice(
at::ScalarType scalar_type, const torch::lazy::BackendDevice& device);

// Upcast type to original PyTorch type.
at::ScalarType MaybeUpcastForHost(xla::PrimitiveType xla_type);

}

#endif // XLA_TORCH_XLA_CSRC_DTYPE_H_
2 changes: 1 addition & 1 deletion torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
}

torch::lazy::Shape XlaHelpers::ConvertXlaShapeToLazy(const xla::Shape& shape) {
at::ScalarType scalar_type = GetHostScalarType(shape.element_type());
at::ScalarType scalar_type = MaybeUpcastForHost(shape.element_type());
c10::optional<std::vector<bool>> is_symbolic = c10::nullopt;
if (shape.is_dynamic()) {
std::vector<bool> xla_dynamic_dimensions =
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ class PyLoweringContext {
xla::Literal& literal = literals[i];
xla::XlaOp op = lowering_ctx.GetParameter(device_data[i]);
at::ScalarType dtype =
GetHostScalarType(literal.shape().element_type());
MaybeUpcastForHost(literal.shape().element_type());
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
results[param_ids[i]] = input;
}
Expand Down Expand Up @@ -1760,7 +1760,7 @@ void InitXlaModuleBindings(py::module m) {
shards.push_back(
XlaDataToTensors(
{shard_handle},
GetHostScalarType(shard_handle->shape().element_type()))
MaybeUpcastForHost(shard_handle->shape().element_type()))
.front());
str_devices.push_back(shard_handle->device());
}
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ int64_t XLATensor::size(int64_t dim) const {
at::ScalarType XLATensor::dtype() const {
return data()->logical_element_type
? *data()->logical_element_type
: GetHostScalarType(shape().get().element_type());
: MaybeUpcastForHost(shape().get().element_type());
}

c10::optional<at::ScalarType> XLATensor::dtype_optional() const {
Expand Down
10 changes: 5 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,9 +1027,9 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type);
bool other_is_float = xla::primitive_util::IsFloatingPointType(other_type);
if (input_is_float && !other_is_float) {
scalar_type = GetHostScalarType(input_type);
scalar_type = MaybeUpcastForHost(input_type);
} else if (!input_is_float && other_is_float) {
scalar_type = GetHostScalarType(other_type);
scalar_type = MaybeUpcastForHost(other_type);
}
// We need to cast both input and other to float to perform true divide, floor
// divide and trunc divide.
Expand Down Expand Up @@ -1074,7 +1074,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
xla::PrimitiveType input_type = input->shape().get().element_type();
bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type);
if (input_is_float) {
scalar_type = GetHostScalarType(input_type);
scalar_type = MaybeUpcastForHost(input_type);
}
torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
Expand Down Expand Up @@ -1182,7 +1182,7 @@ XLATensorPtr eye(int64_t lines, int64_t cols,
void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols) {
out->SetIrValue(
Identity(lines, cols >= 0 ? cols : lines,
GetDevicePrimitiveType(out->shape().get().element_type(),
MaybeDowncastForDevice(out->shape().get().element_type(),
out->GetDevice())));
}

Expand Down Expand Up @@ -2056,7 +2056,7 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
torch::lazy::NodePtr pow_node = Pow(input_node, exponent->GetIrValue());
at::ScalarType input_dtype = GetScalarType(input);
at::ScalarType exp_dtype = exponent->dtype();
at::ScalarType promoted_dtype = GetHostScalarType(XlaHelpers::PromoteType(
at::ScalarType promoted_dtype = MaybeUpcastForHost(XlaHelpers::PromoteType(
XlaTypeFromTorchType(input_dtype), XlaTypeFromTorchType(exp_dtype)));
return exponent->CreateFrom(pow_node, promoted_dtype);
}
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ void TensorToBuffer(const at::Tensor& tensor, const xla::Shape& dest_shape,
at::Tensor contiguous_tensor = tensor.contiguous();
xla::Shape src_shape = MakeTorchTensorLayout(
XlaHelpers::I64List(contiguous_tensor.sizes()), /*dynamic_dimensions=*/{},
GetXlaTypeFromTensorType(contiguous_tensor.type().scalarType(), device));
MaybeDowncastForDevice(contiguous_tensor.type().scalarType(), device));
CopyTensors<SType, DType>(contiguous_tensor.data_ptr<SType>(), src_shape,
dest_buffer, dest_buffer_size, dest_shape);
}
Expand Down Expand Up @@ -761,7 +761,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
auto dimensions = XlaHelpers::I64List(tensor.sizes());
computed_shape = MakeTorchTensorLayout(
dimensions, /*dynamic_dimensions=*/{},
GetXlaTypeFromTensorType(tensor.type().scalarType(), xla_device));
MaybeDowncastForDevice(tensor.type().scalarType(), xla_device));
shape = &computed_shape;
}
xla::Literal literal(*shape);
Expand Down Expand Up @@ -859,13 +859,13 @@ xla::Shape CreateComputationShapeFromTensor(
xla::PrimitiveType GetXlaPrimitiveTypeForCurrentDevice(
xla::PrimitiveType xla_type) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
return GetDevicePrimitiveType(xla_type, xla_device);
return MaybeDowncastForDevice(xla_type, xla_device);
}

xla::PrimitiveType MakeXlaPrimitiveType(
at::ScalarType scalar_type, const torch::lazy::BackendDevice* device) {
torch::lazy::BackendDevice xla_device = bridge::GetDeviceOrCurrent(device);
return GetXlaTypeFromTensorType(scalar_type, xla_device);
return MaybeDowncastForDevice(scalar_type, xla_device);
}

xla::Shape MakeXlaShapeFromLazyShape(torch::lazy::Shape shape,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
const at::Scalar& value, xla::PrimitiveType type,
const torch::lazy::BackendDevice& device) {
torch::lazy::BackendDataPtr data =
GetDeviceData(value, GetHostScalarType(type), device);
GetDeviceData(value, MaybeUpcastForHost(type), device);
data->SetInfo(
std::make_shared<torch::lazy::LazyGraphExecutor::DeviceDataInfo>(
/*tensor_id=*/-1, /*read_only=*/true));
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
// Reshards replicated output if `sharding` is present.
std::vector<at::Tensor> tensors = XlaDataToTensors(
{sharded_results[0][i]},
GetHostScalarType(sharding->shape.element_type()));
MaybeUpcastForHost(sharding->shape.element_type()));
outputs.push_back(
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
CreateTensorsData(
Expand Down

0 comments on commit cdca4a8

Please sign in to comment.