Cast input of embedding op only if its dtype is different than int32. #45

Merged: dgolubovicTT merged 1 commit into main from dgolubovic/erase-default-cast-on-emb on Nov 25, 2024

dgolubovicTT (Contributor) commented on Nov 15, 2024

Cast the input of the embedding op only when its dtype differs from int32, instead of always casting it.

nvukobratTT (Contributor)

> This astype cast of the indices causes TVM to add a cast op on the input of embedding, regardless of the dtype of the indices.

> We can handle the type check here in TVM and cast only if the data format of the input differs from int32. However, I am wondering whether there is a better place for this inside the compiler. @nvukobratTT, what do you think?

I don't particularly mind if we do this in another part of the code. However, from the current perspective, the op definition seems like the best place to keep this logic. Why? When there is uncertainty about whether an op is supported, or about special logic tied to certain ops, the first place to look is the op definition itself. Therefore, the op definition is a good place for this case-specific casting.

Let me know your thoughts, and whether you have a better suggestion for where this logic should live.

dgolubovicTT (Contributor, Author)

Yeah, having the checks in the op definition seems logical!

dgolubovicTT force-pushed the dgolubovic/erase-default-cast-on-emb branch from 1dd990a to 4b8b6be on November 18, 2024 11:28
dgolubovicTT changed the title from "Erase astype cast for input to embedding op." to "Cast input of embedding op only if its dtype is different than int32." on Nov 18, 2024
dgolubovicTT (Contributor, Author)

I added a check so that the input is cast only if its dtype differs from int32. However, this won't save us from an int32-to-int32 cast: if the input to embedding is int64, we still add a cast to int32. Since Forge doesn't support int64, pytorch_dtype_to_forge_dataformat converts that input to int32 anyway, so we again end up with an int32 input followed by a cast to int32. This shouldn't be a problem once MLIR addresses the issue. Currently, casting int32 to int32 exposes a few bugs in MLIR that are described in this issue...
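The redundant int32-to-int32 cast can be traced with a small pure-Python model of the two stages described above. This is a sketch under stated assumptions: `frontend_embedding_ops`, `forge_dataformat`, and `lowered_graph` are hypothetical stand-ins, not TVM or Forge APIs, and the int64 to int32 mapping only mirrors the behaviour attributed to pytorch_dtype_to_forge_dataformat in the comment.

```python
# Hypothetical model of the two lowering stages discussed in the comment.
# None of these names are real TVM or Forge APIs.


def frontend_embedding_ops(index_dtype: str) -> list[str]:
    """Ops the frontend emits for the indices: cast only if not already int32."""
    ops = []
    if index_dtype != "int32":
        ops.append("cast(int32)")
    ops.append("embedding")
    return ops


def forge_dataformat(torch_dtype: str) -> str:
    """Assumed dtype mapping: Forge has no int64, so int64 becomes int32."""
    return {"int64": "int32"}.get(torch_dtype, torch_dtype)


def lowered_graph(index_dtype: str) -> tuple[str, list[str]]:
    """Return (graph input dtype after Forge conversion, emitted op list)."""
    return forge_dataformat(index_dtype), frontend_embedding_ops(index_dtype)


for dtype in ("int32", "int64"):
    print(dtype, "->", lowered_graph(dtype))
```

For an int64 input the model yields an int32 graph input that is still followed by `cast(int32)`, which is the int32-to-int32 cast that currently trips the MLIR bugs.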

nvukobratTT (Contributor)

I see. Can we skip the cast for int64 as well until MLIR solves the problem on their end? Also, if we do so, let's open a tech debt issue and add the good first issue label :))

This is a pretty localised change, so it will be okay :))

dgolubovicTT force-pushed the dgolubovic/erase-default-cast-on-emb branch from 4b8b6be to d0ed894 on November 18, 2024 13:22
dgolubovicTT marked this pull request as ready for review on November 18, 2024 13:23
dgolubovicTT (Contributor, Author)

Added the int64 check and an explanation of why we are doing that. After I merge, I will open a tech debt issue referencing this part of the code...
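A minimal sketch of the resulting rule, assuming the check is keyed on the dtype string of the embedding indices. `SKIP_CAST_DTYPES`, `needs_index_cast`, and `embedding_index_ops` are hypothetical names for illustration; the merged change lives in python/tvm/relay/frontend/pytorch.py and may be structured differently.

```python
# Hypothetical sketch of the dtype check; not the code merged in this PR.

# int32 needs no cast. int64 is also skipped for now: Forge lowers it to int32
# anyway, and the resulting int32 -> int32 cast exposes MLIR bugs.
# Tech debt: drop "int64" from this set once MLIR fixes those bugs.
SKIP_CAST_DTYPES = frozenset({"int32", "int64"})


def needs_index_cast(index_dtype: str) -> bool:
    """Return True if the embedding indices must be cast to int32."""
    return index_dtype not in SKIP_CAST_DTYPES


def embedding_index_ops(index_dtype: str) -> list[str]:
    """Ops emitted for the embedding indices under the rule above."""
    ops = ["cast(int32)"] if needs_index_cast(index_dtype) else []
    return ops + ["embedding"]
```

Keeping the skipped dtypes in one named set makes the temporary int64 exception easy to find and remove when the tech debt issue is resolved; any other dtype string (for example "int16") still gets the cast.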

python/tvm/relay/frontend/pytorch.py (review thread resolved)
dgolubovicTT force-pushed the dgolubovic/erase-default-cast-on-emb branch from d0ed894 to 17e017a on November 18, 2024 14:53
dgolubovicTT merged commit 932d1d4 into main on Nov 25, 2024
5 checks passed
3 participants