From e112a0138f3685b36b36709a758b54ff3f6e0fcd Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 20 Dec 2024 18:31:34 -0500 Subject: [PATCH] Add support for linear indexing (#10) --- Project.toml | 4 +- src/abstractsparsearrayinterface.jl | 83 ++++++++++++++++++++++++++--- src/sparsearrayinterface.jl | 15 ++++++ src/wrappers.jl | 44 ++++++++++----- 4 files changed, 124 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 8136d00..d22082c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -14,7 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Aqua = "0.8.9" ArrayLayouts = "1.11.0" BroadcastMapConversion = "0.1.0" -Derive = "0.3.0" +Derive = "0.3.6" Dictionaries = "0.4.3" LinearAlgebra = "1.10" SafeTestsets = "0.1" diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index d3d10d0..1dc18c6 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -104,6 +104,27 @@ end # type instead so fallback functions can use abstract types. abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end +function Derive.combine_interface_rule( + interface1::AbstractSparseArrayInterface, interface2::AbstractSparseArrayInterface +) + return error("Rule not defined.") +end +function Derive.combine_interface_rule( + interface1::Interface, interface2::Interface +) where {Interface<:AbstractSparseArrayInterface} + return interface1 +end +function Derive.combine_interface_rule( + interface1::AbstractSparseArrayInterface, interface2::AbstractArrayInterface +) + return interface1 +end +function Derive.combine_interface_rule( + interface1::AbstractArrayInterface, interface2::AbstractSparseArrayInterface +) + return interface2 +end + to_vec(x) = vec(collect(x)) to_vec(x::AbstractArray) = vec(x) @@ -178,7 +199,46 @@ end return SparseArrayDOK{T}(size...) end -@interface ::AbstractSparseArrayInterface function Base.map!( +# Only map the stored values of the inputs. +function map_stored! end + +@interface interface::AbstractArrayInterface function map_stored!( + f, a_dest::AbstractArray, as::AbstractArray... +) + for I in eachstoredindex(as...) + a_dest[I] = f(map(a -> a[I], as)...) + end + return a_dest +end + +# Only map all values, not just the stored ones. +function map_all! end + +@interface interface::AbstractArrayInterface function map_all!( + f, a_dest::AbstractArray, as::AbstractArray... +) + for I in eachindex(as...) + a_dest[I] = map(f, map(a -> a[I], as)...) + end + return a_dest +end + +using ArrayLayouts: ArrayLayouts, zero! + +# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` +# and is useful for sparse array logic, since it can be used to empty +# the sparse array storage. +# We use a single function definition to minimize method ambiguities. +@interface interface::AbstractSparseArrayInterface function ArrayLayouts.zero!( + a::AbstractArray +) + # More generally, this codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! + return @interface interface map_stored!(f, a, a) +end + +@interface interface::AbstractSparseArrayInterface function Base.map!( f, a_dest::AbstractArray, as::AbstractArray... ) # TODO: Define a function `preserves_unstored(a_dest, f, as...)` @@ -194,15 +254,22 @@ end preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) if !preserves_unstored # Doesn't preserve unstored values, loop over all elements. - for I in eachindex(as...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end + @interface interface map_all!(f, a_dest, as...) return a_dest end - # Define `eachstoredindex` promotion. - for I in eachstoredindex(as...) - a_dest[I] = f(map(a -> a[I], as)...) - end + # First zero out the destination. + # TODO: Make this more nuanced, skip when possible, for + # example if the sparsity of the destination is a subset of + # the sparsity of the sources, i.e.: + # ```julia + # if eachstoredindex(as...) ∉ eachstoredindex(a_dest) + # zero!(a_dest) + # end + # ``` + # This is the safest thing to do in general, for example + # if the destination is dense but the sources are sparse. + @interface interface zero!(a_dest) + @interface interface map_stored!(f, a_dest, as...) return a_dest end diff --git a/src/sparsearrayinterface.jl b/src/sparsearrayinterface.jl index 577ff33..505ff0e 100644 --- a/src/sparsearrayinterface.jl +++ b/src/sparsearrayinterface.jl @@ -2,6 +2,21 @@ using Derive: Derive struct SparseArrayInterface <: AbstractSparseArrayInterface end +# Fix ambiguity error. +function Derive.combine_interface_rule(::SparseArrayInterface, ::SparseArrayInterface) + return SparseArrayInterface() +end +function Derive.combine_interface_rule( + interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface +) + return interface1 +end +function Derive.combine_interface_rule( + interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface +) + return interface2 +end + # Convenient shorthand to refer to the sparse interface. # Can turn a function into a sparse function with the syntax `sparse(f)`, # i.e. `sparse(map)(x -> 2x, randn(2, 2))` while use the sparse diff --git a/src/wrappers.jl b/src/wrappers.jl index a7123a3..13e6042 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -2,14 +2,28 @@ parentvalue_to_value(a::AbstractArray, value) = value value_to_parentvalue(a::AbstractArray, value) = value eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a)) storedparentvalues(a::AbstractArray) = storedvalues(parent(a)) -parentindex_to_index(a::AbstractArray, I::CartesianIndex) = error() -function parentindex_to_index(a::AbstractArray, I::Int...) + +function parentindex_to_index(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} + return throw(MethodError(parentindex_to_index, Tuple{typeof(a),typeof(I)})) +end +function parentindex_to_index(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} return Tuple(parentindex_to_index(a, CartesianIndex(I))) end -index_to_parentindex(a::AbstractArray, I::CartesianIndex) = error() -function index_to_parentindex(a::AbstractArray, I::Int...) +# Handle linear indexing. +function parentindex_to_index(a::AbstractArray, I::Int) + return parentindex_to_index(a, CartesianIndices(parent(a))[I]) +end + +function index_to_parentindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} + return throw(MethodError(index_to_parentindex, Tuple{typeof(a),typeof(I)})) +end +function index_to_parentindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} return Tuple(index_to_parentindex(a, CartesianIndex(I))) end +# Handle linear indexing. +function index_to_parentindex(a::AbstractArray, I::Int) + return index_to_parentindex(a, CartesianIndices(a)[I]) +end function cartesianindex_reverse(I::CartesianIndex) return CartesianIndex(reverse(Tuple(I))) @@ -21,10 +35,10 @@ tuple_oneto(n) = ntuple(identity, n) genperm(v, perm) = map(j -> v[j], perm) using LinearAlgebra: Adjoint -function parentindex_to_index(a::Adjoint, I::CartesianIndex) +function parentindex_to_index(a::Adjoint, I::CartesianIndex{2}) return cartesianindex_reverse(I) end -function index_to_parentindex(a::Adjoint, I::CartesianIndex) +function index_to_parentindex(a::Adjoint, I::CartesianIndex{2}) return cartesianindex_reverse(I) end function parentvalue_to_value(a::Adjoint, value) @@ -36,18 +50,18 @@ end perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip -function index_to_parentindex(a::PermutedDimsArray, I::CartesianIndex) +function index_to_parentindex(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N} return CartesianIndex(genperm(I, iperm(a))) end -function parentindex_to_index(a::PermutedDimsArray, I::CartesianIndex) +function parentindex_to_index(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N} return CartesianIndex(genperm(I, perm(a))) end using Base: ReshapedArray -function parentindex_to_index(a::ReshapedArray, I::CartesianIndex) +function parentindex_to_index(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N} return CartesianIndices(size(a))[LinearIndices(parent(a))[I]] end -function index_to_parentindex(a::ReshapedArray, I::CartesianIndex) +function index_to_parentindex(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N} return CartesianIndices(parent(a))[LinearIndices(size(a))[I]] end @@ -56,9 +70,15 @@ function eachstoredparentindex(a::SubArray) return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) end end +# Don't constrain the number of dimensions of the array +# and index since the parent array can have a different +# number of dimensions than the `SubArray`. function index_to_parentindex(a::SubArray, I::CartesianIndex) return CartesianIndex(Base.reindex(parentindices(a), Tuple(I))) end +# Don't constrain the number of dimensions of the array +# and index since the parent array can have a different +# number of dimensions than the `SubArray`. function parentindex_to_index(a::SubArray, I::CartesianIndex) nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d return !(parentindices(a)[d] isa Real) @@ -81,10 +101,10 @@ function storedparentvalues(a::SubArray) end using LinearAlgebra: Transpose -function parentindex_to_index(a::Transpose, I::CartesianIndex) +function parentindex_to_index(a::Transpose, I::CartesianIndex{2}) return cartesianindex_reverse(I) end -function index_to_parentindex(a::Transpose, I::CartesianIndex) +function index_to_parentindex(a::Transpose, I::CartesianIndex{2}) return cartesianindex_reverse(I) end function parentvalue_to_value(a::Transpose, value)