diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl
index 28da183a..a60bb857 100644
--- a/src/ITensorNetworks.jl
+++ b/src/ITensorNetworks.jl
@@ -61,6 +61,7 @@ include("solvers/linsolve.jl")
 include("solvers/sweep_plans/sweep_plans.jl")
 include("apply.jl")
 include("inner.jl")
+include("normalize.jl")
 include("expect.jl")
 include("environment.jl")
 include("exports.jl")
diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl
index d5886ec7..787438df 100644
--- a/src/caches/beliefpropagationcache.jl
+++ b/src/caches/beliefpropagationcache.jl
@@ -311,3 +311,29 @@ end
 function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
   return vertex_scalars(bp_cache), edge_scalars(bp_cache)
 end
+
+function normalize_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
+  bp_cache = copy(bp_cache)
+  mts = messages(bp_cache)
+  for pe in pes
+    me, mer = only(mts[pe]), only(mts[reverse(pe)])
+    me, mer = normalize(me), normalize(mer)
+    n = dot(me, mer)
+    if isreal(n) && real(n) < 0
+      set!(mts, pe, ITensor[(-1 / sqrt(abs(n))) * me])
+      set!(mts, reverse(pe), ITensor[(1 / sqrt(abs(n))) * mer])
+    else
+      set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
+      set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
+    end
+  end
+  return bp_cache
+end
+
+function normalize_message(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
+  return normalize_messages(bp_cache, PartitionEdge[pe])
+end
+
+function normalize_messages(bp_cache::BeliefPropagationCache)
+  return normalize_messages(bp_cache, partitionedges(partitioned_tensornetwork(bp_cache)))
+end
diff --git a/src/normalize.jl b/src/normalize.jl
new file mode 100644
index 00000000..52eeed5c
--- /dev/null
+++ b/src/normalize.jl
@@ -0,0 +1,91 @@
+using LinearAlgebra
+
+function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
+  return rescale(Algorithm(alg), tn; kwargs...)
+end
+
+function rescale(
+  alg::Algorithm"exact", tn::AbstractITensorNetwork, vs=collect(vertices(tn)); kwargs...
+)
+  logn = logscalar(alg, tn; kwargs...)
+  c = 1.0 / exp(logn / length(vs))
+  tn = copy(tn)
+  for v in vs
+    tn[v] *= c
+  end
+  return tn
+end
+
+function rescale(
+  alg::Algorithm,
+  tn::AbstractITensorNetwork,
+  vs=collect(vertices(tn));
+  (cache!)=nothing,
+  cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
+  update_cache=isnothing(cache!),
+  cache_update_kwargs=default_cache_update_kwargs(cache!),
+)
+  if isnothing(cache!)
+    cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
+  end
+
+  if update_cache
+    cache![] = update(cache![]; cache_update_kwargs...)
+  end
+
+  tn = copy(tn)
+  cache![] = normalize_messages(cache![])
+  vertices_states = Dictionary()
+  for pv in partitionvertices(cache![])
+    pv_vs = filter(v -> v ∈ vs, vertices(cache![], pv))
+
+    isempty(pv_vs) && continue
+
+    vn = region_scalar(cache![], pv)
+    if isreal(vn) && real(vn) < 0
+      tn[first(pv_vs)] *= -1
+      vn = abs(vn)
+    end
+
+    vn = vn^(1 / length(pv_vs))
+    for v in pv_vs
+      tn[v] /= vn
+      set!(vertices_states, v, tn[v])
+    end
+  end
+
+  cache![] = update_factors(cache![], vertices_states)
+  return tn
+end
+
+function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
+  return normalize(Algorithm(alg), tn; kwargs...)
+end
+
+function LinearAlgebra.normalize(
+  alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
+)
+  norm_tn = QuadraticFormNetwork(tn)
+  vs = filter(v -> v ∉ operator_vertices(norm_tn), collect(vertices(norm_tn)))
+  return ket_network(rescale(alg, norm_tn, vs; kwargs...))
+end
+
+function LinearAlgebra.normalize(
+  alg::Algorithm,
+  tn::AbstractITensorNetwork;
+  (cache!)=nothing,
+  cache_construction_function=tn ->
+    cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
+  update_cache=isnothing(cache!),
+  cache_update_kwargs=default_cache_update_kwargs(cache!),
+)
+  norm_tn = QuadraticFormNetwork(tn)
+  if isnothing(cache!)
+    cache! = Ref(cache_construction_function(norm_tn))
+  end
+
+  vs = filter(v -> v ∉ operator_vertices(norm_tn), collect(vertices(norm_tn)))
+  norm_tn = rescale(alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)
+
+  return ket_network(norm_tn)
+end
diff --git a/test/test_normalize.jl b/test/test_normalize.jl
new file mode 100644
index 00000000..cd95d635
--- /dev/null
+++ b/test/test_normalize.jl
@@ -0,0 +1,52 @@
+@eval module $(gensym())
+using ITensorNetworks:
+  BeliefPropagationCache,
+  QuadraticFormNetwork,
+  edge_scalars,
+  norm_sqr_network,
+  random_tensornetwork,
+  vertex_scalars,
+  rescale
+using ITensors: dag, inner, siteinds, scalar
+using Graphs: SimpleGraph, uniform_tree
+using LinearAlgebra: normalize
+using NamedGraphs: NamedGraph
+using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
+using StableRNGs: StableRNG
+using Test: @test, @testset
+@testset "Normalize" begin
+
+  # First, a flat (scalar-valued) network on a tree
+  nx, ny = 2, 3
+  χ = 2
+  rng = StableRNG(1234)
+
+  g = named_comb_tree((nx, ny))
+  tn = random_tensornetwork(rng, g; link_space=χ)
+
+  tn_r = rescale(tn; alg="exact")
+  @test scalar(tn_r; alg="exact") ≈ 1.0
+
+  tn_r = rescale(tn; alg="bp")
+  @test scalar(tn_r; alg="exact") ≈ 1.0
+
+  # Now a state on a loopy graph
+  Lx, Ly = 3, 2
+  χ = 2
+  rng = StableRNG(1234)
+
+  g = named_grid((Lx, Ly))
+  s = siteinds("S=1/2", g)
+  x = random_tensornetwork(rng, s; link_space=χ)
+
+  ψ = normalize(x; alg="exact")
+  @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0
+
+  ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
+  ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
+  ψIψ_bpc = ψIψ_bpc[]
+  @test all(x -> x ≈ 1.0, edge_scalars(ψIψ_bpc))
+  @test all(x -> x ≈ 1.0, vertex_scalars(ψIψ_bpc))
+  @test scalar(QuadraticFormNetwork(ψ); alg="bp") ≈ 1.0
+end
+end
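For reviewers, here is a minimal usage sketch of the new API. It mirrors `test/test_normalize.jl` above and only calls functions introduced by this PR (`rescale`, the `normalize` overloads) plus existing constructors; the graph size and bond dimension are arbitrary illustrative choices.

```julia
using ITensorNetworks: QuadraticFormNetwork, random_tensornetwork, rescale
using ITensors: scalar, siteinds
using LinearAlgebra: normalize
using NamedGraphs.NamedGraphGenerators: named_grid

g = named_grid((3, 2))

# Scalar-valued network (no site indices): `rescale` spreads a constant over the
# tensors so the full contraction evaluates to ≈ 1. With alg="bp" this is exact
# on trees and approximate on loopy graphs.
tn = random_tensornetwork(g; link_space=2)
tn_r = rescale(tn; alg="bp")
scalar(tn_r; alg="exact")  # ≈ 1.0 up to BP error

# State network: `normalize` rescales ψ so that <ψ|ψ> ≈ 1 under the chosen algorithm.
s = siteinds("S=1/2", g)
ψ = random_tensornetwork(s; link_space=2)
ψn = normalize(ψ; alg="bp")
scalar(QuadraticFormNetwork(ψn); alg="bp")  # ≈ 1.0
```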