Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -32,6 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[extensions]
RecursiveArrayToolsCUDAExt = "CUDA"
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
RecursiveArrayToolsFastBroadcastPolyesterExt = ["FastBroadcast", "Polyester"]
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
RecursiveArrayToolsMeasurementsExt = "Measurements"
Expand Down Expand Up @@ -59,6 +61,7 @@ Measurements = "2.11"
MonteCarloMeasurements = "1.2"
NLsolve = "4.5"
Pkg = "1"
Polyester = "0.7.16"
PrecompileTools = "1.2.1"
Random = "1"
RecipesBase = "1.3.4"
Expand Down Expand Up @@ -86,6 +89,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -98,4 +102,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
34 changes: 14 additions & 20 deletions ext/RecursiveArrayToolsFastBroadcastExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,25 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
return dst
end

@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfSArray,
bc::Broadcast.Broadcasted{S}
) where {S}
if FastBroadcast.use_fast_broadcast(S)
Threads.@threads for i in 1:length(dst.u)
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
unpacked[j]
for j in eachindex(unpacked)
)
end
else
Broadcast.materialize!(dst, bc)
end
return dst
end

# Fallback for non-SArray VectorOfArray: the generic threaded path splits
# along the last axis via views, which does not correctly partition work for
# Fallback for non-SArray VectorOfArray: the generic threaded path splits along
# the last axis via views, which does not correctly partition work for
# VectorOfArray. Fall back to serial broadcasting.
# For SArray VectorOfArray, throw an informative error telling the user to
# load Polyester.jl for threaded broadcasting.
@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfArray,
bc::Broadcast.Broadcasted
)
# When Polyester is loaded, RecursiveArrayToolsFastBroadcastPolyesterExt
# defines more-specific methods for AbstractVectorOfSArray, so reaching
# this method with an SArray VoA means Polyester is not loaded.
if dst isa AbstractVectorOfSArray
error(
"Threaded FastBroadcast on VectorOfArray{SArray} requires Polyester.jl. " *
"Add `using Polyester` to enable threaded broadcasting, or use " *
"`@.. thread=false` for serial broadcasting."
)
end
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
end

Expand Down
48 changes: 48 additions & 0 deletions ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module RecursiveArrayToolsFastBroadcastPolyesterExt

using RecursiveArrayTools
using FastBroadcast
using FastBroadcast: Serial, Threaded
using Polyester
using StaticArraysCore

const AbstractVectorOfSArray = AbstractVectorOfArray{
T, N, <:AbstractVector{<:StaticArraysCore.SArray},
} where {T, N}

@inline function _polyester_fast_materialize!(
dst::AbstractVectorOfSArray,
bc::Broadcast.Broadcasted{S}
) where {S}
if FastBroadcast.use_fast_broadcast(S)
@batch for i in 1:length(dst.u)
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
unpacked[j]
for j in eachindex(unpacked)
)
end
else
Broadcast.materialize!(dst, bc)
end
return dst
end

@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfSArray,
bc::Broadcast.Broadcasted{S}
) where {S}
return _polyester_fast_materialize!(dst, bc)
end

# Disambiguation: this method is more specific than both the base ext's
# (::Threaded, ::AbstractVectorOfArray, ::Broadcasted) fallback and
# the above (::Threaded, ::AbstractVectorOfSArray, ::Broadcasted{S}).
@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfSArray,
bc::Broadcast.Broadcasted
)
return _polyester_fast_materialize!(dst, bc)
end

end # module
22 changes: 22 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using RecursiveArrayTools, StaticArrays, Test
using RecursiveArrayToolsShorthandConstructors
using FastBroadcast
using Polyester
using SymbolicIndexingInterface: SymbolCache

t = 1:3
Expand Down Expand Up @@ -302,6 +303,27 @@ f3!(z, zz)
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
end

# Test Polyester-based threaded FastBroadcast extension (issue #564)
@testset "Polyester-threaded @.. with VectorOfArray{SArray}" begin
# Verify the Polyester extension is loaded
@test Base.get_extension(
Base.PkgId(RecursiveArrayTools),
:RecursiveArrayToolsFastBroadcastPolyesterExt
) !== nothing

# Test basic threaded broadcast with Polyester (Vector-of-SVector storage)
u_p = VectorOfArray([SVector(2.0, 3.0) for _ in 1:9])
v_p = copy(u_p)
@.. thread = true v_p = v_p + u_p
@test all(x -> x == SVector(4.0, 6.0), v_p.u)

# Test with larger array to exercise Polyester batching
u_large = VectorOfArray([SVector(1.0, 1.0, 1.0) for _ in 1:100])
v_large = VectorOfArray([SVector(0.0, 0.0, 0.0) for _ in 1:100])
@.. thread = true v_large = u_large * 2.0
@test all(x -> x == SVector(2.0, 2.0, 2.0), v_large.u)
end

struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
end
Expand Down
Loading