Skip to content

Commit

Permalink
Merge pull request #8 from jmoo2880/josh-research
Browse files Browse the repository at this point in the history
Add CI Workflow + Entanglement utilities
  • Loading branch information
joshuabmoore authored Dec 23, 2024
2 parents bf142c8 + bc4c96f commit b5b6ae1
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 18 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Run tests

on:
push:
branches:
- main
tags: '*'
pull_request:

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.11.1']
julia-arch: [x64]
os: [ubuntu-latest, windows-latest, macOS-latest]

steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
arch: ${{ matrix.julia-arch }}
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
# with:
# annotate: true
8 changes: 8 additions & 0 deletions docs/src/docstrings.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@ MPSTime.symbolic_encoding
MPSTime.model_loss_func
MPSTime.model_bbopt
```

## Analysis
```@docs
MPSTime.von_neumann_entropy
MPSTime.rho_correct
MPSTime.bipartite_spectrum
MPSTime.single_site_spectrum
```
2 changes: 1 addition & 1 deletion docs/src/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ @article{FULCHER2017527
pages = {527-531.e3},
year = {2017},
issn = {2405-4712},
doi = {https://doi.org/10.1016/j.cels.2017.10.001},
doi = {10.1016/j.cels.2017.10.001},
url = {https://www.sciencedirect.com/science/article/pii/S2405471217304386},
author = {Ben D. Fulcher and Nick S. Jones},
}
145 changes: 145 additions & 0 deletions src/Analysis/analyse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
```Julia
von_neumann_entropy(mps::MPS; logfn::Function=log) -> Vector{Float64}
```
Compute the [von Neumann entanglement entropy](https://en.wikipedia.org/wiki/Entropy_of_entanglement) for each site in a Matrix Product State (MPS).
The von Neumann entropy quantifies the entanglement at each bond of the MPS by computing the entropy of the singular value spectrum obtained from a singular value decomposition (SVD). The entropy is computed as:
[ S = -sum_{i} p_i log(p_i) ]
where ( p_i ) are the squared singular values normalized to sum to 1.
# Arguments
- `mps::MPS`: The Matrix Product State (MPS) whose entanglement entropy is to be computed.
- `logfn::Function`: (Optional) The logarithm function to use (`log`, `log2`, or `log10`). Defaults to the natural logarithm (`log`).
# Returns
A vector of `Float64` values where the i-th element represents the von Neumann entropy at site i of the MPS.
"""
function von_neumann_entropy(mps::MPS, logfn::Function=log)
# adapted from http://itensor.org/docs.cgi?page=formulas/entanglement_mps
if !(logfn in (log, log2, log10))
throw(ArgumentError("logfn must be one of: log, log2, or log10"))
end
N = length(mps)
entropy = zeros(Float64, N)
for i in eachindex(entropy)
orthogonalize!(mps, i) # place orthogonality center on site i
S = 0.0
if i == 1 || i == N
_, S, _ = svd(mps[i], (siteind(mps, i))) # make the cut at bond i
else
_, S, _ = svd(mps[i], (linkind(mps, i-1), siteind(mps, i)))
end
SvN = 0.0
for n in 1:ITensors.dim(S, 1)
p = S[n, n]^2
if (p > 1E-12) # avoid log 0
SvN += -p * logfn(p)
end
end
entropy[i] = SvN
end
return entropy
end

"""
```Julia
Compute the bipartite entanglement entropy (BEE) of a trained MPS across each bond.
Given a single unlabeled MPS the BEE is defined as:
∑ α^2 log(α^2)
where α are the eigenvalues obtained from the shmidt decomposition.
```
Compute the bipartite entanglement entropy (BEE) of a trained MPS.
"""
function bipartite_spectrum(mps::TrainedMPS; logfn::Function=log)
if !(logfn in (log, log2, log10))
throw(ArgumentError("logfn must be one of: log, log2, or log10"))
end
mpss, _ = expand_label_index(mps.mps); # expand the label index
bees = Vector{Vector{Float64}}(undef, length(mpss))
for i in eachindex(bees)
bees[i] = von_neumann_entropy(mpss[i], logfn);
end
return bees
end

"""
Check whether the reduced density matrix (rho) is positive semidefinite by
eigendecomposition.
\nIf the eigenvalue decomp of ρ yields negative but small (< tol) eigenvalues,
clamp to them to range [threshold, ∞] and reconstruct ρ.
"""
function rho_correct(rho::Matrix, eigentol::Float64=sqrt(eps()))

eigvals, eigvecs = eigen(rho) # do an eigendecomp on the rdm
neg_eigs = findall(<(0), eigvals) # find negative eigenvalues
if isempty(neg_eigs)
return rho
end
# check eigenvalues within tolerance
oot = findall(x -> x < -eigentol, eigvals) # out of tolerance
if isempty(oot)
# clamp negative eigenvalues to the range [tol, ∞]
eigs_clamped = clamp.(eigvals, eigentol, Inf)
else
throw(DomainError("RDM contains large negative eigenvalues outside of the tolerance $eigentol: λ = $(eigvals[oot]...)"))
end
# reconstruct the rdm with the clamped eigenvalues
rho_corrected = eigvecs * LinearAlgebra.Diagonal(eigs_clamped) * (eigvecs)'
# check trace
if !isapprox(tr(rho_corrected), 1.0)
throw(DomainError("Tr(ρ_corrected) > 1.0!"))
end
return rho_corrected
end

function one_site_rdm(mps::MPS, site::Int)
s = siteinds(mps)
orthogonalize!(mps, site)
psi_dag = dag(mps) # conjugate transpose of MPS
rho = matrix(prime(mps[site], s[site]) * psi_dag[site]) # compute the reduced density matrix
rho_corrected = rho_correct(rho) # clamp negative eigenvalues to pos range
return rho_corrected
end

function single_site_entropy(mps::MPS)
N = length(mps)
entropy = zeros(Float64, N)
for i in 1:N
rho = one_site_rdm(mps, i)
rho_log_rho = rho * log(rho)
entropy[i] = -tr(rho_log_rho)
end
return entropy
end

"""
single_site_spectrum(mps::TrainedMPS) -> Vector{Vector{Float64}}
Compute the single-site entanglement entropy (SEE) spectrum of a trained MPS.
The single-site entanglement entropy (SEE) quantifies the entanglement at each site of the MPS. It is computed as:
[ SEE = -tr(ρ ⋅ log(ρ)) ]
where ρ is the single-site reduced density matrix (RDM).
# Arguments
- `mps::TrainedMPS`: A trained Matrix Product State (MPS) object, which includes the MPS and associated labels.
# Returns
A vector of vectors, where the outer vector corresponds to each label in the expanded MPS, and the inner vectors contain the SEE values for the respective sites.
"""
function single_site_spectrum(mps::TrainedMPS)
# expand the label index
mpss, _ = expand_label_index(mps.mps);
sees = Vector{Vector{Float64}}(undef, length(mpss))
for i in eachindex(sees)
sees[i] = single_site_entropy(mpss[i]);
end
return sees
end
2 changes: 1 addition & 1 deletion src/Encodings/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,4 @@ function project_legendre(xs::AbstractVector{T}, d::Integer; max_series_terms::I
return series_expand(basis, xs_samps, wf, d)
end

project_legendre(Xs::AbstractMatrix{<:Real}, ys::AbstractVector{<:Integer}; opts, kwargs...) = project_legendre(Xs, opts.d; kwargs...)
project_legendre(Xs::AbstractMatrix{<:Real}, ys::AbstractVector{<:Integer}; opts, kwargs...) = project_legendre(Xs, opts.d; kwargs...)
2 changes: 1 addition & 1 deletion src/Encodings/basis_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,4 @@ uniform_split(s::Symbol) = symbolic_encoding(uniform_split(model_encoding(s)))


histogram_split() = histogram_split(uniform())
uniform_split() = uniform_split(uniform())
uniform_split() = uniform_split(uniform())
2 changes: 0 additions & 2 deletions src/Encodings/splitbases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,3 @@ function project_onto_bins(x::Float64, d::Int, ti::Int, all_aux_enc_args::Abstra
return project_onto_bins(x, aux_dim, aux_encoder, bins; norm=norm)

end


2 changes: 1 addition & 1 deletion src/Imputation/MPS_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,4 +433,4 @@ function get_rdms_with_med(
)

return (x_samps, x_wmads, cdfs)
end
end
1 change: 0 additions & 1 deletion src/Imputation/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,3 @@ function compute_all_forecast_metrics(forecast::Vector{Float64},
return metric_outputs

end

2 changes: 0 additions & 2 deletions src/Imputation/sampling_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,5 +361,3 @@ function compute_entanglement_entropy_profile(class_mps::MPS)
return entropy_vals

end


9 changes: 8 additions & 1 deletion src/MPSTime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ include("utils.jl") # Some utils used by the entire library
# Visualisation utilities
include("Vis/vis_encodings.jl")

# Analysis utilities
include("Analysis/analyse.jl")

include("Training/loss_functions.jl") # Where loss functions and the LossFunction type are defined
include("Training/RealRealHighDimension.jl"); # The training algorithm, fitMPS and co

Expand Down Expand Up @@ -111,6 +114,10 @@ export
# vis
plot_encoding,

# analysis
bipartite_spectrum,
single_site_spectrum,

# Training functions
fitMPS, # gotta fit those MPSs somehow

Expand All @@ -122,4 +129,4 @@ export

# MLJ
MPSClassifier
end
end
1 change: 0 additions & 1 deletion src/Training/RealRealHighDimension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -969,4 +969,3 @@ function fitMPS(W::MPS, training_states_meta::EncodedTimeSeriesSet, testing_stat
return TrainedMPS(W, MPSOptions(opts), opts, training_states_meta), training_information, testing_states_meta

end

6 changes: 1 addition & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,13 @@ function expand_label_index(mps::MPS; lstr="f(x)")
for iv in eachindval(l_ind)
mpsc = deepcopy(mps)
mpsc[pos] = mpsc[pos] * onehot(iv)
normalize!(mpsc)
push!(weights_by_class, mpsc)
end

return Vector{MPS}(weights_by_class), l_ind
end




function saveMPS(mps::MPS, path::String; id::String="W")
"""Saves an MPS as a .h5 file"""
file = path[end-2:end] == ".h5" ? path[1:end-3] : path
Expand All @@ -376,12 +374,10 @@ function saveMPS(mps::MPS, path::String; id::String="W")
println("Succesfully saved mps $id at $file.h5")
end


function get_siteinds(W::MPS)
W1 = deepcopy(W)
pos, label_idx = find_label(W1)
W1[pos] *= onehot(label_idx => 1) # eliminate label index

return siteinds(W1)
end

Loading

0 comments on commit b5b6ae1

Please sign in to comment.