
Commit 907196b
Don't produce layout info for TPU Embedding Ops in the old bridge to match behavior of the MLIR Bridge. XLA produces the information.

PiperOrigin-RevId: 585681525
changm authored and tensorflower-gardener committed Nov 27, 2023
1 parent 62525ff commit 907196b
Showing 1 changed file with 7 additions and 1 deletion.
tensorflow/core/tpu/kernels/tpu_embedding_ops.cc (7 additions, 1 deletion)
@@ -25,7 +25,9 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "xla/client/xla_builder.h"
+#include "xla/layout_util.h"
 #include "xla/literal_util.h"
+#include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/stream_executor/tpu/c_api_conversions.h"
 #include "xla/stream_executor/tpu/c_api_decl.h"
@@ -252,7 +254,11 @@ class SendTPUEmbeddingGradientsOp : public XlaOpKernel {
     auto builder = ctx->builder();
     gradient_shapes.reserve(gradients.size());
     for (xla::XlaOp op : gradients) {
-      gradient_shapes.push_back(builder->GetShape(op).value());
+      // Gradient layout information is added by XLA, so we can just create
+      // default layout information.
+      xla::Shape gradient_shape = builder->GetShape(op).value();
+      xla::LayoutUtil::SetToDefaultLayout(&gradient_shape);
+      gradient_shapes.push_back(gradient_shape);
     }

     std::vector<xla::XlaOp> learning_rates;
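
For context, below is a minimal standalone sketch, not part of the commit, showing the effect of xla::LayoutUtil::SetToDefaultLayout on a gradient shape: whatever layout the builder reports is replaced with XLA's default descending minor-to-major layout, and the real TPU layout is chosen later by XLA, which is why the old bridge no longer produces it. The dimensions and the main() harness are illustrative only.

// Illustrative sketch only; not part of the commit. Demonstrates resetting
// a shape's layout to the default, as the new code does per gradient.
#include <iostream>

#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

int main() {
  // A rank-2 f32 shape with an explicit non-default layout {0,1}
  // (dimension 0 minor-most). The dimensions here are made up.
  xla::Shape gradient_shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, /*dimensions=*/{128, 64}, /*minor_to_major=*/{0, 1});
  std::cout << gradient_shape.ToString(/*print_layout=*/true) << "\n";
  // Prints: f32[128,64]{0,1}

  // Reset to the default layout; XLA will assign the actual layout later.
  xla::LayoutUtil::SetToDefaultLayout(&gradient_shape);
  std::cout << gradient_shape.ToString(/*print_layout=*/true) << "\n";
  // Prints: f32[128,64]{1,0}
  return 0;
}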
