
Commit

remove debug print
Siyuan Liu committed Feb 29, 2024
1 parent ef773d5 commit 9d0caed
Showing 2 changed files with 0 additions and 18 deletions.
3 changes: 0 additions & 3 deletions torch_xla/csrc/helpers.cpp
@@ -896,11 +896,8 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
   // Create Concatenate op
   auto concat_op = xla::ConcatInDim(input.builder(), get_dim_ops, {0});

-  std::cout << "check output dimensions, output dynamic: " << output_dimensions
-            << std::endl;
   xla::Shape final_shape = xla::ShapeUtil::MakeShape(
       input_shape.element_type(), output_dimensions, output_dynamic);
-  std::cout << "check final shape: " << final_shape << std::endl;
   return DynamicBroadcastInDim(input, final_shape, concat_op);
 }

15 changes: 0 additions & 15 deletions torch_xla/csrc/softmax_builder.cpp
@@ -64,9 +64,6 @@ static xla::XlaOp BuildBroadcastForReducedLogits(xla::XlaOp reduced_logits,
   std::iota(op_broadcast_dims.begin(), op_broadcast_dims.begin() + dim, 0);
   std::iota(op_broadcast_dims.begin() + dim, op_broadcast_dims.end(), dim + 1);

-  std::cout << "in BuildBroadcastForReducedLogits " << std::endl;
-  std::cout << "check final_shape " << logits_shape << std::endl;
-  std::cout << "check op_broadcast_dims " << op_broadcast_dims << std::endl;
   return xla::CustomCall(
       reduced_logits.builder(), "mhlo.dynamic_broadcast_in_dim",
       /*operands=*/{reduced_logits, final_broadcast_dimensions},
@@ -77,13 +74,8 @@
 SoftMaxPartials LogSoftmaxPartials(xla::XlaOp logits, int64_t dim) {
   const xla::Shape& logits_shape = ShapeHelper::ShapeOfXlaOp(logits);
   bool is_unbounded_dynamic = logits_shape.is_unbounded_dynamic();
-  std::cout << "check dim: " << dim << std::endl;
-  std::cout << "check logits_shape shape: " << ShapeHelper::ShapeOfXlaOp(logits)
-            << std::endl;
   std::vector<int64_t> broadcast_dimensions =
       BroadcastDimensions(logits_shape.rank(), dim);
-  std::cout << "check broadcast_dimensions: " << broadcast_dimensions
-            << std::endl;
   xla::XlaComputation max_func =
       XlaHelpers::CreateMaxComputation(logits_shape.element_type());
   xla::Literal min_value =
@@ -93,23 +85,16 @@ SoftMaxPartials LogSoftmaxPartials(xla::XlaOp logits, int64_t dim) {
       logits, xla::ConstantLiteral(builder, min_value), max_func, {dim});
   if (is_unbounded_dynamic) {
     xla::Shape logits_max_shape = ShapeHelper::ShapeOfXlaOp(logits_max);
-    std::cout << "check logits_max shape: " << logits_max_shape << std::endl;
     logits_max = BuildBroadcastForReducedLogits(logits_max, logits_shape, dim);
   }
   xla::XlaOp shifted_logits =
       is_unbounded_dynamic ? xla::Sub(logits, logits_max)
                            : xla::Sub(logits, logits_max, broadcast_dimensions);
   xla::XlaOp exp_shifted = xla::Exp(shifted_logits);
-  std::cout << "check exp_shifted shape: "
-            << ShapeHelper::ShapeOfXlaOp(exp_shifted) << std::endl;
   xla::XlaOp init_value = xla::Zero(builder, logits_shape.element_type());
-  std::cout << "check init_value shape: "
-            << ShapeHelper::ShapeOfXlaOp(init_value) << std::endl;
   xla::XlaOp reduce = xla::Reduce(
       exp_shifted, init_value,
       XlaHelpers::CreateAddComputation(logits_shape.element_type()), {dim});
-  std::cout << "check reduce shape: " << ShapeHelper::ShapeOfXlaOp(reduce)
-            << std::endl;
   return {std::move(broadcast_dimensions), shifted_logits, exp_shifted, reduce};
 }

