Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Support for half-float mixed precise in brute-force #225

Merged
merged 25 commits into from
Aug 20, 2024

Conversation

rhdong
Copy link
Member

@rhdong rhdong commented Jul 17, 2024

  • distance supports half-float mixed precision
  • prefiltered_brute_force supports half
  • migrate the ann brute force test cases and support half

rhdong added 2 commits July 16, 2024 22:24
- distance adaptation
- prefiltered_brute_force supports half
- migrate the ann brute force test cases and support half
@rhdong rhdong requested review from benfred and cjnolet July 17, 2024 05:25
@rhdong rhdong requested review from a team as code owners July 17, 2024 05:25
@rhdong rhdong added non-breaking Introduces a non-breaking change feature request New feature or request labels Jul 17, 2024
@rhdong
Copy link
Member Author

rhdong commented Jul 25, 2024

Hey @cjnolet @benfred , here is the performance result(code link), one line of float following one for half (the performance is improved significantly for IO workload reducing):

A100 with 80GB PCIE x 1 @computlab
Type           Queries   Vectors   Dim       K         Metric              Layout              Build Time (ms)    Search Time (ms)     Total Time (ms)    Throughput (q/s)
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
float          10        1000000   32        128       InnerProduct        row                           0.004               1.684               1.688            5924.023
half           10        1000000   32        128       InnerProduct        row                           0.003               1.563               1.566            6385.688
float          10        1000000   32        128       InnerProduct        col                           0.392               1.673               2.065            4843.640
half           10        1000000   32        128       InnerProduct        col                           1.098               2.362               3.459            2890.637
float          10        1000000   32        128       L2SqrtExpanded      row                           1.021               2.555               3.576            2796.354
half           10        1000000   32        128       L2SqrtExpanded      row                           0.771               1.613               2.384            4194.696
float          10        1000000   32        128       L2SqrtExpanded      col                           0.696               1.953               2.649            3775.518
half           10        1000000   32        128       L2SqrtExpanded      col                           2.082               3.600               5.682            1759.904
float          10        1000000   32        1024      InnerProduct        row                           0.003               3.867               3.870            2584.163
half           10        1000000   32        1024      InnerProduct        row                           0.003               3.772               3.774            2649.537
float          10        1000000   32        1024      InnerProduct        col                           0.380               3.903               4.283            2334.875
half           10        1000000   32        1024      InnerProduct        col                           1.096               4.575               5.671            1763.367
float          10        1000000   32        1024      L2SqrtExpanded      row                           0.222               5.674               5.895            1696.235
half           10        1000000   32        1024      L2SqrtExpanded      row                           0.192               5.585               5.777            1731.000
float          10        1000000   32        1024      L2SqrtExpanded      col                           0.579               5.914               6.493            1540.118
half           10        1000000   32        1024      L2SqrtExpanded      col                           0.469               5.789               6.258            1597.850
float          10        1000000   256       128       InnerProduct        row                           0.003               2.064               2.067            4838.767
half           10        1000000   256       128       InnerProduct        row                           0.003               2.208               2.211            4522.343
float          10        1000000   256       128       InnerProduct        col                           2.480               2.561               5.041            1983.605
half           10        1000000   256       128       InnerProduct        col                           1.708               1.573               3.280            3048.330
float          10        1000000   256       128       L2SqrtExpanded      row                           1.578               2.404               3.982            2511.399
half           10        1000000   256       128       L2SqrtExpanded      row                           1.253               2.022               3.275            3053.356
float          10        1000000   256       128       L2SqrtExpanded      col                           3.470               4.402               7.872            1270.285
half           10        1000000   256       128       L2SqrtExpanded      col                           2.715               3.080               5.796            1725.468
float          10        1000000   256       1024      InnerProduct        row                           0.003               6.421               6.424            1556.748
half           10        1000000   256       1024      InnerProduct        row                           0.004               4.355               4.358            2294.375
float          10        1000000   256       1024      InnerProduct        col                           2.483               6.152               8.635            1158.097
half           10        1000000   256       1024      InnerProduct        col                           1.706               5.625               7.331            1364.046
float          10        1000000   256       1024      L2SqrtExpanded      row                           1.582               9.255              10.838             922.716
half           10        1000000   256       1024      L2SqrtExpanded      row                           1.253               6.718               7.971            1254.565
float          10        1000000   256       1024      L2SqrtExpanded      col                           3.467              10.577              14.044             712.051
half           10        1000000   256       1024      L2SqrtExpanded      col                           2.731              10.204              12.935             773.093
float          10        1000000   1024      128       InnerProduct        row                           0.003               5.929               5.932            1685.793
half           10        1000000   1024      128       InnerProduct        row                           0.003               3.225               3.228            3097.965
float          10        1000000   1024      128       InnerProduct        col                           6.648               5.696              12.344             810.101
half           10        1000000   1024      128       InnerProduct        col                           4.121               3.334               7.455            1341.401
float          10        1000000   1024      128       L2SqrtExpanded      row                           3.539               5.533               9.072            1102.320
half           10        1000000   1024      128       L2SqrtExpanded      row                           2.331               3.022               5.354            1867.913
float          10        1000000   1024      128       L2SqrtExpanded      col                           9.609               9.480              19.088             523.883
half           10        1000000   1024      128       L2SqrtExpanded      col                           8.514               5.665              14.179             705.248
float          10        1000000   1024      1024      InnerProduct        row                           0.003               7.793               7.796            1282.688
half           10        1000000   1024      1024      InnerProduct        row                           0.003               6.463               6.466            1546.549
float          10        1000000   1024      1024      InnerProduct        col                           6.635               7.572              14.207             703.864
half           10        1000000   1024      1024      InnerProduct        col                           4.138               6.572              10.710             933.723
float          10        1000000   1024      1024      L2SqrtExpanded      row                           3.546              10.093              13.639             733.206
half           10        1000000   1024      1024      L2SqrtExpanded      row                           2.325               9.668              11.993             833.810
float          10        1000000   1024      1024      L2SqrtExpanded      col                           9.578              14.069              23.647             422.885
half           10        1000000   1024      1024      L2SqrtExpanded      col                           8.411              12.422              20.833             479.998
float          100       1000000   32        128       InnerProduct        row                           0.003               9.865               9.869           10133.094
half           100       1000000   32        128       InnerProduct        row                           0.003               9.882               9.885           10116.119
float          100       1000000   32        128       InnerProduct        col                           1.262               9.685              10.947            9134.877
half           100       1000000   32        128       InnerProduct        col                           1.095               9.843              10.938            9142.326
float          100       1000000   32        128       L2SqrtExpanded      row                           1.013               9.841              10.854            9213.248
half           100       1000000   32        128       L2SqrtExpanded      row                           0.985               9.875              10.860            9207.921
float          100       1000000   32        128       L2SqrtExpanded      col                           2.233              10.793              13.026            7677.038
half           100       1000000   32        128       L2SqrtExpanded      col                           2.075              10.968              13.044            7666.443
float          100       1000000   32        1024      InnerProduct        row                           0.003              31.016              31.019            3223.793
half           100       1000000   32        1024      InnerProduct        row                           0.003              30.987              30.990            3226.864
float          100       1000000   32        1024      InnerProduct        col                           1.257              30.811              32.068            3118.342
half           100       1000000   32        1024      InnerProduct        col                           1.094              30.995              32.089            3116.356
float          100       1000000   32        1024      L2SqrtExpanded      row                           1.013              53.023              54.035            1850.639
half           100       1000000   32        1024      L2SqrtExpanded      row                           0.985              53.037              54.022            1851.109
float          100       1000000   32        1024      L2SqrtExpanded      col                           2.251              48.709              50.961            1962.302
half           100       1000000   32        1024      L2SqrtExpanded      col                           2.086              48.884              50.971            1961.911
float          100       1000000   256       128       InnerProduct        row                           0.003               6.092               6.094           16408.592
half           100       1000000   256       128       InnerProduct        row                           0.003               3.126               3.129           31960.277
float          100       1000000   256       128       InnerProduct        col                           2.507               6.781               9.289           10765.918
half           100       1000000   256       128       InnerProduct        col                           1.710               2.488               4.198           23822.428
float          100       1000000   256       128       L2SqrtExpanded      row                           1.578               5.967               7.544           13255.126
half           100       1000000   256       128       L2SqrtExpanded      row                           1.250               3.342               4.592           21778.493
float          100       1000000   256       128       L2SqrtExpanded      col                           3.474               7.944              11.419            8757.537
half           100       1000000   256       128       L2SqrtExpanded      col                           2.743               4.379               7.122           14041.061
float          100       1000000   256       1024      InnerProduct        row                           0.003              18.316              18.319            5458.889
half           100       1000000   256       1024      InnerProduct        row                           0.003              19.368              19.371            5162.445
float          100       1000000   256       1024      InnerProduct        col                           2.492              17.946              20.437            4892.974
half           100       1000000   256       1024      InnerProduct        col                           1.710              17.683              19.393            5156.620
float          100       1000000   256       1024      L2SqrtExpanded      row                           1.585              46.057              47.642            2098.978
half           100       1000000   256       1024      L2SqrtExpanded      row                           1.259              45.341              46.600            2145.905
float          100       1000000   256       1024      L2SqrtExpanded      col                           3.480              47.391              50.872            1965.728
half           100       1000000   256       1024      L2SqrtExpanded      col                           2.736              48.139              50.875            1965.621
float          100       1000000   1024      128       InnerProduct        row                           0.003              16.497              16.500            6060.686
half           100       1000000   1024      128       InnerProduct        row                           0.003               4.195               4.198           23822.428
float          100       1000000   1024      128       InnerProduct        col                           6.655              17.215              23.870            4189.381
half           100       1000000   1024      128       InnerProduct        col                           4.165               4.262               8.427           11866.377
float          100       1000000   1024      128       L2SqrtExpanded      row                           3.542              16.509              20.051            4987.345
half           100       1000000   1024      128       L2SqrtExpanded      row                           2.321               4.400               6.721           14878.718
float          100       1000000   1024      128       L2SqrtExpanded      col                           9.558              20.924              30.482            3280.646
half           100       1000000   1024      128       L2SqrtExpanded      col                           8.448               7.159              15.608            6407.093
float          100       1000000   1024      1024      InnerProduct        row                           0.003              24.618              24.622            4061.449
half           100       1000000   1024      1024      InnerProduct        row                           0.003              18.342              18.345            5451.117
float          100       1000000   1024      1024      InnerProduct        col                           6.650              28.540              35.191            2841.661
half           100       1000000   1024      1024      InnerProduct        col                           4.134              19.459              23.593            4238.568
float          100       1000000   1024      1024      L2SqrtExpanded      row                           3.538              52.542              56.080            1783.158
half           100       1000000   1024      1024      L2SqrtExpanded      row                           2.325              48.467              50.791            1968.835
float          100       1000000   1024      1024      L2SqrtExpanded      col                           9.607              59.168              68.775            1454.023
half           100       1000000   1024      1024      L2SqrtExpanded      col                         151.098              49.229             200.327             499.184
float          1024      1000000   32        128       InnerProduct        row                           0.003              20.491              20.494           49966.165
half           1024      1000000   32        128       InnerProduct        row                           0.003              20.509              20.513           49919.884
float          1024      1000000   32        128       InnerProduct        col                           1.258              20.298              21.556           47503.756
half           1024      1000000   32        128       InnerProduct        col                           1.094              20.477              21.571           47471.076
float          1024      1000000   32        128       L2SqrtExpanded      row                           1.010              23.636              24.646           41548.334
half           1024      1000000   32        128       L2SqrtExpanded      row                           0.982              20.531              21.513           47598.883
float          1024      1000000   32        128       L2SqrtExpanded      col                           2.241              23.541              25.782           39718.195
half           1024      1000000   32        128       L2SqrtExpanded      col                           2.081              21.588              23.669           43263.647
float          1024      1000000   32        1024      InnerProduct        row                           0.003              79.198              79.201           12929.171
half           1024      1000000   32        1024      InnerProduct        row                           0.004              78.151              78.155           13102.164
float          1024      1000000   32        1024      InnerProduct        col                           1.263              77.934              79.198           12929.661
half           1024      1000000   32        1024      InnerProduct        col                           1.094              80.218              81.312           12593.487
float          1024      1000000   32        1024      L2SqrtExpanded      row                           1.013             243.951             244.965            4180.195
half           1024      1000000   32        1024      L2SqrtExpanded      row                           0.974             242.955             243.929            4197.946
float          1024      1000000   32        1024      L2SqrtExpanded      col                           2.240             247.013             249.253            4108.275
half           1024      1000000   32        1024      L2SqrtExpanded      col                           2.081             226.590             228.670            4478.060
float          1024      1000000   256       128       InnerProduct        row                           0.003              38.875              38.879           26338.442
half           1024      1000000   256       128       InnerProduct        row                           0.003              15.260              15.263           67090.257
float          1024      1000000   256       128       InnerProduct        col                           2.481              39.226              41.707           24552.388
half           1024      1000000   256       128       InnerProduct        col                           1.717              19.817              21.534           47553.262
float          1024      1000000   256       128       L2SqrtExpanded      row                           1.479              45.291              46.771           21893.991
half           1024      1000000   256       128       L2SqrtExpanded      row                           1.248              20.523              21.771           47035.155
float          1024      1000000   256       128       L2SqrtExpanded      col                           3.468              45.591              49.059           20872.764
half           1024      1000000   256       128       L2SqrtExpanded      col                           2.763              22.984              25.747           39772.126
float          1024      1000000   256       1024      InnerProduct        row                           0.003              73.301              73.304           13969.220
half           1024      1000000   256       1024      InnerProduct        row                           0.003              53.471              53.474           19149.481
float          1024      1000000   256       1024      InnerProduct        col                           2.463              72.974              75.437           13574.186
half           1024      1000000   256       1024      InnerProduct        col                           1.701              51.873              53.574           19113.680
float          1024      1000000   256       1024      L2SqrtExpanded      row                           1.575             317.503             319.078            3209.245
half           1024      1000000   256       1024      L2SqrtExpanded      row                           1.253             297.144             298.398            3431.664
float          1024      1000000   256       1024      L2SqrtExpanded      col                           3.449             320.774             324.223            3158.318
half           1024      1000000   256       1024      L2SqrtExpanded      col                           2.728             297.009             299.737            3416.330
float          1024      1000000   1024      128       InnerProduct        row                           0.003             125.358             125.361            8168.406
half           1024      1000000   1024      128       InnerProduct        row                           0.003              20.999              21.003           48755.522
float          1024      1000000   1024      128       InnerProduct        col                           6.607             126.106             132.713            7715.899
half           1024      1000000   1024      128       InnerProduct        col                           4.128              22.579              26.707           38342.069
float          1024      1000000   1024      128       L2SqrtExpanded      row                           3.537             131.411             134.947            7588.147
half           1024      1000000   1024      128       L2SqrtExpanded      row                           2.315              26.711              29.026           35278.668
float          1024      1000000   1024      128       L2SqrtExpanded      col                           9.612             133.920             143.532            7134.308
half           1024      1000000   1024      128       L2SqrtExpanded      col                           8.525              28.427              36.952           27711.348
float          1024      1000000   1024      1024      InnerProduct        row                           0.004             171.560             171.563            5968.639
half           1024      1000000   1024      1024      InnerProduct        row                           0.004              76.682              76.685           13353.282
float          1024      1000000   1024      1024      InnerProduct        col                           6.641             157.351             163.992            6244.192
half           1024      1000000   1024      1024      InnerProduct        col                           4.160              82.248              86.408           11850.815
float          1024      1000000   1024      1024      L2SqrtExpanded      row                           3.541             415.505             419.046            2443.645
half           1024      1000000   1024      1024      L2SqrtExpanded      row                           2.324             316.662             318.986            3210.173
float          1024      1000000   1024      1024      L2SqrtExpanded      col                           9.578             417.324             426.902            2398.680
half           1024      1000000   1024      1024      L2SqrtExpanded      col                           8.509             317.856             326.365            3137.591

@rhdong rhdong changed the base branch from branch-24.08 to branch-24.10 July 26, 2024 17:06
@cjnolet
Copy link
Member

cjnolet commented Jul 30, 2024

Linking #110

@cjnolet
Copy link
Member

cjnolet commented Aug 8, 2024

Thanks for providing the benchmarks above @rhdong. For smaller number of queries (e.g. 10) it looks like the half precision is significantly slower than the single-precision. It's very common for these algos to be used in online scenarios where 1 query at a time is used. Any idea why we are seeing this perf degradation and how to fix it?

@rhdong rhdong closed this Aug 8, 2024
@rhdong rhdong reopened this Aug 8, 2024
@rhdong
Copy link
Member Author

rhdong commented Aug 8, 2024

Thanks for providing the benchmarks above @rhdong. For smaller number of queries (e.g. 10) it looks like the half precision is significantly slower than the single-precision. It's very common for these algos to be used in online scenarios where 1 query at a time is used. Any idea why we are seeing this perf degradation and how to fix it?

It looks like it only happens on Col_Major; let me take a look.

@@ -61,8 +61,8 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
FORK ${RAFT_FORK}
PINNED_TAG ${RAFT_PINNED_TAG}
FORK rhdong
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question- why is this dependent upon RAFT? what do we still have in raft that is needed for this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a PR for raft which is required by the feature: rapidsai/raft#2382

@@ -0,0 +1,51 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All new files should have 2024 as the year. Please make this change in all new files introduced in this PR.

template <typename Type,
typename layout = layout_c_contiguous,
typename IdxT = int,
typename DistT = Type>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really really not in love at all with introducing new template types. We should be going the opposite direction.

raft::resources const& handle, \
raft::device_matrix_view<const DataT, IdxT, layout> const x, \
raft::device_matrix_view<const DataT, IdxT, layout> const y, \
raft::device_matrix_view<OutT, IdxT, layout> dist, \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a new template param for this? Is this absolutely necessary or is this just a convenience?

I understand you were able to drop the binary size down in this PR by fixing the pairwise distance instnatiations, but that fix should be decoupled from your additions here. In other words- if we could drop the binary size by 100mb or more just by fixing exiting instantiations, I'd rather not add new intantiations in unecessarily that add to the binary size. Can these be at all avoided using the existing template types? We have a lot of them already...

Copy link
Member Author

@rhdong rhdong Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm afraid it's necessary if we want to support a distance type different from the input type(like half -> float) unless we make large-scale refactoring on this part of code because all of the APIs assume inputs and output share the same type before the PR
.

@@ -15,7 +15,7 @@
*/
#pragma once

#ifndef CUVS_EXPLICIT_INSTANTIATE_ONLY
#ifndef _CUVS_EXPLICIT_INSTANTIATE_ONLY
Copy link
Member Author

@rhdong rhdong Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @cjnolet , I'm not so sure if our Cagra module needs to open this macro, so add a _ for the placeholder and wait for further instruction. If you feel it is necessary, could you help tag suitable reviewers/owners to take a look? (I have tried to open it, got redefined errors, I guess they caused by search_single_cta_inst.cuh and search_multi_cta_inst.cuh). Thanks!

@@ -23,7 +23,7 @@
namespace cuvs::neighbors::cagra::detail {
namespace single_cta_search {

#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY
#ifdef _CUVS_EXPLICIT_INSTANTIATE_ONLY
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @cjnolet , I'm not so sure if our Cagra module needs to open this macro, so add a _ for the placeholder. If you feel it is necessary, could you help tag a suitable reviewer to take a look? Thanks!

@cjnolet
Copy link
Member

cjnolet commented Aug 20, 2024

/merge

@rapids-bot rapids-bot bot merged commit 934645c into rapidsai:branch-24.10 Aug 20, 2024
46 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake cpp feature request New feature or request non-breaking Introduces a non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

3 participants