-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 1xtfloat capability to pairwise_matrix distance computations #1493
base: branch-23.06
Are you sure you want to change the base?
Add 1xtfloat capability to pairwise_matrix distance computations #1493
Conversation
The instance with rbf_fin_op caused some headache: Because the handling of L2 expanded and unexpanded is unified in ``distance_impl_l2_with_options``, an instance of the CUTLASS distance kernel for rbf_fin_op was instantiated. For some reason, CUTLASS did not accept this as a valid argument and threw a very very big error message. I could not get the rbf_fin_op in acceptable state for cutlass: I included a default constructor, put const on every method, but to no avail. The current solution is to avoid CUTLASS when another final op is used than the raft::identity_op.
Peak T ops/s = 74 T/s (1x tfloat) Peak T ops/s = 22 T/s (3x tfloat) This roughly corresponds to: (assuming 2 flops / core op) Peak T ops/s = 144 Tflop/s (1x tfloat) Peak T ops/s = 33 Tflop/s (3x tfloat)
@cjnolet : I have implemented the 1xtfloat distance, but it is not yet exposed in the public API. The distance API is getting a bit unwieldy. I see the following options to expose the 1xtfloat in the API:
Do you have any thought on this? What has your preference? |
@benfred : Related to #852, I have drafted a type to describe the L2 distance options. It describes:
The docstrings in the code explain how each option should work. Please let me know:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Allard for this proposal! Indeed it is important to discuss how do we enable control for distance computation precision.
We should consider how the new arguments could be propagated to IVF methods.
It seems that ann::index_params
already has a metric_arg, expanding that to a struct that has the additional options would work. The parameters would need to be passed to kmeans clustering called from ivf_flat / ivf_pq build(). The kmeans_base_params would need to be extended with the metric_arg.
Having global parameter that we set in resource handle also has its appeal: e.g. if we want to enable 1xtf32 for cuML SVM. Currently (at lest for some of the kernels) NVIDIA_TF32_OVERRIDE would work without passing extra args through the call chain.
// double this number to get the flop/s. For l2 expanded, core_ops/s should | ||
// equal flop/s (modulo the sqrt and subtracting from the norm). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this FMA/s (i.e. you would still need to multiply core_ops/s by 2 to get flops)?
// double this number to get the flop/s. For l2 expanded, core_ops/s should | |
// equal flop/s (modulo the sqrt and subtracting from the norm). | |
// double this number to get the flop/s. For l2 expanded, 2*core_ops/s should | |
// equal flop/s (ignoring the sqrt and subtracting from the norm). |
// Use if constexpr to prevent instantiation of CUTLASS templates with final | ||
// operations like rbf_fin_op, which are somehow not compatible with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mdoijade, are you aware of restrictions that we have for epilogue functions for the cutlass kernels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes AFAIK cutlass doesn't work with lambda function here in this rb_fin_op
must have been a lambda.
bool isRowMajor> | ||
bool isRowMajor, | ||
/// Whether to use 3xtfloat or 1xtfloat: | ||
bool use_1xtfloat> | ||
struct PairwiseDistanceGemm { | ||
// This struct is specialized for fp32/3xTF32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do the same tile sizes work reasonable well for 1xTF32?
using Operator = | ||
std::conditional_t<use_1xtfloat, | ||
cutlass::arch::OpMultiplyAdd, // This implies tensorfloat | ||
cutlass::arch::OpMultiplyAddFastF32>; // This implies 3xtfloat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we decide what precision to use. IIUC, the rest of the PR is responsible to
- propagate the user input parameter until this point,
- update dispatch mechanism accordingly,
- add benchmarks.
/** | ||
* @brief Describes how precise and fast distance should be computed. | ||
*/ | ||
enum class Compute_options { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should consider naming / wording in a way that would clearly describes what happens even if we enable half precision input.
@ahendriksen @tfeher are we still planning to make progress on this feature? I'm doing a little housekeeping on the PRs and just want to make sure the PRs we are keeping open are still valid. |
The main question here what is the best mechanism for the user to opt in/out of 1xTF32 computation. @vinaydes is working on the same question related to #1892. Let's wait until that is fixed, and afterwards we shall return to this PR. Since @ahendriksen is busy with other tasks, we need someone else to continue this. Assigning this to myself for now, we will revisit availability once #1892 is solved. |
This PR adds the possibility to use 1xtfloat in the pairwise matrix computations of
raft::distance
.When 1xtfloat is enabled, the throughput more than triples compared to using 3xtfloat.
Benchmarks below were taken on H100 (unlocked clocks, SXM). The distance computed was the square L2 expanded distance. Therefore, one core_op corresponds to one fused multiply add.