Skip to content

Commit

Permalink
Update code to SizedEinExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 18, 2024
1 parent 6306af0 commit 2ee795e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ext/EinExprsGraphMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0
const MAX_ARROW_SIZE = 35.0
const MAX_NODE_SIZE = 40.0

function Makie.plot(path::EinExpr; kwargs...)
function Makie.plot(path::SizedEinExpr; kwargs...)
f = Figure()
ax, p = plot!(f[1, 1], path; kwargs...)
return Makie.FigureAxisPlot(f, ax, p)
end

function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...)
function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...)
ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
Axis3(f[1, 1])
else
Expand Down Expand Up @@ -65,13 +65,13 @@ end
# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap
function Makie.plot!(
ax::Union{Axis,Axis3},
path::EinExpr;
path::SizedEinExpr;
colormap = to_colormap(:viridis)[begin:end-10],
inds = false,
kwargs...,
)
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args])
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args])

lin_size = length.(PostOrderDFS(path))[1:end-1]
lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path)))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 comments on commit 2ee795e

Please sign in to comment.