[Performance] Solve high memory usage issue during model compilation using OpenVINO backend on Keras 3 #31482

Open
wants to merge 7 commits into
base: master
Original file line number Diff line number Diff line change
@@ -25,5 +25,8 @@ class TRANSFORMATIONS_API EinsumDecomposition;
 class ov::pass::EinsumDecomposition : public ov::pass::MatcherPass {
 public:
     OPENVINO_MATCHER_PASS_RTTI("EinsumDecomposition");
-    EinsumDecomposition();
+    EinsumDecomposition(bool check_const = false);
+
+private:
+    bool m_check_const;  // when true, only Einsum nodes with at least one Constant input are decomposed
 };
Original file line number Diff line number Diff line change
@@ -90,6 +90,7 @@
 #include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
 #include "transformations/op_conversions/convert_subtract.hpp"
 #include "transformations/op_conversions/convert_ti_to_sequences.hpp"
+#include "transformations/op_conversions/einsum_decomposition.hpp"
 #include "transformations/resolve_names_collisions.hpp"
 #include "transformations/smart_reshape/lstm_states_broadcast.hpp"
 #include "transformations/smart_reshape/matmul_sr.hpp"
@@ -164,6 +165,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
     REGISTER_PASS(manager, ConstantFolding)
     REGISTER_PASS(manager, Validate)

+    // EinsumDecomposition should be called after ConstantFolding
+    // for better performance and memory usage.
+    // ConstantFolding creates constant inputs to Einsum operations,
+    // which EinsumDecomposition can then decompose more efficiently with
+    // reduced memory consumption.
+    REGISTER_PASS(manager, EinsumDecomposition, true)
+
     // FusedFilteringBoxesBySize transformation has the complex pattern
     // which can be affected by further transformations. So we have to
     // execute it at the beginning of the pipeline. Also, this pass resolves
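The ordering argument in the comment above can be illustrated with a toy model. The sketch below is not OpenVINO code: `ToyNode`, `toy_constant_folding`, `toy_einsum_decomposition`, and `run_pipeline` are hypothetical stand-ins that only mimic the idea that a gated decomposition finds a constant input once folding has already run.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a graph node; this is NOT an OpenVINO type.
struct ToyNode {
    std::string kind;  // "Parameter", "Constant", "Add", "Einsum" or "Decomposed"
    std::vector<std::shared_ptr<ToyNode>> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

ToyNodePtr make_node(std::string kind, std::vector<ToyNodePtr> inputs = {}) {
    return std::make_shared<ToyNode>(ToyNode{std::move(kind), std::move(inputs)});
}

bool is_constant(const ToyNodePtr& node) {
    return node->kind == "Constant";
}

// Toy folding: an Add whose inputs are all Constant is replaced by a Constant.
void toy_constant_folding(std::vector<ToyNodePtr>& graph) {
    for (auto& node : graph) {
        if (node->kind != "Add") {
            continue;
        }
        if (std::all_of(node->inputs.begin(), node->inputs.end(), is_constant)) {
            node->kind = "Constant";
            node->inputs.clear();
        }
    }
}

// Toy gated decomposition, mirroring the check_const idea from the diff.
// Returns the number of Einsum nodes that were rewritten.
int toy_einsum_decomposition(std::vector<ToyNodePtr>& graph, bool check_const) {
    int rewritten = 0;
    for (auto& node : graph) {
        if (node->kind != "Einsum") {
            continue;
        }
        const bool has_const = std::any_of(node->inputs.begin(), node->inputs.end(), is_constant);
        if (check_const && !has_const) {
            continue;  // leave the node for a later, ungated pass
        }
        node->kind = "Decomposed";
        ++rewritten;
    }
    return rewritten;
}

// Builds Einsum(Parameter, Add(Constant, Constant)) in topological order.
std::vector<ToyNodePtr> build_toy_graph() {
    auto input = make_node("Parameter");
    auto c0 = make_node("Constant");
    auto c1 = make_node("Constant");
    auto weight = make_node("Add", {c0, c1});
    auto einsum = make_node("Einsum", {input, weight});
    return {input, c0, c1, weight, einsum};
}

// Runs the gated pass with or without folding first; returns the rewrite count.
int run_pipeline(bool fold_first) {
    auto graph = build_toy_graph();
    if (fold_first) {
        toy_constant_folding(graph);
    }
    return toy_einsum_decomposition(graph, /*check_const=*/true);
}
```

In this toy, `run_pipeline(true)` rewrites the single Einsum node because its weight input has already been folded into a constant, while `run_pipeline(false)` rewrites nothing because the weight is still an unfolded `Add`.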
Original file line number Diff line number Diff line change
@@ -1291,7 +1291,7 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes,
 /// 8. Transpose dimensions to match the layout required by the output subscript.
 /// 9. Replace the original Einsum node with the last node from the decomposed sub-graph,
 ///    preserving the original node's name and runtime information.
-ov::pass::EinsumDecomposition::EinsumDecomposition() {
+ov::pass::EinsumDecomposition::EinsumDecomposition(bool check_const) : m_check_const(check_const) {
     MATCHER_SCOPE(EinsumDecomposition);
     auto einsum = ov::pass::pattern::wrap_type<ov::op::v7::Einsum>();
     matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
@@ -1300,6 +1300,28 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() {
             return false;
         }

+        if (m_check_const) {
+            // This optimization targets Einsum operations in transformer models
+            // where at least one input is constant. After ConstantFolding,
+            // weight matrices become constants, enabling efficient decomposition.
+            // Patterns and how this check treats them:
+            // 1. Weight projections: einsum("abc,cd->abd", input, weight_matrix) - OPTIMIZED (constant weight)
+            // 2. Attention scores: einsum("aecd,abcd->acbe", key, query) - NOT OPTIMIZED (both variable)
+            // 3. Attention-value: einsum("acbe,aecd->abcd", attention_scores, value) - NOT OPTIMIZED (both variable)
+            // See: https://gist.github.com/Mohamed-Ashraf273/59eddcd120918cb0761ffa5020800d5d
+            bool has_const = false;
+            for (const auto& input : einsum_node->input_values()) {
+                auto node_ptr = input.get_node_shared_ptr();
+                auto constant_ptr = ov::as_type_ptr<ov::op::v0::Constant>(node_ptr);
+                if (constant_ptr) {
+                    has_const = true;
+                    break;
+                }
+            }
+            if (!has_const)
+                return false;
+        }
+
         // Parse the Einsum equation to get input and output subscripts
         auto equation = einsum_node->get_equation();
         std::vector<std::string> input_subscripts;
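For the weight-projection pattern named in the code comment, einsum("abc,cd->abd", input, weight_matrix), the contraction is equivalent to flattening the first two input dimensions and doing one matrix multiplication. The sketch below checks that equivalence on flat row-major buffers. It is an illustration only, not necessarily the sub-graph that EinsumDecomposition emits; the function names `einsum_abc_cd_abd`, `matmul`, and `projection_via_matmul` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference contraction: out[a][b][d] = sum over c of in[a][b][c] * w[c][d].
// All tensors are flat row-major buffers.
std::vector<double> einsum_abc_cd_abd(const std::vector<double>& in,
                                      const std::vector<double>& w,
                                      std::size_t A, std::size_t B,
                                      std::size_t C, std::size_t D) {
    std::vector<double> out(A * B * D, 0.0);
    for (std::size_t a = 0; a < A; ++a)
        for (std::size_t b = 0; b < B; ++b)
            for (std::size_t d = 0; d < D; ++d)
                for (std::size_t c = 0; c < C; ++c)
                    out[(a * B + b) * D + d] += in[(a * B + b) * C + c] * w[c * D + d];
    return out;
}

// Plain row-major matrix product: (M x K) * (K x N) -> (M x N).
std::vector<double> matmul(const std::vector<double>& lhs,
                           const std::vector<double>& rhs,
                           std::size_t M, std::size_t K, std::size_t N) {
    std::vector<double> out(M * N, 0.0);
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t k = 0; k < K; ++k)
                out[m * N + n] += lhs[m * K + k] * rhs[k * N + n];
    return out;
}

// Same contraction as "reshape input to (A*B) x C, then MatMul with C x D".
// For a row-major buffer the reshape moves no data, so only the shape changes.
std::vector<double> projection_via_matmul(const std::vector<double>& in,
                                          const std::vector<double>& w,
                                          std::size_t A, std::size_t B,
                                          std::size_t C, std::size_t D) {
    return matmul(in, w, A * B, C, D);
}
```

Both functions accumulate over `c` in the same order, so their results match element for element on the same inputs.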