Cannot infer term type in egraph #125

Open
vitrun opened this issue May 26, 2022 · 1 comment
vitrun commented May 26, 2022

Something goes wrong when matching against purely literal rules (rules with no pattern variables), such as `pi + 3 --> cos(4)` and `im + (pi + 3) --> sin(4)`. The following demo reproduces the issue. I've tried several versions, including master and v1.3.3.

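```julia
using Metatheory
using Metatheory.EGraphs
using TermInterface

# A minimal custom term type; the parameter T records the term's symbolic type.
struct Term{T}
  f::Any
  args::Vector{Any}
end

# Outer constructor: T is computed by promoting the types of (at most) the first two arguments.
function Term(f, args)
  T = if length(args) == 0
    Any
  elseif length(args) == 1
    promote_type(symtype(args[1]))
  else
    promote_type(symtype(args[1]), symtype(args[2]))
  end
  Term{T}(f, args)
end

# Custom promotion used by the constructor above: π combined with an Int64 is typed as Real.
Base.promote_type(::Type{Irrational{:π}}, ::Type{Int64}) = Real


# TermInterface methods so Metatheory can traverse and rebuild `Term`s.
TermInterface.exprhead(e::Term) = :call
TermInterface.operation(e::Term) = e.f
TermInterface.arguments(e::Term) = e.args
TermInterface.istree(e::Term) = true
TermInterface.symtype(::Term{T}) where {T} = T
TermInterface.symtype(::T) where {T} = T

function TermInterface.similarterm(x::Term, head, args; metadata = nothing, exprhead = :call)
  Term(head, args)
end

# Used by the e-graph when converting e-nodes back into `Term`s.
function EGraphs.egraph_reconstruct_expression(
  T::Type{Term{S}},
  op,
  args;
  metadata = nothing,
  exprhead = nothing,
) where {S}
  Term(op, args)
end

# Ground rules: the patterns contain only literals (the variables a, b, c are unused).
pt = @theory a b c begin
  im + (pi + 3) --> sin(4)
  # pi + 3 --> cos(4)
end


# Build the expression im + (pi + 3) and an e-graph from it.
ex = Term(+, [im, Term(+, [pi, 3])])
g = EGraph(ex)

# Term type used by the e-graph by default; the commented-out variant sets it only for `+` with arity 2.
settermtype!(g, Term{symtype(ex)})
# settermtype!(g, :+, 2, Term{Real})

# Apply the rules, then extract the smallest expression by AST size.
saturate!(g, pt)
r = extract!(g, astsize)
println(r)
```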

I dug into the code and found the following function in `ematch.jl`, which I believe is to blame.

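```julia
function lookup_pat(g::EGraph, p::PatTerm)
  @assert isground(p)   # only ground patterns (no pattern variables) are looked up here

  eh = exprhead(p)
  op = operation(p)
  args = arguments(p)
  ar = arity(p)

  # The e-node's type parameter is chosen from the operation and arity alone.
  T = gettermtype(g, op, ar)

  # Recursively look up the e-class of every argument.
  ids = [lookup_pat(g, pp) for pp in args]
  if all(i -> i isa EClassId, ids)
    n = ENodeTerm{T}(eh, op isa Symbol ? eval(op) : op, ids)
    ec = lookup(g, n)
    # debugging (not library code): look up a hand-built `+` node over e-classes 1 and 2
    mn = ENodeTerm{T}(eh, +, [1, 2])
    ec2 = lookup(g, mn)
    println("T: $T, op: $(typeof(op)), n: $n, ec: $ec, mn:$mn, ec2: $ec2")
    return ec
  else
    return nothing
  end
end
```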
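To make the suspected failure concrete, here is a minimal, self-contained model of the lookup step. It assumes, hypothetically, that the e-graph's hash-consing table distinguishes e-nodes by their type parameter `T` as well as by operation and child e-class ids. `Memo`, `add_node!` and `lookup_node` are made-up names, not Metatheory API, and plain Julia types stand in for `Term{...}`.

```julia
# Hypothetical hash-consing table mapping an e-node key to an e-class id.
# Assumption: the key includes the term type T, so nodes differing only in T are distinct.
const Memo = Dict{Any,Int}

# Insert (or find) the node (T, op, child ids) and return its e-class id.
add_node!(memo::Memo, T::Type, op, ids::Int...) = get!(memo, (T, op, ids), length(memo) + 1)

# Look up an existing node; returns `nothing` if no node has exactly this key.
lookup_node(memo::Memo, T::Type, op, ids::Int...) = get(memo, (T, op, ids), nothing)

memo = Memo()
pi_id    = add_node!(memo, Irrational{:π}, pi)        # e-class 1
three_id = add_node!(memo, Int, 3)                    # e-class 2
# The node for pi + 3 is stored with the type its constructor computed (Real in the demo).
sum_id   = add_node!(memo, Real, +, pi_id, three_id)  # e-class 3

# Same operation and children, but a different T (e.g. a default chosen
# per operation and arity): the lookup misses the existing node.
@show lookup_node(memo, Complex{Real}, +, pi_id, three_id)  # nothing
@show lookup_node(memo, Real, +, pi_id, three_id)           # 3
```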

In the demo above, the overloaded `promote_type` determines the type parameter of `Term{T}`, and the e-graph has no knowledge of it. Beyond that, is `T = gettermtype(g, op, ar)` sufficient to determine the type of an e-node term? I doubt it. Consider `pi + 3` and `im + 3`: both have operation `+` and arity 2, yet their term types are completely different (`Term{Real}` versus `Term{Complex{Int64}}` under the constructor above).
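The `pi + 3` versus `im + 3` point can be checked without Metatheory. The sketch below uses a hypothetical `MiniTerm` that computes its type parameter by promoting the argument types, like the demo's `Term` constructor does for two arguments, and a made-up `key` function that keeps only the operation and arity, roughly the information `gettermtype(g, op, ar)` receives. It deliberately skips the demo's `promote_type` override, so `pi + 3` comes out as `Float64` rather than `Real`; `Any[...]` keeps the array literal from promoting its elements.

```julia
# Hypothetical stand-in for the demo's Term{T}; no Metatheory or TermInterface needed.
struct MiniTerm{T}
    f::Any
    args::Vector{Any}
end

# Type of a leaf or subterm, mirroring the demo's symtype methods.
mytype(::MiniTerm{T}) where {T} = T
mytype(x) = typeof(x)

# T is the promotion of the argument types.
function MiniTerm(f, args::Vector{Any})
    T = isempty(args) ? Any : promote_type(map(mytype, args)...)
    return MiniTerm{T}(f, args)
end

# Hypothetical key carrying only what a per-(operation, arity) table can see.
key(t::MiniTerm) = (t.f, length(t.args))

a = MiniTerm(+, Any[pi, 3])   # pi + 3
b = MiniTerm(+, Any[im, 3])   # im + 3

println(key(a) == key(b))     # true: same operation, same arity
println(typeof(a), " vs ", typeof(b))

# A table indexed by (operation, arity) holds one type per key,
# so the entry for b overwrites the entry for a.
termtypes = Dict{Any,Type}()
for t in (a, b)
    termtypes[key(t)] = typeof(t)
end
println(length(termtypes))    # 1
```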

vitrun commented May 30, 2022

Hi @0x0f0f0f, any idea about this issue?
