diff --git a/Project.toml b/Project.toml index 1f01116..0004d6d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Derive" uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" authors = ["ITensor developers and contributors"] -version = "0.3.1" +version = "0.3.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 3586269..c6fb83a 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -92,6 +92,35 @@ end return error("Not implemented.") end +@interface ::AbstractArrayInterface function Base.mapreduce( + f, op, as::AbstractArray...; kwargs... +) + return error("Not implemented.") +end + +# TODO: Generalize to multiple inputs. +@interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...) + return @interface interface mapreduce(identity, f, a; kwargs...) +end + +@interface interface::AbstractArrayInterface function Base.all(a::AbstractArray) + return @interface interface reduce(&, a; init=true) +end + +@interface interface::AbstractArrayInterface function Base.all( + f::Function, a::AbstractArray +) + return @interface interface mapreduce(f, &, a; init=true) +end + +@interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray) + return @interface interface all(iszero, a) +end + +@interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray) + return @interface interface all(isreal, a) +end + @interface ::AbstractArrayInterface function Base.permutedims!( a_dest::AbstractArray, a_src::AbstractArray, perm ) diff --git a/src/interface_function.jl b/src/interface_function.jl index ae0d0df..d34f3e5 100644 --- a/src/interface_function.jl +++ b/src/interface_function.jl @@ -7,11 +7,11 @@ This errors for debugging, but probably should be defined as: call(interface, f, args...) = f(args...) ``` =# -call(interface, f, args...) = error("Not implemented") +call(interface, f, args...; kwargs...) = error("Not implemented") # Change the behavior of a function to use a certain interface. struct InterfaceFunction{Interface,F} <: Function interface::Interface f::F end -(f::InterfaceFunction)(args...) = call(f.interface, f.f, args...) +(f::InterfaceFunction)(args...; kwargs...) = call(f.interface, f.f, args...; kwargs...) diff --git a/src/traits.jl b/src/traits.jl index 1ee5f1b..8fffa37 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -19,6 +19,12 @@ function derive(::Val{:AbstractArrayOps}, type) Base.copy(::$type) Base.map(::Any, ::$type...) Base.map!(::Any, ::AbstractArray, ::$type...) + Base.mapreduce(::Any, ::Any, ::$type...; kwargs...) + Base.reduce(::Any, ::$type...; kwargs...) + Base.all(::Function, ::$type) + Base.all(::$type) + Base.iszero(::$type) + Base.real(::$type) Base.permutedims!(::Any, ::$type, ::Any) Broadcast.BroadcastStyle(::Type{<:$type}) ArrayLayouts.MemoryLayout(::Type{<:$type}) diff --git a/test/basics/SparseArrayDOKs.jl b/test/basics/SparseArrayDOKs.jl index 563f26a..44a3e32 100644 --- a/test/basics/SparseArrayDOKs.jl +++ b/test/basics/SparseArrayDOKs.jl @@ -10,6 +10,25 @@ function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) return setunstoredindex!(a, value, Tuple(I)...) end +# A view of the stored values of an array. +# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue +# with that is it returns a `SubArray` wrapping a sparse array, which +# is then interpreted as a sparse array. Also, that involves extra +# logic for determining if the indices are stored or not, but we know +# the indices are stored. +struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} + array::A + storedindices::I +end +StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a))) +Base.size(a::StoredValues) = size(a.storedindices) +Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) +function Base.setindex!(a::StoredValues, value, I::Int) + return setstoredindex!(a.array, value, a.storedindices[I]) +end + +storedvalues(a::AbstractArray) = StoredValues(a) + using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout using Derive: Derive, @array_aliases, @derive, @interface, AbstractArrayInterface, interface using LinearAlgebra: LinearAlgebra @@ -29,8 +48,8 @@ end a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} ) where {N} checkbounds(a, I...) - iszero(value) && return a if !isstored(a, I...) + iszero(value) && return a setunstoredindex!(a, value, I...) return a end @@ -67,6 +86,13 @@ end return a_dest end +@interface ::SparseArrayInterface function Base.mapreduce( + f, op, a::AbstractArray; kwargs... +) + # TODO: Need to select a better `init`. + return mapreduce(f, op, storedvalues(a); kwargs...) +end + # ArrayLayouts functionality. function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple) diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index 2b4876d..581c3ef 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -72,4 +72,13 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test b isa SparseArrayDOK{elt,1} @test b == [12, 0] @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + @test iszero(a) + a[2, 1] = 21 + a[1, 2] = 12 + @test !iszero(a) + @test isreal(a) + @test sum(a) == 33 + @test mapreduce(x -> 2x, +, a) == 66 end