diff --git a/Project.toml b/Project.toml
index c7a8ba06..71c98244 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "RecursiveArrayTools"
 uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
 authors = ["Chris Rackauckas "]
-version = "3.52.0"
+version = "3.53.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/ext/RecursiveArrayToolsFastBroadcastExt.jl b/ext/RecursiveArrayToolsFastBroadcastExt.jl
index b5467bc7..0bc7dc4a 100644
--- a/ext/RecursiveArrayToolsFastBroadcastExt.jl
+++ b/ext/RecursiveArrayToolsFastBroadcastExt.jl
@@ -2,7 +2,7 @@ module RecursiveArrayToolsFastBroadcastExt
 
 using RecursiveArrayTools
 using FastBroadcast
-using FastBroadcast: Serial
+using FastBroadcast: Serial, Threaded
 using StaticArraysCore
 
 const AbstractVectorOfSArray = AbstractVectorOfArray{
@@ -27,4 +27,37 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
     return dst
 end
 
+# Threaded `@..` into an `AbstractVectorOfSArray`. Each `dst.u[i]` is rebuilt
+# with `similar_type` from the broadcast unpacked at index `i` and stored back
+# whole. Every iteration writes a distinct slot of `dst.u`; this assumes
+# `unpack_voa` only reads from `bc` (TODO confirm).
+@inline function FastBroadcast.fast_materialize!(
+        ::Threaded, dst::AbstractVectorOfSArray,
+        bc::Broadcast.Broadcasted{S}
+    ) where {S}
+    if FastBroadcast.use_fast_broadcast(S)
+        Threads.@threads for i in eachindex(dst.u)
+            unpacked = RecursiveArrayTools.unpack_voa(bc, i)
+            dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
+                unpacked[j]
+                    for j in eachindex(unpacked)
+            )
+        end
+    else
+        # `use_fast_broadcast(S)` is false for this style: defer to Base.
+        Broadcast.materialize!(dst, bc)
+    end
+    return dst
+end
+
+# Fallback for non-SArray VectorOfArray: the generic threaded path splits
+# along the last axis via views, which does not correctly partition work for
+# VectorOfArray. Fall back to serial broadcasting.
+@inline function FastBroadcast.fast_materialize!(
+        ::Threaded, dst::AbstractVectorOfArray,
+        bc::Broadcast.Broadcasted
+    )
+    return FastBroadcast.fast_materialize!(Serial(), dst, bc)
+end
+
 end # module
diff --git a/test/interface_tests.jl b/test/interface_tests.jl
index 77b75b4e..f5b10118 100644
--- a/test/interface_tests.jl
+++ b/test/interface_tests.jl
@@ -284,6 +284,18 @@ f3!(z, zz)
 @test z == VA[fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]
 @test (@allocated f3!(z, zz)) == 0
 
+# Test threaded FastBroadcast with VectorOfArray of StaticArrays (issue #564)
+@testset "Threaded @.. with VectorOfArray{SArray}" begin
+    u_t = VectorOfArray(fill(SVector(1.0, 1.0), 2, 2))
+    v_t = copy(u_t)
+    @.. thread = true v_t = v_t + u_t
+    @test all(x -> x == SVector(2.0, 2.0), v_t.u)
+
+    # Test that repeated threaded application accumulates correctly
+    @.. thread = true v_t = v_t + u_t
+    @test all(x -> x == SVector(3.0, 3.0), v_t.u)
+end
+
 struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
     u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
 end