Skip to content

Commit

Permalink
Broadcasting Kets and Operators (#172)
Browse files Browse the repository at this point in the history
Co-authored-by: Stefan Krastanov <[email protected]>
  • Loading branch information
apkille and Krastanov authored Aug 10, 2024
1 parent 06c2845 commit 73bac88
Show file tree
Hide file tree
Showing 15 changed files with 153 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
include:
- os: ubuntu-latest
arch: x64
version: '1.6'
version: '1.10'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
Expand All @@ -28,7 +29,8 @@ LinearAlgebra = "1"
QuantumInterface = "0.3.3"
Random = "1"
RandomMatrices = "0.5"
RecursiveArrayTools = "3"
SparseArrays = "1"
Strided = "1, 2"
UnsafeArrays = "1"
julia = "1.6"
julia = "1.10"
1 change: 1 addition & 0 deletions src/QuantumOpticsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module QuantumOpticsBase

using SparseArrays, LinearAlgebra, LRUCache, Strided, UnsafeArrays, FillArrays
import LinearAlgebra: mul!, rmul!
import RecursiveArrayTools

import QuantumInterface: dagger, directsum, , dm, embed, nsubsystems, expect, identityoperator, identitysuperoperator,
permutesystems, projector, ptrace, reduced, tensor, , variance, apply!, basis, AbstractSuperOperator
Expand Down
49 changes: 31 additions & 18 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,37 +423,50 @@ struct OperatorStyle{BL,BR} <: DataOperatorStyle{BL,BR} end
Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL,BR} = OperatorStyle{BL,BR}()
Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1,B2,B3,B4} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bl,br = find_basis(bcf.args)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
return Operator{BL,BR}(bl, br, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(bl), length(br))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Operator{BL,BR}(bl, br, data)
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)
find_dType(a::DataOperator, rest) = eltype(a)
@inline Base.getindex(a::DataOperator, idx) = getindex(a.data, idx)
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::DataOperator, i) = x.data[i]
Base.iterate(a::DataOperator) = iterate(a.data)
Base.iterate(a::DataOperator, idx) = iterate(a.data, idx)

# In-place broadcasting
@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle{BL,BR},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of operators and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL,BR} = (copyto!(A.data,B.data); A)
@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle,Axes,F,Args} =
throw(IncompatibleBases())

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init
Base.any(f::Function, x::Operator; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::Operator; kwargs...) = all(f, x.data; kwargs...)
Base.fill!(x::Operator, a) = typeof(x)(x.basis_l, x.basis_r, fill!(x.data, a))
Base.ndims(x::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = ndims(A)
Base.similar(x::Operator, t) = typeof(x)(x.basis_l, x.basis_r, copy(x.data))
RecursiveArrayTools.recursivecopy!(dest::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copyto!(dest,src) # ODE in-place equations
RecursiveArrayTools.recursivecopy(x::Operator) = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Operator} = copy(x)
RecursiveArrayTools.recursivefill!(x::Operator, a) = fill!(x, a)
90 changes: 48 additions & 42 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,52 +180,51 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B} = BraStyle{B}()
Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1,B2} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T()
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Ket{B}(b, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Ket{B}(b, data)
end
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Bra{B}(b, copy(bc_))
end
find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args)
find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args))
find_basis(x) = x
find_basis(a::StateVector, rest) = a.basis
find_basis(::Any, rest) = find_basis(rest)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::NTuple{N,<:T}, axes) where {T<:StateVector,N}
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
end
function Broadcasted_restrict_f(f, args::Tuple, axes)
error("Cannot broadcast function `$f` on $(typeof(args))")
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Bra{B}(b, data)
end
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{}, axes) # Defined to avoid method ambiguities
error("Cannot broadcast function `$f` on an empty set of arguments")
for f [:find_basis,:find_dType]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)
end

find_basis(x::T, rest) where {T<:Union{Ket, Bra}} = x.basis
find_dType(x::T, rest) where {T<:Union{Ket, Bra}} = eltype(x)
@inline Base.getindex(x::T, idx) where {T<:Union{Ket, Bra}} = getindex(x.data, idx)
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::T, i) where {T<:Union{Ket, Bra}} = x.data[i]

# In-place broadcasting for Kets
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of kets and broadcast them as arrays
bcf = Broadcast.flatten(bc)
args_ = Tuple(a.data for a=bcf.args)
bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:KetStyle{B2},Axes,F,Args} =
Expand All @@ -234,20 +233,27 @@ end
# In-place broadcasting for Bras
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of bras and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:BraStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

@inline Base.copyto!(A::T,B::T) where T<:Union{Ket, Bra} = (copyto!(A.data,B.data); A) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property
@inline Base.copyto!(dest::T,src::T) where {T<:Union{Ket, Bra}} = (copyto!(dest.data,src.data); dest) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
Base.any(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = all(f, x.data; kwargs...)
Base.fill!(x::T, a) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, fill!(x.data, a))
Base.similar(x::T, t) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, similar(x.data))
RecursiveArrayTools.recursivecopy!(dest::Ket{B,A},src::Ket{B,A}) where {B,A} = copyto!(dest, src) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dest::Bra{B,A},src::Bra{B,A}) where {B,A} = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::T) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a)
2 changes: 1 addition & 1 deletion src/superoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
# end
find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5"
QuantumOptics = "6e0679c1-51ea-5a7c-ac74-d61b76210b0c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ names = [
"test_subspace.jl",
"test_state_definitions.jl",

"test_sciml_broadcast_interfaces.jl",

"test_transformations.jl",

"test_metrics.jl",
Expand Down
1 change: 0 additions & 1 deletion test/test_abstractdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ op1 .= op1_ .+ 3 * op1_
bf = FockBasis(3)
op3 = randtestoperator(bf)
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
@test_throws ErrorException cos.(op1)

####################
# Test lazy tensor #
Expand Down
2 changes: 1 addition & 1 deletion test/test_jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using LinearAlgebra, LRUCache, Strided, StridedViews, Dates, SparseArrays, Rando
AnyFrameModule(RandomMatrices))
)
@show rep
@test length(JET.get_reports(rep)) <= 24
@test length(JET.get_reports(rep)) <= 28
@test_broken length(JET.get_reports(rep)) == 0
end
end # testset
1 change: 0 additions & 1 deletion test/test_operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,6 @@ op1 .= op1_ .+ 3 * op1_
bf = FockBasis(3)
op3 = randoperator(bf)
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
@test_throws ErrorException cos.(op1)

# Dimension mismatches
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
Expand Down
1 change: 0 additions & 1 deletion test/test_operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ op3 = sprandop(FockBasis(1),FockBasis(2))
op_ = copy(op1)
op_ .+= op1
@test op_ == 2*op1
@test_throws ErrorException cos.(op_)

# Dimension mismatches
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
Expand Down
63 changes: 63 additions & 0 deletions test/test_sciml_broadcast_interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Test
using QuantumOptics
using OrdinaryDiffEq

@testset "sciml interface" begin

# ket ODE problem
= SpinBasis(1//2)
= spindown(ℋ)
t₀, t₁ = (0.0, pi)
σx = sigmax(ℋ)
iσx = im*σx
schrod!(dψ, ψ, p, t) = QuantumOptics.mul!(dψ, iσx, ψ)

ix = iσx.data
schrod_data!(dψ,ψ,p,t) = QuantumOptics.mul!(dψ, ix, ψ)
u0 = ().data

prob! = ODEProblem(schrod!, , (t₀, t₁))
prob_data! = ODEProblem(schrod_data!, u0, (t₀, t₁))
sol = solve(prob!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
sol_data = solve(prob_data!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)

@test sol[end].data sol_data[end]

# dense operator ODE problem
σ₋ = sigmam(ℋ)
σ₊ = σ₋'
mhalfσ₊σ₋ = -σ₊*σ₋/2
ρ0 = dm()
tmp = zero(ρ0)
function lind!(dρ,ρ,p,t)
QuantumOptics.mul!(tmp, ρ, σ₊)
QuantumOptics.mul!(dρ, σ₋, ρ)
QuantumOptics.mul!(dρ, ρ, mhalfσ₊σ₋, true, true)
QuantumOptics.mul!(dρ, mhalfσ₊σ₋, ρ, true, true)
QuantumOptics.mul!(dρ, iσx, ρ, -ComplexF64(1), ComplexF64(1))
QuantumOptics.mul!(dρ, ρ, iσx, true, true)
return
end
m0 = ρ0.data
σ₋d = σ₋.data
σ₊d = σ₊.data
mhalfσ₊σ₋d = mhalfσ₊σ₋.data
tmpd = zero(m0)
function lind_data!(dρ,ρ,p,t)
QuantumOptics.mul!(tmpd, ρ, σ₊d)
QuantumOptics.mul!(dρ, σ₋d, ρ)
QuantumOptics.mul!(dρ, ρ, mhalfσ₊σ₋d, true, true)
QuantumOptics.mul!(dρ, mhalfσ₊σ₋d, ρ, true, true)
QuantumOptics.mul!(dρ, ix, ρ, -ComplexF64(1), ComplexF64(1))
QuantumOptics.mul!(dρ, ρ, ix, true, true)
return
end

prob! = ODEProblem(lind!, ρ0, (t₀, t₁))
prob_data! = ODEProblem(lind_data!, m0, (t₀, t₁))
sol = solve(prob!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
sol_data = solve(prob_data!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)

@test sol[end].data sol_data[end]

end
2 changes: 0 additions & 2 deletions test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ psi_ .+= psi123
bra_ = copy(bra123)
bra_ .= 3*bra123
@test bra_ == 3*dagger(psi123)
@test_throws ErrorException cos.(psi_)
@test_throws ErrorException cos.(bra_)

end # testset

Expand Down
2 changes: 0 additions & 2 deletions test/test_superoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,6 @@ Ldense .+= Ldense
Ldense .+= L
@test isa(Ldense, DenseSuperOpType)
@test isapprox(Ldense.data, 5*Ldense_.data)
@test_throws ErrorException cos.(Ldense)
@test_throws ErrorException cos.(L)

b = FockBasis(20)
L = liouvillian(identityoperator(b), [destroy(b)])
Expand Down

0 comments on commit 73bac88

Please sign in to comment.