From 63fcdb6099deb58fa7e8f7a35a799741865db4b0 Mon Sep 17 00:00:00 2001
From: dgolubovicTT
Date: Tue, 10 Dec 2024 09:50:44 +0000
Subject: [PATCH] Cast float32 embedding weights to bfloat16 in
 third_party/tvm/python/tvm/relay/frontend/pytorch.py and cast the embedding
 result back to float32.

---
 python/tvm/relay/frontend/pytorch.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d8f657ad5..f214e957d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2700,6 +2700,10 @@ def embedding(self, inputs, input_types):
             # exposes a few bugs in tt-mlir https://github.com/tenstorrent/tt-mlir/issues/1215
             logger.warning("Casting input indices of embedding op from {} to int32", indicies_dtype)
             indices = tvm.relay.cast(indices, "int32")
-        return _op.embedding(weight, indices, axis=0)
+        # Cast float32 weights to bfloat16 for the embedding lookup, then cast
+        # the result back to float32 so downstream ops see the original dtype.
+        if weight.type_annotation.dtype == "float32":
+            weight = tvm.relay.cast(weight, "bfloat16")
+        return tvm.relay.cast(_op.embedding(weight, indices, axis=0), "float32")
 
     def embedding_bag(self, inputs, input_types):
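
Note: the snippet below is a minimal standalone sketch (not part of the patch)
of the dtype round-trip the patched converter emits. It uses relay.take as a
stand-in for the fork-specific _op.embedding operator, which upstream Relay
does not have; the shapes and variable names are illustrative only.

    import tvm
    from tvm import relay

    # Embedding table and lookup indices; shapes and names are illustrative.
    weight = relay.var("weight", shape=(10, 4), dtype="float32")
    indices = relay.var("indices", shape=(3,), dtype="int32")

    # Same round-trip as the patched embedding() converter: downcast the
    # float32 weights to bfloat16, gather rows, then upcast back to float32.
    weight_bf16 = relay.cast(weight, "bfloat16")
    gathered = relay.take(weight_bf16, indices, axis=0)  # stand-in for _op.embedding
    out = relay.cast(gathered, "float32")

    func = relay.Function([weight, indices], out)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    print(mod)  # printed IR shows cast -> take -> cast, with a float32 result

Printing the module confirms the graph's external contract is unchanged: both
parameters and the result keep their original dtypes, and the bfloat16 cast is
purely internal to the lookup.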