Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding NamedArrayPartition type #293

Merged
merged 12 commits into from
Jan 4, 2024
3 changes: 2 additions & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("utils.jl")
include("vector_of_array.jl")
include("tabletraits.jl")
include("array_partition.jl")
include("named_array_partition.jl")

function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
invoke(show, Tuple{typeof(io), Any}, io, x)
Expand Down Expand Up @@ -52,6 +53,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition
export ArrayPartition, NamedArrayPartition

end # module
114 changes: 114 additions & 0 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)

Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
must have the same element type.
"""
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T}
array_partition::A
names_to_indices::NT
end
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
function NamedArrayPartition(x::NamedTuple)
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x)))

Check warning on line 15 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L13-L15

Added lines #L13 - L15 were not covered by tests

# enforce homogeneity of eltypes
@assert all(eltype.(values(x)) .== eltype(first(x)))
T = eltype(first(x))
S = typeof(values(x))
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)

Check warning on line 21 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L18-L21

Added lines #L18 - L21 were not covered by tests
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)

Check warning on line 26 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L26

Added line #L26 was not covered by tests

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))

Check warning on line 28 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L28

Added line #L28 was not covered by tests

Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} =

Check warning on line 30 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L30

Added line #L30 was not covered by tests
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors

Check warning on line 32 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L32

Added line #L32 was not covered by tests


Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
Base.getproperty(x::NamedArrayPartition, s::Symbol) =

Check warning on line 36 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))

# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v

Check warning on line 42 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L40-L42

Added lines #L40 - L42 were not covered by tests
end

# print out NamedArrayPartition as a NamedTuple
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) =

Check warning on line 47 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L46-L47

Added lines #L46 - L47 were not covered by tests
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))

Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)

Check warning on line 52 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L50-L52

Added lines #L50 - L52 were not covered by tests

Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))

Check warning on line 56 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L54-L56

Added lines #L54 - L56 were not covered by tests
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))

Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} =

Check warning on line 59 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L59

Added line #L59 was not covered by tests
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices))

# broadcasting
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},

Check warning on line 64 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
::Type{ElType}) where {ElType}
x = find_NamedArrayPartition(bc)
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))

Check warning on line 67 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end

# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) =

Check warning on line 71 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L71

Added line #L71 was not covered by tests
Broadcast.DefaultArrayStyle{1}()

# hook into ArrayPartition broadcasting routines
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) =

Check warning on line 76 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L75-L76

Added lines #L75 - L76 were not covered by tests
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)

Check warning on line 78 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L78

Added line #L78 was not covered by tests

Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} =

Check warning on line 80 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L80

Added line #L80 was not covered by tests
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))

@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function =

Check warning on line 83 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L83

Added line #L83 was not covered by tests
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices)

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))

Check warning on line 89 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L86-L89

Added lines #L86 - L89 were not covered by tests
end
x = find_NamedArrayPartition(bc)
NamedArrayPartition(f, N, getfield(x, :names_to_indices))

Check warning on line 92 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

@inline function Base.copyto!(dest::NamedArrayPartition,

Check warning on line 95 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L95

Added line #L95 was not covered by tests
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
N = npartitions(dest, bc)
@inline function f(i)
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))

Check warning on line 99 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L97-L99

Added lines #L97 - L99 were not covered by tests
end
ntuple(f, Val(N))
return dest

Check warning on line 102 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L101-L102

Added lines #L101 - L102 were not covered by tests
end

# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
find_NamedArrayPartition(args::Tuple) =

Check warning on line 107 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args))
find_NamedArrayPartition(x) = x
find_NamedArrayPartition(::Tuple{}) = nothing
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)

Check warning on line 112 in src/named_array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/named_array_partition.jl#L109-L112

Added lines #L109 - L112 were not covered by tests


30 changes: 30 additions & 0 deletions test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
@test typeof(x.^2) <: NamedArrayPartition
@test x.a ≈ ones(10)
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
@test all(x .== x[1:end])
y = copy(x)
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
@test typeof(zero(x)) <: NamedArrayPartition
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast

@test length(Array(x))==30
@test typeof(Array(x)) <: Array
@test propertynames(x) == (:a, :b)

x = NamedArrayPartition(a = ones(1), b = 2*ones(1))
@test Base.summary(x) == string(typeof(x), " with arrays:")
@test (@capture_out Base.show(stdout, MIME"text/plain"(), x)) == "(a = [1.0], b = [2.0])"
jlchan marked this conversation as resolved.
Show resolved Hide resolved
jlchan marked this conversation as resolved.
Show resolved Hide resolved

using StructArrays
using StaticArrays: SVector
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))),
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2))))
@test typeof(x.a) <: StructVector{<:SVector{2}}
@test typeof(x.b) <: StructArray{<:SVector{2}, 2}
@test typeof((x->x[1]).(x)) <: NamedArrayPartition
@test typeof(map(x->x[1], x)) <: NamedArrayPartition
end

5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ end
@time @safetestset "Utils Tests" begin
include("utils_test.jl")
end
@time @safetestset "NamedArrayPartition Tests" begin
include("named_array_partition_tests.jl")
end
@time @safetestset "Partitions Tests" begin
include("partitions_test.jl")
end
end
@time @safetestset "VecOfArr Indexing Tests" begin
include("basic_indexing.jl")
end
Expand Down
Loading