Skip to content

Commit

Permalink
Define map_blocklabels (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe authored Dec 16, 2024
1 parent 4f628c4 commit 625665d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 13 deletions.
7 changes: 1 addition & 6 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,4 @@ nondual_type(x) = nondual_type(typeof(x))
nondual_type(T::Type) = T

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

flip(a::AbstractUnitRange) = dual(label_dual(a))
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))
flip(a::AbstractUnitRange) = dual(map_blocklabels(dual, a))
6 changes: 6 additions & 0 deletions src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,9 @@ function blockedunitrange_getindices(
# if `a isa `GradedUnitRange`, for example.
return mortar(blks, labelled_length.(blks))
end

map_blocklabels(::Any, a::AbstractUnitRange) = a
function map_blocklabels(f, g::AbstractGradedUnitRange)
# use labelled_blocks to preserve GradedUnitRange
return labelled_blocks(unlabel_blocks(g), f.(blocklabels(g)))
end
16 changes: 10 additions & 6 deletions src/gradedunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ end
## TODO: Define this to instantiate a dual unit range.
## materialize_dual(a::GradedUnitRangeDual) = materialize_dual(nondual(a))

Base.first(a::GradedUnitRangeDual) = label_dual(first(nondual(a)))
Base.last(a::GradedUnitRangeDual) = label_dual(last(nondual(a)))
Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a)))
Base.first(a::GradedUnitRangeDual) = dual(first(nondual(a)))
Base.last(a::GradedUnitRangeDual) = dual(last(nondual(a)))
Base.step(a::GradedUnitRangeDual) = dual(step(nondual(a)))

Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index]

Expand All @@ -40,7 +40,7 @@ function blockedunitrange_getindices(
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer)
return label_dual(getindex(nondual(a), indices))
return dual(getindex(nondual(a), indices))
end

function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1})
Expand Down Expand Up @@ -123,8 +123,8 @@ function Base.iterate(a::GradedUnitRangeDual, i)
end

BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
BlockArrays.blocklasts(a::GradedUnitRangeDual) = label_dual.(blocklasts(nondual(a)))
BlockArrays.blockfirsts(a::GradedUnitRangeDual) = dual.(blockfirsts(nondual(a)))
BlockArrays.blocklasts(a::GradedUnitRangeDual) = dual.(blocklasts(nondual(a)))
function BlockArrays.findblock(a::GradedUnitRangeDual, index::Integer)
return findblock(nondual(a), index)
end
Expand All @@ -138,3 +138,7 @@ end
function unlabel_blocks(a::GradedUnitRangeDual)
return unlabel_blocks(nondual(a))
end

function map_blocklabels(f, g::GradedUnitRangeDual)
return dual(map_blocklabels(f, dual(g)))
end
4 changes: 3 additions & 1 deletion src/labelledunitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ end
dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a)
nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange
dual(a::LabelledUnitRangeDual) = nondual(a)
label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a)))
isdual(::LabelledUnitRangeDual) = true
blocklabels(la::LabelledUnitRangeDual) = [label(la)]

map_blocklabels(f, la::LabelledUnitRange) = labelled(unlabel(la), f(label(la)))
map_blocklabels(f, lad::LabelledUnitRangeDual) = dual(map_blocklabels(f, nondual(lad)))

function nondual_type(
::Type{<:LabelledUnitRangeDual{<:Any,NondualUnitRange}}
) where {NondualUnitRange}
Expand Down
7 changes: 7 additions & 0 deletions test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,24 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n

a = 1:3
ad = dual(a)
af = flip(a)
@test !isdual(a)
@test !isdual(ad)
@test !isdual(af)
@test ad isa UnitRange
@test af isa UnitRange
@test space_isequal(ad, a)
@test space_isequal(af, a)

a = blockedrange([2, 3])
ad = dual(a)
af = flip(a)
@test !isdual(a)
@test !isdual(ad)
@test ad isa BlockedOneTo
@test af isa BlockedOneTo
@test blockisequal(ad, a)
@test blockisequal(af, a)
end

@testset "LabelledUnitRangeDual" begin
Expand Down

0 comments on commit 625665d

Please sign in to comment.