-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding bf16 data type workaround for max_pool2d op and fixing embedding workaround tests #1657
Conversation
599bfb9
to
ce775d3
Compare
ce775d3
to
ee54049
Compare
ee54049
to
f3ec40c
Compare
Created the following issue to track the runtime workaround: |
f3ec40c
to
f88de78
Compare
private: | ||
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands, | ||
bool readUpdateIndexFromDeviceForKVCache) | ||
bool readUpdateIndexFromDeviceForKVCache, bool typecastOnHost) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe name this as toDtypeOnHost or something so that it's explicit that we're using to_dtype when typcasting on host.
Also please add a ttrt command line option to toggle this in runtime/tools/python/ttrt/common/run.py
, the procedure will be the same as the other workaround flags.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed typecastOnHost to toDtypeOnHost. I added an option in run.py; please double-check if I implemented it correctly.
f88de78
to
e320d04
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Runtime changes look good, thanks Stefan!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks stefi!
e320d04
to
5ba32ac
Compare
Applying bf16 data type workaround on max_pool2d op.
With this change, I rewrote the workaround test, silicon, and compiler test.
Closes #1389