Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: "IntegrationTest"

on:
push:
branches:
- 'main'
tags: '*'
paths:
- 'Project.toml'
pull_request:
paths:
- 'Project.toml'

jobs:
integration-test:
name: "IntegrationTest"
strategy:
matrix:
pkg:
- 'BlockSparseArrays'
- 'DiagonalArrays'
- 'ITensorBase'
- 'ITensorNetworksNext'
- 'KroneckerArrays'
- 'NamedDimsArrays'
- 'SparseArraysBase'
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"
with:
localregistry: "https://github.com/ITensor/ITensorRegistry.git"
pkg: "${{ matrix.pkg }}"
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FunctionImplementations"
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.2.1"
version = "0.3.0"

[weakdeps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ FunctionImplementations = {path = ".."}
[compat]
Documenter = "1"
Literate = "2"
FunctionImplementations = "0.2"
FunctionImplementations = "0.3"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
FunctionImplementations = {path = ".."}

[compat]
FunctionImplementations = "0.2"
FunctionImplementations = "0.3"
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module FunctionImplementationsLinearAlgebraExt
import FunctionImplementations as FI
import LinearAlgebra as LA

struct DiagonalStyle <: FI.AbstractMatrixStyle end
struct DiagonalStyle <: FI.AbstractArrayStyle end
FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle()
const permuteddims_diag = DiagonalStyle()(FI.permuteddims)
function permuteddims_diag(a::AbstractArray, perm)
Expand Down
66 changes: 19 additions & 47 deletions src/style.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,57 +28,32 @@ Style `s`.
(s::Style)(f) = Implementation(f, s)

"""
`FunctionImplementations.AbstractArrayStyle{N} <: Style` is the abstract supertype for any style
`FunctionImplementations.AbstractArrayStyle <: Style` is the abstract supertype for any style
associated with an `AbstractArray` type.
The `N` parameter is the dimensionality, which can be handy for AbstractArray types
that only support specific dimensionalities:

struct SparseMatrixStyle <: FunctionImplementations.AbstractArrayStyle{2} end
FunctionImplementations.Style(::Type{<:SparseMatrixCSC}) = SparseMatrixStyle()

For `AbstractArray` types that support arbitrary dimensionality, `N` can be set to `Any`:

struct MyArrayStyle <: FunctionImplementations.AbstractArrayStyle{Any} end
FunctionImplementations.Style(::Type{<:MyArray}) = MyArrayStyle()

In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track
of dimensionality, your style needs to support a `Val` constructor:

struct MyArrayStyleDim{N} <: FunctionImplementations.AbstractArrayStyle{N} end
(::Type{<:MyArrayStyleDim})(::Val{N}) where N = MyArrayStyleDim{N}()

Note that if two or more `AbstractArrayStyle` subtypes conflict, the resulting
style will fall back to that of `Array`s. If this is undesirable, you may need to
define binary [`Style`](@ref) rules to control the output type.

See also [`FunctionImplementations.DefaultArrayStyle`](@ref).
"""
abstract type AbstractArrayStyle{N} <: Style end
abstract type AbstractVectorStyle <: AbstractArrayStyle{1} end
abstract type AbstractMatrixStyle <: AbstractArrayStyle{2} end
abstract type AbstractArrayStyle <: Style end

"""
`FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object
behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is
used for any
`AbstractArray` type that hasn't defined a specialized style, and in the absence of
overrides from other arguments the resulting output type is `Array`.
`FunctionImplementations.DefaultArrayStyle()` is a [`FunctionImplementations.Style`](@ref)
indicating that an object behaves as an array. Specifically, `DefaultArrayStyle` is
used for any `AbstractArray` type that hasn't defined a specialized style, and in the
absence of overrides from other arguments the resulting output type is `Array`.
"""
struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end
DefaultArrayStyle() = DefaultArrayStyle{Any}()
DefaultArrayStyle(::Val{N}) where {N} = DefaultArrayStyle{N}()
DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}()
const DefaultVectorStyle = DefaultArrayStyle{1}
const DefaultMatrixStyle = DefaultArrayStyle{2}
Style(::Type{<:AbstractArray{T, N}}) where {T, N} = DefaultArrayStyle{N}()
struct DefaultArrayStyle <: AbstractArrayStyle end
Style(::Type{<:AbstractArray}) = DefaultArrayStyle()

# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle`
# objects were supplied as arguments, and that no rule was defined for resolving the
# conflict. The resulting output is `Array`. While this is the same output type
# produced by `DefaultArrayStyle`, `ArrayConflict` "poisons" the Style so that
# 3 or more arguments still return an `ArrayConflict`.
struct ArrayConflict <: AbstractArrayStyle{Any} end
ArrayConflict(::Val) = ArrayConflict()
struct ArrayConflict <: AbstractArrayStyle end

### Binary Style rules
"""
Expand All @@ -100,17 +75,14 @@ Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle()
Style(::S, ::UnknownStyle) where {S <: Style} = S()
# Precedence rules
Style(::A, ::A) where {A <: AbstractArrayStyle} = A()
function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N}
if Base.typename(A) === Base.typename(B)
return A(Val(Any))
function Style(a::A, b::B) where {A <: AbstractArrayStyle, B <: AbstractArrayStyle}
if Base.typename(A) Base.typename(B)
return A()
end
return UnknownStyle()
end
# Any specific array type beats DefaultArrayStyle
Style(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a
Style(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where {N} = a
Style(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M, N} =
typeof(a)(Val(Any))
Style(a::AbstractArrayStyle, ::DefaultArrayStyle) = a

## logic for deciding the Style

Expand All @@ -124,12 +96,12 @@ Uses [`Style`](@ref) to get the style for each argument, and uses
# Examples
```jldoctest
julia> FunctionImplementations.style([1], [1 2; 3 4])
FunctionImplementations.DefaultArrayStyle{Any}()
FunctionImplementations.DefaultArrayStyle()
```
"""
function style end

style() = DefaultArrayStyle{0}()
style() = DefaultArrayStyle()
style(c) = result_style(Style(typeof(c)))
style(c1, c2) = result_style(style(c1), style(c2))
@inline style(c1, c2, cs...) = result_style(style(c1), style(c2, cs...))
Expand All @@ -143,11 +115,11 @@ determine a common `Style`.
# Examples

```jldoctest
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle{0}(), FunctionImplementations.DefaultArrayStyle{3}())
FunctionImplementations.DefaultArrayStyle{Any}()
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle(), FunctionImplementations.DefaultArrayStyle())
FunctionImplementations.DefaultArrayStyle()

julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle{1}())
FunctionImplementations.DefaultArrayStyle{1}()
julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle())
FunctionImplementations.DefaultArrayStyle()
```
"""
function result_style end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ FunctionImplementations = {path = ".."}

[compat]
Aqua = "0.8"
FunctionImplementations = "0.2"
FunctionImplementations = "0.3"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Suppressor = "0.2"
Expand Down
75 changes: 20 additions & 55 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ using Test: @test, @testset
# Test the shorthand for creating an Implementation by calling a Style with a
# function.
@test FI.style([1, 2, 3])(getindex) ≡
FI.Implementation(getindex, FI.DefaultArrayStyle{1}())
FI.Implementation(getindex, FI.DefaultArrayStyle())
end
@testset "Style" begin
# Test basic Style trait for different array types
@test FI.Style(typeof([1, 2, 3])) isa FI.DefaultArrayStyle{1}
@test FI.style([1, 2, 3]) isa FI.DefaultArrayStyle{1}
@test FI.Style(typeof([1 2; 3 4])) isa FI.DefaultArrayStyle{2}
@test FI.Style(typeof(rand(2, 3, 4))) isa FI.DefaultArrayStyle{3}
@test FI.Style(typeof([1, 2, 3])) FI.DefaultArrayStyle()
@test FI.style([1, 2, 3]) FI.DefaultArrayStyle()
@test FI.Style(typeof([1 2; 3 4])) FI.DefaultArrayStyle()
@test FI.Style(typeof(rand(2, 3, 4))) FI.DefaultArrayStyle()

# Test custom Style definition
struct CustomStyle <: FI.Style end
Expand All @@ -35,91 +35,56 @@ using Test: @test, @testset
struct MyArray{T, N} <: AbstractArray{T, N}
data::Array{T, N}
end
struct MyArrayStyle <: FI.AbstractArrayStyle{Any} end
struct MyArrayStyle <: FI.AbstractArrayStyle end
FI.Style(::Type{<:MyArray}) = MyArrayStyle()
@test FI.Style(MyArray) isa MyArrayStyle

# Test style homogeneity rule (same type returns preserved)
s1 = FI.DefaultArrayStyle{1}()
s2 = FI.DefaultArrayStyle{1}()
@test FI.Style(s1, s2) ≡ s1
s1 = FI.DefaultArrayStyle()
s2 = FI.DefaultArrayStyle()
@test FI.Style(s1, s2) ≡ FI.DefaultArrayStyle()

# Test UnknownStyle precedence
unknown = FI.UnknownStyle()
known = FI.DefaultArrayStyle{1}()
known = FI.DefaultArrayStyle()
@test FI.Style(known, unknown) ≡ known
@test FI.Style(unknown, unknown) ≡ unknown

# Test AbstractArrayStyle with different dimensions uses max
@test FI.Style(
FI.DefaultArrayStyle{1}(),
FI.DefaultArrayStyle{2}()
) isa FI.DefaultArrayStyle{Any}

# Test DefaultArrayStyle Val constructor preserves type when dimension matches
default_style = FI.DefaultArrayStyle{1}(Val(1))
@test FI.DefaultArrayStyle{1}(Val(1)) isa FI.DefaultArrayStyle{1}

# Test DefaultArrayStyle Val constructor changes dimension
@test FI.DefaultArrayStyle{1}(Val(2)) isa FI.DefaultArrayStyle{2}

# Test DefaultArrayStyle constructor defaults to Any dimension
@test FI.DefaultArrayStyle() isa FI.DefaultArrayStyle{Any}

# Test const aliases
@test FI.DefaultVectorStyle ≡ FI.DefaultArrayStyle{1}
@test FI.DefaultMatrixStyle ≡ FI.DefaultArrayStyle{2}

# Test ArrayConflict
conflict = FI.ArrayConflict()
@test conflict isa FI.ArrayConflict
@test conflict isa FI.AbstractArrayStyle{Any}

# Test ArrayConflict Val constructor
conflict_val = FI.ArrayConflict(Val(3))
@test conflict_val isa FI.ArrayConflict
@test conflict isa FI.AbstractArrayStyle

# Test style with no arguments
@test FI.style() isa FI.DefaultArrayStyle{0}
@test FI.style() FI.DefaultArrayStyle()

# Test style with single argument
@test FI.style([1, 2]) isa FI.DefaultArrayStyle{1}
@test FI.style([1 2; 3 4]) isa FI.DefaultArrayStyle{2}
@test FI.style([1, 2]) FI.DefaultArrayStyle()
@test FI.style([1 2; 3 4]) FI.DefaultArrayStyle()

# Test style with two arguments
result = FI.style([1, 2], [1 2; 3 4])
@test result isa FI.DefaultArrayStyle{Any}
@test result FI.DefaultArrayStyle()

# Test style with same dimensions
result = FI.style([1], [2])
@test result isa FI.DefaultArrayStyle{1}
@test result FI.DefaultArrayStyle()

# Test style with multiple arguments
result = FI.style([1], [1 2], rand(2, 3, 4))
@test result isa FI.DefaultArrayStyle{Any}
@test result FI.DefaultArrayStyle()

# Test result_style with single argument
@test FI.result_style(FI.DefaultArrayStyle{1}()) isa FI.DefaultArrayStyle{1}
@test FI.result_style(FI.DefaultArrayStyle()) isa FI.DefaultArrayStyle

# Test result_style with two identical styles
s = FI.DefaultArrayStyle{2}()
s = FI.DefaultArrayStyle()
@test FI.result_style(s, s) ≡ s

# Test result_style with UnknownStyle
known = FI.DefaultArrayStyle{1}()
known = FI.DefaultArrayStyle()
unknown = FI.UnknownStyle()
@test FI.result_style(known, unknown) ≡ known
@test FI.result_style(unknown, known) ≡ known

# Test result_style with different dimension DefaultArrayStyle uses max
result = FI.result_style(
FI.DefaultArrayStyle{1}(),
FI.DefaultArrayStyle{2}()
)
@test result isa FI.DefaultArrayStyle{Any}

# Test result_style with same shape behaves consistently
same_style = FI.DefaultArrayStyle{2}()
@test FI.result_style(same_style, same_style) ≡ same_style
end
end
Loading