Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds polymorphism #2

Open
wants to merge 1 commit into
base: progagate-cancellation-outward
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.1.17")
87 changes: 50 additions & 37 deletions src/main/scala/cats/effect/std/IOAsync.scala
Original file line number Diff line number Diff line change
@@ -1,42 +1,51 @@
package cats.effect.std

import java.util.Objects
import scala.annotation.compileTimeOnly
import scala.reflect.macros.whitebox
// import language.experimental.macros

import scala.util.{Failure, Success, Try}
import cats.effect.std.Dispatcher
import cats.effect.IO
import cats.effect.kernel.Outcome
import cats.effect.kernel.Outcome.Canceled
import cats.effect.kernel.Outcome.Errored
import cats.effect.kernel.Outcome.Succeeded
import java.util.concurrent.CancellationException
import cats.effect.kernel.Sync
import cats.effect.kernel.Async
import cats.effect.syntax.all._
import cats.effect.IO

import language.experimental.macros
import scala.annotation.compileTimeOnly
import scala.reflect.macros.blackbox
import scala.annotation.compileTimeOnly
import scala.reflect.macros.whitebox
import cats.syntax
object IOAsync extends AsyncAwaitDsl[IO]

object IOAsync {
class AsyncAwaitDsl[F[_]](implicit F: Async[F]) {

type Callback = Either[Throwable, AnyRef] => Unit
/** Type member used by the macro expansion to recover what `F` is without typetags
*/
type _AsyncContext[A] = F[A]

/** Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `Future` are needed; this is translated into non-blocking code.
/** Value member used by the macro expansion to recover the Async instance associated to the block.
*/
def async[T](body: => T): IO[T] = macro asyncImpl[T]
implicit val _AsyncInstance: Async[F] = F

/** Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block.
*
* Internally, this will register the remainder of the code in enclosing `async` block as a callback
* in the `onComplete` handler of `awaitable`, and will *not* block a thread.
*/
@compileTimeOnly("[async] `await` must be enclosed in an `async` block")
def await[T](awaitable: IO[T]): T =
def await[T](awaitable: F[T]): T =
??? // No implementation here, as calls to this are translated to `onComplete` by the macro.

def asyncImpl[T: c.WeakTypeTag](c: whitebox.Context)(body: c.Tree): c.Tree = {
/** Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `Future` are needed; this is translated into non-blocking code.
*/
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T]

}

object AsyncAwaitDsl {

type Callback = Either[Throwable, AnyRef] => Unit

def asyncImpl[F[_], T](
c: whitebox.Context
)(body: c.Tree): c.Tree = {
import c.universe._
if (!c.compilerSettings.contains("-Xasync")) {
c.abort(
Expand All @@ -45,7 +54,7 @@ object IOAsync {
)
} else
try {
val awaitSym = typeOf[IOAsync.type].decl(TermName("await"))
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await"))
def mark(t: DefDef): Tree = {
import language.reflectiveCalls
c.internal
Expand All @@ -68,14 +77,16 @@ object IOAsync {
val name = TypeName("stateMachine$async")
// format: off
q"""
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[IO], callback: _root_.cats.effect.std.IOAsync.Callback) extends _root_.cats.effect.std.IOStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[_root_.cats.effect.IO, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")}
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[${c.prefix}._AsyncContext, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")}
}
_root_.cats.effect.std.Dispatcher[IO].use { dispatcher =>
_root_.cats.effect.IO.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start())
}.handleErrorWith {
case _root_.cats.effect.std.IOAsync.CancelBridge => _root_.cats.effect.IO.canceled
case _root_.scala.util.control.NonFatal(other) => _root_.cats.effect.IO.raiseError(other)
${c.prefix}._AsyncInstance.recoverWith {
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher =>
${c.prefix}._AsyncInstance.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start())
}
}{
case _root_.cats.effect.std.AsyncAwaitDsl.CancelBridge =>
${c.prefix}._AsyncInstance.map(${c.prefix}._AsyncInstance.canceled)(_ => null.asInstanceOf[AnyRef])
}.asInstanceOf[${c.macroApplication.tpe}]
"""
} catch {
Expand All @@ -91,10 +102,10 @@ object IOAsync {
object CancelBridge extends Throwable with scala.util.control.NoStackTrace
}

abstract class IOStateMachine(
dispatcher: Dispatcher[IO],
callback: IOAsync.Callback
) extends Function1[Outcome[IO, Throwable, AnyRef], Unit] {
abstract class AsyncAwaitStateMachine[F[_]](
dispatcher: Dispatcher[F],
callback: AsyncAwaitDsl.Callback
)(implicit F: Sync[F]) extends Function1[Outcome[F, Throwable, AnyRef], Unit] {

// FSM translated method
//def apply(v1: Outcome[IO, Throwable, AnyRef]): Unit = ???
Expand All @@ -114,27 +125,29 @@ abstract class IOStateMachine(
callback(Right(value))
}

protected def onComplete(f: IO[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => IO(this(outcome))))
protected def onComplete(f: F[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => F.delay(this(outcome))))
}

protected def getCompleted(f: IO[AnyRef]): Outcome[IO, Throwable, AnyRef] = {
protected def getCompleted(f: F[AnyRef]): Outcome[F, Throwable, AnyRef] = {
val _ = f
null
}

protected def tryGet(tr: Outcome[IO, Throwable, AnyRef]): AnyRef =
protected def tryGet(tr: Outcome[F, Throwable, AnyRef]): AnyRef =
tr match {
case Outcome.Succeeded(value) =>
dispatcher.unsafeRunSync(value)
case Outcome.Errored(e) =>
callback(Left(e))
this // sentinel value to indicate the dispatch loop should exit.
case Outcome.Canceled() =>
callback(Left(IOAsync.CancelBridge))
callback(Left(AsyncAwaitDsl.CancelBridge))
this
}

def start(): Unit = {
// Required to kickstart the async state machine.
// `def apply` does not consult its argument when `state == 0`.
apply(null)
}
Expand Down
11 changes: 10 additions & 1 deletion src/test/scala/com/example/HelloWorldSuite.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.example

import cats.effect.{IO, SyncIO}
import cats.effect.IO
import munit.CatsEffectSuite
import cats.effect.std.IOAsync._
import scala.concurrent.duration._
Expand Down Expand Up @@ -51,4 +51,13 @@ class HelloWorldSuite extends CatsEffectSuite {
assertEquals(result, 0)
}
}

test("side effects in the async block are suspended") {
var x = 0

val io = async { x += 1; await(IO(x)) }

IO(assertEquals(x, 0)) *> io *> IO(assertEquals(x, 1))
}

}
22 changes: 22 additions & 0 deletions src/test/scala/com/example/OptionTSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.example

import cats.effect.IO
import munit.CatsEffectSuite
import cats.effect.std.AsyncAwaitDsl
import cats.data.OptionT

class OptionTSuite extends CatsEffectSuite {

type Foo[A] = OptionT[IO, A]
object dsl extends AsyncAwaitDsl[Foo]
import dsl._

test("optionT") {

val io: Foo[Int] = OptionT.none[IO, Int]

val foo = async(await(io))

foo.value.map(it => assertEquals(it, None))
}
}