Skip to content

Commit

Permalink
Disable GGNN for FP16 input type
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Feb 5, 2024
1 parent 77a296b commit c3beee7
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions cpp/bench/ann/src/ggnn/ggnn_benchmark.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -92,11 +92,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
raft::bench::ann::Metric metric = parse_metric(distance);
std::unique_ptr<raft::bench::ann::ANN<T>> ann;

if constexpr (std::is_same_v<T, float>) {}

if constexpr (std::is_same_v<T, uint8_t>) {}

if (algo == "ggnn") { ann = make_algo<T, raft::bench::ann::Ggnn>(metric, dim, conf); }
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>) {
if (algo == "ggnn") { ann = make_algo<T, raft::bench::ann::Ggnn>(metric, dim, conf); }
}
if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); }

return ann;
Expand All @@ -106,10 +105,13 @@ template <typename T>
std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search_param(
const std::string& algo, const nlohmann::json& conf)
{
if (algo == "ggnn") {
auto param = std::make_unique<typename raft::bench::ann::Ggnn<T>::SearchParam>();
parse_search_param<T>(conf, *param);
return param;
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>) {
if (algo == "ggnn") {
auto param = std::make_unique<typename raft::bench::ann::Ggnn<T>::SearchParam>();
parse_search_param<T>(conf, *param);
return param;
}
}
// else
throw std::runtime_error("invalid algo: '" + algo + "'");
Expand Down

0 comments on commit c3beee7

Please sign in to comment.