Skip to content

Commit

Permalink
Reductions (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Dec 11, 2024
1 parent f08a672 commit b2fea1d
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Derive"
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.1"
version = "0.3.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
29 changes: 29 additions & 0 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,35 @@ end
return error("Not implemented.")
end

@interface ::AbstractArrayInterface function Base.mapreduce(
f, op, as::AbstractArray...; kwargs...
)
return error("Not implemented.")
end

# TODO: Generalize to multiple inputs.
@interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...)
return @interface interface mapreduce(identity, f, a; kwargs...)
end

@interface interface::AbstractArrayInterface function Base.all(a::AbstractArray)
return @interface interface reduce(&, a; init=true)
end

@interface interface::AbstractArrayInterface function Base.all(
f::Function, a::AbstractArray
)
return @interface interface mapreduce(f, &, a; init=true)
end

@interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray)
return @interface interface all(iszero, a)
end

@interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray)
return @interface interface all(isreal, a)
end

@interface ::AbstractArrayInterface function Base.permutedims!(
a_dest::AbstractArray, a_src::AbstractArray, perm
)
Expand Down
4 changes: 2 additions & 2 deletions src/interface_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ This errors for debugging, but probably should be defined as:
call(interface, f, args...) = f(args...)
```
=#
call(interface, f, args...) = error("Not implemented")
call(interface, f, args...; kwargs...) = error("Not implemented")

# Change the behavior of a function to use a certain interface.
struct InterfaceFunction{Interface,F} <: Function
interface::Interface
f::F
end
(f::InterfaceFunction)(args...) = call(f.interface, f.f, args...)
(f::InterfaceFunction)(args...; kwargs...) = call(f.interface, f.f, args...; kwargs...)
6 changes: 6 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ function derive(::Val{:AbstractArrayOps}, type)
Base.copy(::$type)
Base.map(::Any, ::$type...)
Base.map!(::Any, ::AbstractArray, ::$type...)
Base.mapreduce(::Any, ::Any, ::$type...; kwargs...)
Base.reduce(::Any, ::$type...; kwargs...)
Base.all(::Function, ::$type)
Base.all(::$type)
Base.iszero(::$type)
Base.real(::$type)
Base.permutedims!(::Any, ::$type, ::Any)
Broadcast.BroadcastStyle(::Type{<:$type})
ArrayLayouts.MemoryLayout(::Type{<:$type})
Expand Down
28 changes: 27 additions & 1 deletion test/basics/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,25 @@ function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
return setunstoredindex!(a, value, Tuple(I)...)
end

# A view of the stored values of an array.
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
# with that is it returns a `SubArray` wrapping a sparse array, which
# is then interpreted as a sparse array. Also, that involves extra
# logic for determining if the indices are stored or not, but we know
# the indices are stored.
struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
array::A
storedindices::I
end
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
Base.size(a::StoredValues) = size(a.storedindices)
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
function Base.setindex!(a::StoredValues, value, I::Int)
return setstoredindex!(a.array, value, a.storedindices[I])
end

storedvalues(a::AbstractArray) = StoredValues(a)

using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout
using Derive: Derive, @array_aliases, @derive, @interface, AbstractArrayInterface, interface
using LinearAlgebra: LinearAlgebra
Expand All @@ -29,8 +48,8 @@ end
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
checkbounds(a, I...)
iszero(value) && return a
if !isstored(a, I...)
iszero(value) && return a
setunstoredindex!(a, value, I...)
return a
end
Expand Down Expand Up @@ -67,6 +86,13 @@ end
return a_dest
end

@interface ::SparseArrayInterface function Base.mapreduce(
f, op, a::AbstractArray; kwargs...
)
# TODO: Need to select a better `init`.
return mapreduce(f, op, storedvalues(a); kwargs...)
end

# ArrayLayouts functionality.

function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple)
Expand Down
9 changes: 9 additions & 0 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,13 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test b isa SparseArrayDOK{elt,1}
@test b == [12, 0]
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
@test iszero(a)
a[2, 1] = 21
a[1, 2] = 12
@test !iszero(a)
@test isreal(a)
@test sum(a) == 33
@test mapreduce(x -> 2x, +, a) == 66
end

0 comments on commit b2fea1d

Please sign in to comment.