Skip to content

Commit

Permalink
[Snippets] Applied comments
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 10, 2025
1 parent a67c947 commit 2c9cfcb
Show file tree
Hide file tree
Showing 25 changed files with 268 additions and 209 deletions.
33 changes: 13 additions & 20 deletions src/common/snippets/include/snippets/lowered/loop_port.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class LoopPort {
enum {UNDEFINED_DIM_IDX = std::numeric_limits<size_t>::max()};
enum class Type {
Incremented, // Loop port which data ptr should be incremented after each Loop iteration
NotIncremented, // Loop port which data ptr should not be to avoid double increment
NotIncremented, // Loop port which data ptr should not be incremented (for example, to avoid double increment)
NotProcessed, // LoopPort which doesn't process the dim by `dim_idx` (UNDEFINED_DIM_IDX) and is used only for Loop bound definition
};

LoopPort() = default;

template<LoopPort::Type T = Type::Incremented,
template<LoopPort::Type T,
typename std::enable_if<T == Type::Incremented || T == Type::NotIncremented, bool>::type = true>
static LoopPort create(const ExpressionPort& port, size_t dim_idx = 0) {
return LoopPort(port, dim_idx, T);
Expand All @@ -46,12 +46,20 @@ class LoopPort {

const std::shared_ptr<ExpressionPort>& get_expr_port() const { return m_expr_port; }
Type get_type() const { return m_type; }
size_t get_dim_idx() const { return m_dim_idx; }
size_t get_dim_idx() const;

void set_expr_port(std::shared_ptr<ExpressionPort> p);
void set_type(Type type);
void set_dim_idx(size_t idx);

template<LoopPort::Type T = Type::Incremented,
typename std::enable_if<T == Type::Incremented || T == Type::NotIncremented, bool>::type = true>
void convert_to_type() {
m_type = T;
}

bool is_processed() const;
bool is_incremented() const;

private:
LoopPort(const ExpressionPort& port, size_t dim_idx, Type type);

Expand All @@ -60,22 +68,7 @@ class LoopPort {
Type m_type = Type::Incremented;
};

inline std::ostream& operator<<(std::ostream& out, const LoopPort::Type& type) {
switch (type) {
case LoopPort::Type::Incremented:
out << "Incremented";
break;
case LoopPort::Type::NotIncremented:
out << "NotIncremented";
break;
case LoopPort::Type::NotProcessed:
out << "NotProcessed";
break;
default:
OPENVINO_THROW("Unknown LoopPort Type");
}
return out;
}
std::ostream& operator<<(std::ostream& out, const LoopPort::Type& type);

} // namespace lowered
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class InsertBuffers : public RangedPass {
const LinearIR::constExprIt& begin_it,
const LinearIR::constExprIt& end_it,
const LoopManagerPtr& loop_manager,
const std::vector<LoopPort>& loop_entries,
const std::vector<LoopPort>& loop_exits) const;
const std::vector<ExpressionPort>& loop_entries,
const std::vector<ExpressionPort>& loop_exits) const;

static LinearIR::constExprIt insertion_position(const LinearIR& linear_ir,
const LoopManagerPtr& loop_manager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ void BufferExpression::init_allocation_size(const std::shared_ptr<LoopManager>&
OPENVINO_ASSERT(it != output_ports.end(), "compute_allocation_shape: output port of parent loop can not be found");
}
const auto& loop_port = *it;
if (!loop_port.is_processed())
continue;
const auto& dim_idx = loop_port.get_dim_idx();
if (loop_port.get_type() != LoopPort::Type::NotProcessed && dim_idx < rank) {
if (dim_idx < rank) {
if (const auto& unified_loop_info = ov::as_type_ptr<UnifiedLoopInfo>(loop_info))
m_allocation_size = utils::dynamic_safe_mul(m_allocation_size, unified_loop_info->get_work_amount());
else if (const auto& expanded_loop_info = ov::as_type_ptr<ExpandedLoopInfo>(loop_info))
Expand Down
27 changes: 16 additions & 11 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ LoopInfo::LoopInfo(size_t work_amount, size_t increment, const std::vector<Expre
m_input_ports.reserve(entries.size());
m_output_ports.reserve(exits.size());
for (const auto& port : entries)
m_input_ports.push_back(LoopPort::create(port));
m_input_ports.push_back(LoopPort::create<LoopPort::Type::Incremented>(port));
for (const auto& port : exits)
m_output_ports.push_back(LoopPort::create(port));
m_output_ports.push_back(LoopPort::create<LoopPort::Type::Incremented>(port));
}

bool LoopInfo::is_dynamic() const {
Expand All @@ -30,12 +30,20 @@ bool LoopInfo::is_dynamic() const {

size_t LoopInfo::get_dim_idx() const {
OPENVINO_ASSERT(!m_input_ports.empty(), "Loop info must have at least one input port");
auto equal_dim_idxes = [&](const LoopPort& p) {
return p.get_type() == LoopPort::Type::NotProcessed || p.get_dim_idx() == m_input_ports[0].get_dim_idx();
};

auto is_processed = [](const LoopPort& p) { return p.is_processed(); };
auto is_processed_it = std::find_if(m_input_ports.begin(), m_input_ports.end(), is_processed);
if (is_processed_it == m_input_ports.end()) {
is_processed_it = std::find_if(m_output_ports.begin(), m_output_ports.end(), is_processed);
if (is_processed_it == m_output_ports.end())
return LoopPort::UNDEFINED_DIM_IDX;
}
const auto dim_idx = is_processed_it->get_dim_idx();

auto equal_dim_idxes = [&](const LoopPort& p) { return !p.is_processed() || p.get_dim_idx() == dim_idx; };
if (std::all_of(m_input_ports.begin(), m_input_ports.end(), equal_dim_idxes) &&
std::all_of(m_output_ports.begin(), m_output_ports.end(), equal_dim_idxes)) {
return m_input_ports[0].get_dim_idx();
return dim_idx;
} else {
return LoopPort::UNDEFINED_DIM_IDX;
}
Expand All @@ -60,7 +68,7 @@ size_t LoopInfo::get_increment() const {
std::vector<bool> LoopInfo::get_is_incremented() const {
std::vector<bool> values;
values.reserve(get_input_count() + get_output_count());
iterate_through_ports([&values](const LoopPort& port) { values.push_back(port.get_type() == LoopPort::Type::Incremented); });
iterate_through_ports([&values](const LoopPort& port) { values.push_back(port.is_incremented()); });
return values;
}

Expand All @@ -81,10 +89,7 @@ void LoopInfo::set_increment(size_t increment) {
}

void LoopInfo::set_dim_idx(size_t dim_idx) {
auto setter = [dim_idx](LoopPort& port) {
if (port.get_type() != LoopPort::Type::NotProcessed)
port.set_dim_idx(dim_idx);
};
auto setter = [dim_idx](LoopPort& port) { if (port.is_processed()) port.set_dim_idx(dim_idx); };
std::for_each(m_input_ports.begin(), m_input_ports.end(), setter);
std::for_each(m_output_ports.begin(), m_output_ports.end(), setter);
}
Expand Down
47 changes: 41 additions & 6 deletions src/common/snippets/src/lowered/loop_port.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,35 @@ std::shared_ptr<LoopPort> LoopPort::clone_with_new_expr(const ExpressionPtr& new
return new_loop_port;
}

bool LoopPort::is_processed() const {
switch (m_type) {
case Type::Incremented:
case Type::NotIncremented:
return true;
case Type::NotProcessed:
return false;
default:
OPENVINO_THROW("Unknown LoopPort type");
}
}

bool LoopPort::is_incremented() const {
return m_type == Type::Incremented;
}

size_t LoopPort::get_dim_idx() const {
OPENVINO_ASSERT(is_processed(), "NotProcessed LoopPort cannot call `get_dim_idx()`");
return m_dim_idx;
}

void LoopPort::set_expr_port(std::shared_ptr<ExpressionPort> p) {
OPENVINO_ASSERT(p, "Expression port is missed");
m_expr_port = std::move(p);
}

void LoopPort::set_type(Type type) {
m_type = type;
}

void LoopPort::set_dim_idx(size_t idx) {
if (get_type() == LoopPort::Type::NotProcessed) {
OPENVINO_ASSERT(idx == UNDEFINED_DIM_IDX, "NotProcessed LoopPort cah have only UNDEFINED_DIM_IDX");
if (!is_processed()) {
OPENVINO_ASSERT(idx == UNDEFINED_DIM_IDX, "NotProcessed LoopPort can have only UNDEFINED_DIM_IDX");
} else {
OPENVINO_ASSERT(idx < m_expr_port->get_descriptor_ptr()->get_shape().size(),
"LoopPort dim_idx (",
Expand Down Expand Up @@ -63,6 +80,24 @@ bool operator<(const LoopPort& lhs, const LoopPort& rhs) {
(lhs.m_type == rhs.m_type && lhs.m_dim_idx < rhs.m_dim_idx)));
}

std::ostream& operator<<(std::ostream& out, const LoopPort::Type& type) {
switch (type) {
case LoopPort::Type::Incremented:
out << "Incremented";
break;
case LoopPort::Type::NotIncremented:
out << "NotIncremented";
break;
case LoopPort::Type::NotProcessed:
out << "NotProcessed";
break;
default:
OPENVINO_THROW("Unknown LoopPort Type");
}
return out;
}


} // namespace lowered
} // namespace snippets
} // namespace ov
12 changes: 6 additions & 6 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,21 +146,21 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear

const auto& loop_manager = linear_ir.get_loop_manager();
if (!ov::snippets::utils::is_full_dim_value(k_block)) {
const std::vector<LoopPort> entries{LoopPort::create(brgemm_expr->get_input_port(0), 0),
LoopPort::create(brgemm_expr->get_input_port(1), 1)};
const std::vector<LoopPort> entries{LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_input_port(0), 0),
LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_input_port(1), 1)};
const std::vector<LoopPort> exits{LoopPort::create<LoopPort::Type::NotProcessed>(brgemm_expr->get_output_port(0))};
mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block);
}
if (!ov::snippets::utils::is_full_dim_value(n_block)) {
const std::vector<LoopPort> entries{LoopPort::create<LoopPort::Type::NotProcessed>(brgemm_expr->get_input_port(0)),
LoopPort::create(brgemm_expr->get_input_port(1))};
const std::vector<LoopPort> exits{LoopPort::create(brgemm_expr->get_output_port(0))};
LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_input_port(1))};
const std::vector<LoopPort> exits{LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_output_port(0))};
mark_n_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, n_block);
}
if (!ov::snippets::utils::is_full_dim_value(m_block)) {
const std::vector<LoopPort> entries{LoopPort::create(brgemm_expr->get_input_port(0), 1),
const std::vector<LoopPort> entries{LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_input_port(0), 1),
LoopPort::create<LoopPort::Type::NotProcessed>(brgemm_expr->get_input_port(1))};
const std::vector<LoopPort> exits{LoopPort::create(brgemm_expr->get_output_port(0), 1)};
const std::vector<LoopPort> exits{LoopPort::create<LoopPort::Type::Incremented>(brgemm_expr->get_output_port(0), 1)};
mark_m_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, m_block);
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ bool CleanRepeatedDataPointerShifts::reuse_increments(const LoopManagerPtr& loop
const auto loop_info = loop_manager->get_loop_info<UnifiedLoopInfo>(loop_end->get_id());
size_t loop_port_idx = 0;
loop_info->iterate_through_infos([&resetting_data_indexes, &loop_port_idx](LoopPort& loop_port, UnifiedLoopInfo::LoopPortDesc& shifts) {
if (resetting_data_indexes.count(loop_port_idx) && loop_port.get_type() != LoopPort::Type::NotProcessed) {
if (resetting_data_indexes.count(loop_port_idx) && loop_port.is_processed()) {
shifts.ptr_increment = 0;
shifts.finalization_offset = 0;
loop_port.set_type(LoopPort::Type::NotIncremented);
loop_port.convert_to_type<LoopPort::Type::NotIncremented>();
}
++loop_port_idx;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ void DefineBufferClusters::parse_loop(const LoopManagerPtr& loop_manager, const
const auto out_path = MarkInvariantShapePath::getInvariantPortShapePath(*output_buffer_port_info.port.get_expr_port());
// - Memory can be reused if there are the same loop pointer increments (data size, final offsets, ptr increments).
// For that, loop ports with buffers should be on the same shape-path and have the same value of `is_incremented`.
const auto in_is_incremented = input_buffer_port_info.port.get_type() == LoopPort::Type::Incremented;
const auto out_is_incremented = output_buffer_port_info.port.get_type() == LoopPort::Type::Incremented;
const auto in_is_incremented = input_buffer_port_info.port.is_incremented();
const auto out_is_incremented = output_buffer_port_info.port.is_incremented();
if (in_path != out_path || in_is_incremented != out_is_incremented)
continue;

Expand Down Expand Up @@ -176,8 +176,7 @@ void DefineBufferClusters::parse_nested_loops(const LoopManagerPtr& loop_manager
auto can_be_data_ptr_proportionally_shifted = [](const LoopPortInfo& outer_port_info, const LoopPortInfo& inner_port_info) {
// Outer Buffer ptr should be shifted to emulate "window" sliding
const auto& outer_desc = outer_port_info.desc;
if (outer_port_info.port.get_type() != LoopPort::Type::Incremented ||
(!utils::is_dynamic_value(outer_desc.ptr_increment) && outer_desc.ptr_increment == 0))
if (!outer_port_info.port.is_incremented() || (!utils::is_dynamic_value(outer_desc.ptr_increment) && outer_desc.ptr_increment == 0))
return false;

OPENVINO_ASSERT(inner_port_info.port.get_expr_port() && outer_port_info.port.get_expr_port(), "Expression ports are nullptr!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ bool is_extraction_applicable(const ExpressionPtr& expr, const UnifiedLoopInfoPt
if (is_loop_port) {
// stride is not 1 after move to outside, then should not extract.
const auto& loop_port = inner_loop_info->get_loop_port(expr_input_ports[i]);
if (loop_port.get_type() == LoopPort::Type::NotProcessed || get_stride_after_move_outer(loop_port) != 1) {
if (!loop_port.is_processed() || get_stride_after_move_outer(loop_port) != 1) {
return false;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/fuse_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bool FuseLoops::loop_ports_are_compatible(const LoopInfoPtr& loop_upper,
const auto upper_exit_port_it = found_port(upper_exit_ports, src_port);
if (upper_exit_port_it != upper_exit_ports.cend()) {
const auto& upper_exit_port = *upper_exit_port_it;
if (!utils::everyone_is(LoopPort::Type::Incremented, lower_entry_port.get_type(), upper_exit_port.get_type()))
if (!lower_entry_port.is_incremented() || !upper_exit_port.is_incremented())
return false;
if (lower_entry_port.get_dim_idx() != upper_exit_port.get_dim_idx())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/init_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace {
inline void init_is_incremented(LoopPort& port) {
const auto& expr = port.get_expr_port()->get_expr();
if (!std::dynamic_pointer_cast<modifier::MemoryAccess>(expr->get_node())) {
port.set_type(LoopPort::Type::NotIncremented);
port.convert_to_type<LoopPort::Type::NotIncremented>();
}
}

Expand Down
Loading

0 comments on commit 2c9cfcb

Please sign in to comment.