Skip to content

Commit

Permalink
Add test_bases from QuantumOpticsBase and reorganize a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Dec 7, 2024
1 parent f22adb7 commit 662993f
Show file tree
Hide file tree
Showing 12 changed files with 415 additions and 327 deletions.
36 changes: 31 additions & 5 deletions src/QuantumInterface.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
module QuantumInterface

import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy
import LinearAlgebra: tr, ishermitian, norm, normalize, normalize!
import Base: show, summary
import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension
##
# Basis specific
##

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end

"""
Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end


##
# Standard methods
##

function apply! end

function dagger end

"""
directsum(x, y, z...)
Direct sum of the given objects. Alternatively, the unicode
symbol ⊕ (\\oplus) can be used.
"""
function directsum end
const = directsum
directsum() = GenericBasis(0)
Expand Down Expand Up @@ -86,8 +111,9 @@ function squeeze end
function wigner end


include("bases.jl")
include("abstract_types.jl")
include("bases.jl")
include("show.jl")

include("linalg.jl")
include("tensor.jl")
Expand Down
32 changes: 15 additions & 17 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end

"""
Abstract base class for `Bra` and `Ket` states.
Expand Down Expand Up @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
```
"""
abstract type AbstractSuperOperator{B1,B2} end

function summary(stream::IO, x::AbstractOperator)
print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n")
if samebases(x)
print(stream, " basis: ")
show(stream, basis(x))
else
print(stream, " basis left: ")
show(stream, x.basis_l)
print(stream, "\n basis right: ")
show(stream, x.basis_r)
end
end

show(stream::IO, x::AbstractOperator) = summary(stream, x)

traceout!(s::StateVector, i) = ptrace(s,i)
274 changes: 6 additions & 268 deletions src/bases.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end
##
# GenericBasis, CompositeBasis
##

"""
length(b::Basis)
Expand All @@ -20,17 +9,6 @@ Total dimension of the Hilbert space.
"""
Base.length(b::Basis) = prod(b.shape)

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end


"""
GenericBasis(N)
Expand Down Expand Up @@ -67,39 +45,6 @@ CompositeBasis(bases::Vector) = CompositeBasis((bases...,))

Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape)

tensor(b::Basis) = b

"""
tensor(x::Basis, y::Basis, z::Basis...)
Create a [`CompositeBasis`](@ref) from the given bases.
Any given CompositeBasis is expanded so that the resulting CompositeBasis never
contains another CompositeBasis.
"""
tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2))
tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...))
function tensor(b1::CompositeBasis, b2::Basis)
N = length(b1.bases)
shape = vcat(b1.shape, length(b2))
bases = (b1.bases..., b2)
CompositeBasis(shape, bases)
end
function tensor(b1::Basis, b2::CompositeBasis)
N = length(b2.bases)
shape = vcat(length(b1), b2.shape)
bases = (b1, b2.bases...)
CompositeBasis(shape, bases)
end
tensor(bases::Basis...) = reduce(tensor, bases)

function Base.:^(b::Basis, N::Integer)
if N < 1
throw(ArgumentError("Power of a basis is only defined for positive integers."))
end
tensor([b for i=1:N]...)
end

"""
equal_shape(a, b)
Expand Down Expand Up @@ -137,130 +82,6 @@ function equal_bases(a, b)
return true
end

"""
Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end

const BASES_CHECK = Ref(true)

"""
@samebases
Macro to skip checks for same bases. Useful for `*`, `expect` and similar
functions.
"""
macro samebases(ex)
return quote
BASES_CHECK.x = false
local val = $(esc(ex))
BASES_CHECK.x = true
val
end
end

"""
samebases(a, b)
Test if two objects have the same bases.
"""
samebases(b1::Basis, b2::Basis) = b1==b2
samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators

"""
check_samebases(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects don't have
the same bases.
"""
function check_samebases(b1, b2)
if BASES_CHECK[] && !samebases(b1, b2)
throw(IncompatibleBases())
end
end


"""
multiplicable(a, b)
Check if two objects are multiplicable.
"""
multiplicable(b1::Basis, b2::Basis) = b1==b2

function multiplicable(b1::CompositeBasis, b2::CompositeBasis)
if !equal_shape(b1.shape,b2.shape)
return false
end
for i=1:length(b1.shape)
if !multiplicable(b1.bases[i], b2.bases[i])
return false
end
end
return true
end

"""
check_multiplicable(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects are
not multiplicable.
"""
function check_multiplicable(b1, b2)
if BASES_CHECK[] && !multiplicable(b1, b2)
throw(IncompatibleBases())
end
end

"""
reduced(a, indices)
Reduced basis, state or operator on the specified subsystems.
The `indices` argument, which can be a single integer or a vector of integers,
specifies which subsystems are kept. At least one index must be specified.
"""
function reduced(b::CompositeBasis, indices)
if length(indices)==0
throw(ArgumentError("At least one subsystem must be specified in reduced."))
elseif length(indices)==1
return b.bases[indices[1]]
else
return CompositeBasis(b.shape[indices], b.bases[indices])
end
end

"""
ptrace(a, indices)
Partial trace of the given basis, state or operator.
The `indices` argument, which can be a single integer or a vector of integers,
specifies which subsystems are traced out. The number of indices has to be
smaller than the number of subsystems, i.e. it is not allowed to perform a
full trace.
"""
function ptrace(b::CompositeBasis, indices)
J = [i for i in 1:length(b.bases) if i indices]
length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace."))
reduced(b, J)
end


"""
permutesystems(a, perm)
Change the ordering of the subsystems of the given object.
For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]`
this function results in `[b2, b1, b3]`.
"""
function permutesystems(b::CompositeBasis, perm)
@assert length(b.bases) == length(perm)
@assert isperm(perm)
CompositeBasis(b.shape[perm], b.bases[perm])
end


##
# Common bases
##
Expand Down Expand Up @@ -366,89 +187,6 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp))
SumBasis(bases::Vector) = SumBasis((bases...,))
SumBasis(bases::Basis...) = SumBasis((bases...,))

==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
==(b1::SumBasis, b2::SumBasis) = false
length(b::SumBasis) = sum(b.shape)

"""
directsum(b1::Basis, b2::Basis)
Construct the [`SumBasis`](@ref) out of two sub-bases.
"""
directsum(b1::Basis, b2::Basis) = SumBasis(Int[length(b1); length(b2)], Basis[b1, b2])
directsum(b::Basis) = b
directsum(b::Basis...) = reduce(directsum, b)
function directsum(b1::SumBasis, b2::Basis)
shape = [b1.shape;length(b2)]
bases = [b1.bases...;b2]
return SumBasis(shape, (bases...,))
end
function directsum(b1::Basis, b2::SumBasis)
shape = [length(b1);b2.shape]
bases = [b1;b2.bases...]
return SumBasis(shape, (bases...,))
end
function directsum(b1::SumBasis, b2::SumBasis)
shape = [b1.shape;b2.shape]
bases = [b1.bases...;b2.bases...]
return SumBasis(shape, (bases...,))
end

embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)

##
# show methods
##

function show(stream::IO, x::GenericBasis)
if length(x.shape) == 1
write(stream, "Basis(dim=$(x.shape[1]))")
else
s = replace(string(x.shape), " " => "")
write(stream, "Basis(shape=$s)")
end
end

function show(stream::IO, x::CompositeBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end

function show(stream::IO, x::SpinBasis)
d = denominator(x.spinnumber)
n = numerator(x.spinnumber)
if d == 1
write(stream, "Spin($n)")
else
write(stream, "Spin($n/$d)")
end
end

function show(stream::IO, x::FockBasis)
if iszero(x.offset)
write(stream, "Fock(cutoff=$(x.N))")
else
write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))")
end
end

function show(stream::IO, x::NLevelBasis)
write(stream, "NLevel(N=$(x.N))")
end

function show(stream::IO, x::SumBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end
Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
Base.:(==)(b1::SumBasis, b2::SumBasis) = false
Base.length(b::SumBasis) = sum(b.shape)

Check warning on line 192 in src/bases.jl

View check run for this annotation

Codecov / codecov/patch

src/bases.jl#L190-L192

Added lines #L190 - L192 were not covered by tests
Loading

0 comments on commit 662993f

Please sign in to comment.