diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 8ec94c67cc0a4..42e8e9c5e3cbe 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -105,9 +105,9 @@ class OpKernel { return Status::OK(); } - // Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead. // Override this function to use provided pre-packed weight. // Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + // gsl::span prepacked_buffer_sizes, // int input_idx, // /*out*/ bool& used_shared_buffers) { // used_shared_buffers = true; @@ -121,37 +121,18 @@ class OpKernel { // and must use the same order for retrieval in UseSharedPrePackedBuffers(). Though each element // of this vector is a BufferUniquePtr, the deleter of the BufferUniquePtr is NULL. So actually they // are raw pointers. + // @param prepacked_buffer_sizes: The sizes (in bytes) of each buffer in prepacked_buffers. // @param input_idx: The input index of the tensor in this kernel // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating // that the provided weight has been used by the kernel. virtual Status UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; return Status::OK(); } - /// - /// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter. - /// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers() - /// to avoid the need to update all existing kernel-based provider-bridge EPs. - /// - /// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function, - /// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu). 
- /// - /// - /// - /// - /// - /// - /// - virtual Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span /*prepacked_buffer_sizes*/, - int input_idx, - /*out*/ bool& used_shared_buffers) { - return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers); - } - const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a938688fcfd5a..e457a2a57065e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3613,12 +3613,6 @@ struct KernelRegistry : detail::Base { }; namespace detail { -/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. - * - * Holds a single type constraint from an operator schema, providing access to - * the constraint's name, allowed data types, and associated input/output indices. - * This is a non-owning view — the lifetime is tied to the parent OrtOpSchema. - */ template struct OpSchemaTypeConstraintImpl : Base { using B = Base; @@ -3639,15 +3633,11 @@ struct OpSchemaTypeConstraintImpl : Base { } // namespace detail /// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. +/// Holds a single type constraint from an operator schema, providing access to +/// the constraint's name, allowed data types, and associated input/output indices. using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl>; namespace detail { -/** \brief Owning wrapper around an `OrtOpSchema*`. - * - * Provides access to operator schema metadata such as version, input/output names, - * and type constraints. The underlying OrtOpSchema is owned by this wrapper and - * released automatically on destruction. 
- */ template struct OpSchemaImpl : Base { using B = Base; @@ -3685,6 +3675,9 @@ struct OpSchemaImpl : Base { } // namespace detail /// Owning wrapper around an `OrtOpSchema*`. +/// Provides access to operator schema metadata such as version, input/output names, +/// and type constraints. The underlying OrtOpSchema is owned by this wrapper and +/// released automatically on destruction. using OpSchema = detail::OpSchemaImpl; /// \brief Get an operator schema from the global schema registry. diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 7268b32623b95..e1981fb5c2442 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -34,6 +34,7 @@ class Attention : public OpKernel, public AttentionCPUBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -176,6 +177,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr template Status Attention::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index ca2c3ab001da6..a674d05b6daae 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -578,10 +578,10 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all } template -Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span /*prepacked_buffer_sizes*/, - int input_idx, - /*out*/ bool& used_shared_buffers) { +Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + 
gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (expert_weight_bits_ != 4) { @@ -1577,11 +1577,11 @@ template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index f678a27190c90..c5e6904ae48c2 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -32,10 +32,10 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span prepacked_buffer_sizes, - int 
input_idx, - /*out*/ bool& used_shared_buffers) override; + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; void ApplyActivationVectorized(float* data, int64_t size) const; diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index b30fa1e5e618a..931677582d469 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -28,6 +28,7 @@ class QAttention : public OpKernel, public AttentionCPUBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -117,6 +118,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr template Status QAttention::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index f55e66f9c5d81..2094af78f40b7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -17,6 +17,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -117,6 +118,7 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, Allocat } Status 
DynamicQuantizeLSTM::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index d2996b122c5f7..3da0ee19d4cde 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -135,7 +135,9 @@ class MatMulNBits final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) override; private: @@ -557,7 +559,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou #endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 template -Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, +Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 5c33a621cf514..84521af2d8532 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -436,8 +436,8 @@ static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx, } bool used_shared_buffers = false; - ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes, - input_idx, used_shared_buffers)); + ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, 
shared_prepacked_buffer_sizes, + input_idx, used_shared_buffers)); // BUG CHECK: Ensure that the kernel used the provided shared buffers // Mostly a debug check to ensure that the kernel has an overridden implementation of the diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 9fd71b3b00cd0..7fe7c914fa796 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -2,15 +2,673 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/utils.h" #include "core/optimizer/attention_fusion_helper.h" -#include "core/graph/graph_utils.h" #include +#include namespace onnxruntime { +static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size); + +namespace { + +static bool ValidateAddBiasInitializerEitherInput(const Graph& graph, const Node& add, int64_t hidden_size) { + if (add.InputDefs().size() < 2) { + return false; + } + + const NodeArg& input_0 = *(add.InputDefs()[0]); + const NodeArg& input_1 = *(add.InputDefs()[1]); + const bool input_0_is_bias = graph_utils::IsInitializer(graph, input_0.Name(), true) && + optimizer_utils::ValidateShape(input_0, {hidden_size}); + const bool input_1_is_bias = graph_utils::IsInitializer(graph, input_1.Name(), true) && + optimizer_utils::ValidateShape(input_1, {hidden_size}); + return input_0_is_bias || input_1_is_bias; +} + +static bool ValidateProjectionGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidden_size) { + if (gemm.InputDefs().size() < 3) { + return false; + } + + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* beta_attr = 
graph_utils::GetNodeAttribute(gemm, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* trans_a_attr = graph_utils::GetNodeAttribute(gemm, "transA"); + trans_a_attr && trans_a_attr->i() != 0) { + return false; + } + + if (const auto* trans_b_attr = graph_utils::GetNodeAttribute(gemm, "transB"); + trans_b_attr && trans_b_attr->i() != 0) { + return false; + } + + const NodeArg& input_b = *(gemm.InputDefs()[1]); + const NodeArg& input_c = *(gemm.InputDefs()[2]); + if (!graph_utils::IsInitializer(graph, input_b.Name(), true) || + !graph_utils::IsInitializer(graph, input_c.Name(), true)) { + return false; + } + + return optimizer_utils::ValidateShape(input_b, {hidden_size, hidden_size}) && + optimizer_utils::ValidateShape(input_c, {hidden_size}); +} + +// Most attention fusions require all matched nodes to already be assigned to an execution provider +// that supports the fused op. MobileClipMHA is also matched before partitioning in graph-transform +// tests, so nodes may still be unassigned here. Accept nodes that are either unassigned or already +// assigned to a compatible provider, and preserve the original provider string on the fused nodes +// once the pattern is rewritten. 
+static bool IsSupportedOrUnassignedNode(const Node& node, + const InlinedHashSet& compatible_execution_providers) { + return node.GetExecutionProviderType().empty() || + graph_utils::IsSupportedProvider(node, compatible_execution_providers); +} + +static bool IsSupportedOrUnassignedNode(const Node& node, + std::string_view required_execution_provider) { + const auto& execution_provider = node.GetExecutionProviderType(); + return execution_provider.empty() || + execution_provider == required_execution_provider; +} + +static bool AreSupportedOrUnassignedNodes( + const Node& anchor_node, + const std::initializer_list& nodes, + const InlinedHashSet& compatible_execution_providers) { + if (!IsSupportedOrUnassignedNode(anchor_node, compatible_execution_providers)) { + return false; + } + + const auto& required_execution_provider = anchor_node.GetExecutionProviderType(); + for (const Node* node : nodes) { + if (node == nullptr) { + continue; + } + + if (!IsSupportedOrUnassignedNode(*node, required_execution_provider)) { + return false; + } + } + + return true; +} + +static bool HasExpectedPerm(const Node& node, const std::initializer_list& expected_perm) { + return optimizer_utils::IsAttributeWithExpectedValues(node, "perm", std::vector(expected_perm)); +} + +static bool HasExpectedAxesInput(const Graph& graph, const Node& node, const std::initializer_list& expected_axes) { + if (node.InputDefs().size() < 2) { + return false; + } + + InlinedVector axes; + if (!optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], axes, true)) { + return false; + } + + return axes == InlinedVector(expected_axes.begin(), expected_axes.end()); +} + +static bool TryGetMobileClipQkvReshapeInfo(const Graph& graph, const Node& qkv_reshape, + int64_t& num_heads, int64_t& head_size, int64_t& hidden_size) { + if (qkv_reshape.InputDefs().size() < 2) { + return false; + } + + InlinedVector reshape_dims; + if (!optimizer_utils::AppendTensorFromInitializer(graph, 
*qkv_reshape.InputDefs()[1], reshape_dims, true)) { + return false; + } + + if (reshape_dims.size() != 5 || reshape_dims[2] != 3 || reshape_dims[3] <= 0 || reshape_dims[4] <= 0) { + return false; + } + + num_heads = reshape_dims[3]; + head_size = reshape_dims[4]; + + try { + hidden_size = SafeInt(num_heads) * head_size; + } catch (const OnnxRuntimeException&) { + return false; + } + + return hidden_size > 0; +} + +static std::optional TryCreateMobileClipMhaOutputType(const NodeArg& qkv_output, + int64_t hidden_size) { + const auto* qkv_output_type = qkv_output.TypeAsProto(); + if (qkv_output_type == nullptr || !qkv_output_type->has_tensor_type()) { + return std::nullopt; + } + + ONNX_NAMESPACE::TypeProto mha_output_type(*qkv_output_type); + auto* shape = mha_output_type.mutable_tensor_type()->mutable_shape(); + if (shape->dim_size() > 0) { + auto* last_dim = shape->mutable_dim(shape->dim_size() - 1); + last_dim->clear_dim_param(); + last_dim->set_dim_value(hidden_size); + } + + return mha_output_type; +} + +static Node* GetOnlyChildByOutputIndex(Graph& graph, const Node& node, size_t output_index, const char* child_op_type) { + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, output_index); + if (output_edges.size() != 1) { + return nullptr; + } + + Node* child = graph.GetNode(output_edges[0].dst_node); + if (child == nullptr || child->OpType() != child_op_type) { + return nullptr; + } + + return child; +} + +static bool TryCreateNormalizedProjectionGemm(Graph& graph, + NodeArg& projection_input, + const NodeArg& original_projection_input, + const NodeArg& proj_weight, + const NodeArg& proj_bias, + NodeArg& projection_output, + const std::string& base_name, + const std::string& provider_type) { + const auto* proj_input_shape = original_projection_input.Shape(); + const auto* proj_weight_shape = proj_weight.Shape(); + if (proj_input_shape == nullptr || proj_weight_shape == nullptr || proj_weight_shape->dim_size() != 2) { + return false; + } 
+ + auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*proj_input_shape); + if (input_shape.Size() == -1 || input_shape.NumDimensions() < 2) { + return false; + } + + const auto& dim_k = proj_weight_shape->dim(0); + const auto& dim_n = proj_weight_shape->dim(1); + if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) { + return false; + } + + const int64_t m = input_shape.SizeToDimension(input_shape.NumDimensions() - 1); + if (m <= 0) { + return false; + } + + const int64_t k = dim_k.dim_value(); + const int64_t n = dim_n.dim_value(); + if (input_shape[input_shape.NumDimensions() - 1] != k) { + return false; + } + + const auto* bias_shape = proj_bias.Shape(); + if (bias_shape == nullptr || bias_shape->dim_size() != 1 || !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != n) { + return false; + } + + const auto* input_type = original_projection_input.TypeAsProto(); + if (input_type == nullptr || !input_type->has_tensor_type()) { + return false; + } + + const auto element_type = static_cast(input_type->tensor_type().elem_type()); + + auto add_shape_initializer = [&](const std::string& name, const InlinedVector& shape) -> NodeArg& { + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeArgName(name)); + shape_initializer_proto.add_dims(static_cast(shape.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + const size_t shape_bytes = SafeInt(shape.size()) * sizeof(int64_t); + utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape_bytes); + return graph_utils::AddInitializerWithOrtValue(graph, shape_initializer_proto); + }; + + auto make_tensor_arg = [&](const std::string& name, const InlinedVector& shape) -> NodeArg* { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t dim_value : shape) { + 
type_proto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim_value); + } + + return &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name), &type_proto); + }; + + InlinedVector gemm_input_shape{m, k}; + InlinedVector gemm_output_shape{m, n}; + InlinedVector output_shape_values = input_shape.AsShapeVector(); + output_shape_values.back() = n; + + NodeArg* gemm_input_arg = make_tensor_arg("mobileclip_proj_gemm_input", gemm_input_shape); + NodeArg* gemm_output_arg = make_tensor_arg("mobileclip_proj_gemm_output", gemm_output_shape); + NodeArg& gemm_input_shape_arg = add_shape_initializer("mobileclip_proj_gemm_input_shape", gemm_input_shape); + NodeArg& gemm_output_shape_arg = add_shape_initializer("mobileclip_proj_gemm_output_shape", output_shape_values); + + Node& input_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmInputReshape"), + "Reshape", + "Reshape MobileCLIP projection input for Gemm", + {&projection_input, &gemm_input_shape_arg}, + {gemm_input_arg}); + input_reshape.SetExecutionProviderType(provider_type); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeName(base_name + "/MobileClipProjectionGemm"), + "Gemm", + "Normalized MobileCLIP projection Gemm", + {gemm_input_arg, const_cast(&proj_weight), const_cast(&proj_bias)}, + {gemm_output_arg}); + gemm_node.SetExecutionProviderType(provider_type); + + Node& output_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmOutputReshape"), + "Reshape", + "Restore MobileCLIP projection output shape after Gemm", + {gemm_output_arg, &gemm_output_shape_arg}, + {&projection_output}); + output_reshape.SetExecutionProviderType(provider_type); + + return true; +} + +static bool TryRewriteProjectionMatMulAddToGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_matmul, + Node& proj_add) { + if (proj_matmul.InputDefs().size() < 2 || proj_add.InputDefs().size() < 2) { + return false; + } + + const int bias_idx = proj_matmul.OutputDefs()[0]->Name() == 
proj_add.InputDefs()[0]->Name() ? 1 : 0; + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_matmul.InputDefs()[0], + *proj_matmul.InputDefs()[1], + *proj_add.InputDefs()[bias_idx], + *proj_add.MutableOutputDefs()[0], + proj_matmul.Name(), + proj_matmul.GetExecutionProviderType()); +} + +static bool TryRewriteProjectionGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_gemm) { + if (proj_gemm.InputDefs().size() < 3 || proj_gemm.OutputDefs().empty()) { + return false; + } + + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_gemm.InputDefs()[0], + *proj_gemm.InputDefs()[1], + *proj_gemm.InputDefs()[2], + *proj_gemm.MutableOutputDefs()[0], + proj_gemm.Name(), + proj_gemm.GetExecutionProviderType()); +} + +static bool TryFuseMobileClipMHA(Node& qkv_matmul, + Graph& graph, + const InlinedHashSet& compatible_execution_providers, + const logging::Logger& logger) { + const auto fail = [&](const char* message) { + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() << "]: fusion skipped: " << message; + return false; + }; + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(qkv_matmul, "MatMul", {1, 9, 13}, kOnnxDomain)) { + return false; + } + + if (!IsSupportedOrUnassignedNode(qkv_matmul, compatible_execution_providers)) { + return false; + } + + if (!optimizer_utils::CheckOutputEdges(graph, qkv_matmul, 1) || qkv_matmul.InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, qkv_matmul.InputDefs()[1]->Name(), true)) { + return fail("qkv MatMul output count or weight initializer check failed"); + } + + const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0); + if (sequence_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*sequence_transpose, {0, 2, 1}) || + !optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) { + return false; + } + + const Node* 
input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0); + if (input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) { + return fail("missing input Reshape before sequence transpose"); + } + + Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape"); + if (qkv_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) { + return fail("qkv Reshape after MatMul not matched"); + } + + Node* split = GetOnlyChildByOutputIndex(graph, *qkv_reshape, 0, "Split"); + if (split == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*split, "Split", {13, 18}, kOnnxDomain) || + split->OutputDefs().size() != 3 || !optimizer_utils::IsAttributeWithExpectedValue(*split, "axis", static_cast(2))) { + return fail("qkv Split(axis=2, outputs=3) not matched"); + } + + Node* q_transpose = GetOnlyChildByOutputIndex(graph, *split, 0, "Transpose"); + Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze"); + Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose"); + if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedAxesInput(graph, *k_squeeze, {2})) { + return fail("q/k/v branch entry pattern after Split not matched"); + } + + Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze"); + 
Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze"); + if (q_squeeze == nullptr || v_squeeze == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) || + !HasExpectedAxesInput(graph, *q_squeeze, {0}) || + !HasExpectedAxesInput(graph, *v_squeeze, {0})) { + return fail("q/v squeeze pattern not matched"); + } + + Node* q_scale_mul = GetOnlyChildByOutputIndex(graph, *q_squeeze, 0, "Mul"); + Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose"); + if (q_scale_mul == nullptr || k_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) { + return fail("q scale Mul or k Transpose(0,2,3,1) not matched"); + } + + float scale = 0.0f; + if (q_scale_mul->InputDefs().size() < 2) { + return fail("q scale constant not found"); + } + + const NodeArg* q_squeeze_output = q_squeeze->OutputDefs()[0]; + const NodeArg* mul_input_0 = q_scale_mul->InputDefs()[0]; + const NodeArg* mul_input_1 = q_scale_mul->InputDefs()[1]; + const bool input_0_is_q_squeeze = mul_input_0 != nullptr && q_squeeze_output != nullptr && + mul_input_0->Name() == q_squeeze_output->Name(); + const bool input_1_is_q_squeeze = mul_input_1 != nullptr && q_squeeze_output != nullptr && + mul_input_1->Name() == q_squeeze_output->Name(); + + const NodeArg* scale_input = nullptr; + if (input_0_is_q_squeeze && !input_1_is_q_squeeze) { + scale_input = mul_input_1; + } else if (input_1_is_q_squeeze && !input_0_is_q_squeeze) { + scale_input = mul_input_0; + } + + if (scale_input == nullptr || + !optimizer_utils::GetScalarInitializerValue(graph, *scale_input, scale, true)) { + return fail("q scale constant not found"); 
+ } + + Node* qk_matmul = GetOnlyChildByOutputIndex(graph, *q_scale_mul, 0, "MatMul"); + if (qk_matmul == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qk_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qk_matmul, 1) == nullptr || + graph_utils::GetInputNode(*qk_matmul, 1)->Index() != k_transpose->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qk_matmul, 1)) { + return fail("qk MatMul not matched"); + } + + Node* softmax = GetOnlyChildByOutputIndex(graph, *qk_matmul, 0, "Softmax"); + if (softmax == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*softmax, "Softmax", {1, 11, 13}, kOnnxDomain) || + !optimizer_utils::IsAttributeWithExpectedValue(*softmax, "axis", static_cast(-1)) || + !optimizer_utils::CheckOutputEdges(graph, *softmax, 1)) { + return fail("Softmax(axis=-1) not matched"); + } + + Node* qkv_matmul_1 = GetOnlyChildByOutputIndex(graph, *softmax, 0, "MatMul"); + if (qkv_matmul_1 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_matmul_1, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qkv_matmul_1, 1) == nullptr || + graph_utils::GetInputNode(*qkv_matmul_1, 1)->Index() != v_squeeze->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qkv_matmul_1, 1)) { + return fail("attention-value MatMul not matched"); + } + + Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose"); + if (transpose_3 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) || + !optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) { + return fail("output Transpose(0,2,1,3) not matched"); + } + + Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape"); + if (reshape_2 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) || + 
!optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) { + return fail("output Reshape not matched"); + } + + Node* proj_matmul = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "MatMul"); + Node* proj_gemm = proj_matmul == nullptr ? GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Gemm") : nullptr; + Node* proj_gemm_input_reshape = nullptr; + Node* proj_gemm_output_reshape = nullptr; + Node* proj_add = nullptr; + + if (proj_matmul != nullptr) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + proj_matmul->InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, proj_matmul->InputDefs()[1]->Name(), true) || + !optimizer_utils::CheckOutputEdges(graph, *proj_matmul, 1)) { + return fail("projection MatMul not matched"); + } + + proj_add = GetOnlyChildByOutputIndex(graph, *proj_matmul, 0, "Add"); + if (proj_add == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_add, "Add", {7, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_add, 1)) { + return fail("projection Add not matched"); + } + } else { + if (proj_gemm == nullptr) { + proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape"); + if (proj_gemm_input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm = GetOnlyChildByOutputIndex(graph, *proj_gemm_input_reshape, 0, "Gemm"); + if (proj_gemm == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape"); + if (proj_gemm_output_reshape == 
nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) { + return fail("normalized projection Gemm output Reshape not matched"); + } + } else if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + } + + int64_t num_heads = 0; + int64_t head_size = 0; + int64_t hidden_size = 0; + if (!TryGetMobileClipQkvReshapeInfo(graph, *qkv_reshape, num_heads, head_size, hidden_size)) { + return fail("unable to derive num_heads/head_size from qkv reshape initializer"); + } + + if (proj_matmul != nullptr) { + if (!ValidateMatMulInitializer(graph, *proj_matmul, hidden_size) || + !ValidateAddBiasInitializerEitherInput(graph, *proj_add, hidden_size)) { + return fail("projection weight/bias shape validation failed"); + } + } else { + if (!ValidateProjectionGemmInitializer(graph, *proj_gemm, hidden_size)) { + return fail("projection Gemm weight/bias shape validation failed"); + } + } + + const NodeArg& qkv_weight = *qkv_matmul.InputDefs()[1]; + if (!optimizer_utils::ValidateShape(qkv_weight, {hidden_size, 3 * hidden_size})) { + return fail("qkv weight shape is not [hidden, 3*hidden]"); + } + + if (!AreSupportedOrUnassignedNodes( + qkv_matmul, + {sequence_transpose, + input_reshape, + qkv_reshape, + split, + q_transpose, + k_squeeze, + v_transpose, + q_squeeze, + v_squeeze, + q_scale_mul, + k_transpose, + qk_matmul, + softmax, + qkv_matmul_1, + transpose_3, + reshape_2, + proj_matmul, + proj_add, + proj_gemm_input_reshape, + proj_gemm, + proj_gemm_output_reshape}, + compatible_execution_providers)) { + return fail("matched nodes are assigned to incompatible execution providers"); + } + + auto mha_output_type = TryCreateMobileClipMhaOutputType(*qkv_matmul.OutputDefs()[0], 
hidden_size); + auto* mha_output = &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("mobileclip_mha_output"), + mha_output_type ? &*mha_output_type : nullptr); + + if (proj_matmul != nullptr) { + if (!TryRewriteProjectionMatMulAddToGemm(graph, *mha_output, *proj_matmul, *proj_add)) { + return fail("projection MatMul/Add could not be rewritten to Gemm"); + } + } else if (proj_gemm_input_reshape == nullptr) { + if (!TryRewriteProjectionGemm(graph, *mha_output, *proj_gemm)) { + return fail("projection Gemm could not be normalized"); + } + } + + ONNX_NAMESPACE::TensorProto split_sizes_tensor; + split_sizes_tensor.set_name(graph.GenerateNodeArgName("mobileclip_mha_split_sizes")); + split_sizes_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_sizes_tensor.add_dims(3); + const std::array split_sizes{hidden_size, hidden_size, hidden_size}; + utils::SetRawDataInTensorProto(split_sizes_tensor, split_sizes.data(), split_sizes.size() * sizeof(int64_t)); + NodeArg& split_sizes_arg = graph_utils::AddInitializerWithOrtValue(graph, split_sizes_tensor); + + auto* mha_q = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_q"), nullptr); + auto* mha_k = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_k"), nullptr); + auto* mha_v = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_v"), nullptr); + + Node& split_for_mha = graph.AddNode( + graph.GenerateNodeName("MobileClipSplitForMHA"), + "Split", + "Split packed MobileCLIP QKV for MultiHeadAttention", + {qkv_matmul.MutableOutputDefs()[0], &split_sizes_arg}, + {mha_q, mha_k, mha_v}, + nullptr, + kOnnxDomain); + split_for_mha.AddAttribute("axis", static_cast(2)); + + Node& mha_node = graph.AddNode( + graph.GenerateNodeName("MobileClipMultiHeadAttention"), + "MultiHeadAttention", + "Fused MobileCLIP attention subgraph", + {mha_q, mha_k, mha_v}, + {mha_output}, + nullptr, + kMSDomain); + mha_node.AddAttribute("num_heads", num_heads); + 
mha_node.AddAttribute("scale", scale); + + const auto& provider = qkv_matmul.GetExecutionProviderType(); + split_for_mha.SetExecutionProviderType(provider); + mha_node.SetExecutionProviderType(provider); + + if (proj_gemm_input_reshape != nullptr) { + graph_utils::ReplaceDownstreamNodeInput(graph, *reshape_2, 0, mha_node, 0); + } + + std::vector nodes_to_remove{ + qkv_reshape->Index(), + split->Index(), + q_transpose->Index(), + q_squeeze->Index(), + q_scale_mul->Index(), + k_squeeze->Index(), + k_transpose->Index(), + qk_matmul->Index(), + softmax->Index(), + v_transpose->Index(), + v_squeeze->Index(), + qkv_matmul_1->Index(), + transpose_3->Index(), + reshape_2->Index(), + }; + + if (proj_matmul != nullptr) { + nodes_to_remove.push_back(proj_matmul->Index()); + nodes_to_remove.push_back(proj_add->Index()); + } else if (proj_gemm_input_reshape == nullptr) { + nodes_to_remove.push_back(proj_gemm->Index()); + } + + for (const auto& node_index : nodes_to_remove) { + Node* node = graph.GetNode(node_index); + if (node == nullptr) { + continue; + } + + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node_index); + } + + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() + << "]: fused MobileCLIP attention subgraph to MultiHeadAttention"; + + return true; +} + +} // namespace + static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size) { const NodeArg& input_b = *(matmul.InputDefs()[1]); if (!graph_utils::IsInitializer(graph, input_b.Name(), true)) { @@ -179,6 +837,12 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (TryFuseMobileClipMHA(node, graph, GetCompatibleExecutionProviders(), logger)) { + fused_count++; + modified = true; + continue; + } + // Add node.GetOutputEdgesCount() == 5/6 for distilbert if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 
6) && graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && diff --git a/onnxruntime/core/providers/acl/math/matmul.cc b/onnxruntime/core/providers/acl/math/matmul.cc index 468b394471c13..029a9ebe2768a 100644 --- a/onnxruntime/core/providers/acl/math/matmul.cc +++ b/onnxruntime/core/providers/acl/math/matmul.cc @@ -269,6 +269,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (input_idx != 1) { diff --git a/onnxruntime/core/providers/acl/math/matmul.h b/onnxruntime/core/providers/acl/math/matmul.h index b137e33833de9..783e15585ebf5 100644 --- a/onnxruntime/core/providers/acl/math/matmul.h +++ b/onnxruntime/core/providers/acl/math/matmul.h @@ -34,6 +34,7 @@ class MatMul : public OpKernel { bool& is_packed, PrePackedWeights*) override; Status UseSharedPrePackedBuffers(std::vector&, + gsl::span, int, bool&) override; Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index a62158f1c26ee..5cc10f7cfd2a8 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -370,6 +370,7 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } Status Conv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (isQuantized ? 
(input_idx != 3) : (input_idx != 1)) { diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index b05ba5363542f..7af086a410857 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -36,6 +36,7 @@ class Conv : public onnxruntime::OpKernel { bool& is_packed, PrePackedWeights*) override; Status UseSharedPrePackedBuffers(std::vector&, + gsl::span, int, bool&) override; Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 790b1543bbd74..08dbc46213f65 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -54,6 +54,7 @@ class FusedConvFp16 final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -211,6 +212,7 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (input_idx != 1) { diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index ac931c76ee3ae..c0da9aec1e1b1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -296,6 +296,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, template Status Gemm::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; @@ -304,6 +305,7 @@ Status Gemm::UseSharedPrePackedBuffers(std::vector& 
/*prepac template <> Status Gemm::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index c65f3eb96f62e..d9e66df4bee7c 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -37,6 +37,7 @@ class Gemm : protected GemmBase, public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8a7795a81027d..8dea41e3488e2 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -220,6 +220,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc } Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index 7f2d2ee400b63..9e6ef1a486235 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -47,7 +47,9 @@ class MatMul final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) override; Status Compute(OpKernelContext* context) const override; diff --git 
a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 6ebd12a525371..bbb530d037cec 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -102,6 +102,7 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Alloca template Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; @@ -110,6 +111,7 @@ Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& template <> Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index fd6021e65670e..96e3ecf912f32 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -35,6 +35,7 @@ class ConvTranspose : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index fb86e9731035c..9916c426a54fe 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -80,6 +80,7 @@ class MatMulIntegerBase : public OpKernel { } Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override { 
used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 24c8b0d57294e..a5e3d4b04a1e3 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -30,6 +30,7 @@ class QLinearConv : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -495,6 +496,7 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, Alloca template Status QLinearConv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (input_idx != 3) { diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index d1ddd04a953ef..d5be6bd29592e 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -322,6 +322,7 @@ Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr a } Status DeepCpuGruOp::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 881adf9efb376..fa233cc6f9cde 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -69,6 +69,7 @@ class DeepCpuGruOp final : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& 
used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 4b3ea672c0812..d2520804bb64c 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -260,6 +260,7 @@ Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, } Status DeepCpuLstmOp::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index c949b62ce7186..487e2a3fb8129 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -24,6 +24,7 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 697428e1ce140..c2a8896b84a7e 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/providers/webgpu/nn/conv.h" #include "core/providers/webgpu/nn/conv2d_mm.h" +#include "core/providers/webgpu/nn/conv3d_naive.h" #include "core/providers/webgpu/nn/im2col_matmul.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -80,8 +81,42 @@ Status Conv::ComputeInternal(ComputeContext& context std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); auto rank = input_shape.NumDimensions(); const InlinedVector perm = {2, 3, 1, 0}; - if (rank > 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported."); + if (rank > 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d, Conv2d, and Conv3d are supported."); + } else if (rank == 5) { + // Conv3D - use naive per-element shader (matching JS implementation) + if (conv_attrs_.group != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Conv3D does not support grouped convolution (group=", conv_attrs_.group, ")."); + } + const auto output_size = static_cast(output_shape.Size()); + const auto kernel_depth = static_cast(kernel_shape[2]); + const auto kernel_height = static_cast(kernel_shape[3]); + const auto kernel_width = static_cast(kernel_shape[4]); + // pads: head padding values for each spatial dim (front, top, left) + std::vector pads_3d{pads[0], pads[1], pads[2]}; + // Extract spatial dims and channels for explicit uniforms + const auto x_depth = static_cast(input_shape[is_channels_last ? 1 : 2]); + const auto x_height = static_cast(input_shape[is_channels_last ? 2 : 3]); + const auto x_width = static_cast(input_shape[is_channels_last ? 3 : 4]); + const auto x_channels = static_cast(input_shape[is_channels_last ? 
4 : 1]); + Conv3DNaiveProgram program(activation_, has_bias, is_channels_last); + program.CacheHint(activation_.ToString(), std::to_string(is_channels_last)) + .AddInput({input, ProgramTensorMetadataDependency::TypeAndRank, input_shape, 1}) + .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, 1}) + .AddUniformVariables({{output_size}, + {std::vector{kernel_depth, kernel_height, kernel_width}}, + {pads_3d}, + {strides}, + {dilations}, + {std::vector{x_depth, x_height, x_width}}, + {x_channels}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, bias->Shape(), 1}); + } + return context.RunProgram(program); } else if (rank == 4) { // Conv2D } else if (rank == 3) { diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc new file mode 100644 index 0000000000000..76895e684eeab --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "core/providers/webgpu/nn/conv3d_naive.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" + +namespace onnxruntime { +namespace webgpu { + +/* Builds the WGSL source for the naive Conv3D program. Registers inputs "x" and "w" and the output, emits getX/getW helpers that read them by five indices, then writes a main body in which each invocation (guarded against uniforms.output_size) converts global_idx to output indices, walks the filter depth/row/column taps while skipping any tap whose input coordinate is >= the matching uniforms.x_spatial extent, accumulates input channels four at a time plus a 1-3 channel remainder, adds the bias element at d2 when has_bias_ is set, appends the activation snippet, and stores value at global_idx. is_channels_last_ selects which index positions are read as channel versus spatial coordinates (d2 is output index 4 when channels-last, otherwise index 1). Always returns Status::OK(). NOTE(review): taps that fall in the leading padding appear to rely on unsigned wraparound of "corner = index * strides - pads" being rejected by the >= extent checks; the vector element types are not visible in this view, so confirm. */ Status Conv3DNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + + std::string apply_activation = GetActivationSnippet(activation_, "x_value_t", "x_element_t"); + + // Helper functions to access x and w by 5D indices + shader.AdditionalImplementation() + << "fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = x_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << x.GetByIndices("aIndices") << ";\n" + << "}\n" + << "fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = w_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << w.GetByIndices("aIndices") << ";\n" + << "}\n"; + + // Spatial dimensions and channels are passed as explicit uniforms + // to avoid rank-5 shape packing issues (array,2> vs vec4). + shader.MainFunctionBody() + << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch = output_indices[0];\n" + << "let d2 = " << output.IndicesGet("output_indices", is_channels_last_ ? "4" : "1") << ";\n" + << "let xFRCCorner = vec3(" << output.IndicesGet("output_indices", is_channels_last_ ?
"1" : "2") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "2" : "3") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "3" : "4") << ") * uniforms.strides - uniforms.pads;\n" + << "let xFCorner = xFRCCorner.x;\n" + << "let xRCorner = xFRCCorner.y;\n" + << "let xCCorner = xFRCCorner.z;\n" + << "let xDepth = uniforms.x_spatial[0];\n" + << "let xHeight = uniforms.x_spatial[1];\n" + << "let xWidth = uniforms.x_spatial[2];\n" + << "let xChannels = uniforms.x_channels;\n" + << "let inputChannelsNearestVec4 = (xChannels / 4u) * 4u;\n" + << "let inputChannelsVec4Remainder = xChannels % 4u;\n" + << "\n" + << "var value = x_value_t(0);\n" + << "for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {\n" + << " let xF = xFCorner + wF * uniforms.dilations[0];\n" + << " if (xF >= xDepth) {\n" + << " continue;\n" + << " }\n" + << " for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) {\n" + << " let xR = xRCorner + wR * uniforms.dilations[1];\n" + << " if (xR >= xHeight) {\n" + << " continue;\n" + << " }\n" + << " for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) {\n" + << " let xC = xCCorner + wC * uniforms.dilations[2];\n" + << " if (xC >= xWidth) {\n" + << " continue;\n" + << " }\n" + << " for (var d1 = 0u; d1 < inputChannelsNearestVec4; d1 += 4u) {\n"; + + // vec4 dot product accumulation over input channels + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, xF, xR, xC, d1),\n" + << " getX(batch, xF, xR, xC, d1 + 1u),\n" + << " getX(batch, xF, xR, xC, d1 + 2u),\n" + << " getX(batch, xF, xR, xC, d1 + 3u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, d1, xF, xR, xC),\n" + << " getX(batch, d1 + 1u, xF, xR, xC),\n" + << " getX(batch, d1 + 2u, xF, xR, xC),\n" + << " getX(batch, d1 + 3u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec4(\n" + << " getW(d2, d1, wF, wR, wC),\n" + << " getW(d2, d1 + 1u, wF, wR,
wC),\n" + << " getW(d2, d1 + 2u, wF, wR, wC),\n" + << " getW(d2, d1 + 3u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n"; + + // Handle remainder channels (1, 2, or 3) + shader.MainFunctionBody() + << " if (inputChannelsVec4Remainder == 1u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " value += getX(batch, xF, xR, xC, inputChannelsNearestVec4)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } else { + shader.MainFunctionBody() + << " value += getX(batch, inputChannelsNearestVec4, xF, xR, xC)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } + shader.MainFunctionBody() + << " } else if (inputChannelsVec4Remainder == 2u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec2(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " } else if (inputChannelsVec4Remainder == 3u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 2u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 2u, xF, xR, xC));\n"; + } +
shader.MainFunctionBody() + << " let wValues = vec3(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 2u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n" + << " }\n" + << " }\n" + << "}\n"; + + // Apply bias + if (has_bias_) { + const auto& b = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.MainFunctionBody() << "value = value + " << b.GetByIndices("d2") << ";\n"; + } + + // Apply activation + shader.MainFunctionBody() << apply_activation << "\n"; + + // Write output + shader.MainFunctionBody() << output.SetByOffset("global_idx", "value"); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h new file mode 100644 index 0000000000000..25ae449a7d02c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#pragma once + +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +/* WebGPU program named "Conv3DNaive" whose shader is produced by GenerateShaderCode. It is constructed from the activation to apply, a flag saying whether a bias input is present, and a flag saying whether the layout is channels-last. activation_ is stored as a const reference rather than a copy, so the Activation passed to the constructor must outlive this program object. Seven Uint32 uniforms are declared, in this order: output_size, filter_dims, pads, strides, dilations, x_spatial, x_channels. */ class Conv3DNaiveProgram final : public Program { + public: + Conv3DNaiveProgram(const Activation& activation, bool has_bias, bool is_channels_last) + : Program("Conv3DNaive"), activation_(activation), has_bias_(has_bias), is_channels_last_(is_channels_last) { + } + Status GenerateShaderCode(ShaderHelper& shader) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"filter_dims", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"x_spatial", ProgramUniformVariableDataType::Uint32}, + {"x_channels", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + bool has_bias_; + bool is_channels_last_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc index 625645e71cfec..6f29361502a73 100644 --- a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -126,9 +126,9 @@ class PluginEpOpKernel final : public controlflow::IControlFlowKernel { return Status::OK(); } - Status UseSharedPrePackedBuffers_V2(std::vector& buffer_unique_ptrs, - gsl::span buffer_sizes, - int input_idx, /*out*/ bool& used_shared_buffers) override { + Status UseSharedPrePackedBuffers(std::vector& buffer_unique_ptrs, + gsl::span buffer_sizes, + int input_idx, /*out*/ bool& used_shared_buffers) override { assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create().
if (kernel_impl_->ort_version_supported < 24 || kernel_impl_->SetSharedPrePackedWeight == nullptr) { diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index da958ba6fc970..9640d94aebe58 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -876,11 +876,8 @@ TEST(OpSchemaTypeConstraintTest, Add_SingleConstraint) { // T should allow tensor(float) and tensor(double) among others auto allowed_types = tc.GetAllowedTypes(); EXPECT_GT(allowed_types.size(), 1u); - auto has_type = [&](const char* t) { - return std::find(allowed_types.begin(), allowed_types.end(), t) != allowed_types.end(); - }; - EXPECT_TRUE(has_type("tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type("tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // Both inputs use T auto input_indices = tc.GetInputIndices(); @@ -921,22 +918,18 @@ TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) { ASSERT_NE(t_ptr, nullptr) << "Expected to find type constraint 'T'"; ASSERT_NE(t1_ptr, nullptr) << "Expected to find type constraint 'T1'"; - auto has_type = [](gsl::span types, const char* t) { - return std::find(types.begin(), types.end(), t) != types.end(); - }; - // T should include tensor(float) and tensor(double) auto t_types = t_tc.GetAllowedTypes(); EXPECT_GT(t_types.size(), 0u); - EXPECT_TRUE(has_type(t_types, "tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type(t_types, "tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(t_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(t_types, 
::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // T1 should include tensor(int32) (sequence_lens is int32) auto t1_types = t1_tc.GetAllowedTypes(); EXPECT_GT(t1_types.size(), 0u); // T1 is for sequence_lens which is int32 - EXPECT_TRUE(has_type(t1_types, "tensor(int32)")) << "Expected T1 to allow tensor(int32)"; + EXPECT_THAT(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)"; // T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7) auto t_inputs = t_tc.GetInputIndices(); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 656b0ef86289d..418bb2a809259 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -662,6 +662,7 @@ class PrePackingTestOpKernel : public OpKernel { } Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override { ORT_UNUSED_PARAMETER(input_idx); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 18933e45b8922..75ba3b802f9ae 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5826,6 +5826,363 @@ TEST_F(GraphTransformationTests, AttentionFusionDistilBertTest) { EXPECT_EQ(op_to_count["Shape"], 0); } +/* Projection variant passed to BuildMobileClipAttentionTestCase. NOTE(review): the enumerator names suggest a MatMul followed by Add versus a Gemm wrapped in Reshape nodes, but the code that branches on this value is outside this view, so confirm. */ enum class MobileClipProjectionType { + MatMulAdd, + GemmWithReshapes, +}; + +/* Dimension settings for BuildMobileClipAttentionTestCase, which takes this struct with a default of {}. Defaults: 512 input channels, hidden size 512, 16 heads of size 32, and a QKV weight input dimension of 512. */ struct MobileClipAttentionShapeConfig { + int64_t input_channels = 512; + int64_t hidden_size = 512; + int64_t num_heads = 16; + int64_t head_size = 32; + int64_t qkv_weight_input_dim = 512; +}; + +static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder, + MobileClipProjectionType projection_type, + const MobileClipAttentionShapeConfig& shape_config = {}, + bool
use_non_default_projection_gemm_attributes = false, + bool use_runtime_projection_shape_input = false) { + const int64_t input_channels = shape_config.input_channels; + const int64_t hidden_size = shape_config.hidden_size; + const int64_t num_heads = shape_config.num_heads; + const int64_t head_size = shape_config.head_size; + const int64_t qkv_weight_input_dim = shape_config.qkv_weight_input_dim; + const int64_t qkv_hidden_size = num_heads * head_size; + const int64_t qkv_output_size = 3 * qkv_hidden_size; + + auto* input_x = builder.MakeInput({1, input_channels, 8, 8}, -1.0f, 1.0f); + auto* input_skip = builder.MakeInput({1, hidden_size, 8, 8}, -1.0f, 1.0f); + + auto* reshape0_shape = builder.Make1DInitializer({1, input_channels, 64}); + auto* qkv_weight = builder.MakeInitializer({qkv_weight_input_dim, qkv_output_size}, -0.05f, 0.05f); + auto* qkv_reshape_shape = builder.Make1DInitializer({1, 64, 3, num_heads, head_size}); + auto* split_sizes = builder.Make1DInitializer({1, 1, 1}); + auto* squeeze_axis_0 = builder.Make1DInitializer({0}); + auto* squeeze_axis_2 = builder.Make1DInitializer({2}); + auto* scale = builder.MakeScalarInitializer(1.0f / std::sqrt(static_cast(head_size))); + auto* reshape2_shape = use_runtime_projection_shape_input + ? 
builder.MakeInput({3}, {1, 64, hidden_size}) + : builder.Make1DInitializer({1, 64, hidden_size}); + auto* proj_gemm_input_shape = builder.Make1DInitializer({64, hidden_size}); + auto* proj_weight = builder.MakeInitializer({hidden_size, hidden_size}, -0.05f, 0.05f); + auto* proj_bias = builder.MakeInitializer({hidden_size}, -0.02f, 0.02f); + auto* proj_gemm_output_shape = builder.Make1DInitializer({1, 64, hidden_size}); + auto* reshape3_shape = builder.Make1DInitializer({1, hidden_size, 8, 8}); + auto* layer_scale = builder.MakeInitializer({hidden_size, 1, 1}, 0.9f, 1.1f); + + auto* reshape0_out = builder.MakeIntermediate(std::vector{1, input_channels, 64}); + auto* transpose0_out = builder.MakeIntermediate(std::vector{1, 64, input_channels}); + auto* qkv_out = builder.MakeIntermediate(std::vector{1, 64, qkv_output_size}); + auto* qkv_reshape_out = builder.MakeIntermediate(std::vector{1, 64, 3, num_heads, head_size}); + auto* split_q = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_k = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_v = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* q_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* q_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* k_squeeze_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* k_transpose_out = builder.MakeIntermediate(std::vector{1, num_heads, head_size, 64}); + auto* q_scaled_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* qk_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* softmax_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* v_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* v_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 
64, head_size}); + auto* attn_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* transpose3_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* reshape2_out = use_runtime_projection_shape_input + ? builder.MakeIntermediate(std::nullopt) + : builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* proj_gemm_input_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_gemm_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_linear_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* transpose4_out = builder.MakeIntermediate(std::vector{1, hidden_size, 64}); + auto* reshape3_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* layer_scale_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* output = builder.MakeOutput(std::vector{1, hidden_size, 8, 8}); + + auto& reshape0 = builder.AddNode("Reshape", std::vector{input_x, reshape0_shape}, std::vector{reshape0_out}); + reshape0.AddAttribute("allowzero", static_cast(0)); + + auto& transpose0 = builder.AddNode("Transpose", std::vector{reshape0_out}, std::vector{transpose0_out}); + transpose0.AddAttribute("perm", std::vector{0, 2, 1}); + + builder.AddNode("MatMul", std::vector{transpose0_out, qkv_weight}, std::vector{qkv_out}); + + auto& qkv_reshape = builder.AddNode("Reshape", std::vector{qkv_out, qkv_reshape_shape}, std::vector{qkv_reshape_out}); + qkv_reshape.AddAttribute("allowzero", static_cast(0)); + + auto& split = builder.AddNode("Split", std::vector{qkv_reshape_out, split_sizes}, std::vector{split_q, split_k, split_v}); + split.AddAttribute("axis", static_cast(2)); + + auto& q_transpose = builder.AddNode("Transpose", std::vector{split_q}, std::vector{q_transpose_out}); + q_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{q_transpose_out, squeeze_axis_0}, std::vector{q_squeeze_out}); + 
builder.AddNode("Squeeze", std::vector{split_k, squeeze_axis_2}, std::vector{k_squeeze_out}); + + auto& k_transpose = builder.AddNode("Transpose", std::vector{k_squeeze_out}, std::vector{k_transpose_out}); + k_transpose.AddAttribute("perm", std::vector{0, 2, 3, 1}); + + builder.AddNode("Mul", std::vector{q_squeeze_out, scale}, std::vector{q_scaled_out}); + builder.AddNode("MatMul", std::vector{q_scaled_out, k_transpose_out}, std::vector{qk_out}); + + auto& softmax = builder.AddNode("Softmax", std::vector{qk_out}, std::vector{softmax_out}); + softmax.AddAttribute("axis", static_cast(-1)); + + auto& v_transpose = builder.AddNode("Transpose", std::vector{split_v}, std::vector{v_transpose_out}); + v_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{v_transpose_out, squeeze_axis_0}, std::vector{v_squeeze_out}); + builder.AddNode("MatMul", std::vector{softmax_out, v_squeeze_out}, std::vector{attn_out}); + + auto& transpose3 = builder.AddNode("Transpose", std::vector{attn_out}, std::vector{transpose3_out}); + transpose3.AddAttribute("perm", std::vector{0, 2, 1, 3}); + + auto& reshape2 = builder.AddNode("Reshape", std::vector{transpose3_out, reshape2_shape}, std::vector{reshape2_out}); + reshape2.AddAttribute("allowzero", static_cast(0)); + + if (projection_type == MobileClipProjectionType::GemmWithReshapes) { + auto& proj_gemm_input = builder.AddNode("Reshape", std::vector{reshape2_out, proj_gemm_input_shape}, + std::vector{proj_gemm_input_out}); + proj_gemm_input.AddAttribute("allowzero", static_cast(0)); + + auto& proj_gemm = builder.AddNode("Gemm", std::vector{proj_gemm_input_out, proj_weight, proj_bias}, + std::vector{proj_gemm_out}); + if (use_non_default_projection_gemm_attributes) { + proj_gemm.AddAttribute("transB", static_cast(1)); + } + + auto& proj_gemm_output = builder.AddNode("Reshape", std::vector{proj_gemm_out, proj_gemm_output_shape}, + std::vector{proj_linear_out}); + 
proj_gemm_output.AddAttribute("allowzero", static_cast(0)); + } else { + auto* proj_matmul_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + builder.AddNode("MatMul", std::vector{reshape2_out, proj_weight}, std::vector{proj_matmul_out}); + builder.AddNode("Add", std::vector{proj_bias, proj_matmul_out}, std::vector{proj_linear_out}); + } + + auto& transpose4 = builder.AddNode("Transpose", std::vector{proj_linear_out}, std::vector{transpose4_out}); + transpose4.AddAttribute("perm", std::vector{0, 2, 1}); + + auto& reshape3 = builder.AddNode("Reshape", std::vector{transpose4_out, reshape3_shape}, std::vector{reshape3_out}); + reshape3.AddAttribute("allowzero", static_cast(0)); + + builder.AddNode("Mul", std::vector{layer_scale, reshape3_out}, std::vector{layer_scale_out}); + builder.AddNode("Add", std::vector{input_skip, layer_scale_out}, std::vector{output}); +} + +static Status CheckMobileClipAttentionFusedGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int mha_nodes = 0; + int gemm_nodes = 0; + int split_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) { + ++mha_nodes; + TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16); + TEST_RETURN_IF_NOT(std::abs(node.GetAttributes().at("scale").f() - (1.0f / std::sqrt(32.0f))) < 1e-6f); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + 
TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 3); + } else if (node.OpType() == "Split") { + ++split_nodes; + } else if (node.OpType() == "Gemm") { + ++gemm_nodes; + TEST_RETURN_IF_NOT(node.InputDefs().size() == 3); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 2); + + const Node* gemm_input_node = graph_utils::GetInputNode(node, 0); + TEST_RETURN_IF_NOT(gemm_input_node != nullptr); + TEST_RETURN_IF_NOT(gemm_input_node->OpType() == "Reshape"); + + bool has_output_reshape = false; + for (const Node& consumer : graph.Nodes()) { + for (const NodeArg* input_def : consumer.InputDefs()) { + if (input_def != nullptr && input_def->Name() == node.OutputDefs()[0]->Name()) { + has_output_reshape = consumer.OpType() == "Reshape"; + break; + } + } + + if (has_output_reshape) { + break; + } + } + + TEST_RETURN_IF_NOT(has_output_reshape); + } + } + + TEST_RETURN_IF_NOT(mha_nodes == 1); + TEST_RETURN_IF_NOT(gemm_nodes == 1); + TEST_RETURN_IF_NOT(split_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) { + ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph)); + + for (Node& node : graph.Nodes()) { + TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider); + } + + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); 
+ TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int gemm_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() != "Gemm") { + continue; + } + + ++gemm_nodes; + const auto& attrs = node.GetAttributes(); + auto trans_b_attr = attrs.find("transB"); + TEST_RETURN_IF_NOT(trans_b_attr != attrs.end()); + TEST_RETURN_IF_NOT(trans_b_attr->second.i() == 1); + } + + TEST_RETURN_IF_NOT(gemm_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedMatMulGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 2); + return Status::OK(); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, + MobileClipProjectionType::MatMulAdd, + MobileClipAttentionShapeConfig{512, 510, 15, 34, 512}); + }; + + 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionUnfusedMatMulGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmNonDefaultAttributesTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes, {}, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedProjectionGemmGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionRewriteFailureLeavesGraphUnfusedTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd, {}, false, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedMatMulGraph)); +} + TEST_F(GraphTransformationTests, GeluFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 6d6fedb3c9812..843d925ed6638 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -686,6 +686,8 @@ TEST(ConvFp16Test, Conv2D_AutoPad2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +// TODO: Enable Conv3D fp16 tests for WebGPU when the test infrastructure supports +// conditionally skipping based on device capabilities (e.g., wgpu::FeatureName::ShaderF16). 
TEST(ConvFp16Test, Conv3D_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 060b61c61532a..f8e93c19dc8d3 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -812,7 +812,7 @@ TEST(ConvTest, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {-0.43337246775627136f, -0.48385289311408997f, -0.30954962968826294f, @@ -849,7 +849,7 @@ TEST(ConvTest, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.010772407054901123f, -0.43806642293930054f, 0.455391526222229f, -0.28657248616218567f, @@ -892,7 +892,7 @@ TEST(ConvTest, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.46796226501464844f, -0.4613912105560303f, 0.33512794971466064f, -0.4010460674762726f, diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 4afe9dc51b9e5..d151955c7549c 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -3,12 +3,14 @@ #include #include +#include +#include +#include #include #include #include #include #ifdef _WIN32 -#include #include #endif @@ -51,6 +53,9 @@ constexpr const char* kLogLevel = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; // Specify dynamic plugin EP configuration JSON. // Refer to `onnxruntime::test::dynamic_plugin_ep_infra::ParseInitializationConfig()` for more information. 
constexpr const char* kDynamicPluginEpConfigJson = "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON"; +// Specify a file path from which to read dynamic plugin EP configuration JSON. +// Mutually exclusive with kDynamicPluginEpConfigJson. +constexpr const char* kDynamicPluginEpConfigJsonFile = "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON_FILE"; #endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) } // namespace env_var_names @@ -79,9 +84,27 @@ extern "C" void ortenv_setup() { #if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) { namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; - if (auto dynamic_plugin_ep_config_json = onnxruntime::ParseEnvironmentVariable( - env_var_names::kDynamicPluginEpConfigJson); - dynamic_plugin_ep_config_json.has_value()) { + + auto dynamic_plugin_ep_config_json = onnxruntime::ParseEnvironmentVariable( + env_var_names::kDynamicPluginEpConfigJson); + auto dynamic_plugin_ep_config_json_file = onnxruntime::ParseEnvironmentVariable( + env_var_names::kDynamicPluginEpConfigJsonFile); + + ORT_ENFORCE(!dynamic_plugin_ep_config_json.has_value() || !dynamic_plugin_ep_config_json_file.has_value(), + "Only one of ", env_var_names::kDynamicPluginEpConfigJson, + " and ", env_var_names::kDynamicPluginEpConfigJsonFile, + " should be set, not both."); + + if (dynamic_plugin_ep_config_json_file.has_value()) { + const auto& config_file_path = *dynamic_plugin_ep_config_json_file; + std::cout << "Reading dynamic plugin EP configuration from file: " << config_file_path << "\n"; + std::ifstream config_file{config_file_path}; + ORT_ENFORCE(config_file, "Failed to open dynamic plugin EP configuration file: ", config_file_path); + dynamic_plugin_ep_config_json.emplace( + std::istreambuf_iterator{config_file}, std::istreambuf_iterator{}); + } + + if (dynamic_plugin_ep_config_json.has_value()) { std::cout << "Initializing dynamic plugin EP infrastructure with configuration:\n" << *dynamic_plugin_ep_config_json << "\n"; 
dynamic_plugin_ep_infra::InitializationConfig config{};