Skip to content

Commit

Permalink
Decouple size dictionary from EinExprs (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Dec 28, 2023
1 parent 39f52af commit 316069d
Show file tree
Hide file tree
Showing 22 changed files with 380 additions and 292 deletions.
23 changes: 10 additions & 13 deletions benchmark/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@ suite["greedy"] = BenchmarkGroup([])
suite["kahypar"] = BenchmarkGroup([])

# BENCHMARK 1
expr = EinExpr(
Symbol[],
[
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
],
)
expr = sum([
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
])

suite["naive"][1] = @benchmarkable einexpr(EinExprs.Naive(), $expr)
suite["exhaustive"][1] = @benchmarkable einexpr(Exhaustive(), $expr)
Expand All @@ -41,7 +38,7 @@ D = EinExpr([:c, :h, :d, :i], Dict(:c => 2, :h => 2, :d => 2, :i => 2))
E = EinExpr([:f, :i, :g, :j], Dict(:f => 2, :i => 2, :g => 2, :j => 2))
F = EinExpr([:B, :h, :k, :l], Dict(:B => 2, :h => 2, :k => 2, :l => 2))
G = EinExpr([:j, :k, :l, :D], Dict(:j => 2, :k => 2, :l => 2, :D => 2))
expr = EinExpr([:A, :B, :C, :D], [A, B, C, D, E, F, G])
expr = sum([A, B, C, D, E, F, G], skip = [:A, :B, :C, :D])

suite["naive"][2] = @benchmarkable einexpr(EinExprs.Naive(), $expr)
suite["exhaustive"][2] = @benchmarkable einexpr(Exhaustive(), $expr)
Expand Down
10 changes: 5 additions & 5 deletions ext/EinExprsMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0
const MAX_ARROW_SIZE = 35.0
const MAX_NODE_SIZE = 40.0

function Makie.plot(path::EinExpr; kwargs...)
function Makie.plot(path::SizedEinExpr; kwargs...)
f = Figure()
ax, p = plot!(f[1, 1], path; kwargs...)
return Makie.FigureAxisPlot(f, ax, p)
end

function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...)
function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...)
ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
Axis3(f[1, 1])
else
Expand Down Expand Up @@ -65,13 +65,13 @@ end
# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap
function Makie.plot!(
ax::Union{Axis,Axis3},
path::EinExpr;
path::SizedEinExpr;
colormap = to_colormap(:viridis)[begin:end-10],
inds = false,
kwargs...,
)
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args])
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args])

lin_size = length.(PostOrderDFS(path))[1:end-1]
lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path)))
Expand Down
26 changes: 22 additions & 4 deletions src/Counters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,41 @@
Count the number of mathematical operations that will be performed by the contraction of the root of the `path` tree.
"""
flops(expr::EinExpr) =
if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr))
flops(sexpr::SizedEinExpr) =
if nargs(sexpr) == 0 || nargs(sexpr) == 1 && isempty(suminds(sexpr))
0
else
mapreduce(Base.Fix1(size, expr), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt))
mapreduce(
Base.Fix1(getindex, sexpr.size),
*,
Iterators.flatten((head(sexpr), suminds(sexpr))),
init = one(BigInt),
)
end

# Convenience method: attach the size dictionary `size` to `expr` and defer to the `SizedEinExpr` method.
function flops(expr::EinExpr, size)
    sized = SizedEinExpr(expr, size)
    return flops(sized)
end

"""
removedsize(path::EinExpr)
Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree.
"""
removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr))
function removedsize(sexpr::SizedEinExpr)
    # Length of the tensor produced by contracting the root.
    resultsize = length(sexpr)

    # Accumulated length of the root's arguments, each measured with the same size dictionary.
    operandsize = mapreduce(arg -> length(SizedEinExpr(arg, sexpr.size)), +, sexpr.args)

    return -resultsize + operandsize
end

# Convenience method: attach the size dictionary `size` to `expr` and defer to the `SizedEinExpr` method.
function removedsize(expr::EinExpr, size)
    sized = SizedEinExpr(expr, size)
    return removedsize(sized)
end

"""
removedrank(path::EinExpr)
Count the rank reduction after performing the contraction of the root of the `path` tree.
"""
removedrank(expr::EinExpr) = mapreduce(ndims, max, expr.args) - ndims(expr)
# Two-argument forms mirror the signatures of the size-aware counters; the second argument is ignored.
function removedrank(expr::EinExpr, _)
    return removedrank(expr)
end

function removedrank(sexpr::SizedEinExpr, _)
    return removedrank(sexpr.path)
end

# Curried forms: calling `flops` or `removedsize` with only a size dictionary returns a
# one-argument callable that applies the counter with the dictionary fixed as second argument.
for f in [:flops, :removedsize]
@eval $f(sizedict::Dict{Symbol}) = Base.Fix2($f, sizedict)
end
# For `removedrank` the dictionary is discarded and the function itself is returned.
removedrank(::Dict) = removedrank
62 changes: 34 additions & 28 deletions src/EinExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,18 @@ using Base: AbstractVecOrTuple
using DataStructures: DefaultDict
using AbstractTrees

struct EinExpr
Base.@kwdef struct EinExpr
head::Vector{Symbol}
args::Vector{EinExpr}
size::Dict{Symbol,Int}

# TODO checks: same dim for index, valid indices
EinExpr(head, args) = new(head, args, Dict{Symbol,EinExpr}())

function EinExpr(head::AbstractVector{Symbol}, size::AbstractDict{Symbol,Int})
head ⊆ keys(size) || throw(ArgumentError("Missing sizes for indices $(setdiff(head, keys(size)))"))
new(head, EinExpr[], size)
end
args::Vector{EinExpr} = EinExpr[]
end

EinExpr(head) = EinExpr(head, EinExpr[])
EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}) = EinExpr(head, map(EinExpr, args))

EinExpr(head::NTuple, args) = EinExpr(collect(head), args)
EinExpr(head, args::NTuple) = EinExpr(head, collect(args))
EinExpr(head::NTuple, args::NTuple) = EinExpr(collect(head), collect(args))

function EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}, sizes)
args = map(args) do arg
sizedict = filter(∈(arg) ∘ first, sizes)
EinExpr(arg, sizedict)
end
EinExpr(head, args)
end

"""
head(path::EinExpr)
Expand All @@ -46,6 +32,8 @@ See also: [`head`](@ref).
"""
args(path::EinExpr) = path.args

"""
    nargs(path::EinExpr)

Return the number of arguments (children) stored in the root of `path`.
"""
function nargs(path::EinExpr)
    return length(path.args)
end

"""
inds(path)
Expand Down Expand Up @@ -100,11 +88,8 @@ Base.ndims(path::EinExpr) = length(head(path))
Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index.
"""
Base.size(path::EinExpr) = (size(path, i) for i in head(path)) |> splat(tuple)
Base.size(path::EinExpr, i::Symbol) =
Iterators.filter(∋(i) ∘ head, Leaves(path)) |> first |> Base.Fix2(getproperty, :size) |> Base.Fix2(getindex, i)

Base.length(path::EinExpr) = (prod ∘ size)(path)
function Base.size(path::EinExpr, sizedict)
    # Look up every index of the head in `sizedict`, keeping the order of the head.
    return tuple((sizedict[index] for index in head(path))...)
end
# Number of elements of the tensor resulting from `path`: the product of its `size(path, sizedict)`.
Base.length(path::EinExpr, sizedict) = (prod ∘ size)(path, sizedict)

"""
collapse!(path::EinExpr)
Expand Down Expand Up @@ -241,24 +226,45 @@ Create an `EinExpr` from other `EinExpr`s.
function Base.sum(args::Vector{EinExpr}; skip = Symbol[])
_head = Symbol[]
_counts = Int[]

for arg in args
for index in head(arg)
i = findfirst(Base.Fix1(===, index), _head)
if isnothing(i)
push!(_head, index)
push!(_counts, 1)
else
_counts[i] += 1
@inbounds _counts[i] += 1
end
end
end

_head = map(first, Iterators.filter(zip(_head, _counts)) do (index, count)
count == 1 || index ∈ skip
end)
# NOTE `map` with `Iterators.filter` induces many heap grows; allocating once and deleting is faster
for i in Iterators.reverse(eachindex(_head, _counts))
(_counts[i] == 1 || _head[i] ∈ skip) && continue
deleteat!(_head, i)
end

EinExpr(_head, args)
end

"""
    sum(a::EinExpr, b::EinExpr; skip = Symbol[])

Create the `EinExpr` whose arguments are `a` and `b`.

The head of the result starts as a copy of `head(a)`. Then, for each index in `head(b)`:

  - an index not yet present in the head is appended;
  - an index already present and contained in `skip` is kept;
  - any other index already present is removed from the head.
"""
function Base.sum(a::EinExpr, b::EinExpr; skip = Symbol[])
    _head = copy(head(a))

    for index in head(b)
        i = findfirst(Base.Fix1(===, index), _head)
        if isnothing(i)
            # index only seen in `b` so far: it stays open
            push!(_head, index)
        elseif index ∈ skip
            # shared index that the caller asked not to contract
            continue
        else
            # shared index: contracted, so it leaves the head
            deleteat!(_head, i)
        end
    end

    return EinExpr(_head, [a, b])
end

function Base.string(path::EinExpr; recursive::Bool = false)
!recursive && return "$(join(map(x -> string.(head(x)) |> join, args(path)), ","))->$(string.(head(path)) |> join)"
map(string, Branches(path))
Expand Down
3 changes: 3 additions & 0 deletions src/EinExprs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ export EinExpr
export head, args, inds, hyperinds, suminds, parsuminds, collapse!, contractorder, select, neighbours
export Branches, branches, leaves

include("SizedEinExpr.jl")
export SizedEinExpr

include("Counters.jl")
export flops, removedsize

Expand Down
34 changes: 20 additions & 14 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,41 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``
end

function einexpr(config::Exhaustive, path; cost = BigInt(0))
leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((;
# metric = Base.Fix2(config.metric, path.size)
leader = Ref((;
path = einexpr(Naive(), path),
cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt,
))
cache = Dict{Vector{Symbol},BigInt}()
__einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader, cache)
__einexpr_exhaustive_it(path, cost, Val(config.metric), config.outer, leader)
return leader[].path
end

function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache)
if length(path.args) == 1
# remove identity einsum (i.e. "i...->i...")
path = path.args[1]

leader[] = (; path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt)
function __einexpr_exhaustive_it(
path,
cost,
@specialize(metric::Val{Metric}),
outer,
leader;
cache = Dict{Vector{Symbol},BigInt}(),
hashyperinds = !isempty(hyperinds(path)),
) where {Metric}
if nargs(path) <= 2
#= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =#
leader[] = (; path = path, cost = cost)
return
end

for (i, j) in combinations(args(path), 2)
for (i, j) in combinations(path.args, 2)
!outer && isdisjoint(head(i), head(j)) && continue
candidate = sum([i, j], skip = path.head ∪ hyperinds(path))
candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head)

# prune paths based on metric
new_cost = cost + get!(cache, head(candidate)) do
metric(candidate)
Metric(SizedEinExpr(candidate, path.size))
end
new_cost >= leader[].cost && continue

new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...])
__einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader, cache)
new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), path.args)...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head)
__einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader; cache, hashyperinds)
end
end
12 changes: 8 additions & 4 deletions src/Optimizers/Greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ The implementation uses a binary heaptree to sort candidate pairwise tensor cont
outer::Bool = false
end

function einexpr(config::Greedy, path)
function einexpr(config::Greedy, path, sizedict)
metric = config.metric(sizedict)

# generate initial candidate contractions
queue = MutableBinaryHeap{Tuple{Float64,EinExpr}}(
Base.By(first, Base.Reverse),
Expand All @@ -36,12 +38,12 @@ function einexpr(config::Greedy, path)
) do (a, b)
# TODO don't consider outer products
candidate = sum([a, b], skip = path.head ∪ hyperinds(path))
weight = config.metric(candidate)
weight = metric(candidate)
(weight, candidate)
end,
)

while length(path.args) > 2 && length(queue) > 1
while nargs(path) > 2 && length(queue) > 1
# choose winner
_, winner = config.choose(queue)

Expand All @@ -55,7 +57,7 @@ function einexpr(config::Greedy, path)
for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args)
# TODO don't consider outer products
candidate = sum([winner, other], skip = path.head ∪ hyperinds(path))
weight = config.metric(candidate)
weight = metric(candidate)
push!(queue, (weight, candidate))
end

Expand All @@ -65,3 +67,5 @@ function einexpr(config::Greedy, path)

return path
end

# `SizedEinExpr` entry point: run the greedy search on the bare path with its own size
# dictionary, then wrap the result with that same dictionary.
function einexpr(config::Greedy, path::SizedEinExpr)
    optimized = einexpr(config, path.path, path.size)
    return SizedEinExpr(optimized, path.size)
end
12 changes: 6 additions & 6 deletions src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Suppressor
@kwdef struct HyPar <: Optimizer
parts::Int = 2
imbalance::Float32 = 0.03
stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args)
stop::Function = <=(2) ∘ length ∘ Base.Fix2(getproperty, :args)
configuration::Union{Nothing,Symbol,String} = nothing
edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
Expand All @@ -26,7 +26,7 @@ function EinExprs.einexpr(config::HyPar, path)

# NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds)
vertex_weights = map(config.vertex_scaler ∘ length, path.args)
vertex_weights = map(config.vertex_scaler ∘ length, args(path))

hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)
KaHyPar.kahypar_set_seed(hypergraph.context, config.seed)
Expand All @@ -38,13 +38,13 @@ function EinExprs.einexpr(config::HyPar, path)
configuration = config.configuration,
)

args = map(unique(partitions)) do partition
_args = map(unique(partitions)) do partition
selection = partitions .== partition
count(selection) == 1 && return only(path.args[selection])
count(selection) == 1 && return only(args(path)[selection])

expr = sum(path.args[selection], skip = path.head)
expr = sum(args(path)[selection], skip = path.head)
einexpr(config, expr)
end

return EinExpr(path.head, args)
return sum(_args, skip = path.head)
end
12 changes: 9 additions & 3 deletions src/Optimizers/Naive.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
using AbstractTrees

struct Naive <: Optimizer end

# Three-argument form: the third argument is discarded and the call is forwarded to the two-argument method.
function einexpr(::Naive, path, _)
    return einexpr(Naive(), path)
end

function einexpr(::Naive, path)
hist = Dict(i => count(∋(i) ∘ head, path.args) for i in hyperinds(path))
hist = Dict(i => count(∋(i) ∘ head, args(path)) for i in hyperinds(path))

foldl(path.args) do a, b
expr = sum([a, b], skip = path.head ∪ collect(keys(hist)))
foldl(args(path)) do a, b
expr = sum([a, b], skip = head(path) ∪ collect(keys(hist)))

for i in Iterators.filter(∈(keys(hist)), ∩(head(a), head(b)))
hist[i] -= 1
Expand All @@ -14,3 +18,5 @@ function einexpr(::Naive, path)
return expr
end
end

# `SizedEinExpr` entry point: optimise the bare path and wrap the result with the original size dictionary.
function einexpr(::Naive, path::SizedEinExpr)
    optimized = einexpr(Naive(), path.path)
    return SizedEinExpr(optimized, path.size)
end
Loading

0 comments on commit 316069d

Please sign in to comment.