Skip to content

Commit

Permalink
Make CountingAll exact (#91)
Browse files Browse the repository at this point in the history
* make counting all computational exact

* fix CountingAll for GPU
  • Loading branch information
GiggleLiu authored Dec 15, 2024
1 parent fe56997 commit 4759afb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
16 changes: 10 additions & 6 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ max_k(::SizeMin{K}) where K = K
CountingAll <: AbstractProperty
CountingAll()
Counting the total number of sets. e.g. for the [`IndependentSet`](@ref) problem, it counts the independent sets.
Counting the total number of sets exactly without overflow. e.g. for the [`IndependentSet`](@ref) problem, it counts the independent sets.
Note that `PartitionFunction(0.0)` also does the counting. It is more efficient, but uses floating point numbers, which does not have arbitrary precision.
* The corresponding tensor element type is `Base.Real`.
* The corresponding tensor element type is `BigInt`.
* The weights on graph does not have effect.
* BLAS (GPU and CPU) and GPU are supported,
"""
Expand Down Expand Up @@ -261,7 +262,10 @@ function solve(gp::GenericTensorNetwork, property::AbstractProperty; T=Float64,
res = contractx(gp, _x(ExtendedTropical{max_k(property), Tropical{T}}; invert=true); usecuda=usecuda)
return asarray(post_invert_exponent.(res), res)
elseif property isa CountingAll
return contractx(gp, one(T); usecuda=usecuda)
return big_integer_solve(Int32, 100) do T
# NOTE: download to CPU after computation for post-processing (CRT)
Array(contractx(gp, one(T); usecuda=usecuda))
end
elseif property isa PartitionFunction
return contractx(gp, exp(property.beta); usecuda=usecuda)
elseif property isa CountingMax{Single}
Expand Down Expand Up @@ -433,15 +437,15 @@ function _estimate_memory(::Type{ET}, problem::GenericTensorNetwork) where ET
end

for (PROP, ET) in [
(:(PartitionFunction{T}), :(T)),
(:(SizeMax{Single}), :(Tropical{T})), (:(SizeMin{Single}), :(Tropical{T})),
(:(CountingAll), :T), (:(CountingMax{Single}), :(CountingTropical{T,T})), (:(CountingMin{Single}), :(CountingTropical{T,T})),
(:(CountingAll), :Int32), (:(CountingMax{Single}), :(CountingTropical{T,T})), (:(CountingMin{Single}), :(CountingTropical{T,T})),
(:(GraphPolynomial{:polynomial}), :(Polynomial{T, :x})), (:(GraphPolynomial{:fitting}), :T),
(:(GraphPolynomial{:laurent}), :(LaurentPolynomial{T, :x})), (:(GraphPolynomial{:fft}), :(Complex{T})),
(:(GraphPolynomial{:finitefield}), :(Mod{N,Int32} where N))
]
@eval tensor_element_type(::Type{T}, n::Int, num_flavors::Int, ::$PROP) where {T} = $ET
end
tensor_element_type(::Type{T}, n::Int, num_flavors::Int, ::PartitionFunction{T2}) where {T, T2} = T2
for (PROP, ET) in [
(:(SizeMax{K}), :(ExtendedTropical{K,Tropical{T}})), (:(SizeMin{K}), :(ExtendedTropical{K,Tropical{T}})),
(:(CountingMax{K}), :(TruncatedPoly{K,T,T})), (:(CountingMin{K}), :(TruncatedPoly{K,T,T})),
Expand Down Expand Up @@ -497,4 +501,4 @@ end
function Base.findmax(problem::AbstractProblem, solver::GTNSolver)
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMax(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
end
end
3 changes: 2 additions & 1 deletion test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,11 @@ end
ConfigsMax(;bounded=true), ConfigsMin(;bounded=true), ConfigsMax(2;bounded=true), ConfigsMin(2;bounded=true),
ConfigsMax(;bounded=false), ConfigsMin(;bounded=false), ConfigsMax(2;bounded=false), ConfigsMin(2;bounded=false), SingleConfigMax(;bounded=false), SingleConfigMin(;bounded=false),
CountingAll(), ConfigsAll(), SingleConfigMax(2), SingleConfigMin(2), SingleConfigMax(2; bounded=true), SingleConfigMin(2,bounded=true),
PartitionFunction(0.0)
]
@show property
ET = GenericTensorNetworks.tensor_element_type(Float32, 10, 2, property)
@test eltype(solve(gp, property, T=Float32)) <: ET
@test eltype(solve(gp, property, T=Float32)) <: (property isa CountingAll ? BigInt : ET)
@test estimate_memory(gp, property) isa Integer
end
@test GenericTensorNetworks.tensor_element_type(Float32, 10, 2, GraphPolynomial(method=:polynomial)) == Polynomial{Float32, :x}
Expand Down

0 comments on commit 4759afb

Please sign in to comment.