Skip to content

Commit

Permalink
Turn i4 to canonical u4+zp to enable folding
Browse files Browse the repository at this point in the history
  • Loading branch information
dmatveev authored and TolyaTalamanov committed Jul 26, 2024
1 parent cd3d2c2 commit 3e98a4d
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@

#include "openvino/opsets/opset13.hpp"


#include "openvino/pass/validate.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

#include "text_callback_streamer.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -117,6 +127,72 @@ ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string
return stage_cfg;
}

inline int8_t hi4(int8_t x) {
return ((x & (1 << 7)) >> 4) | ((x & (1 << 6)) >> 4) | ((x & (1 << 5)) >> 4) | ((x & (1 << 4)) >> 4);
}

inline int8_t lo4(int8_t x) {
return (x & (1 << 3)) | (x & (1 << 2)) | (x & (1 << 1)) | (x & (1 << 0));
}

inline int8_t upc(int8_t h) {
return h | (-((h & (1 << 3)) >> 3) & (-8));
}

void cvt(const ov::Tensor &src, ov::Tensor &dst) {

int8_t const* pSrc = static_cast<const int8_t*>(src.data());
int8_t *pDst = static_cast<int8_t*>(dst.data());
for (int i = 0; i < src.get_size() / 2; i++) {
uint8_t a0 = upc(lo4(*pSrc)) + 8;
uint8_t a1 = upc(hi4(*pSrc)) + 8;
*pDst = a1 << 4 | a0;
pSrc++;
pDst++;
}
}

struct DQMM1: public ov::pass::MatcherPass {
DQMM1() {
namespace opp = ov::pass::pattern;

auto w = opp::wrap_type<ov::op::v0::Constant>();
auto s = opp::wrap_type<ov::op::v0::Constant>();
auto cvtw = opp::wrap_type<ov::op::v0::Convert>({w});
auto mply = opp::wrap_type<ov::op::v1::Multiply>({cvtw, s});

auto cb = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
auto mw_const = std::static_pointer_cast<ov::op::v0::Constant>(node_to_output.at(w).get_node_shared_ptr());
auto mupscale = node_to_output.at(mply).get_node_shared_ptr();
if (ov::element::i4 == mw_const->get_element_type()) {

ov::Tensor src(mw_const->get_element_type(), mw_const->get_shape(), const_cast<void*>(mw_const->get_data_ptr()));
ov::Tensor dst(ov::element::u4, mw_const->get_shape());
cvt(src, dst);

auto new_w = std::make_shared<ov::op::v0::Constant>(dst);

ov::Tensor zp(ov::element::u4, ov::Shape{1});
*static_cast<uint8_t*>(zp.data()) = 8;

auto new_z = std::make_shared<ov::op::v0::Constant>(zp);

auto mply_type = mupscale->input(1).get_element_type();

auto new_wcvt = std::make_shared<ov::op::v0::Convert>(new_w, mply_type);
auto new_zcvt = std::make_shared<ov::op::v0::Convert>(new_z, mply_type);
auto new_sub = std::make_shared<ov::op::v1::Subtract>(new_wcvt, new_zcvt);

mupscale->input(0).replace_source_output(new_sub);
}

return false;
};
register_matcher(std::make_shared<opp::Matcher>(mply, "DQMM1"), cb);
}
};

} // anonymous namespace

namespace ov {
Expand Down Expand Up @@ -145,6 +221,13 @@ StaticLLMPipeline::StaticLLMPipeline(
ov::Core core;
// (1) Read the template model - this will be kvcache model
m_kvcache_model = core.read_model(path / "openvino_model.xml");

// (1.5): Some rewrites
ov::pass::GraphRewrite rewr;
rewr.add_matcher<DQMM1>();
rewr.run_on_model(m_kvcache_model);
ov::pass::Validate().run_on_model(m_kvcache_model);

// (2) Expose KV-cache input and output layers from kvcache model
ov::pass::StatefulToStateless().run_on_model(m_kvcache_model);
// (3) Clone the model - this will be prefill
Expand Down

0 comments on commit 3e98a4d

Please sign in to comment.