diff --git a/effekt/jvm/src/test/scala/effekt/core/InterpreterTests.scala b/effekt/jvm/src/test/scala/effekt/core/InterpreterTests.scala new file mode 100644 index 000000000..d6fb5b243 --- /dev/null +++ b/effekt/jvm/src/test/scala/effekt/core/InterpreterTests.scala @@ -0,0 +1,321 @@ +package effekt +package core + +import effekt.core.Interpreter.{ InterpreterError, State } +import effekt.source.FeatureFlag +import effekt.symbols.QualifiedName +import effekt.symbols.given +import kiama.util.FileSource + +class InterpreterTests extends munit.FunSuite { + + import effekt.context.{ Context, IOModuleDB } + import effekt.util.AnsiColoredMessaging + import kiama.util.{ Positions, StringSource } + + // object driver extends effekt.Driver + // + // def run(content: String): String = + // var options = Seq( + // "--Koutput", "string", + // "--backend", "js", + // ) + // val configs = driver.createConfig(options) + // configs.verify() + // + // val compiler = new TestFrontend + // compiler.compile(StringSource(content, "input.effekt"))(using context).map { + // case (_, decl) => decl + // } + // configs.stringEmitter.result() + val positions = new Positions + object ansiMessaging extends AnsiColoredMessaging + object context extends Context(positions) with IOModuleDB { + val messaging = ansiMessaging + object testFrontend extends TestFrontend + override lazy val compiler = testFrontend.asInstanceOf + } + + def run(content: String) = + val config = new EffektConfig(Seq("--Koutput", "string")) + config.verify() + context.setup(config) + context.testFrontend.compile(StringSource(content, "input.effekt"))(using context).map { + case (_, decl) => decl + } + + def runFile(path: String) = + val config = new EffektConfig(Seq("--Koutput", "string")) + config.verify() + context.setup(config) + context.testFrontend.compile(FileSource(path))(using context).map { + case (_, decl) => decl + } + + + val recursion = + """def countdown(n: Int): Int = + | if (n == 0) 42 + | else countdown(n - 1) + | + |def fib(n: Int): Int = + | if (n == 0) 1 + | else if (n == 1) 1 + | else fib(n - 2) + fib(n - 1) + | + |def main() = { + | println(fib(10)) + |} + |""".stripMargin + + val dynamicDispatch = """def size[T](l: List[T]): Int = + | l match { + | case Nil() => 0 + | case Cons(hd, tl) => 1 + size(tl) + | } + | + |def map[A, B](l: List[A]) { f: A => B }: List[B] = + | l match { + | case Nil() => Nil() + | case Cons(hd, tl) => Cons(f(hd), map(tl){f}) + | } + | + |def main() = { + | println(size([1, 2, 3].map { x => x + 1 })) + |} + |""".stripMargin + + val eraseUnused = + """def replicate(v: Int, n: Int, a: List[Int]): List[Int] = + | if (n == 0) { + | a + | } else { + | replicate(v, n - 1, Cons(v, a)) + | } + | + |def useless(i: Int, n: Int, _: List[Int]): Int = + | if (i < n) { + | useless(i + 1, n, replicate(0, i, Nil())) + | } else { + | i + | } + | + |def run(n: Int) = + | useless(0, n, Nil()) + | + |def main() = { + | println(run(10)) + |} + |""".stripMargin + + val simpleObject = + """interface Counter { + | def inc(): Unit + | def get(): Int + |} + | + |def main() = { + | def c = new Counter { + | def inc() = println("tick") + | def get() = 0 + | }; + | c.inc() + | c.inc() + | c.inc() + |} + | + |""".stripMargin + + val factorialAccumulator = + """import examples/benchmarks/runner + | + |def factorial(a: Int, i: Int): Int = + | if (i == 0) { + | a + | } else { + | factorial((i * a).mod(1000000007), i - 1) + | } + | + |def run(n: Int): Int = + | factorial(1, n) + | + |def main() = benchmark(5){run} + | + |""".stripMargin + + val sort = + """import list + | + |def main() = { + | // synchronized with doctest in `sortBy` + | println([1, 3, -1, 5].sortBy { (a, b) => a <= b }) + |} + |""".stripMargin + + val mutableState = + """def main() = { + | var x = 42; + | x = x + 1; + | println(x.show) + | [1, 2, 3].map { x => x + 1 }.foreach { x => println(x) }; + | + | region r { + | var x in r = 42; + | x = x + 1 + | println(x) + | } + |} + |""".stripMargin + + val simpleException = + """effect raise(): Unit + | + |def main() = { + | try { + | println("before"); + | do raise() + | println("after") + | } with raise { println("caught") } + |} + | + |""".stripMargin + + val triples = + """import examples/benchmarks/runner + | + |record Triple(a: Int, b: Int, c: Int) + | + |interface Flip { + | def flip(): Bool + |} + | + |interface Fail { + | def fail(): Nothing + |} + | + |def choice(n: Int): Int / {Flip, Fail} = { + | if (n < 1) { + | do fail() match {} + | } else if (do flip()) { + | n + | } else { + | choice(n - 1) + | } + |} + | + |def triple(n: Int, s: Int): Triple / {Flip, Fail} = { + | val i = choice(n) + | val j = choice(i - 1) + | val k = choice(j - 1) + | if (i + j + k == s) { + | Triple(i, j, k) + | } else { + | do fail() match {} + | } + |} + | + |def hash(triple: Triple): Int = triple match { + | case Triple(a, b, c) => mod(((53 * a) + 2809 * b + 148877 * c), 1000000007) + |} + | + |def run(n: Int) = + | try { + | hash(triple(n, n)) + | } with Flip { + | def flip() = mod(resume(true) + resume(false), 1000000007) + | } with Fail { + | def fail() = 0 + | } + | + |def main() = benchmark(10){run} + | + | + |""".stripMargin + + // doesn't work: product_early (since it SO due to run run run) + + def runTest(file: String): Unit = + + val Some(main, mod, decl) = runFile(file): @unchecked + + val gced = Deadcode.remove(main, decl) + + val inlined = Inline.full(Set(main), gced, 40) + + try { + object data extends Counting { + override def step(state: Interpreter.State) = state match { + case State.Done(result) => ??? + case State.Step(stmt, env, stack) => + //println(Interpreter.show(stack)) + } + } + Interpreter(data).run(main, inlined) + + data.report() + + } catch { + case err: InterpreterError => + err match { + case InterpreterError.NotFound(id) => println(s"Not found: ${util.show(id)}") + case InterpreterError.NotAnExternFunction(id) => err.printStackTrace() + case InterpreterError.MissingBuiltin(name) => println(s"Missing ${name}") + case InterpreterError.RuntimeTypeError(msg) => err.printStackTrace() + case InterpreterError.NonExhaustive(missingCase) => err.printStackTrace() + case InterpreterError.Hole() => err.printStackTrace() + case InterpreterError.NoMain() => err.printStackTrace() + } + } + + // TODO allocate arrays and ref on a custom heap that could be inspected and visualized + + runTest("examples/benchmarks/are_we_fast_yet/bounce.effekt") + runTest("examples/benchmarks/are_we_fast_yet/list_tail.effekt") + runTest("examples/benchmarks/are_we_fast_yet/mandelbrot.effekt") + runTest("examples/benchmarks/are_we_fast_yet/nbody.effekt") + + // global is missing + //runTest("examples/benchmarks/are_we_fast_yet/permute.effekt") + //runTest("examples/benchmarks/are_we_fast_yet/storage.effekt") + + runTest("examples/benchmarks/are_we_fast_yet/queens.effekt") + runTest("examples/benchmarks/are_we_fast_yet/sieve.effekt") + runTest("examples/benchmarks/are_we_fast_yet/towers.effekt") +} + +class TestFrontend extends Compiler[(Id, symbols.Module, core.ModuleDecl)] { + + + import effekt.PhaseResult.CoreTransformed + import effekt.context.Context + import kiama.output.PrettyPrinterTypes.Document + import kiama.util.Source + + + // Implementation of the Compiler Interface: + // ----------------------------------------- + def extension = ".class" + + override def supportedFeatureFlags: List[String] = List("jvm") + + override def prettyIR(source: Source, stage: Stage)(using C: Context): Option[Document] = None + + override def treeIR(source: Source, stage: Stage)(using Context): Option[Any] = None + + override def compile(source: Source)(using C: Context): Option[(Map[String, String], (Id, symbols.Module, core.ModuleDecl))] = + Optimized.run(source).map { res => (Map.empty, res) } + + + // The Compilation Pipeline + // ------------------------ + // Source => Core => CPS => JS + lazy val Core = Phase.cached("core") { + Frontend andThen Middleend + } + + lazy val Optimized = allToCore(Core) andThen Aggregate map { + case input @ CoreTransformed(source, tree, mod, core) => + val mainSymbol = Context.checkMain(mod) + (mainSymbol, mod, core) + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/Deadcode.scala b/effekt/shared/src/main/scala/effekt/core/Deadcode.scala index 8765d570a..3a88f8e9c 100644 --- a/effekt/shared/src/main/scala/effekt/core/Deadcode.scala +++ b/effekt/shared/src/main/scala/effekt/core/Deadcode.scala @@ -19,14 +19,15 @@ class Deadcode(entrypoints: Set[Id], definitions: Map[Id, Definition]) extends c }, rewrite(stmt)) } - override def rewrite(m: ModuleDecl): ModuleDecl = m.copy( - // Remove top-level unused definitions - definitions = m.definitions.collect { case d if reachable.isDefinedAt(d.id) => rewrite(d) }, - externs = m.externs.collect { - case e: Extern.Def if reachable.isDefinedAt(e.id) => e - case e: Extern.Include => e - } - ) + override def rewrite(m: ModuleDecl): ModuleDecl = + m.copy( + // Remove top-level unused definitions + definitions = m.definitions.collect { case d if reachable.isDefinedAt(d.id) => rewrite(d) }, + externs = m.externs.collect { + case e: Extern.Def if reachable.isDefinedAt(e.id) => e + case e: Extern.Include => e + } + ) } object Deadcode { diff --git a/effekt/shared/src/main/scala/effekt/core/Inline.scala b/effekt/shared/src/main/scala/effekt/core/Inline.scala index 18f91c607..91b5efa84 100644 --- a/effekt/shared/src/main/scala/effekt/core/Inline.scala +++ b/effekt/shared/src/main/scala/effekt/core/Inline.scala @@ -188,7 +188,7 @@ object Inline { } def rewrite(p: Pure)(using InlineContext): Pure = p match { - case Pure.PureApp(b, targs, vargs) => pureApp(rewrite(b), targs, vargs.map(rewrite)) + case Pure.PureApp(b, targs, vargs) => pureApp(b, targs, vargs.map(rewrite)) case Pure.Make(data, tag, vargs) => make(data, tag, vargs.map(rewrite)) // currently, we don't inline values, but we can dealias them case x @ Pure.ValueVar(id, annotatedType) => dealias(x) @@ -200,7 +200,7 @@ object Inline { } def rewrite(e: Expr)(using InlineContext): Expr = e match { - case DirectApp(b, targs, vargs, bargs) => directApp(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) + case DirectApp(b, targs, vargs, bargs) => directApp(b, targs, vargs.map(rewrite), bargs.map(rewrite)) // congruences case Run(s) => run(rewrite(s)) diff --git a/effekt/shared/src/main/scala/effekt/core/Interpreter.scala b/effekt/shared/src/main/scala/effekt/core/Interpreter.scala new file mode 100644 index 000000000..3810289c4 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/Interpreter.scala @@ -0,0 +1,773 @@ +package effekt +package core + +import effekt.core.Interpreter.Stack.Segment +import effekt.source.FeatureFlag + +import scala.annotation.tailrec + + +type ~>[-A, +B] = PartialFunction[A, B] + + +trait Instrumentation { + def staticDispatch(id: Id): Unit = () + def dynamicDispatch(id: Id): Unit = () + def patternMatch(comparisons: Int): Unit = () + def branch(): Unit = () + def pushFrame(): Unit = () + def popFrame(): Unit = () + def allocate(v: Interpreter.Value.Data): Unit = () + def closure(): Unit = () + def fieldLookup(id: Id): Unit = () + def step(state: Interpreter.State): Unit = () + def readMutableVariable(id: Id): Unit = () + def writeMutableVariable(id: Id): Unit = () + def allocateVariable(id: Id): Unit = () + def allocateRegion(region: Interpreter.Address): Unit = () + def allocateVariableIntoRegion(id: Id, region: Interpreter.Address): Unit = () + def reset(): Unit = () + def shift(): Unit = () + def resume(): Unit = () +} +object NoInstrumentation extends Instrumentation + +class Counting extends Instrumentation { + var staticDispatches = 0 + var dynamicDispatches = 0 + var patternMatches = 0 + var branches = 0 + var pushedFrames = 0 + var poppedFrames = 0 + var allocations = 0 + var closures = 0 + var fieldLookups = 0 + var variableReads = 0 + var variableWrites = 0 + + var resets = 0 + var shifts = 0 + var resumes = 0 + + override def staticDispatch(id: Id): Unit = staticDispatches += 1 + override def dynamicDispatch(id: Id): Unit = dynamicDispatches += 1 + override def patternMatch(comparisons: Int): Unit = patternMatches += 1 + override def branch(): Unit = branches += 1 + override def pushFrame(): Unit = pushedFrames += 1 + override def popFrame(): Unit = poppedFrames += 1 + override def allocate(v: Interpreter.Value.Data): Unit = allocations += 1 + override def closure(): Unit = closures += 1 + override def fieldLookup(id: Id): Unit = fieldLookups += 1 + override def readMutableVariable(id: Id): Unit = variableReads += 1 + override def writeMutableVariable(id: Id): Unit = variableWrites += 1 + override def reset(): Unit = resets += 1 + override def shift(): Unit = shifts += 1 + override def resume(): Unit = resumes += 1 + + def report() = + println(s"Static dispatches: ${staticDispatches}") + println(s"Dynamic dispatches: ${dynamicDispatches}") + println(s"Branches: ${branches}") + println(s"Pattern matches: ${patternMatches}") + println(s"Frames (pushed: ${pushedFrames}, popped: ${poppedFrames})") + println(s"Allocations: ${allocations}") + println(s"Closures: ${closures}") + println(s"Field lookups: ${fieldLookups}") + println(s"Variable reads: ${variableReads}") + println(s"Variable writes: ${variableWrites}") + println(s"Installed delimiters: ${resets}") + println(s"Captured continuations: ${shifts}") + println(s"Resumed continuations: ${resumes}") +} + +class Interpreter(instrumentation: Instrumentation = NoInstrumentation) { + + import Interpreter.* + + // TODO maybe replace region values by integers instead of Id + + // TODO instrument the interpreter and count: + // - heap allocations + // - function calls + // - virtual dispatch + // - pattern match and field selection, etc. + // - primitive operations + + // things we need to know to run the interpreter: + // - FFI Builtins (like infixAdd) --> prepopulate environemtn + // - Datatype declarations (for generic comparison and field selection) + + @tailrec + private def returnWith(value: Value, env: Env, stack: Stack): State = + @tailrec + def go(frames: List[Frame], prompt: Address, stack: Stack): State = + frames match { + case Frame.Val(x, body, frameEnv) :: rest => + instrumentation.popFrame() + State.Step(body, frameEnv.bind(x, value), Stack.Segment(rest, prompt, stack)) + // free the mutable state + case Frame.Var(x, value) :: rest => go(rest, prompt, stack) + // free the region + case Frame.Region(x, values) :: rest => go(rest, prompt, stack) + case Nil => returnWith(value, env, stack) + } + stack match { + case Stack.Empty => State.Done(value) + case Stack.Segment(frames, prompt, rest) => + go(frames, prompt, rest) + } + + private def push(frame: Frame, stack: Stack): Stack = stack match { + case Stack.Empty => ??? + case Stack.Segment(frames, prompt, rest) => Stack.Segment(frame :: frames, prompt, rest) + } + + @tailrec + private def findFirst[A](stack: Stack)(f: Frame ~> A): Option[A] = + stack match { + case Stack.Empty => None + case Stack.Segment(frames, prompt, rest) => + @tailrec + def go(frames: List[Frame]): Option[A] = + frames match { + case Nil => findFirst(rest)(f) + case frame :: rest if f.isDefinedAt(frame) => Some(f(frame)) + case frame :: rest => go(rest) + } + go(frames) + } + + def updateOnce(stack: Stack)(f: Frame ~> Frame): Stack = + stack match { + case Stack.Empty => ??? + case Stack.Segment(frames, prompt, rest) => + def go(frames: List[Frame], acc: List[Frame]): Stack = + frames match { + case Nil => + Stack.Segment(acc.reverse, prompt, updateOnce(rest)(f)) + case frame :: frames if f.isDefinedAt(frame) => + Stack.Segment(acc.reverse ++ (f(frame) :: frames), prompt, rest) + case frame :: frames => + go(frames, frame :: acc) + } + go(frames, Nil) + } + + @tailrec + private def findFirst[T](env: Env)(f: Env ~> T): T = env match { + case e if f.isDefinedAt(e) => f(e) + case Env.Top(functions, builtins, declarations) => ??? + case Env.Static(id, block, rest) => findFirst(rest)(f) + case Env.Dynamic(id, block, rest) => findFirst(rest)(f) + case Env.Let(id, value, rest) => findFirst(rest)(f) + } + + def step(s: State): State = + instrumentation.step(s) + s match { + case State.Done(result) => s + case State.Step(stmt, env, stack) => stmt match { + + case Stmt.Scope(definitions, body) => + var envSoFar = env + definitions.foreach { + case Definition.Def(id, block: Block.BlockLit) => envSoFar = envSoFar.bind(id, block) + // TODO what is the cost model for aliased block literals? + case Definition.Def(id, block) => envSoFar = envSoFar.bind(id, eval(block, env)) + case Definition.Let(id, tpe, binding) => envSoFar = envSoFar.bind(id, eval(binding, envSoFar)) + } + State.Step(body, envSoFar, stack) + + case Stmt.Return(expr) => + val v = eval(expr, env) + returnWith(v, env, stack) + + case Stmt.Val(id, annotatedTpe, binding, body) => + instrumentation.pushFrame() + State.Step(binding, env, push(Frame.Val(id, body, env), stack)) + + case Stmt.App(Block.BlockVar(id, _, _), targs, vargs, bargs) => + @tailrec + def lookup(env: Env): (BlockLit, Env) = env match { + case Env.Top(functions, builtins, declarations) => + instrumentation.staticDispatch(id) + (functions.getOrElse(id, throw InterpreterError.NotFound(id)), env) + case Env.Static(other, block, rest) if id == other => + instrumentation.staticDispatch(id) + (block, env) + case Env.Static(other, block, rest) => lookup(rest) + case Env.Dynamic(other, block, rest) if id == other => block match { + case Computation.Closure(target, env) => + instrumentation.dynamicDispatch(id) + env.lookupStatic(target) + case _ => + throw InterpreterError.RuntimeTypeError("Can only call functions") + } + case Env.Dynamic(other, block, rest) => lookup(rest) + case Env.Let(other, value, rest) => lookup(rest) + } + + val (Block.BlockLit(_, _, vparams, bparams, body), definitionEnv) = lookup(env) + + State.Step( + body, + definitionEnv + .bindValues((vparams zip vargs).map { case (p, a) => p.id -> eval(a, env) }) + .bindBlocks((bparams zip bargs).map { case (p, a) => p.id -> eval(a, env) }), + stack) + + case Stmt.App(callee, targs, vargs, bargs) => ??? + + case Stmt.Invoke(b, method, methodTpe, targs, vargs, bargs) => + eval(b, env) match { + case Computation.Object(methods, definitionEnv) => + val BlockLit(_, _, vparams, bparams, body) = methods.getOrElse(method, throw InterpreterError.NonExhaustive(method)) + instrumentation.dynamicDispatch(method) + State.Step( + body, + definitionEnv + .bindValues((vparams zip vargs).map { case (p, a) => p.id -> eval(a, env) }) + .bindBlocks((bparams zip bargs).map { case (p, a) => p.id -> eval(a, env) }), + stack) + case _ => throw InterpreterError.RuntimeTypeError("Can only call methods on objects") + } + + case Stmt.If(cond, thn, els) => + instrumentation.branch() + eval(cond, env) match { + case As.Bool(true) => State.Step(thn, env, stack) + case As.Bool(false) => State.Step(els, env, stack) + case v => throw InterpreterError.RuntimeTypeError(s"Expected Bool, but got ${v}") + } + + case Stmt.Match(scrutinee, clauses, default) => eval(scrutinee, env) match { + case Value.Data(data, tag, fields) => + @tailrec + def search(clauses: List[(Id, BlockLit)], comparisons: Int): State = (clauses, default) match { + case (Nil, None) => + throw InterpreterError.NonExhaustive(tag) + case (Nil, Some(stmt)) => + instrumentation.patternMatch(comparisons) + State.Step(stmt, env, stack) + case ((id, BlockLit(tparams, cparams, vparams, bparams, body)) :: clauses, _) if id == tag => + instrumentation.patternMatch(comparisons) + State.Step(body, env.bindValues(vparams.map(p => p.id) zip fields), stack) + case (_ :: clauses, _) => search(clauses, comparisons + 1) + } + search(clauses, 0) + + case other => throw InterpreterError.RuntimeTypeError(s"Expected value of a data type, but got ${other}") + } + + case Stmt.Region(Block.BlockLit(_, _, _, List(region), body)) => + val fresh = freshAddress() + + instrumentation.allocateRegion(fresh) + + State.Step(body, env.bind(region.id, Computation.Region(fresh)), + push(Frame.Region(fresh, Map.empty), stack)) + + // TODO make the type of Region more precise... + case Stmt.Region(_) => ??? + + case Stmt.Alloc(id, init, region, body) => + val value = eval(init, env) + + val address = findFirst(env) { + case Env.Dynamic(id, Computation.Region(r), rest) => r + } + + instrumentation.allocateVariableIntoRegion(id, address) + + val updated = updateOnce(stack) { + case Frame.Region(r, values) if r == address => + Frame.Region(r, values.updated(id, value)) + } + returnWith(Value.Literal(()), env, updated) + + // TODO also use addresses for variables + case Stmt.Var(id, init, capture, body) => + instrumentation.allocateVariable(id) + State.Step(body, env, push(Frame.Var(id, eval(init, env)), stack)) + + case Stmt.Get(id, annotatedCapt, annotatedTpe) => + instrumentation.readMutableVariable(id) + val value = findFirst(stack) { + case Frame.Var(other, value) if other == id => value + case Frame.Region(_, values) if values.isDefinedAt(id) => values(id) + } getOrElse ??? + + returnWith(value, env, stack) + + case Stmt.Put(id, annotatedCapt, value) => + instrumentation.writeMutableVariable(id) + val newValue = eval(value, env) + val updated = updateOnce(stack) { + case Frame.Var(other, value) if other == id => + Frame.Var(other, newValue) + case Frame.Region(r, values) if values.isDefinedAt(id) => + Frame.Region(r, values.updated(id, newValue)) + } + + returnWith(Value.Literal(()), env, updated) + + case Stmt.Reset(BlockLit(_, _, _, List(prompt), body)) => + val freshPrompt = freshAddress() + instrumentation.reset() + State.Step(body, env.bind(prompt.id, Computation.Prompt(freshPrompt)), + Segment(Nil, freshPrompt, stack)) + + case Stmt.Reset(b) => ??? + + case Stmt.Shift(prompt, BlockLit(tparams, cparams, vparams, List(resume), body)) => + instrumentation.shift() + val address = findFirst(env) { + case Env.Dynamic(id, Computation.Prompt(addr), rest) => addr + } + @tailrec + def unwind(stack: Stack, cont: Stack): (Stack, Stack) = stack match { + case Stack.Empty => ??? + case Stack.Segment(frames, prompt, rest) if prompt == address => + (Stack.Segment(frames, prompt, cont), rest) + case Stack.Segment(frames, prompt, rest) => + unwind(rest, Stack.Segment(frames, prompt, cont)) + } + val (cont, rest) = unwind(stack, Stack.Empty) + + State.Step(body, env.bind(resume.id, Computation.Resumption(cont)), rest) + case Stmt.Shift(_, _) => ??? + + + case Stmt.Resume(k, body) => + instrumentation.resume() + val cont = findFirst(env) { + case Env.Dynamic(id, Computation.Resumption(k), rest) => k + } + @tailrec + def rewind(k: Stack, onto: Stack): Stack = k match { + case Stack.Empty => onto + case Stack.Segment(frames, prompt, rest) => + rewind(rest, Stack.Segment(frames, prompt, onto)) + } + State.Step(body, env, rewind(cont, stack)) + + case Stmt.Hole() => throw InterpreterError.Hole() + } + } + + @tailrec + private def run(s: State): Value = s match { + case State.Done(result) => result + case other => run(step(other)) + } + + def eval(b: Block, env: Env): Computation = b match { + case Block.BlockVar(id, annotatedTpe, annotatedCapt) => + @tailrec + def go(env: Env): Computation = env match { + case Env.Top(functions, builtins, declarations) => instrumentation.closure(); Computation.Closure(id, env) + case Env.Static(other, block, rest) if other == id => instrumentation.closure(); Computation.Closure(id, env) + case Env.Static(other, block, rest) => go(rest) + case Env.Dynamic(other, block, rest) if other == id => block + case Env.Dynamic(other, block, rest) => go(rest) + case Env.Let(other, value, rest) => go(rest) + } + go(env) + case b @ Block.BlockLit(tparams, cparams, vparams, bparams, body) => + val tmp = Id("tmp") + instrumentation.closure() + Computation.Closure(tmp, env.bind(tmp, b)) + case Block.Unbox(pure) => eval(pure, env) match { + case Value.Boxed(block) => block + case other => throw InterpreterError.RuntimeTypeError(s"Expected boxed block, but got ${other}") + } + case Block.New(Implementation(interface, operations)) => + instrumentation.closure() + Computation.Object(operations.map { + case Operation(id, tparams, cparams, vparams, bparams, body) => + id -> (BlockLit(tparams, cparams, vparams, bparams, body): BlockLit) + }.toMap, env) + } + + def eval(e: Expr, env: Env): Value = e match { + case DirectApp(b, targs, vargs, Nil) => env.lookupBuiltin(b.id) match { + case impl => + val arguments = vargs.map(a => eval(a, env)) + try { impl(arguments) } catch { case e => sys error s"Cannot call ${b} with arguments ${arguments.map { + case Value.Literal(l) => s"${l}: ${l.getClass.getName}" + case other => other.toString + }.mkString(", ")}" } + } + case DirectApp(b, targs, vargs, bargs) => ??? + case Run(s) => run(State.Step(s, env, Stack.Toplevel)) + case Pure.ValueVar(id, annotatedType) => env.lookupValue(id) + case Pure.Literal(value, annotatedType) => Value.Literal(value) + case Pure.PureApp(x, targs, vargs) => env.lookupBuiltin(x.id) match { + case impl => + val arguments = vargs.map(a => eval(a, env)) + try { impl(arguments) } catch { case e => sys error s"Cannot call ${x} with arguments ${arguments.map { + case Value.Literal(l) => s"${l}: ${l.getClass.getName}" + case other => other.toString + }.mkString(", ")}" } + } + case Pure.Make(data, tag, vargs) => + val result: Value.Data = Value.Data(data, tag, vargs.map(a => eval(a, env))) + instrumentation.allocate(result) + result + case Pure.Select(target, field, annotatedType) => + @tailrec + def declarations(env: Env): List[Declaration] = env match { + case Env.Top(functions, builtins, declarations) => declarations + case Env.Static(id, block, rest) => declarations(rest) + case Env.Dynamic(id, block, rest) => declarations(rest) + case Env.Let(id, value, rest) => declarations(rest) + } + val decls = DeclarationContext(declarations(env), Nil) + + // TODO clean this mess up! + val fieldSymbol = decls.findField(field).getOrElse(???) + val constrSymbol = decls.findConstructor(fieldSymbol).getOrElse(???) + val index = constrSymbol.fields.indexOf(fieldSymbol) + + instrumentation.fieldLookup(field) + + eval(target, env) match { + case Value.Data(data, tag, fields) => fields(index) + case _ => ??? + } + + case Pure.Box(b, annotatedCapture) => Value.Boxed(eval(b, env)) + } + + def run(main: Id, m: ModuleDecl): Unit = { + val mainFun = m.definitions.collectFirst { + case Definition.Def(id, b: BlockLit) if id == main => b + }.getOrElse { throw new InterpreterError.NoMain() } + + val functions = m.definitions.collect { case Definition.Def(id, b: Block.BlockLit) => id -> b }.toMap + + val builtinFunctions = m.externs.collect { + case Extern.Def(id, tparams, cparams, vparams, bparams, ret, annotatedCapture, + ExternBody.StringExternBody(FeatureFlag.NamedFeatureFlag("jvm"), Template(name :: Nil, Nil))) => + id -> builtins.getOrElse(name, throw InterpreterError.MissingBuiltin(name)) + }.toMap + + val env = Env.Top(functions, builtinFunctions, m.declarations) + + val initial = State.Step(mainFun.body, env, Stack.Toplevel) + + run(initial) + } +} + +object Interpreter { + + type Address = Int + private var lastAddress: Address = 0 + def freshAddress(): Address = { lastAddress += 1; lastAddress } + + val GLOBAL_PROMPT = 0 + + class Reference(var value: Value) + + enum Value { + case Literal(value: Any) + // TODO this could also be Pointer(Array | Ref) + case Array(array: scala.Array[Value]) + case Ref(ref: Reference) + case Data(data: ValueType.Data, tag: Id, fields: List[Value]) + case Boxed(block: Computation) + } + object Value { + def Int(v: Long): Value = Value.Literal(v) + def Bool(b: Boolean): Value = Value.Literal(b) + def Unit(): Value = Value.Literal(()) + def Double(d: scala.Double): Value = Value.Literal(d) + def String(s: java.lang.String): Value = Value.Literal(s) + } + + def inspect(v: Value): String = v match { + case Value.Literal(value) => value.toString + case Value.Data(data, tag, fields) => + tag.name.name + "(" + fields.map(inspect).mkString(", ") + ")" + case Value.Boxed(block) => block.toString + case Value.Array(arr) => ??? + case Value.Ref(ref) => ??? + } + + enum Computation { + case Closure(id: Id, env: Env) + case Object(methods: Map[Id, BlockLit], env: Env) + case Region(address: Address) + case Prompt(address: Address) + case Resumption(cont: Stack) + } + + enum Env { + case Top(functions: Map[Id, BlockLit], builtins: Map[Id, Builtin], declarations: List[core.Declaration]) + case Static(id: Id, block: BlockLit, rest: Env) + case Dynamic(id: Id, block: Computation, rest: Env) + case Let(id: Id, value: Value, rest: Env) + + def bind(id: Id, value: Value): Env = Let(id, value, this) + def bind(id: Id, lit: BlockLit): Env = Static(id, lit, this) + def bind(id: Id, block: Computation): Env = Dynamic(id, block, this) + def bindValues(otherValues: List[(Id, Value)]): Env = + otherValues.foldLeft(this) { case (env, (id, value)) => Let(id, value, env) } + def bindBlocks(otherBlocks: List[(Id, Computation)]): Env = + otherBlocks.foldLeft(this) { case (env, (id, block)) => Dynamic(id, block, env) } + + def lookupValue(id: Id): Value = { + @tailrec + def go(rest: Env): Value = rest match { + case Env.Top(functions, builtins, declarations) => throw InterpreterError.NotFound(id) + case Env.Static(id, block, rest) => go(rest) + case Env.Dynamic(id, block, rest) => go(rest) + case Env.Let(otherId, value, rest) => if (id == otherId) value else go(rest) + } + go(this) + } + + def lookupBuiltin(id: Id): Builtin = { + @tailrec + def go(rest: Env): Builtin = rest match { + case Env.Top(functions, builtins, declarations) => builtins.getOrElse(id, throw InterpreterError.NotFound(id)) + case Env.Static(id, block, rest) => go(rest) + case Env.Dynamic(id, block, rest) => go(rest) + case Env.Let(id, value, rest) => go(rest) + } + go(this) + } + + def lookupStatic(id: Id): (BlockLit, Env) = this match { + case Env.Top(functions, builtins, declarations) => (functions.getOrElse(id, throw InterpreterError.NotFound(id)), this) + case Env.Static(other, block, rest) => if (id == other) (block, this) else rest.lookupStatic(id) + case Env.Dynamic(other, block, rest) => rest.lookupStatic(id) + case Env.Let(other, value, rest) => rest.lookupStatic(id) + } + } + + enum InterpreterError extends Throwable { + case NotFound(id: Id) + case NotAnExternFunction(id: Id) + case MissingBuiltin(name: String) + case RuntimeTypeError(msg: String) + case NonExhaustive(missingCase: Id) + case Hole() + case NoMain() + } + + enum Stack { + case Empty + case Segment(frames: List[Frame], prompt: Address, rest: Stack) + } + object Stack { + val Toplevel = Stack.Segment(Nil, GLOBAL_PROMPT, Stack.Empty) + } + def show(stack: Stack): String = stack match { + case Stack.Empty => "Empty" + case Stack.Segment(frames, prompt, rest) => + s"${frames.map(show).mkString(" :: ")} :: p${prompt } :: ${show(rest)}" + } + + def show(frame: Frame): String = frame match { + case Frame.Var(x, value) => s"${util.show(x)}=${show(value)}" + case Frame.Val(x, body, env) => s"val ${util.show(x)}" + case Frame.Region(r, values) => s"region ${r} {${values.map { + case (id, value) => s"${util.show(id)}=${show(value)}}" + }.mkString(", ")}}" + } + + def show(value: Value): String = inspect(value) + + enum Frame { + // mutable state + case Var(x: Id, value: Value) + // sequencing + case Val(x: Id, body: Stmt, env: Env) + // local regions + case Region(r: Address, values: Map[Id, Value]) + } + + enum State { + case Done(result: Value) + case Step(stmt: Stmt, env: Env, stack: Stack) + } + + + type Builtin = List[Value] ~> Value + + def Builtin(impl: List[Value] ~> Value) = impl + + val builtins = Map( + "effekt::println(String)" -> Builtin { + case As.String(msg) :: Nil => + println(msg); + Value.Unit() + }, + "effekt::show(Int)" -> Builtin { + case As.Int(n) :: Nil => Value.String(n.toString) + }, + "effekt::infixAdd(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x + y) + }, + "effekt::infixSub(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x - y) + }, + "effekt::infixMul(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x * y) + }, + "effekt::infixAdd(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Double(x + y) + }, + "effekt::infixSub(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Double(x - y) + }, + "effekt::infixMul(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Double(x * y) + }, + "effekt::infixDiv(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Double(x / y) + }, + "effekt::toInt(Double)" -> Builtin { + case As.Double(x) :: Nil => Value.Int(x.toLong) + }, + "effekt::toDouble(Int)" -> Builtin { + case As.Int(x) :: Nil => Value.Double(x.toDouble) + }, + "effekt::mod(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x % y) + }, + "effekt::infixEq(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x == y) + }, + "effekt::infixNeq(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x != y) + }, + "effekt::infixLt(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x < y) + }, + "effekt::infixGt(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x > y) + }, + "effekt::infixLte(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x <= y) + }, + "effekt::infixGte(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Bool(x >= y) + }, + + "effekt::infixEq(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x == y) + }, + "effekt::infixNeq(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x != y) + }, + "effekt::infixLt(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x < y) + }, + "effekt::infixGt(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x > y) + }, + "effekt::infixLte(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x <= y) + }, + "effekt::infixGte(Double, Double)" -> Builtin { + case As.Double(x) :: As.Double(y) :: Nil => Value.Bool(x >= y) + }, + "effekt::sqrt(Double)" -> Builtin { + case As.Double(x) :: Nil => Value.Double(Math.sqrt(x)) + }, + + "effekt::bitwiseShl(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x << y) + }, + "effekt::bitwiseShr(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x >> y) + }, + "effekt::bitwiseAnd(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x & y) + }, + "effekt::bitwiseOr(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x | y) + }, + "effekt::bitwiseXor(Int, Int)" -> Builtin { + case As.Int(x) :: As.Int(y) :: Nil => Value.Int(x ^ y) + }, + + "effekt::infixConcat(String, String)" -> Builtin { + case As.String(x) :: As.String(y) :: Nil => Value.String(x + y) + }, + "effekt::not(Bool)" -> Builtin { + case As.Bool(x) :: Nil => Value.Bool(!x) + }, + "effekt::inspect(Any)" -> Builtin { + case any :: Nil => Value.String(inspect(any)) + }, + + // array + // ----- + "array::allocate(Int)" -> Builtin { + case As.Int(x) :: Nil => Value.Array(scala.Array.ofDim(x.toInt)) + }, + "array::size[T](Array[T])" -> Builtin { + case As.Array(arr) :: Nil => Value.Int(arr.length.toLong) + }, + "array::unsafeGet[T](Array[T], Int)" -> Builtin { + case As.Array(arr) :: As.Int(index) :: Nil => arr(index.toInt) + }, + "array::unsafeSet[T](Array[T], Int, T)" -> Builtin { + case As.Array(arr) :: As.Int(index) :: value :: Nil => arr.update(index.toInt, value); Value.Unit() + }, + + // ref + // --- + "ref::ref[T](T)" -> Builtin { + case init :: Nil => Value.Ref(Reference(init)) + }, + "ref::get[T](Ref[T])" -> Builtin { + case As.Reference(ref) :: Nil => ref.value + }, + "ref::set[T](Ref[T], T)" -> Builtin { + case As.Reference(ref) :: value :: Nil => ref.value = value; Value.Unit() + }, + ) + object As { + object String { + def unapply(v: Value): Option[java.lang.String] = v match { + case Value.Literal(value: java.lang.String) => Some(value) + case _ => None + } + } + object Int { + def unapply(v: Value): Option[scala.Long] = v match { + case Value.Literal(value: scala.Long) => Some(value) + case _ => None + } + } + object Bool { + def unapply(v: Value): Option[scala.Boolean] = v match { + case Value.Literal(value: scala.Boolean) => Some(value) + case _ => None + } + } + object Double { + def unapply(v: Value): Option[scala.Double] = v match { + case Value.Literal(value: scala.Double) => Some(value) + case _ => None + } + } + object Array { + def unapply(v: Value): Option[scala.Array[Value]] = v match { + case Value.Array(array) => Some(array) + case _ => None + } + } + object Reference { + def unapply(v: Value): Option[Reference] = v match { + case Value.Ref(ref) => Some(ref) + case _ => None + } + } + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/Parser.scala b/effekt/shared/src/main/scala/effekt/core/Parser.scala index 7d80f3607..a3a9b7779 100644 --- a/effekt/shared/src/main/scala/effekt/core/Parser.scala +++ b/effekt/shared/src/main/scala/effekt/core/Parser.scala @@ -170,7 +170,7 @@ class CoreParsers(positions: Positions, names: Names) extends EffektLexers(posit | id ~ (`:` ~> valueType) ^^ Pure.ValueVar.apply | `box` ~> captures ~ block ^^ { case capt ~ block => Pure.Box(block, capt) } | `make` ~> dataType ~ id ~ valueArgs ^^ Pure.Make.apply - | block ~ maybeTypeArgs ~ valueArgs ^^ Pure.PureApp.apply + | maybeParens(blockVar) ~ maybeTypeArgs ~ valueArgs ^^ Pure.PureApp.apply | failure("Expected a pure expression.") ) @@ -191,14 +191,16 @@ class CoreParsers(positions: Positions, names: Names) extends EffektLexers(posit lazy val expr: P[Expr] = ( pure | `run` ~> stmt ^^ Run.apply - | (`!` ~/> block) ~ maybeTypeArgs ~ valueArgs ~ blockArgs ^^ DirectApp.apply + | (`!` ~/> maybeParens(blockVar)) ~ maybeTypeArgs ~ valueArgs ~ blockArgs ^^ DirectApp.apply ) + def maybeParens[T](p: P[T]): P[T] = (p | `(` ~> p <~ `)`) + // Blocks // ------ lazy val block: P[Block] = - ( id ~ (`:` ~> blockType) ~ (`@` ~> captures) ^^ Block.BlockVar.apply + ( blockVar | `unbox` ~> pure ^^ Block.Unbox.apply | `new` ~> implementation ^^ Block.New.apply | blockLit @@ -206,6 +208,11 @@ class CoreParsers(positions: Positions, names: Names) extends EffektLexers(posit | `(` ~> block <~ `)` ) + lazy val blockVar: P[Block.BlockVar] = + id ~ (`:` ~> blockType) ~ (`@` ~> captures) ^^ { + case x ~ tpe ~ capt => Block.BlockVar(x, tpe, capt) : Block.BlockVar + } + lazy val blockLit: P[Block.BlockLit] = `{` ~> parameters ~ (`=>` ~/> stmts) <~ `}` ^^ { case (tparams, cparams, vparams, bparams) ~ body => diff --git a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala index c2057425c..c4e7dace2 100644 --- a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala +++ b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala @@ -386,8 +386,7 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] { def transform(expr: Expr)(using PContext): Expr = expr match { - case DirectApp(b, targs, vargs, bargs) => - val callee = transform(b) + case DirectApp(callee, targs, vargs, bargs) => val tpe: BlockType.Function = callee.tpe match { case tpe: BlockType.Function => tpe case _ => sys error "Callee does not have function type" @@ -567,17 +566,17 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] { } trait FunctionCoercer[Ty <: BlockType, Te <: Block] extends Coercer[Ty, Te] { - def callPure(block: Te, vargs: List[Pure])(using PContext): Pure - def callDirect(block: Te, vargs: List[Pure], bargs: List[Block])(using PContext): Expr + def callPure(block: Block.BlockVar, vargs: List[Pure])(using PContext): Pure + def callDirect(block: Block.BlockVar, vargs: List[Pure], bargs: List[Block])(using PContext): Expr def call(block: Te, vargs: List[Pure], bargs: List[Block])(using PContext): Stmt } class FunctionIdentityCoercer[Ty <: BlockType, Te <: Block]( from: Ty, to: Ty, targs: List[ValueType]) extends IdentityCoercer[Ty, Te](from, to) with FunctionCoercer[Ty, Te] { override def call(block: Te, vargs: List[Pure], bargs: List[Block])(using PContext): Stmt = Stmt.App(block, targs map transformArg, vargs, bargs) - override def callPure(block: Te, vargs: List[Pure])(using PContext): Pure = + override def callPure(block: Block.BlockVar, vargs: List[Pure])(using PContext): Pure = Pure.PureApp(block, targs map transformArg, vargs) - override def callDirect(block: Te, vargs: List[Pure], bargs: List[Block])(using PContext): Expr = + override def callDirect(block: Block.BlockVar, vargs: List[Pure], bargs: List[Block])(using PContext): Expr = DirectApp(block, targs map transformArg, vargs, bargs) } def coercer[B >: Block.BlockLit <: Block](fromtpe: BlockType, totpe: BlockType, targs: List[ValueType] = List())(using PContext): FunctionCoercer[BlockType, B] = @@ -617,11 +616,11 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] { Stmt.Return(rcoercer(Pure.ValueVar(result, rcoercer.from)))))) } - override def callPure(block: B, vargs: List[Pure])(using PContext): Pure = { + override def callPure(block: Block.BlockVar, vargs: List[Pure])(using PContext): Pure = { rcoercer(Pure.PureApp(block, targs map transformArg, (vcoercers zip vargs).map { case (c,v) => c(v) })) } - override def callDirect(block: B, vargs: List[Pure], bargs: List[Block])(using PContext): Expr = { + override def callDirect(block: Block.BlockVar, vargs: List[Pure], bargs: List[Block])(using PContext): Expr = { val result = TmpValue("coe") Run(Let(result, rcoercer.from, DirectApp(block, targs map transformArg, (vcoercers zip vargs).map {case (c,v) => c(v)}, diff --git a/effekt/shared/src/main/scala/effekt/core/Reachable.scala b/effekt/shared/src/main/scala/effekt/core/Reachable.scala index 3d51c94f0..603150830 100644 --- a/effekt/shared/src/main/scala/effekt/core/Reachable.scala +++ b/effekt/shared/src/main/scala/effekt/core/Reachable.scala @@ -3,6 +3,8 @@ package core /** * A simple reachability analysis. + * + * TODO reachability should also process externs since they now contain splices. */ class Reachable( var reachable: Map[Id, Usage], diff --git a/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala b/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala index 4aabbf42d..b3908af93 100644 --- a/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala +++ b/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala @@ -180,7 +180,7 @@ object StaticArguments { } def rewrite(p: Pure)(using StaticArgumentsContext): Pure = p match { - case Pure.PureApp(b, targs, vargs) => pureApp(rewrite(b), targs, vargs.map(rewrite)) + case Pure.PureApp(b, targs, vargs) => pureApp(b, targs, vargs.map(rewrite)) case Pure.Make(data, tag, vargs) => make(data, tag, vargs.map(rewrite)) case x @ Pure.ValueVar(id, annotatedType) => x @@ -191,7 +191,7 @@ object StaticArguments { } def rewrite(e: Expr)(using StaticArgumentsContext): Expr = e match { - case DirectApp(b, targs, vargs, bargs) => directApp(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) + case DirectApp(b, targs, vargs, bargs) => directApp(b, targs, vargs.map(rewrite), bargs.map(rewrite)) // congruences case Run(s) => run(rewrite(s)) diff --git a/effekt/shared/src/main/scala/effekt/core/Tree.scala b/effekt/shared/src/main/scala/effekt/core/Tree.scala index 24d4cdfe9..bfa6c2cf2 100644 --- a/effekt/shared/src/main/scala/effekt/core/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/core/Tree.scala @@ -184,7 +184,7 @@ sealed trait Expr extends Tree { } // invariant, block b is {io}. -case class DirectApp(b: Block, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]) extends Expr +case class DirectApp(b: Block.BlockVar, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]) extends Expr // only inserted by the transformer if stmt is pure / io case class Run(s: Stmt) extends Expr @@ -214,7 +214,7 @@ enum Pure extends Expr { /** * Pure FFI calls. Invariant, block b is pure. */ - case PureApp(b: Block, targs: List[ValueType], vargs: List[Pure]) + case PureApp(b: Block.BlockVar, targs: List[ValueType], vargs: List[Pure]) /** * Constructor calls @@ -404,18 +404,8 @@ object normal { def make(tpe: ValueType.Data, tag: Id, vargs: List[Pure]): Pure = Pure.Make(tpe, tag, vargs) - def pureApp(callee: Block, targs: List[ValueType], vargs: List[Pure]): Pure = - callee match { - case b : Block.BlockLit => - INTERNAL_ERROR( - """|This should not happen! - |User defined functions always have to be called with App, not PureApp. - |If this error does occur, this means this changed. - |Check `core.Transformer.makeFunctionCall` for details. - |""".stripMargin) - case other => - Pure.PureApp(callee, targs, vargs) - } + def pureApp(callee: Block.BlockVar, targs: List[ValueType], vargs: List[Pure]): Pure = + Pure.PureApp(callee, targs, vargs) // "match" is a keyword in Scala def patternMatch(scrutinee: Pure, clauses: List[(Id, BlockLit)], default: Option[Stmt]): Stmt = @@ -432,12 +422,8 @@ object normal { } } - - def directApp(callee: Block, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]): Expr = - callee match { - case b : Block.BlockLit => run(reduce(b, targs, vargs, Nil)) - case other => DirectApp(callee, targs, vargs, bargs) - } + def directApp(callee: Block.BlockVar, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]): Expr = + DirectApp(callee, targs, vargs, bargs) def reduce(b: BlockLit, targs: List[core.ValueType], vargs: List[Pure], bargs: List[Block]): Stmt = { @@ -772,8 +758,10 @@ object substitutions { def substitute(expression: Expr)(using Substitution): Expr = expression match { - case DirectApp(b, targs, vargs, bargs) => - DirectApp(substitute(b), targs.map(substitute), vargs.map(substitute), bargs.map(substitute)) + case DirectApp(b, targs, vargs, bargs) => substitute(b) match { + case x : Block.BlockVar => DirectApp(x, targs.map(substitute), vargs.map(substitute), bargs.map(substitute)) + case _ => INTERNAL_ERROR("Should never substitute a concrete block for an FFI function.") + } case Run(s) => Run(substitute(s)) @@ -874,8 +862,10 @@ object substitutions { case Make(tpe, tag, vargs) => Make(substitute(tpe).asInstanceOf, tag, vargs.map(substitute)) - case PureApp(b, targs, vargs) => - PureApp(substitute(b), targs.map(substitute), vargs.map(substitute)) + case PureApp(b, targs, vargs) => substitute(b) match { + case x : Block.BlockVar => PureApp(x, targs.map(substitute), vargs.map(substitute)) + case _ => INTERNAL_ERROR("Should never substitute a concrete block for an FFI function.") + } case Select(target, field, annotatedType) => Select(substitute(target), field, substitute(annotatedType)) diff --git a/libraries/common/args.effekt b/libraries/common/args.effekt index eaf1d2336..af53dd664 100644 --- a/libraries/common/args.effekt +++ b/libraries/common/args.effekt @@ -4,6 +4,7 @@ extern io def commandLineArgs(): List[String] = js { js::commandLineArgs() } chez { chez::commandLineArgs() } llvm { llvm::commandLineArgs() } + jvm { Nil() } namespace js { extern type Args // = Array[String] diff --git a/libraries/common/array.effekt b/libraries/common/array.effekt index f778ca598..75bc07ebc 100644 --- a/libraries/common/array.effekt +++ b/libraries/common/array.effekt @@ -16,6 +16,7 @@ extern global def allocate[T](size: Int): Array[T] = %z = call %Pos @c_array_new(%Int ${size}) ret %Pos %z """ + jvm "array::allocate(Int)" /// Creates a new Array of size `size` filled with the value `init` def array[T](size: Int, init: T): Array[T] = { @@ -45,6 +46,7 @@ extern pure def size[T](arr: Array[T]): Int = %z = call %Int @c_array_size(%Pos ${arr}) ret %Int %z """ + jvm "array::size[T](Array[T])" /// Gets the element of the `arr` at given `index` in constant time. /// Unchecked Precondition: `index` is in bounds (0 ≤ index < arr.size) @@ -57,6 +59,7 @@ extern global def unsafeGet[T](arr: Array[T], index: Int): T = %z = call %Pos @c_array_get(%Pos ${arr}, %Int ${index}) ret %Pos %z """ + jvm "array::unsafeGet[T](Array[T], Int)" extern js """ function array$set(arr, index, value) { @@ -76,6 +79,7 @@ extern global def unsafeSet[T](arr: Array[T], index: Int, value: T): Unit = %z = call %Pos @c_array_set(%Pos ${arr}, %Int ${index}, %Pos ${value}) ret %Pos %z """ + jvm "array::unsafeSet[T](Array[T], Int, T)" diff --git a/libraries/common/effekt.effekt b/libraries/common/effekt.effekt index 65505277f..4d4809028 100644 --- a/libraries/common/effekt.effekt +++ b/libraries/common/effekt.effekt @@ -45,6 +45,7 @@ extern def println(value: String): Unit = call void @c_io_println_String(%Pos ${value}) ret %Pos zeroinitializer ; Unit """ + jvm "effekt::println(String)" def println(value: Int): Unit = println(value.show) def println(value: Unit): Unit = println(value.show) @@ -59,6 +60,7 @@ extern pure def show(value: Int): String = %z = call %Pos @c_bytearray_show_Int(%Int ${value}) ret %Pos %z """ + jvm"effekt::show(Int)" def show(value: Unit): String = "()" @@ -97,6 +99,7 @@ extern pure def genericShow[R](value: R): String = extern io def inspect[R](value: R): Unit = js { println(genericShow(value)) } chez { println(genericShow(value)) } + jvm "effekt::inspect(Any)" // Strings @@ -108,6 +111,7 @@ extern pure def infixConcat(s1: String, s2: String): String = %spz = call %Pos @c_bytearray_concatenate(%Pos ${s1}, %Pos ${s2}) ret %Pos %spz """ + jvm"effekt::infixConcat(String, String)" extern pure def length(str: String): Int = js "${str}.length" @@ -204,6 +208,7 @@ extern pure def infixEq(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixEq(Int, Int)" extern pure def infixNeq(x: Int, y: Int): Bool = js "${x} !== ${y}" @@ -214,6 +219,7 @@ extern pure def infixNeq(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixNeq(Int, Int)" extern pure def infixEq(x: Char, y: Char): Bool = js "${x} === ${y}" @@ -276,46 +282,55 @@ extern pure def infixAdd(x: Int, y: Int): Int = js "(${x} + ${y})" chez "(+ ${x} ${y})" llvm "%z = add %Int ${x}, ${y} ret %Int %z" + jvm"effekt::infixAdd(Int, Int)" extern pure def infixMul(x: Int, y: Int): Int = js "(${x} * ${y})" chez "(* ${x} ${y})" llvm "%z = mul %Int ${x}, ${y} ret %Int %z" + jvm"effekt::infixMul(Int, Int)" extern pure def infixDiv(x: Int, y: Int): Int = js "Math.floor(${x} / ${y})" chez "(floor (/ ${x} ${y}))" llvm "%z = sdiv %Int ${x}, ${y} ret %Int %z" + jvm"effekt::infixDiv(Int, Int)" extern pure def infixSub(x: Int, y: Int): Int = js "(${x} - ${y})" chez "(- ${x} ${y})" llvm "%z = sub %Int ${x}, ${y} ret %Int %z" + jvm"effekt::infixSub(Int, Int)" extern pure def mod(x: Int, y: Int): Int = js "(${x} % ${y})" chez "(modulo ${x} ${y})" llvm "%z = srem %Int ${x}, ${y} ret %Int %z" + jvm"effekt::mod(Int, Int)" extern pure def infixAdd(x: Double, y: Double): Double = js "(${x} + ${y})" chez "(+ ${x} ${y})" llvm "%z = fadd %Double ${x}, ${y} ret %Double %z" + jvm"effekt::infixAdd(Double, Double)" extern pure def infixMul(x: Double, y: Double): Double = js "(${x} * ${y})" chez "(* ${x} ${y})" llvm "%z = fmul %Double ${x}, ${y} ret %Double %z" + jvm"effekt::infixMul(Double, Double)" extern pure def infixSub(x: Double, y: Double): Double = js "(${x} - ${y})" chez "(- ${x} ${y})" llvm "%z = fsub %Double ${x}, ${y} ret %Double %z" + jvm"effekt::infixSub(Double, Double)" extern pure def infixDiv(x: Double, y: Double): Double = js "(${x} / ${y})" chez "(/ ${x} ${y})" llvm "%z = fdiv %Double ${x}, ${y} ret %Double %z" + jvm"effekt::infixDiv(Double, Double)" extern pure def cos(x: Double): Double = js "Math.cos(${x})" @@ -337,6 +352,7 @@ extern pure def sqrt(x: Double): Double = js "Math.sqrt(${x})" chez "(sqrt ${x})" llvm "%z = call %Double @llvm.sqrt.f64(double ${x}) ret %Double %z" + jvm "effekt::sqrt(Double)" def square(x: Double): Double = x * x @@ -385,11 +401,14 @@ extern pure def toInt(d: Double): Int = js "Math.trunc(${d})" chez "(flonum->fixnum ${d})" llvm "%z = fptosi double ${d} to %Int ret %Int %z" + jvm"effekt::toInt(Double)" extern pure def toDouble(d: Int): Double = js "${d}" chez "${d}" llvm "%z = sitofp i64 ${d} to double ret double %z" + jvm"effekt::toDouble(Int)" + extern pure def round(d: Double): Int = js "Math.round(${d})" @@ -442,6 +461,7 @@ extern pure def infixLt(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixLt(Int, Int)" extern pure def infixLte(x: Int, y: Int): Bool = js "(${x} <= ${y})" @@ -452,6 +472,7 @@ extern pure def infixLte(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixLte(Int, Int)" extern pure def infixGt(x: Int, y: Int): Bool = js "(${x} > ${y})" @@ -462,6 +483,7 @@ extern pure def infixGt(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixGt(Int, Int)" extern pure def infixGte(x: Int, y: Int): Bool = js "(${x} >= ${y})" @@ -472,14 +494,17 @@ extern pure def infixGte(x: Int, y: Int): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixGte(Int, Int)" extern pure def infixEq(x: Double, y: Double): Bool = js "${x} === ${y}" chez "(= ${x} ${y})" + jvm"effekt::infixEq(Double, Double)" extern pure def infixNeq(x: Double, y: Double): Bool = js "${x} !== ${y}" chez "(not (= ${x} ${y}))" + jvm"effekt::infixNeq(Double, Double)" extern pure def infixLt(x: Double, y: Double): Bool = js "(${x} < ${y})" @@ -490,10 +515,12 @@ extern pure def infixLt(x: Double, y: Double): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixLt(Double, Double)" extern pure def infixLte(x: Double, y: Double): Bool = js "(${x} <= ${y})" chez "(<= ${x} ${y})" + jvm"effekt::infixLte(Double, Double)" extern pure def infixGt(x: Double, y: Double): Bool = js "(${x} > ${y})" @@ -504,10 +531,12 @@ extern pure def infixGt(x: Double, y: Double): Bool = %adt_boolean = insertvalue %Pos zeroinitializer, i64 %fat_z, 0 ret %Pos %adt_boolean """ + jvm"effekt::infixGt(Double, Double)" extern pure def infixGte(x: Double, y: Double): Bool = js "(${x} >= ${y})" chez "(>= ${x} ${y})" + jvm"effekt::infixGte(Double, Double)" // TODO do we really need those? if yes, move to string.effekt extern pure def infixLt(x: String, y: String): Bool = @@ -538,6 +567,7 @@ extern pure def not(b: Bool): Bool = %adt_q = insertvalue %Pos zeroinitializer, i64 %q, 0 ret %Pos %adt_q """ + jvm"effekt::not(Bool)" def infixOr { first: => Bool } { second: => Bool }: Bool = if (first()) true else second() @@ -554,27 +584,32 @@ extern pure def bitwiseShl(x: Int, y: Int): Int = js "(${x} << ${y})" chez "(ash ${x} ${y})" llvm "%z = shl %Int ${x}, ${y} ret %Int %z" + jvm"effekt::bitwiseShl(Int, Int)" /// Arithmetic right shift extern pure def bitwiseShr(x: Int, y: Int): Int = js "(${x} >> ${y})" chez "(ash ${x} (- ${y}))" llvm "%z = ashr %Int ${x}, ${y} ret %Int %z" + jvm"effekt::bitwiseShr(Int, Int)" extern pure def bitwiseAnd(x: Int, y: Int): Int = js "(${x} & ${y})" chez "(logand ${x} ${y})" llvm "%z = and %Int ${x}, ${y} ret %Int %z" + jvm"effekt::bitwiseAnd(Int, Int)" extern pure def bitwiseOr(x: Int, y: Int): Int = js "(${x} | ${y})" chez "(logior ${x} ${y})" llvm "%z = or %Int ${x}, ${y} ret %Int %z" + jvm"effekt::bitwiseOr(Int, Int)" extern pure def bitwiseXor(x: Int, y: Int): Int = js "(${x} ^ ${y})" chez "(logxor ${x} ${y})" llvm "%z = xor %Int ${x}, ${y} ret %Int %z" + jvm"effekt::bitwiseXor(Int, Int)" // Byte operations diff --git a/libraries/common/ref.effekt b/libraries/common/ref.effekt index bdd7613e0..300368621 100644 --- a/libraries/common/ref.effekt +++ b/libraries/common/ref.effekt @@ -30,6 +30,7 @@ extern global def ref[T](init: T): Ref[T] = %z = call %Pos @c_ref_fresh(%Pos ${init}) ret %Pos %z """ + jvm "ref::ref[T](T)" /// Gets the referenced element of the `ref` in constant time. extern global def get[T](ref: Ref[T]): T = @@ -39,6 +40,7 @@ extern global def get[T](ref: Ref[T]): T = %z = call %Pos @c_ref_get(%Pos ${ref}) ret %Pos %z """ + jvm "ref::get[T](Ref[T])" /// Sets the referenced element of the `ref` to `value` in constant time. extern global def set[T](ref: Ref[T], value: T): Unit = @@ -48,3 +50,4 @@ extern global def set[T](ref: Ref[T], value: T): Unit = %z = call %Pos @c_ref_set(%Pos ${ref}, %Pos ${value}) ret %Pos %z """ + jvm "ref::set[T](Ref[T], T)"