Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting Unitful LinearMaps #196

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Manifest.toml
16 changes: 7 additions & 9 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ julia> A(x)
"""
function Base.:(*)(A::LinearMap, x::AbstractVector)
check_dim_mul(A, x)
T = promote_type(eltype(A), eltype(x))
T = typeof(oneunit(eltype(A)) * oneunit(eltype(x)))
y = similar(x, T, axes(A)[1])
return @inbounds mul!(y, A, x)
end
Expand Down Expand Up @@ -308,25 +308,23 @@ function _generic_map_mul!(Y, A, X::AbstractMatrix, α, β)
end
return Y
end
function _generic_map_mul!(Y, A, s::Number)
T = promote_type(eltype(A), typeof(s))
function _generic_map_mul!(Y, A, s::S) where {S <: Number}
ax2 = axes(A)[2]
xi = zeros(T, ax2)
xi = zeros(S, ax2)
@inbounds for (i, Yi) in zip(ax2, eachcol(Y))
xi[i] = s
mul!(Yi, A, xi)
xi[i] = zero(T)
xi[i] = zero(S)
end
return Y
end
function _generic_map_mul!(Y, A, s::Number, α, β)
T = promote_type(eltype(A), typeof(s))
function _generic_map_mul!(Y, A, s::S, α, β) where {S <: Number}
ax2 = axes(A)[2]
xi = zeros(T, ax2)
xi = zeros(S, ax2)
@inbounds for (i, Yi) in zip(ax2, eachcol(Y))
xi[i] = s
mul!(Yi, A, xi, α, β)
xi[i] = zero(T)
xi[i] = zero(S)
end
return Y
end
Expand Down
52 changes: 34 additions & 18 deletions src/composition.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
# appropriate type for product of two (possibly Unitful) quantities:
_multype(a, b) = typeof(oneunit(eltype(a)) * oneunit(eltype(b)))
# this variant is needed because reinterpret changes length of complex vectors
#_multype(a, b, iscomplex::Bool) =
# typeof((1+0im) * oneunit(eltype(a)) * oneunit(eltype(b)))

struct CompositeMap{T, As<:LinearMapTupleOrVector} <: LinearMap{T}
maps::As # stored in order of application to vector
function CompositeMap{T, As}(maps::As) where {T, As}
N = length(maps)
for n in 2:N
check_dim_mul(maps[n], maps[n-1])
end
for TA in Base.Iterators.map(eltype, maps)
promote_type(T, TA) == T ||
error("eltype $TA cannot be promoted to $T in CompositeMap constructor")
end
Tprod = typeof(*(map(oneunit∘eltype, maps)...)) # handles units
promote_type(T, Tprod) == T ||
error("eltype $Tprod and $T incompatible in CompositeMap constructor")
new{T, As}(maps)
end
end

CompositeMap{T}(maps::As) where {T, As<:LinearMapTupleOrVector} = CompositeMap{T, As}(maps)

Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTupleOrVector) =
CompositeMap{promote_type(map(eltype, maps)...)}(reverse(maps))
function Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTupleOrVector)
Tprod = typeof(*(map(oneunit∘eltype, maps)...)) # handles units
return CompositeMap{Tprod}(reverse(maps))
end
Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector{<:LinearMap{T}}) where {T} =
CompositeMap{T}(reverse(maps))

Expand Down Expand Up @@ -80,11 +88,11 @@ end

# scalar multiplication and division (non-commutative case)
function Base.:(*)(α::Number, A::LinearMap)
T = promote_type(typeof(α), eltype(A))
T = _multype(α, A)
return CompositeMap{T}(_combine(A, UniformScalingMap(α, size(A, 1))))
end
function Base.:(*)(α::Number, A::CompositeMap)
T = promote_type(typeof(α), eltype(A))
T = _multype(α, A)
Alast = last(A.maps)
if Alast isa UniformScalingMap
return CompositeMap{T}(_combine(_front(A.maps), α * Alast))
Expand All @@ -94,15 +102,15 @@ function Base.:(*)(α::Number, A::CompositeMap)
end
# needed for disambiguation
function Base.:(*)(α::RealOrComplex, A::CompositeMap{<:RealOrComplex})
T = Base.promote_op(*, typeof(α), eltype(A))
T = _multype(α, A)
return ScaledMap{T}(α, A)
end
function Base.:(*)(A::LinearMap, α::Number)
T = promote_type(typeof(α), eltype(A))
T = _multype(A, α)
return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A))
end
function Base.:(*)(A::CompositeMap, α::Number)
T = promote_type(typeof(α), eltype(A))
T = _multype(A, α)
Afirst = first(A.maps)
if Afirst isa UniformScalingMap
return CompositeMap{T}(_combine(Afirst * α, _tail(A.maps)))
Expand All @@ -112,7 +120,7 @@ function Base.:(*)(A::CompositeMap, α::Number)
end
# needed for disambiguation
function Base.:(*)(A::CompositeMap{<:RealOrComplex}, α::RealOrComplex)
T = Base.promote_op(*, typeof(α), eltype(A))
T = _multype(A, α)
return ScaledMap{T}(α, A)
end

Expand All @@ -137,19 +145,19 @@ julia> LinearMap(ones(Int, 3, 3)) * CS * I * rand(3, 3);
```
"""
function Base.:(*)(A₁::LinearMap, A₂::LinearMap)
T = promote_type(eltype(A₁), eltype(A₂))
T = _multype(A₁, A₂)
return CompositeMap{T}(_combine(A₂, A₁))
end
function Base.:(*)(A₁::LinearMap, A₂::CompositeMap)
T = promote_type(eltype(A₁), eltype(A₂))
T = _multype(A₁, A₂)
return CompositeMap{T}(_combine(A₂.maps, A₁))
end
function Base.:(*)(A₁::CompositeMap, A₂::LinearMap)
T = promote_type(eltype(A₁), eltype(A₂))
T = _multype(A₁, A₂)
return CompositeMap{T}(_combine(A₂, A₁.maps))
end
function Base.:(*)(A₁::CompositeMap, A₂::CompositeMap)
T = promote_type(eltype(A₁), eltype(A₂))
T = _multype(A₁, A₂)
return CompositeMap{T}(_combine(A₂.maps, A₁.maps))
end
# needed for disambiguation
Expand Down Expand Up @@ -217,6 +225,11 @@ function _compositemulN!(y, A::CompositeMap, x,
src = nothing,
dst = nothing)
N = length(A.maps) # ≥ 3
# caution: be careful if y is complex but intermediate products are not
# source = reinterpret(_multype(A.maps[1], x, !isreal(y)), source) # for units
# resize!(source, size(A.maps[1],1)) # trick due to complex case
# todo: build reinterpret into _unsafe_mul! instead?
# only necessary if either source or map has Number type instead of Real|Complex
n = n0 = firstindex(A.maps)
source = isnothing(src) ?
convert(AbstractArray, A.maps[n] * x) :
Expand All @@ -227,9 +240,12 @@ function _compositemulN!(y, A::CompositeMap, x,
_unsafe_mul!(dst, A.maps[n], source)
dest, source = source, dest # alternate dest and source
for n in (n0+2):N-1
dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...))
# dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...))
# dest = reinterpret(_multype(A.maps[n], source), dest)
dest = similar([], _multype(A.maps[n], source), (size(A.maps[n], 1), size(source)[2:end]...))
_unsafe_mul!(dest, A.maps[n], source)
dest, source = source, dest # alternate dest and source
# dest, source = source, dest # alternate dest and source
source = dest
end
_unsafe_mul!(y, last(A.maps), source)
return y
Expand Down
6 changes: 4 additions & 2 deletions src/scaledmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ struct ScaledMap{T, S<:RealOrComplex, L<:LinearMap} <: LinearMap{T}
λ::S
lmap::L
function ScaledMap{T}(λ::S, A::L) where {T, S <: RealOrComplex, L <: LinearMap{<:RealOrComplex}}
@assert Base.promote_op(*, S, eltype(A)) == T "target type $T cannot hold products of $S and $(eltype(A)) objects"
Tprod = typeof(oneunit(S) * oneunit(eltype(A)))
promote_type(T, Tprod) == T ||
error("target type $T vs product of $S and $(eltype(A))")
new{T,S,L}(λ, A)
end
end

# constructor
ScaledMap(λ::RealOrComplex, lmap::LinearMap{<:RealOrComplex}) =
ScaledMap{Base.promote_op(*, typeof(λ), eltype(lmap))}(λ, lmap)
ScaledMap{typeof(oneunit(λ) * oneunit(eltype(lmap)))}(λ, lmap)

# basic methods
Base.size(A::ScaledMap) = size(A.lmap)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Aqua = "0.5, 0.6"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ include("inversemap.jl")

include("rrules.jl")

include("units.jl")

include("khatrirao.jl")

include("trace.jl")
2 changes: 1 addition & 1 deletion test/scaledmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using Test, LinearMaps, LinearAlgebra
# complex case
β = π + 2π * im
C = @inferred β * A
@test_throws AssertionError LinearMaps.ScaledMap{Float64}(β, A)
@test_throws ErrorException LinearMaps.ScaledMap{Float64}(β, A)
@inferred conj(β) * A' # needed in left-mul
T = ComplexF64
xc = rand(T, N)
Expand Down
26 changes: 26 additions & 0 deletions test/units.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# test/units

using Test: @test, @testset, @inferred, @test_throws
using LinearMaps: LinearMap
using Unitful: m, s, g

@testset "units" begin
A = rand(4,3) * 1m
B = rand(3,2) * 1s
C = A * B
D = 1f0g * C

Ma = @inferred LinearMap(A)
Mb = @inferred LinearMap(B)
Mc = Ma * Mb
Md = 1f0g * Mc

@test Matrix(Ma) == A
@test Matrix(Mc) == C
@test Matrix(Md) == D

x = randn(2)
@test B * x == Mb * x
@test C * x ≈ Mc * x
@test D * x ≈ Md * x
end