Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 10, 2024
1 parent 233d3f0 commit 9c1665b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 33 deletions.
11 changes: 8 additions & 3 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,23 @@ end
return Broadcast.DefaultArrayStyle{ndims(type)}()

Check warning on line 29 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L28-L29

Added lines #L28 - L29 were not covered by tests
end

@interface ::AbstractArrayInterface function Base.similar(
@interface interface::AbstractArrayInterface function Base.similar(

Check warning on line 32 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L32

Added line #L32 was not covered by tests
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
)
# TODO: Maybe define as `Array{T}(undef, size...)` or
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`.
# TODO: Use `MethodError`?
return error("Not implemented.")
return similar(arraytype(interface, T), size)

Check warning on line 38 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L38

Added line #L38 was not covered by tests
end

@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
a_dest = similar(a)
return a_dest .= a

Check warning on line 43 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L41-L43

Added lines #L41 - L43 were not covered by tests
end

# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`).
@interface interface::AbstractArrayInterface function Base.similar(

Check warning on line 47 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L47

Added line #L47 was not covered by tests
a::AbstractArray, T::Type, axes::Tuple{Vararg{Base.OneTo}}
a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
# TODO: Use `Base.to_shape(axes)` or
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`.
Expand Down
7 changes: 5 additions & 2 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ using LinearAlgebra: LinearAlgebra
=#
function derive(::Val{:AbstractArrayOps}, type)
return quote
Base.getindex(::$type, ::Any...)

Check warning on line 14 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L14

Added line #L14 was not covered by tests
Base.getindex(::$type, ::Int...)
Base.setindex!(::$type, ::Any, ::Int...)
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
Base.similar(::$type, ::Type, ::Tuple{Vararg{Base.OneTo}})
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
Base.copy(::$type)

Check warning on line 19 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
Base.map(::Any, ::$type...)
Base.map!(::Any, ::Any, ::$type...)
Base.map!(::Any, ::AbstractArray, ::$type...)
Base.permutedims!(::Any, ::$type, ::Any)

Check warning on line 22 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
Broadcast.BroadcastStyle(::Type{<:$type})
ArrayLayouts.MemoryLayout(::Type{<:$type})
Expand All @@ -27,5 +29,6 @@ end
function derive(::Val{:AbstractArrayStyleOps}, type)
return quote
Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple)
Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type})

Check warning on line 32 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L29-L32

Added lines #L29 - L32 were not covered by tests
end
end
130 changes: 104 additions & 26 deletions test/basics/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SparseArrayDOKs

using ArrayLayouts: ArrayLayouts
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout
using Derive: Derive, @array_aliases, @derive, @interface, AbstractArrayInterface, interface
using LinearAlgebra: LinearAlgebra

Expand All @@ -26,6 +26,109 @@ end
return a
end

struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

Derive.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()

@derive SparseArrayStyle AbstractArrayStyleOps

Derive.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}

# Interface functions.
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
return SparseArrayStyle{ndims(type)}()
end

struct SparseLayout <: MemoryLayout end

@interface ::SparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
return SparseLayout()
end

@interface ::SparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in union(eachstoredindex.(as)...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
return a_dest
end

# ArrayLayouts functionality.

function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple)
a_dest = similar(a)
a_dest .= a
return a_dest
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{<:SparseLayout,<:SparseLayout,<:SparseLayout}
)
a_dest, a1, a2, α, β = m.C, m.A, m.B, m.α, m.β
for I1 in eachstoredindex(a1)
for I2 in eachstoredindex(a2)
if I1[2] == I2[1]
I_dest = CartesianIndex(I1[1], I2[2])
a_dest[I_dest] = a1[I1] * a2[I2] * α + a_dest[I_dest] * β
end
end
end
return a_dest
end

# Sparse array minimal interface
using LinearAlgebra: Adjoint
function isstored(a::Adjoint, i::Int, j::Int)
return isstored(parent(a), j, i)
end
function getstoredindex(a::Adjoint, i::Int, j::Int)
return getstoredindex(parent(a), j, i)'
end
function getunstoredindex(a::Adjoint, i::Int, j::Int)
return getunstoredindex(parent(a), j, i)'
end
function eachstoredindex(a::Adjoint)
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
end

function isstored(a::PermutedDimsArray, I::Int...)
return isstored(parent(a), reverse(I)...)
end
function getstoredindex(a::PermutedDimsArray, I::Int...)
return getstoredindex(parent(a), reverse(I)...)
end
function getunstoredindex(a::PermutedDimsArray, I::Int...)
return getunstoredindex(parent(a), reverse(I)...)
end
function eachstoredindex(a::PermutedDimsArray)
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
end

function isstored(a::SubArray, I::Int...)
return isstored(parent(a), Base.reindex(parentindices(a), I)...)
end
function getstoredindex(a::SubArray, I::Int...)
return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
end
function getunstoredindex(a::SubArray, I::Int...)
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
end
function eachstoredindex(a::SubArray)
nonscalardims = filter(ntuple(identity, ndims(parent(a)))) do d
return !(parentindices(a)[d] isa Real)
end
nonscalar_parentindices = map(d -> parentindices(a)[d], nonscalardims)
subindices = filter(eachstoredindex(parent(a))) do I
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
end
return map(collect(subindices)) do I
I_nonscalar = CartesianIndex(map(d -> I[d], nonscalardims))
return CartesianIndex(Base.reindex(nonscalar_parentindices, Tuple(I_nonscalar)))
end
end

# Define a type that will derive the interface.
struct SparseArrayDOK{T,N} <: AbstractArray{T,N}
storage::Dict{CartesianIndex{N},T}
Expand Down Expand Up @@ -61,37 +164,12 @@ end
eachstoredindex(a::SparseArrayDOK) = keys(storage(a))
storedlength(a::SparseArrayDOK) = length(eachstoredindex(a))

using LinearAlgebra: Adjoint
function isstored(a::Adjoint{<:Any,<:SparseArrayDOK}, i::Int, j::Int)
return isstored(parent(a), j, i)
end
function getstoredindex(a::Adjoint{<:Any,<:SparseArrayDOK}, i::Int, j::Int)
return getstoredindex(parent(a), j, i)'
end
function getunstoredindex(a::Adjoint{<:Any,<:SparseArrayDOK}, i::Int, j::Int)
return getunstoredindex(parent(a), j, i)'
end

# Specify the interface the type adheres to.
Derive.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()

# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
@array_aliases SparseArrayDOK

struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

Derive.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()

@derive SparseArrayStyle AbstractArrayStyleOps

Derive.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}

# Interface functions.
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
return SparseArrayStyle{ndims(type)}()
end

# Derive the interface for the type.
@derive AnySparseArrayDOK AbstractArrayOps

Expand Down
60 changes: 58 additions & 2 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,72 @@
using Test: @testset
using Test: @test, @testset
include("SparseArrayDOKs.jl")
using .SparseArrayDOKs: SparseArrayDOK, storedlength

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "Derive" for elt in elts
a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
@test a isa SparseArrayDOK{elt,2}
@test size(a) == (2, 2)
@test a[1, 2] == 12
@test storedlength(a) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3))))
@test b isa SparseArrayDOK{Float32,2}
@test b == zeros(Float32, 3, 3)
@test size(b) == (3, 3)
end

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = similar(a)
bc = Broadcast.Broadcasted(x -> 2x, (a,))
copyto!(b, bc)
@test b isa SparseArrayDOK{elt,2}
@test b == [0 24; 0 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = permutedims(a, (2, 1))
@test b isa SparseArrayDOK{elt,2}
@test b == [0 0; 12 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = copy(a')
@test b isa SparseArrayDOK{elt,2}
@test b == [0 0; 12 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = map(x -> 2x, a)
@test b isa SparseArrayDOK{elt,2}
@test b == [0 24; 0 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = a * a'
@test b isa SparseArrayDOK{elt,2}
@test b == [144 0; 0 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = a .+ 2 .* a'
@test b isa SparseArrayDOK{elt}
@test b isa SparseArrayDOK{elt,2}
@test b == [0 12; 24 0]
@test storedlength(b) == 2

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = a[1:2, 2]
@test b isa SparseArrayDOK{elt,1}
@test b == [12, 0]
@test storedlength(b) == 1
end

0 comments on commit 9c1665b

Please sign in to comment.