@@ -164,7 +164,7 @@ namespace npuw {
namespace llm {
enum class PrefillHint { DYNAMIC, STATIC };
enum class GenerateHint { FAST_COMPILE, BEST_PERF };
enum class AttentionHint { DYNAMIC, STATIC, PYRAMID };
enum class AttentionHint { DYNAMIC, STATIC, PYRAMID, FLASH };
} // namespace llm
} // namespace npuw

@@ -228,6 +228,8 @@ struct ATTN_HINT_BASE : OptionBase<ATTN_HINT_BASE, ::intel_npu::npuw::llm::Atten
return ::intel_npu::npuw::llm::AttentionHint::STATIC;
} else if (val == "PYRAMID") {
return ::intel_npu::npuw::llm::AttentionHint::PYRAMID;
} else if (val == "FLASH") {
return ::intel_npu::npuw::llm::AttentionHint::FLASH;
}
OPENVINO_THROW("Unsupported attention hint provided: ", val);
return {};
@@ -241,6 +243,8 @@ struct ATTN_HINT_BASE : OptionBase<ATTN_HINT_BASE, ::intel_npu::npuw::llm::Atten
return "STATIC";
case ::intel_npu::npuw::llm::AttentionHint::PYRAMID:
return "PYRAMID";
case ::intel_npu::npuw::llm::AttentionHint::FLASH:
return "FLASH";
default:
OPENVINO_THROW("Can't convert provided attention hint : ", int(val), " to string.");
}
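As a usage note, a minimal sketch of how an application could select the new FLASH hint through the plugin configuration. The NPUW_LLM_ATTENTION_HINT key below is an assumption for illustration (the actual option name is defined by the ATTN_HINT_BASE option above, not by this snippet); the other keys follow the usual NPUW enablement pattern, and the model path is a placeholder.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("llm_model.xml");  // placeholder model path
    // Hypothetical config: the attention-hint key name is an assumption.
    ov::AnyMap config = {
        {"NPU_USE_NPUW", "YES"},
        {"NPUW_LLM", "YES"},
        {"NPUW_LLM_ATTENTION_HINT", "FLASH"},  // parsed into AttentionHint::FLASH by ATTN_HINT_BASE
    };
    auto compiled = core.compile_model(model, "NPU", config);
    return 0;
}
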
134 changes: 133 additions & 1 deletion src/plugins/intel_npu/src/plugin/npuw/base_sync_infer_request.cpp
@@ -496,6 +496,7 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
const bool is_spatial = proto_comp_model_desc.spatial.has_value();
const bool is_attention = proto_comp_model_desc.attention.has_value();
const bool is_pyramid_attention = proto_comp_model_desc.pyramid_attention.has_value();
const bool is_flash_attention = proto_comp_model_desc.flash_attention.has_value();

// a list of ports to copy tensors, if needed: FROM -> TO
std::vector<std::pair<ov::SoPtr<ov::ITensor>, ov::Output<const ov::Node>>> copy_list;
@@ -534,6 +535,18 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
});
};

auto is_flash_attn_param = [&](std::size_t sub_in_idx) -> bool {
if (!is_flash_attention) {
return false;
}

// FIXME: the per-parameter check below is disabled for now, so every
// global input of a flash-attention subgraph is registered as an
// attention input.
// auto& attn = proto_comp_model_desc.flash_attention.value();
// return std::any_of(attn.params.begin(), attn.params.end(), [&](const auto& p) -> bool {
//     return p.idx == sub_in_idx;
// });
return true;
};

for (auto&& it : iodesc.global_params) {
std::size_t param_idx{}, sub_in_idx{};
std::tie(param_idx, sub_in_idx) = it;
@@ -554,8 +567,9 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
// function pipelining
NPUW_ASSERT(false && "Global parameter can't be spatial");
m_spatial_io[real_idx].inputs.at(sub_in_idx) = g_tnsr;
} else if (is_attn_param(sub_in_idx) || is_pyramid_attn_param(sub_in_idx)) {
} else if (is_attn_param(sub_in_idx) || is_pyramid_attn_param(sub_in_idx) || is_flash_attn_param(sub_in_idx)) {
// Register for future use
LOG_DEBUG("Register for future use as attention_io["<< idx <<"]");
m_attention_io[idx].inputs.at(sub_in_idx) = g_tnsr;
} else {
// Input parameter is non-spatial, do normal handling
@@ -602,6 +616,11 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
bind_pyramid_attention_inputs(idx, request);
});

// Handle flash attention inputs, if required
m_profile["attn(io)"].record([&]() {
bind_flash_attention_inputs(idx, request);
});

LOG_DEBUG("Done");
}

@@ -755,6 +774,119 @@ void ov::npuw::IBaseInferRequest::bind_attention_inputs(std::size_t idx, RqPtr r
LOG_DEBUG("Done");
}

void ov::npuw::IBaseInferRequest::bind_flash_attention_inputs(std::size_t idx, RqPtr request) {
auto& comp_model_desc = m_npuw_model->m_compiled_submodels[real(idx)];
if (!comp_model_desc.flash_attention) {
return;
}

LOG_DEBUG("Binding Flash Attention inputs...");
LOG_BLOCK();

const auto tile_id = m_flash_selector->tile_id();
const auto& flash_attention = comp_model_desc.flash_attention.value();
const auto& attention_params = flash_attention.params;
const auto& flash_models = flash_attention._compiled_models;

// TODO(es): why isn't bind_global_params enough here?
/// const auto& attention_model = pyramid_attention._compiled_models[pyramid_id];

using PA = npuw::function::FlashAttention;
// using the concat model: bind its inputs to the globals
// const auto & concat_model = flash_models[PA::eConcat];
// for (auto&& param : attention_params[PA::eConcat]) {
// const auto& iport = concat_model->inputs()[param.idx];
// const auto& input = m_attention_io[idx].inputs.at(param.idx);
// request->set_tensor(iport, input);
// }

// TODO: decide when to (re)create infer requests for each flash attention model - currently we need 3

// Pyramid dynamic range identified
// const auto past_len = m_pyramid_selector->past_length();
// const auto infer_case = m_pyramid_selector->this_case();

// using namespace ov::npuw::runtime;

// // Process each KV parameter based on inference case
// if (infer_case == pyramid_attention::Selector::Case::PREFILL) {
// // PREFILL: Set or copy past KV to destination tensors
// for (auto&& param : attention_info.params) {
// const auto& iport = pyramid_model->inputs()[param.idx];
// const auto& input = m_attention_io[idx].inputs.at(param.idx);
// const auto& input_shape = input->get_shape();

// LOG_DEBUG(iport);
// LOG_BLOCK();

// // Optimization for the last chunk: Direct tensor reuse when shapes match
// if (static_cast<int64_t>(input_shape[param.dim]) == past_len) {
// request->set_tensor(iport, input);
// continue;
// }

// // Create view of past KV data
// const auto& view = ov::npuw::util::view(input, param.dim, 0, past_len);
// const auto& shape = view->get_shape();

// // Handle empty shape case (first chunk)
// if (ov::shape_size(shape) == 0) {
// request->get_tensor(iport)->set_shape(shape);
// continue;
// }

// // Copy past KV to full destination tensor
// LOG_DEBUG("Do copy: " << shape << "...");
// const auto& dst = request->get_tensor(iport);
// ov::npuw::util::copy_tensor_by_dim(view,
// dst,
// static_cast<uint32_t>(param.dim),
// static_cast<uint32_t>(param.dim));
// }
// } else if (infer_case == pyramid_attention::Selector::Case::GENERATE) {
// // GENERATE: Set or copy past KV, preserving existing data
// for (auto&& param : attention_info.params) {
// const auto& iport = pyramid_model->inputs()[param.idx];
// const auto& input = m_attention_io[idx].inputs.at(param.idx);
// const auto& input_shape = input->get_shape();

// LOG_DEBUG(iport);
// LOG_BLOCK();

// // Validation: ensure space for new tokens
// if (static_cast<int64_t>(input_shape[param.dim]) == past_len) {
// NPUW_ASSERT(false && "Past KV is full, no space for generation");
// }

// const auto& dst = request->get_tensor(iport);
// const auto& dst_shape = dst->get_shape();

// // Optimization: Direct tensor reuse when destination matches input
// if (dst_shape == input_shape) {
// request->set_tensor(iport, input);
// continue;
// }

// // FIXME: No need to copy whole past KV, just the new part

// // Create view of past KV data
// const auto& view = ov::npuw::util::view(input, param.dim, 0, past_len);

// // Copy past KV to sliced destination (preserve space for new tokens)
// LOG_DEBUG("Do copy: " << view->get_shape() << "...");
// const auto& dst_slice = ov::npuw::util::view(dst, param.dim, 0, past_len);
// ov::npuw::util::copy_tensor_by_dim(view,
// dst_slice,
// static_cast<uint32_t>(param.dim),
// static_cast<uint32_t>(param.dim));
// }
// } else {
// NPUW_ASSERT(false && "Unsupported pyramid attention case");
// }

LOG_DEBUG("Done");
}

void ov::npuw::IBaseInferRequest::bind_pyramid_attention_inputs(std::size_t idx, RqPtr request) {
auto& comp_model_desc = m_npuw_model->m_compiled_submodels[real(idx)];
if (!comp_model_desc.pyramid_attention) {
@@ -17,6 +17,7 @@
#include "openvino/runtime/so_ptr.hpp"
#include "perf.hpp"
#include "pyramid_attention.hpp"
#include "flash_attention.hpp"
#include "spatial.hpp"
#include "util.hpp"

@@ -150,6 +151,9 @@ class IBaseInferRequest : public ov::ISyncInferRequest {
// Separate selector for pyramid attention
runtime::pyramid_attention::Selector::Ptr m_pyramid_selector;

// Separate selector for flash attention
runtime::flash_attention::Selector::Ptr m_flash_selector;

// This structure tracks how every individual subrequest
// access the model's top-level (global, public, etc) parameters
// and results. Again, is managed by subclasses
@@ -183,6 +187,7 @@ class IBaseInferRequest : public ov::ISyncInferRequest {

void bind_attention_inputs(std::size_t idx, RqPtr request);
void bind_pyramid_attention_inputs(std::size_t idx, RqPtr request);
void bind_flash_attention_inputs(std::size_t idx, RqPtr request);

void dump_input_tensors(std::size_t idx);
void dump_output_tensors(std::size_t idx);
89 changes: 88 additions & 1 deletion src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
@@ -382,7 +382,13 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
LOG_INFO("Creating compiled::PyramidAttention for Subgraph[" << id << "] (function "
<< subgraph._funcall << ")");
m_compiled_submodels[id].pyramid_attention =
compiled::PyramidAttention(fcn_template._pyramid_attention.value());
compiled::PyramidAttention{fcn_template._pyramid_attention.value()};
}
if (fcn_template._flash_attention) {
LOG_INFO("Creating compiled::FlashAttention for Subgraph[" << id << "] (function "
<< subgraph._funcall << ")");
m_compiled_submodels[id].flash_attention =
compiled::FlashAttention{fcn_template._flash_attention.value()};
}
LOG_INFO("Subgraph[" << id << "] is a function body for " << subgraph._funcall);
} else {
@@ -455,6 +461,28 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
LOG_INFO("Wrote " << pyramid_attention_model_dump_path);
}
}

// Dump flash-attention subgraphs
if (m_compiled_submodels[id].flash_attention) {
// TODO: make a loop over the flash attention submodels (a possible shape is sketched after this hunk)
auto tiled_model = m_compiled_submodels[id].flash_attention.value()._models_to_compile[ov::npuw::function::FlashAttention::eTile];
std::string flash_attention_model_dump_path =
m_name + (subgraph._funcall.empty() ? "" : "_" + subgraph._funcall) + "_flash_" + "tile.xml";
ov::save_model(tiled_model, flash_attention_model_dump_path);
LOG_INFO("Wrote " << flash_attention_model_dump_path);

auto kv_cache_concat_model = m_compiled_submodels[id].flash_attention.value()._models_to_compile[ov::npuw::function::FlashAttention::eConcat];
flash_attention_model_dump_path =
m_name + (subgraph._funcall.empty() ? "" : "_" + subgraph._funcall) + "_flash_" + "kv_cache_concat.xml";
ov::save_model(kv_cache_concat_model, flash_attention_model_dump_path);
LOG_INFO("Wrote " << flash_attention_model_dump_path);

auto divide_model = m_compiled_submodels[id].flash_attention.value()._models_to_compile[ov::npuw::function::FlashAttention::eDivide];
flash_attention_model_dump_path =
m_name + (subgraph._funcall.empty() ? "" : "_" + subgraph._funcall) + "_flash_" + "divide.xml";
ov::save_model(divide_model, flash_attention_model_dump_path);
LOG_INFO("Wrote " << flash_attention_model_dump_path);
}
} // if(dump)
} // for(orderedSubGraphs)

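Regarding the "make a loop" TODO above, a possible shape for that refactor, kept within the names already used in this hunk (m_compiled_submodels, _models_to_compile, eTile/eConcat/eDivide, m_name, subgraph._funcall, ov::save_model). The flash_parts table and its index-to-suffix mapping are illustrative assumptions, not part of this change:

// Sketch only: folds the three dump blocks above into one loop.
using FA = ov::npuw::function::FlashAttention;
const std::pair<decltype(FA::eTile), const char*> flash_parts[] = {
    {FA::eTile, "tile"},
    {FA::eConcat, "kv_cache_concat"},
    {FA::eDivide, "divide"},
};
for (const auto& [kind, suffix] : flash_parts) {
    const auto& part_model = m_compiled_submodels[id].flash_attention.value()._models_to_compile[kind];
    const std::string dump_path =
        m_name + (subgraph._funcall.empty() ? "" : "_" + subgraph._funcall) + "_flash_" + suffix + ".xml";
    ov::save_model(part_model, dump_path);
    LOG_INFO("Wrote " << dump_path);
}
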
@@ -698,6 +726,19 @@ void ov::npuw::CompiledModel::CompiledModelDesc::serialize(std::ostream& stream,
write(stream, ss.str());
}
}
// TODO: add flash-attention serialize
// write(stream, flash_attention);
// if (flash_attention.has_value()) {
// size_t num_models = flash_attention.value()._compiled_models.size();
// write(stream, num_models);

// for (size_t i = 0; i < num_models; ++i) {
// std::stringstream ss;
// auto compiled_model = flash_attention.value()._compiled_models[i];
// compiled_model->export_model(ss);
// write(stream, ss.str());
// }
// }

auto& closure_desc = closure.get();

@@ -800,6 +841,7 @@ void ov::npuw::CompiledModel::CompiledModelDesc::deserialize(std::istream& strea
}
}
}
// TODO: repeat deserialize for flash attention

auto& closure_desc = closure.get();

@@ -1451,6 +1493,7 @@ void ov::npuw::CompiledModel::detach_memory() {
// No need to clear pyramid attention data - it's self-contained!
// The _models_to_compile is already cleared in set_compiled_models()
// and compiled::PyramidAttention only stores _compiled_models (not original models)
// The same applies to FlashAttention
}
LOG_INFO("Done");
}
@@ -1622,6 +1665,9 @@ bool ov::npuw::CompiledModel::compile_for_device(std::size_t id, const std::stri

// Compile pyramid attention models if present
compile_pyramid_attention_models(id, device_to_try);

// Compile flash attention models if present
compile_flash_attention_model(id, device_to_try);
} catch (const std::exception& ex) {
LOG_ERROR("Subgraph [" << id << "] Failed to compile: " << std::endl << ex.what());
dump_on_fail(id, device_to_try, ex.what());
@@ -1636,6 +1682,47 @@
return true;
}

void ov::npuw::CompiledModel::compile_flash_attention_model(std::size_t id, const std::string& device) {
// Check if we have flash attention to compile
if (!m_compiled_submodels[id].flash_attention.has_value()) {
return;
}

LOG_INFO("Compiling flash attention submodel for Subgraph[" << id << "]...");
LOG_BLOCK();

auto& flash_attn = m_compiled_submodels[id].flash_attention.value();

auto compile_tile = [&]() {
try {
std::vector<ov::SoPtr<ov::ICompiledModel>> compiled;
for (size_t sub_idx = 0; sub_idx != flash_attn._models_to_compile.size(); ++sub_idx) {
const auto& model = flash_attn._models_to_compile[sub_idx];

LOG_DEBUG("Compiling flash attention submodel [ "<< sub_idx << "]: "<< model->get_friendly_name());

auto compiled_submodel = compile_submodel(model, device);
OPENVINO_ASSERT(compiled_submodel, "Failed to compile flash attention submodel");

compiled.push_back(std::move(compiled_submodel));
}

// Set compiled models - this also clears _models_to_compile internally
LOG_INFO("Setting compiled models into compiled::FlashAttention...");
flash_attn.set_compiled_models(std::move(compiled));

// TODO: maybe report the tile size in the logs
LOG_INFO("Flash attention compilation complete for Subgraph[" << id << "]");
} catch (const std::exception& ex) {
OPENVINO_THROW("Flash attention compilation failed: ", ex.what());
} catch (...) {
OPENVINO_THROW("Flash attention compilation failed with unknown error");
}
};

compile_tile();
}

void ov::npuw::CompiledModel::compile_pyramid_attention_models(std::size_t id, const std::string& device) {
// Check if we have pyramid attention to compile
if (!m_compiled_submodels[id].pyramid_attention.has_value()) {
6 changes: 6 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp
@@ -19,6 +19,7 @@
#include "partitioning/partitioning.hpp"
#include "perf.hpp"
#include "pyramid_attention.hpp"
#include "flash_attention.hpp"
#include "serialization.hpp"
#include "spatial.hpp"
#include "weights_bank.hpp"
@@ -80,6 +81,7 @@ class CompiledModel : public ov::npuw::ICompiledModel {
ov::SoPtr<ov::ICompiledModel> compile_submodel(const std::shared_ptr<ov::Model>& submodel,
const std::string& device);
void compile_pyramid_attention_models(std::size_t id, const std::string& device);
void compile_flash_attention_model(std::size_t id, const std::string& device);

void dump_on_fail(std::size_t id, const std::string& device_to_stry, const char* extra);

@@ -180,6 +182,10 @@ class CompiledModel : public ov::npuw::ICompiledModel {
std::optional<ov::npuw::compiled::Spatial> spatial;
std::optional<ov::npuw::compiled::Attention> attention;
std::optional<ov::npuw::compiled::PyramidAttention> pyramid_attention;
std::optional<ov::npuw::compiled::FlashAttention> flash_attention;

// TODO: maybe reuse infer requests between flash_attention and pyramid_attention
std::vector<ov::SoPtr<ov::IAsyncInferRequest>> flash_infer_requests;

// Infer requests for pyramid attention models (if pyramid_attention is present)
std::vector<ov::SoPtr<ov::IAsyncInferRequest>> pyramid_infer_requests;