From 662993ffb67162957842f04598542007b79a62bd Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 11:07:41 -0700 Subject: [PATCH] Add test_bases from QuantumOpticsBase and reorganize a bit --- src/QuantumInterface.jl | 36 +++++- src/abstract_types.jl | 32 +++-- src/bases.jl | 274 +--------------------------------------- src/embed_permute.jl | 26 ++-- src/identityoperator.jl | 4 +- src/julia_base.jl | 39 +++--- src/julia_linalg.jl | 6 +- src/linalg.jl | 198 ++++++++++++++++++++++++++++- src/show.jl | 69 ++++++++++ src/sparse.jl | 2 +- test/runtests.jl | 1 + test/test_bases.jl | 55 ++++++++ 12 files changed, 415 insertions(+), 327 deletions(-) create mode 100644 src/show.jl create mode 100644 test/test_bases.jl diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index 34efa04..89d3c9a 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -1,14 +1,39 @@ module QuantumInterface -import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy -import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! -import Base: show, summary -import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension +## +# Basis specific +## + +""" + basis(a) + +Return the basis of an object. + +If it's ambiguous, e.g. if an operator has a different left and right basis, +an [`IncompatibleBases`](@ref) error is thrown. +""" +function basis end + +""" +Exception that should be raised for an illegal algebraic operation. +""" +mutable struct IncompatibleBases <: Exception end + + +## +# Standard methods +## function apply! end function dagger end +""" + directsum(x, y, z...) + +Direct sum of the given objects. Alternatively, the unicode +symbol ⊕ (\\oplus) can be used. +""" function directsum end const ⊕ = directsum directsum() = GenericBasis(0) @@ -86,8 +111,9 @@ function squeeze end function wigner end -include("bases.jl") include("abstract_types.jl") +include("bases.jl") +include("show.jl") include("linalg.jl") include("tensor.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index f8667c9..0650290 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,3 +1,18 @@ +""" +Abstract base class for all specialized bases. + +The Basis class is meant to specify a basis of the Hilbert space of the +studied system. Besides basis specific information all subclasses must +implement a shape variable which indicates the dimension of the used +Hilbert space. For a spin-1/2 Hilbert space this would be the +vector `[2]`. A system composed of two spins would then have a +shape vector `[2 2]`. + +Composite systems can be defined with help of the [`CompositeBasis`](@ref) +class. +""" +abstract type Basis end + """ Abstract base class for `Bra` and `Ket` states. @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ abstract type AbstractSuperOperator{B1,B2} end - -function summary(stream::IO, x::AbstractOperator) - print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") - if samebases(x) - print(stream, " basis: ") - show(stream, basis(x)) - else - print(stream, " basis left: ") - show(stream, x.basis_l) - print(stream, "\n basis right: ") - show(stream, x.basis_r) - end -end - -show(stream::IO, x::AbstractOperator) = summary(stream, x) - -traceout!(s::StateVector, i) = ptrace(s,i) diff --git a/src/bases.jl b/src/bases.jl index 6e4b077..2c059e0 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -1,17 +1,6 @@ -""" -Abstract base class for all specialized bases. - -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. - -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. -""" -abstract type Basis end +## +# GenericBasis, CompositeBasis +## """ length(b::Basis) @@ -20,17 +9,6 @@ Total dimension of the Hilbert space. """ Base.length(b::Basis) = prod(b.shape) -""" - basis(a) - -Return the basis of an object. - -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. -""" -function basis end - - """ GenericBasis(N) @@ -67,39 +45,6 @@ CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) -tensor(b::Basis) = b - -""" - tensor(x::Basis, y::Basis, z::Basis...) - -Create a [`CompositeBasis`](@ref) from the given bases. - -Any given CompositeBasis is expanded so that the resulting CompositeBasis never -contains another CompositeBasis. -""" -tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2)) -tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...)) -function tensor(b1::CompositeBasis, b2::Basis) - N = length(b1.bases) - shape = vcat(b1.shape, length(b2)) - bases = (b1.bases..., b2) - CompositeBasis(shape, bases) -end -function tensor(b1::Basis, b2::CompositeBasis) - N = length(b2.bases) - shape = vcat(length(b1), b2.shape) - bases = (b1, b2.bases...) - CompositeBasis(shape, bases) -end -tensor(bases::Basis...) = reduce(tensor, bases) - -function Base.:^(b::Basis, N::Integer) - if N < 1 - throw(ArgumentError("Power of a basis is only defined for positive integers.")) - end - tensor([b for i=1:N]...) -end - """ equal_shape(a, b) @@ -137,130 +82,6 @@ function equal_bases(a, b) return true end -""" -Exception that should be raised for an illegal algebraic operation. -""" -mutable struct IncompatibleBases <: Exception end - -const BASES_CHECK = Ref(true) - -""" - @samebases - -Macro to skip checks for same bases. Useful for `*`, `expect` and similar -functions. -""" -macro samebases(ex) - return quote - BASES_CHECK.x = false - local val = $(esc(ex)) - BASES_CHECK.x = true - val - end -end - -""" - samebases(a, b) - -Test if two objects have the same bases. -""" -samebases(b1::Basis, b2::Basis) = b1==b2 -samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators - -""" - check_samebases(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects don't have -the same bases. -""" -function check_samebases(b1, b2) - if BASES_CHECK[] && !samebases(b1, b2) - throw(IncompatibleBases()) - end -end - - -""" - multiplicable(a, b) - -Check if two objects are multiplicable. -""" -multiplicable(b1::Basis, b2::Basis) = b1==b2 - -function multiplicable(b1::CompositeBasis, b2::CompositeBasis) - if !equal_shape(b1.shape,b2.shape) - return false - end - for i=1:length(b1.shape) - if !multiplicable(b1.bases[i], b2.bases[i]) - return false - end - end - return true -end - -""" - check_multiplicable(a, b) - -Throw an [`IncompatibleBases`](@ref) error if the objects are -not multiplicable. -""" -function check_multiplicable(b1, b2) - if BASES_CHECK[] && !multiplicable(b1, b2) - throw(IncompatibleBases()) - end -end - -""" - reduced(a, indices) - -Reduced basis, state or operator on the specified subsystems. - -The `indices` argument, which can be a single integer or a vector of integers, -specifies which subsystems are kept. At least one index must be specified. -""" -function reduced(b::CompositeBasis, indices) - if length(indices)==0 - throw(ArgumentError("At least one subsystem must be specified in reduced.")) - elseif length(indices)==1 - return b.bases[indices[1]] - else - return CompositeBasis(b.shape[indices], b.bases[indices]) - end -end - -""" - ptrace(a, indices) - -Partial trace of the given basis, state or operator. - -The `indices` argument, which can be a single integer or a vector of integers, -specifies which subsystems are traced out. The number of indices has to be -smaller than the number of subsystems, i.e. it is not allowed to perform a -full trace. -""" -function ptrace(b::CompositeBasis, indices) - J = [i for i in 1:length(b.bases) if i ∉ indices] - length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace.")) - reduced(b, J) -end - - -""" - permutesystems(a, perm) - -Change the ordering of the subsystems of the given object. - -For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]` -this function results in `[b2, b1, b3]`. -""" -function permutesystems(b::CompositeBasis, perm) - @assert length(b.bases) == length(perm) - @assert isperm(perm) - CompositeBasis(b.shape[perm], b.bases[perm]) -end - - ## # Common bases ## @@ -366,89 +187,6 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp)) SumBasis(bases::Vector) = SumBasis((bases...,)) SumBasis(bases::Basis...) = SumBasis((bases...,)) -==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) -==(b1::SumBasis, b2::SumBasis) = false -length(b::SumBasis) = sum(b.shape) - -""" - directsum(b1::Basis, b2::Basis) - -Construct the [`SumBasis`](@ref) out of two sub-bases. -""" -directsum(b1::Basis, b2::Basis) = SumBasis(Int[length(b1); length(b2)], Basis[b1, b2]) -directsum(b::Basis) = b -directsum(b::Basis...) = reduce(directsum, b) -function directsum(b1::SumBasis, b2::Basis) - shape = [b1.shape;length(b2)] - bases = [b1.bases...;b2] - return SumBasis(shape, (bases...,)) -end -function directsum(b1::Basis, b2::SumBasis) - shape = [length(b1);b2.shape] - bases = [b1;b2.bases...] - return SumBasis(shape, (bases...,)) -end -function directsum(b1::SumBasis, b2::SumBasis) - shape = [b1.shape;b2.shape] - bases = [b1.bases...;b2.bases...] - return SumBasis(shape, (bases...,)) -end - -embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) - -## -# show methods -## - -function show(stream::IO, x::GenericBasis) - if length(x.shape) == 1 - write(stream, "Basis(dim=$(x.shape[1]))") - else - s = replace(string(x.shape), " " => "") - write(stream, "Basis(shape=$s)") - end -end - -function show(stream::IO, x::CompositeBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊗ ") - end - end - write(stream, "]") -end - -function show(stream::IO, x::SpinBasis) - d = denominator(x.spinnumber) - n = numerator(x.spinnumber) - if d == 1 - write(stream, "Spin($n)") - else - write(stream, "Spin($n/$d)") - end -end - -function show(stream::IO, x::FockBasis) - if iszero(x.offset) - write(stream, "Fock(cutoff=$(x.N))") - else - write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") - end -end - -function show(stream::IO, x::NLevelBasis) - write(stream, "NLevel(N=$(x.N))") -end - -function show(stream::IO, x::SumBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊕ ") - end - end - write(stream, "]") -end +Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) +Base.:(==)(b1::SumBasis, b2::SumBasis) = false +Base.length(b::SumBasis) = sum(b.shape) diff --git a/src/embed_permute.jl b/src/embed_permute.jl index c2cc4ca..297ab58 100644 --- a/src/embed_permute.jl +++ b/src/embed_permute.jl @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, ops_sb = [x[2] for x in idxop_sb] for (idxsb, opsb) in zip(indices_sb, ops_sb) - (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) - (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) + (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 + (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 end S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any @@ -83,10 +83,20 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, return embed_op end -permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) +embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) + +""" + permutesystems(a, perm) + +Change the ordering of the subsystems of the given object. -nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) -nsubsystems(s::AbstractOperator) = nsubsystems(basis(s)) -nsubsystems(b::CompositeBasis) = length(b.bases) -nsubsystems(b::Basis) = 1 -nsubsystems(::Nothing) = 1 # TODO Exists because of QuantumSavory; Consider removing this and reworking the functions that depend on it. E.g., a reason to have it when performing a project_traceout measurement on a state that contains only one subsystem +For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]` +this function results in `[b2, b1, b3]`. +""" +function permutesystems(b::CompositeBasis, perm) + @assert length(b.bases) == length(perm) + @assert isperm(perm) + CompositeBasis(b.shape[perm], b.bases[perm]) +end + +permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) diff --git a/src/identityoperator.jl b/src/identityoperator.jl index 5959882..aa031ff 100644 --- a/src/identityoperator.jl +++ b/src/identityoperator.jl @@ -1,4 +1,4 @@ -one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) +Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) """ identityoperator(a::Basis[, b::Basis]) @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2) """Prepare the identity superoperator over a given space.""" -function identitysuperoperator end \ No newline at end of file +function identitysuperoperator end diff --git a/src/julia_base.jl b/src/julia_base.jl index 9a0532d..2d8e085 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -1,3 +1,5 @@ +import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy + # Common error messages arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse().")) arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse().")) @@ -8,33 +10,31 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op # States ## --(a::T) where {T<:StateVector} = T(a.basis, -a.data) +-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12 *(a::StateVector, b::Number) = b*a -copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) -length(a::StateVector) = length(a.basis)::Int -basis(a::StateVector) = a.basis -directsum(x::StateVector...) = reduce(directsum, x) +copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 +length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 +basis(a::StateVector) = a.basis # FIXME issue #12 +adjoint(a::StateVector) = dagger(a) + # Array-like functions -Base.size(x::StateVector) = size(x.data) -@inline Base.axes(x::StateVector) = axes(x.data) +Base.size(x::StateVector) = size(x.data) # FIXME issue #12 +@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12 Base.ndims(x::StateVector) = 1 Base.ndims(::Type{<:StateVector}) = 1 -Base.eltype(x::StateVector) = eltype(x.data) +Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12 # Broadcasting Base.broadcastable(x::StateVector) = x -Base.adjoint(a::StateVector) = dagger(a) - - ## # Operators ## -length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) +length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 +basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 +basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) @@ -60,14 +60,19 @@ Operator exponential. """ exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense().")) -Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) +Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12 function Base.size(op::AbstractOperator, i::Int) i < 1 && throw(ErrorException("dimension index is < 1")) i > 2 && return 1 - i==1 ? length(op.basis_l) : length(op.basis_r) + i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12 end -Base.adjoint(a::AbstractOperator) = dagger(a) +dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) + +adjoint(a::AbstractOperator) = dagger(a) + +transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) + conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a) conj!(a::AbstractOperator) = conj(a::AbstractOperator) diff --git a/src/julia_linalg.jl b/src/julia_linalg.jl index d2f4d3d..3087d0a 100644 --- a/src/julia_linalg.jl +++ b/src/julia_linalg.jl @@ -1,3 +1,5 @@ +import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! + """ ishermitian(op::AbstractOperator) @@ -17,7 +19,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x) Norm of the given bra or ket state. """ -norm(x::StateVector) = norm(x.data) +norm(x::StateVector) = norm(x.data) # FIXME issue #12 """ normalize(x::StateVector) @@ -31,7 +33,7 @@ normalize(x::StateVector) = x/norm(x) In-place normalization of the given bra or ket so that `norm(x)` is one. """ -normalize!(x::StateVector) = (normalize!(x.data); x) +normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME issue #12 """ normalize(op) diff --git a/src/linalg.jl b/src/linalg.jl index 8bb47cd..076ec32 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,10 +1,194 @@ -samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool -samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool -check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) -dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) -transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) -directsum(a::AbstractOperator...) = reduce(directsum, a) +## +# Basis checks +## + +const BASES_CHECK = Ref(true) + +""" + @samebases + +Macro to skip checks for same bases. Useful for `*`, `expect` and similar +functions. +""" +macro samebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + +""" + samebases(a, b) + +Test if two objects have the same bases. +""" +samebases(b1::Basis, b2::Basis) = b1==b2 +samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators + +""" + check_samebases(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects don't have +the same bases. +""" +function check_samebases(b1, b2) + if BASES_CHECK[] && !samebases(b1, b2) + throw(IncompatibleBases()) + end +end + + +""" + multiplicable(a, b) + +Check if two objects are multiplicable. +""" +multiplicable(b1::Basis, b2::Basis) = b1==b2 + +function multiplicable(b1::CompositeBasis, b2::CompositeBasis) + if !equal_shape(b1.shape,b2.shape) + return false + end + for i=1:length(b1.shape) + if !multiplicable(b1.bases[i], b2.bases[i]) + return false + end + end + return true +end + +""" + check_multiplicable(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not multiplicable. +""" +function check_multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) + throw(IncompatibleBases()) + end +end + +samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 +samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 +check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 +multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 + +## +# tensor, reduce, ptrace +## + +tensor(b::Basis) = b + +""" + tensor(x::Basis, y::Basis, z::Basis...) + +Create a [`CompositeBasis`](@ref) from the given bases. + +Any given CompositeBasis is expanded so that the resulting CompositeBasis never +contains another CompositeBasis. +""" +tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2)) +tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...)) +function tensor(b1::CompositeBasis, b2::Basis) + N = length(b1.bases) + shape = vcat(b1.shape, length(b2)) + bases = (b1.bases..., b2) + CompositeBasis(shape, bases) +end +function tensor(b1::Basis, b2::CompositeBasis) + N = length(b2.bases) + shape = vcat(length(b1), b2.shape) + bases = (b1, b2.bases...) + CompositeBasis(shape, bases) +end +tensor(bases::Basis...) = reduce(tensor, bases) + +function Base.:^(b::Basis, N::Integer) + if N < 1 + throw(ArgumentError("Power of a basis is only defined for positive integers.")) + end + tensor([b for i=1:N]...) +end + +""" + reduced(a, indices) + +Reduced basis, state or operator on the specified subsystems. + +The `indices` argument, which can be a single integer or a vector of integers, +specifies which subsystems are kept. At least one index must be specified. +""" +function reduced(b::CompositeBasis, indices) + if length(indices)==0 + throw(ArgumentError("At least one subsystem must be specified in reduced.")) + elseif length(indices)==1 + return b.bases[indices[1]] + else + return CompositeBasis(b.shape[indices], b.bases[indices]) + end +end + +""" + ptrace(a, indices) + +Partial trace of the given basis, state or operator. + +The `indices` argument, which can be a single integer or a vector of integers, +specifies which subsystems are traced out. The number of indices has to be +smaller than the number of subsystems, i.e. it is not allowed to perform a +full trace. +""" +function ptrace(b::CompositeBasis, indices) + J = [i for i in 1:length(b.bases) if i ∉ indices] + length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace.")) + reduced(b, J) +end + ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a) _index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices) reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices)) +traceout!(s::StateVector, i) = ptrace(s,i) + +## +# nsubsystems +## + +nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) +nsubsystems(s::AbstractOperator) = nsubsystems(basis(s)) +nsubsystems(b::CompositeBasis) = length(b.bases) +nsubsystems(b::Basis) = 1 +nsubsystems(::Nothing) = 1 # TODO Exists because of QuantumSavory; Consider removing this and reworking the functions that depend on it. E.g., a reason to have it when performing a project_traceout measurement on a state that contains only one subsystem + +## +# directsum +## + +""" + directsum(b1::Basis, b2::Basis) + +Construct the [`SumBasis`](@ref) out of two sub-bases. +""" +directsum(b1::Basis, b2::Basis) = SumBasis(Int[length(b1); length(b2)], Basis[b1, b2]) +directsum(b::Basis) = b +directsum(b::Basis...) = reduce(directsum, b) +function directsum(b1::SumBasis, b2::Basis) + shape = [b1.shape;length(b2)] + bases = [b1.bases...;b2] + return SumBasis(shape, (bases...,)) +end +function directsum(b1::Basis, b2::SumBasis) + shape = [length(b1);b2.shape] + bases = [b1;b2.bases...] + return SumBasis(shape, (bases...,)) +end +function directsum(b1::SumBasis, b2::SumBasis) + shape = [b1.shape;b2.shape] + bases = [b1.bases...;b2.bases...] + return SumBasis(shape, (bases...,)) +end + +directsum(x::StateVector...) = reduce(directsum, x) +directsum(a::AbstractOperator...) = reduce(directsum, a) diff --git a/src/show.jl b/src/show.jl new file mode 100644 index 0000000..38607b0 --- /dev/null +++ b/src/show.jl @@ -0,0 +1,69 @@ +import Base: show, summary + +function summary(stream::IO, x::AbstractOperator) + print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") + if samebases(x) + print(stream, " basis: ") + show(stream, basis(x)) + else + print(stream, " basis left: ") + show(stream, x.basis_l) + print(stream, "\n basis right: ") + show(stream, x.basis_r) + end +end + +show(stream::IO, x::AbstractOperator) = summary(stream, x) + +function show(stream::IO, x::GenericBasis) + if length(x.shape) == 1 + write(stream, "Basis(dim=$(x.shape[1]))") + else + s = replace(string(x.shape), " " => "") + write(stream, "Basis(shape=$s)") + end +end + +function show(stream::IO, x::CompositeBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊗ ") + end + end + write(stream, "]") +end + +function show(stream::IO, x::SpinBasis) + d = denominator(x.spinnumber) + n = numerator(x.spinnumber) + if d == 1 + write(stream, "Spin($n)") + else + write(stream, "Spin($n/$d)") + end +end + +function show(stream::IO, x::FockBasis) + if iszero(x.offset) + write(stream, "Fock(cutoff=$(x.N))") + else + write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") + end +end + +function show(stream::IO, x::NLevelBasis) + write(stream, "NLevel(N=$(x.N))") +end + +function show(stream::IO, x::SumBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊕ ") + end + end + write(stream, "]") +end diff --git a/src/sparse.jl b/src/sparse.jl index 2ba8f5f..d6b301c 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -1,4 +1,4 @@ -# TODO make an extension? +import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension # dense(a::AbstractOperator) = arithmetic_unary_error("Conversion to dense", a) diff --git a/test/runtests.jl b/test/runtests.jl index 0bccf25..826fe33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ end println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREADS = $(Sys.CPU_THREADS)`...") @doset "sortedindices" +@doset "bases" #VERSION >= v"1.9" && @doset "doctests" get(ENV,"JET_TEST","")=="true" && @doset "jet" VERSION >= v"1.9" && @doset "aqua" diff --git a/test/test_bases.jl b/test/test_bases.jl new file mode 100644 index 0000000..1d91673 --- /dev/null +++ b/test/test_bases.jl @@ -0,0 +1,55 @@ +using Test +using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, equal_bases, multiplicable +using QuantumInterface: GenericBasis, CompositeBasis, NLevelBasis, FockBasis + +@testset "basis" begin + +shape1 = [5] +shape2 = [2, 3] +shape3 = [6] + +b1 = GenericBasis(shape1) +b2 = GenericBasis(shape2) +b3 = GenericBasis(shape3) + +@test b1.shape == shape1 +@test b2.shape == shape2 +@test b1 != b2 +@test b1 != FockBasis(2) +@test b1 == b1 + +@test tensor(b1) == b1 +comp_b1 = tensor(b1, b2) +comp_uni = b1 ⊗ b2 +comp_b2 = tensor(b1, b1, b2) +@test comp_b1.shape == [prod(shape1), prod(shape2)] +@test comp_uni.shape == [prod(shape1), prod(shape2)] +@test comp_b2.shape == [prod(shape1), prod(shape1), prod(shape2)] + +@test b1^3 == CompositeBasis(b1, b1, b1) +@test (b1⊗b2)^2 == CompositeBasis(b1, b2, b1, b2) +@test_throws ArgumentError b1^(0) + +comp_b1_b2 = tensor(comp_b1, comp_b2) +@test comp_b1_b2.shape == [prod(shape1), prod(shape2), prod(shape1), prod(shape1), prod(shape2)] +@test comp_b1_b2 == CompositeBasis(b1, b2, b1, b1, b2) + +@test_throws ArgumentError tensor() +@test comp_b2.shape == tensor(b1, comp_b1).shape +@test comp_b2 == tensor(b1, comp_b1) + +@test_throws ArgumentError ptrace(comp_b1, [1, 2]) +@test ptrace(comp_b2, [1]) == ptrace(comp_b2, [2]) == comp_b1 == ptrace(comp_b2, 1) +@test ptrace(comp_b2, [1, 2]) == ptrace(comp_b1, [1]) +@test ptrace(comp_b2, [2, 3]) == ptrace(comp_b1, [2]) +@test ptrace(comp_b2, [2, 3]) == reduced(comp_b2, [1]) +@test_throws ArgumentError reduced(comp_b1, []) + +comp1 = tensor(b1, b2, b3) +comp2 = tensor(b2, b1, b3) +@test permutesystems(comp1, [2,1,3]) == comp2 + +@test !equal_bases([b1, b2], [b1, b3]) +@test !multiplicable(comp1, b1 ⊗ b2 ⊗ NLevelBasis(prod(b3.shape))) + +end # testset