diff --git a/Project.toml b/Project.toml index eadfae50..0488c683 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -32,6 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] RecursiveArrayToolsCUDAExt = "CUDA" RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" +RecursiveArrayToolsFastBroadcastPolyesterExt = ["FastBroadcast", "Polyester"] RecursiveArrayToolsForwardDiffExt = "ForwardDiff" RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions" RecursiveArrayToolsMeasurementsExt = "Measurements" @@ -59,6 +61,7 @@ Measurements = "2.11" MonteCarloMeasurements = "1.2" NLsolve = "4.5" Pkg = "1" +Polyester = "0.7.16" PrecompileTools = "1.2.1" Random = "1" RecipesBase = "1.3.4" @@ -86,6 +89,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -98,4 +102,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] +test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] diff --git a/ext/RecursiveArrayToolsFastBroadcastExt.jl b/ext/RecursiveArrayToolsFastBroadcastExt.jl index 0bc7dc4a..f8d3ce3d 100644 --- a/ext/RecursiveArrayToolsFastBroadcastExt.jl +++ b/ext/RecursiveArrayToolsFastBroadcastExt.jl @@ -27,31 +27,25 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{ return dst end -@inline function FastBroadcast.fast_materialize!( - ::Threaded, dst::AbstractVectorOfSArray, - bc::Broadcast.Broadcasted{S} - ) where {S} - if FastBroadcast.use_fast_broadcast(S) - Threads.@threads for i in 1:length(dst.u) - unpacked = RecursiveArrayTools.unpack_voa(bc, i) - dst.u[i] = StaticArraysCore.similar_type(dst.u[i])( - unpacked[j] - for j in eachindex(unpacked) - ) - end - else - Broadcast.materialize!(dst, bc) - end - return dst -end - -# Fallback for non-SArray VectorOfArray: the generic threaded path splits -# along the last axis via views, which does not correctly partition work for +# Fallback for non-SArray VectorOfArray: the generic threaded path splits along +# the last axis via views, which does not correctly partition work for # VectorOfArray. Fall back to serial broadcasting. +# For SArray VectorOfArray, throw an informative error telling the user to +# load Polyester.jl for threaded broadcasting. @inline function FastBroadcast.fast_materialize!( ::Threaded, dst::AbstractVectorOfArray, bc::Broadcast.Broadcasted ) + # When Polyester is loaded, RecursiveArrayToolsFastBroadcastPolyesterExt + # defines more-specific methods for AbstractVectorOfSArray, so reaching + # this method with an SArray VoA means Polyester is not loaded. + if dst isa AbstractVectorOfSArray + error( + "Threaded FastBroadcast on VectorOfArray{SArray} requires Polyester.jl. " * + "Add `using Polyester` to enable threaded broadcasting, or use " * + "`@.. thread=false` for serial broadcasting." + ) + end return FastBroadcast.fast_materialize!(Serial(), dst, bc) end diff --git a/ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl b/ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl new file mode 100644 index 00000000..2e6c2fc2 --- /dev/null +++ b/ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl @@ -0,0 +1,48 @@ +module RecursiveArrayToolsFastBroadcastPolyesterExt + +using RecursiveArrayTools +using FastBroadcast +using FastBroadcast: Serial, Threaded +using Polyester +using StaticArraysCore + +const AbstractVectorOfSArray = AbstractVectorOfArray{ + T, N, <:AbstractVector{<:StaticArraysCore.SArray}, +} where {T, N} + +@inline function _polyester_fast_materialize!( + dst::AbstractVectorOfSArray, + bc::Broadcast.Broadcasted{S} + ) where {S} + if FastBroadcast.use_fast_broadcast(S) + @batch for i in 1:length(dst.u) + unpacked = RecursiveArrayTools.unpack_voa(bc, i) + dst.u[i] = StaticArraysCore.similar_type(dst.u[i])( + unpacked[j] + for j in eachindex(unpacked) + ) + end + else + Broadcast.materialize!(dst, bc) + end + return dst +end + +@inline function FastBroadcast.fast_materialize!( + ::Threaded, dst::AbstractVectorOfSArray, + bc::Broadcast.Broadcasted{S} + ) where {S} + return _polyester_fast_materialize!(dst, bc) +end + +# Disambiguation: this method is more specific than both the base ext's +# (::Threaded, ::AbstractVectorOfArray, ::Broadcasted) fallback and +# the above (::Threaded, ::AbstractVectorOfSArray, ::Broadcasted{S}). +@inline function FastBroadcast.fast_materialize!( + ::Threaded, dst::AbstractVectorOfSArray, + bc::Broadcast.Broadcasted + ) + return _polyester_fast_materialize!(dst, bc) +end + +end # module diff --git a/test/interface_tests.jl b/test/interface_tests.jl index cc51974a..91a68a76 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -1,6 +1,7 @@ using RecursiveArrayTools, StaticArrays, Test using RecursiveArrayToolsShorthandConstructors using FastBroadcast +using Polyester using SymbolicIndexingInterface: SymbolCache t = 1:3 @@ -302,6 +303,27 @@ f3!(z, zz) @test all(x -> x == SVector(3.0, 3.0), v_t.u) end +# Test Polyester-based threaded FastBroadcast extension (issue #564) +@testset "Polyester-threaded @.. with VectorOfArray{SArray}" begin + # Verify the Polyester extension is loaded + @test Base.get_extension( + Base.PkgId(RecursiveArrayTools), + :RecursiveArrayToolsFastBroadcastPolyesterExt + ) !== nothing + + # Test basic threaded broadcast with Polyester (Vector-of-SVector storage) + u_p = VectorOfArray([SVector(2.0, 3.0) for _ in 1:9]) + v_p = copy(u_p) + @.. thread = true v_p = v_p + u_p + @test all(x -> x == SVector(4.0, 6.0), v_p.u) + + # Test with larger array to exercise Polyester batching + u_large = VectorOfArray([SVector(1.0, 1.0, 1.0) for _ in 1:100]) + v_large = VectorOfArray([SVector(0.0, 0.0, 0.0) for _ in 1:100]) + @.. thread = true v_large = u_large * 2.0 + @test all(x -> x == SVector(2.0, 2.0, 2.0), v_large.u) +end + struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A} u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}} end