Skip to content

Commit

Permalink
manual override for half equal
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Dec 2, 2024
1 parent b41bff6 commit 1ea41bf
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
12 changes: 4 additions & 8 deletions cpp/include/cuvs/neighbors/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,15 @@ class ScalarQuantizer {
raft::resources const& res, raft::host_matrix_view<const QuantI, int64_t> dataset);

// returns whether the instance can be used for transform
RAFT_INLINE_FUNCTION bool is_trained() const { return is_trained_; };
_RAFT_HOST_DEVICE bool is_trained() const;

RAFT_INLINE_FUNCTION bool operator==(const ScalarQuantizer<T, QuantI>& other) const
{
return (!is_trained() && !other.is_trained()) ||
(is_trained() == other.is_trained() && min() == other.min() && max() == other.max());
};
_RAFT_HOST_DEVICE bool operator==(const ScalarQuantizer<T, QuantI>& other) const;

// the minimum value covered by the quantized datatype
RAFT_INLINE_FUNCTION T min() const { return min_; };
_RAFT_HOST_DEVICE T min() const;

// the maximum value covered by the quantized datatype
RAFT_INLINE_FUNCTION T max() const { return max_; };
_RAFT_HOST_DEVICE T max() const;

private:
bool is_trained_ = false;
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/neighbors/detail/quantization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@

namespace cuvs::neighbors::detail {

template <class T>
_RAFT_HOST_DEVICE bool fp_equals(const T& a, const T& b)
{
return a == b;
}

template <>
_RAFT_HOST_DEVICE bool fp_equals(const half& a, const half& b)
{
return static_cast<float>(a) == static_cast<float>(b);
}

template <typename T, typename QuantI, typename TempT = double>
struct quantize_op {
const T min_;
Expand Down
27 changes: 27 additions & 0 deletions cpp/src/neighbors/quantization.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,33 @@ raft::host_matrix<T, int64_t> ScalarQuantizer<T, QuantI>::inverse_transform(
return detail::inverse_scalar_transform<T, QuantI>(res, dataset, min_, max_);
}

template <typename T, typename QuantI>
_RAFT_HOST_DEVICE bool ScalarQuantizer<T, QuantI>::is_trained() const
{
return is_trained_;
}

template <typename T, typename QuantI>
_RAFT_HOST_DEVICE bool ScalarQuantizer<T, QuantI>::operator==(
const ScalarQuantizer<T, QuantI>& other) const
{
return (!is_trained() && !other.is_trained()) ||
(is_trained() == other.is_trained() && detail::fp_equals(min(), other.min()) &&
detail::fp_equals(max(), other.max()));
}

template <typename T, typename QuantI>
_RAFT_HOST_DEVICE T ScalarQuantizer<T, QuantI>::min() const
{
return min_;
}

template <typename T, typename QuantI>
_RAFT_HOST_DEVICE T ScalarQuantizer<T, QuantI>::max() const
{
return max_;
}

#define CUVS_INST_QUANTIZATION(T, QuantI) \
template struct cuvs::neighbors::quantization::ScalarQuantizer<T, QuantI>;

Expand Down

0 comments on commit 1ea41bf

Please sign in to comment.