diff --git a/include/pisa/linear_quantizer.hpp b/include/pisa/linear_quantizer.hpp index 5be852ee..9d2faf50 100644 --- a/include/pisa/linear_quantizer.hpp +++ b/include/pisa/linear_quantizer.hpp @@ -1,19 +1,23 @@ #pragma once -#include "spdlog/spdlog.h" + #include +#include + +#include #include namespace pisa { struct LinearQuantizer { - explicit LinearQuantizer(float max, uint8_t bits) - : m_max(max), m_scale(static_cast(1U << (bits)) / max) { + explicit LinearQuantizer(float max, std::uint8_t bits) + : m_max(max), m_scale(static_cast((1U << bits) - 1U) / max) { if (bits > 32 or bits == 0) { throw std::runtime_error(fmt::format( "Linear quantizer must take a number of bits between 1 and 32 but {} passed", bits )); } } + [[nodiscard]] auto operator()(float value) const -> std::uint32_t { Expects(value <= m_max); return std::ceil(value * m_scale); @@ -24,4 +28,4 @@ struct LinearQuantizer { float m_scale; }; -} // namespace pisa \ No newline at end of file +} // namespace pisa diff --git a/test/test_linear_quantizer.cpp b/test/test_linear_quantizer.cpp new file mode 100644 index 00000000..2787e92f --- /dev/null +++ b/test/test_linear_quantizer.cpp @@ -0,0 +1,24 @@ +#define CATCH_CONFIG_MAIN +#include "catch2/catch.hpp" + +#include + +#include "linear_quantizer.hpp" + +TEST_CASE("LinearQuantizer", "[scoring][unit]") { + SECTION("construct") { + WHEN("number of bits is 0 or 33") { + std::uint8_t bits = GENERATE(0, 33); + THEN("constructor fails") { + REQUIRE_THROWS(pisa::LinearQuantizer(10.0, bits)); + } + } + } + SECTION("scores") { + std::uint8_t bits = GENERATE(3, 8, 12, 16, 19, 32); + float max = GENERATE(1.0, 100.0, std::numeric_limits::max()); + pisa::LinearQuantizer quantizer(max, bits); + REQUIRE(quantizer(0) == 0); + REQUIRE(quantizer(max) == (1 << bits) - 1); + } +}