From 3e98a4db2ea6ec7a2264aa4281af770ecbe049a7 Mon Sep 17 00:00:00 2001
From: Dmitry Matveev
Date: Thu, 25 Jul 2024 23:41:34 +0100
Subject: [PATCH] Turn i4 to canonical u4+zp to enable folding

---
 src/cpp/src/llm_pipeline_static.cpp | 101 +++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)

diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index 351e10b523..942f8f2076 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -5,6 +5,16 @@
 
 #include "openvino/opsets/opset13.hpp"
 
+
+#include "openvino/pass/validate.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/util/op_types.hpp"
+#include "openvino/pass/pattern/op/label.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+
 #include "text_callback_streamer.hpp"
 #include "utils.hpp"
 
@@ -117,6 +127,90 @@ ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string
     return stage_cfg;
 }
 
+// Returns the upper four bits of x, moved into bits 0..3 (result is 0..15).
+inline int8_t hi4(int8_t x) {
+    return static_cast<int8_t>((x >> 4) & 0x0F);
+}
+
+// Returns the lower four bits of x (result is 0..15).
+inline int8_t lo4(int8_t x) {
+    return static_cast<int8_t>(x & 0x0F);
+}
+
+// Sign-extends a 4-bit two's complement value held in bits 0..3 of h,
+// so nibbles 8..15 become -8..-1 and nibbles 0..7 are returned unchanged.
+inline int8_t upc(int8_t h) {
+    return static_cast<int8_t>((h & 0x08) ? (h | ~0x0F) : h);
+}
+
+// Repacks packed 4-bit data from src into dst, adding 8 to every element:
+// each signed nibble (-8..7) is stored as the unsigned nibble (0..15) in the
+// same position of the same byte. dst must provide at least as many bytes
+// as are read from src; neither the element types nor the sizes are checked.
+void cvt(const ov::Tensor &src, ov::Tensor &dst) {
+    const auto* pSrc = static_cast<const int8_t*>(src.data());
+    auto* pDst = static_cast<uint8_t*>(dst.data());
+    // Two elements share one byte. Round up so that the last element of a
+    // tensor with an odd element count is converted as well.
+    const size_t nbytes = (src.get_size() + 1) / 2;
+    for (size_t i = 0; i < nbytes; ++i) {
+        const uint8_t a0 = static_cast<uint8_t>(upc(lo4(pSrc[i])) + 8);
+        const uint8_t a1 = static_cast<uint8_t>(upc(hi4(pSrc[i])) + 8);
+        pDst[i] = static_cast<uint8_t>((a1 << 4) | a0);
+    }
+}
+
+// Matcher pass for the pattern Multiply(Convert(Constant w), Constant s).
+// When w holds i4 data, the Convert feeding the Multiply is replaced with
+// Subtract(Convert(w as u4), Convert(u4 constant 8)), converted to the
+// element type the original Convert produced.
+struct DQMM1: public ov::pass::MatcherPass {
+    DQMM1() {
+        namespace opp = ov::pass::pattern;
+
+        auto w = opp::wrap_type<ov::op::v0::Constant>();
+        auto s = opp::wrap_type<ov::op::v0::Constant>();
+        auto cvtw = opp::wrap_type<ov::op::v0::Convert>({w});
+        auto mply = opp::wrap_type<ov::op::v1::Multiply>({cvtw, s});
+
+        auto cb = [=](ov::pass::pattern::Matcher& m) {
+            auto& node_to_output = m.get_pattern_value_map();
+            auto mw_const = std::static_pointer_cast<ov::op::v0::Constant>(node_to_output.at(w).get_node_shared_ptr());
+            if (ov::element::i4 != mw_const->get_element_type()) {
+                return false;  // nothing to rewrite
+            }
+            auto mupscale = node_to_output.at(mply).get_node_shared_ptr();
+            const auto mcvt_out = node_to_output.at(cvtw);
+
+            // Constant exposes its data as const; the tensor is only read by cvt().
+            ov::Tensor src(mw_const->get_element_type(), mw_const->get_shape(), const_cast<void*>(mw_const->get_data_ptr()));
+            ov::Tensor dst(ov::element::u4, mw_const->get_shape());
+            cvt(src, dst);
+            auto new_w = std::make_shared<ov::op::v0::Constant>(dst);
+
+            // Single-element u4 zero point: writing the byte 8 sets the low nibble to 8.
+            ov::Tensor zp(ov::element::u4, ov::Shape{1});
+            *static_cast<uint8_t*>(zp.data()) = 8;
+            auto new_z = std::make_shared<ov::op::v0::Constant>(zp);
+
+            const auto cvt_type = mcvt_out.get_element_type();
+            auto new_wcvt = std::make_shared<ov::op::v0::Convert>(new_w, cvt_type);
+            auto new_zcvt = std::make_shared<ov::op::v0::Convert>(new_z, cvt_type);
+            auto new_sub = std::make_shared<ov::op::v1::Subtract>(new_wcvt, new_zcvt);
+
+            // Multiply is commutative, so do not assume the matched Convert
+            // sits on input 0: rewire whichever input it actually feeds.
+            for (auto&& in : mupscale->inputs()) {
+                if (in.get_source_output() == mcvt_out) {
+                    in.replace_source_output(new_sub);
+                }
+            }
+            return true;  // the graph was modified
+        };
+        register_matcher(std::make_shared<opp::Matcher>(mply, "DQMM1"), cb);
+    }
+};
+
 } // anonymous namespace
 
 namespace ov {
@@ -145,6 +239,13 @@ StaticLLMPipeline::StaticLLMPipeline(
     ov::Core core;
     // (1) Read the template model - this will be kvcache model
     m_kvcache_model = core.read_model(path / "openvino_model.xml");
+
+    // (1.5) Rewrite i4 weights as u4 with an explicit zero point
+    ov::pass::GraphRewrite rewr;
+    rewr.add_matcher<DQMM1>();
+    rewr.run_on_model(m_kvcache_model);
+    ov::pass::Validate().run_on_model(m_kvcache_model);
+
     // (2) Expose KV-cache input and output layers from kvcache model
     ov::pass::StatefulToStateless().run_on_model(m_kvcache_model);
     // (3) Clone the model - this will be prefill