Skip to content

Commit

Permalink
Add embedding cast in third_party/tvm/python/tvm/relay/frontend/pytorch.py
Browse files Browse the repository at this point in the history
to cast the weight to bfloat16 if it is float32.
  • Loading branch information
dgolubovicTT committed Dec 25, 2024
1 parent 05cd0bc commit 63fcdb6
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,7 @@ def rsub(self, inputs, input_types):
return get_relay_op("subtract")(data1, alpha * data0)

def embedding(self, inputs, input_types):
# print("********** EMBEDDING")
weight = inputs[0]
indices = inputs[1]
# Check the type of indices
Expand All @@ -2700,7 +2701,11 @@ def embedding(self, inputs, input_types):
# exposes a few bugs in tt-mlir https://github.com/tenstorrent/tt-mlir/issues/1215
logger.warning("Casting input indices of embedding op from {} to int32", indicies_dtype)
indices = tvm.relay.cast(indices, "int32")
return _op.embedding(weight, indices, axis=0)
# cast the weight to bfloat16 if it is float32
if weight.type_annotation.dtype == "float32":
weight = tvm.relay.cast(weight, "bfloat16")
return tvm.relay.cast(_op.embedding(weight, indices, axis=0), "float32")
# return _op.embedding(weight, indices, axis=0)

def embedding_bag(self, inputs, input_types):

Expand Down

0 comments on commit 63fcdb6

Please sign in to comment.