From 37a4a22d9400696f4545314a832ccd418b8a71a7 Mon Sep 17 00:00:00 2001 From: Olivier Melois Date: Thu, 22 Apr 2021 15:18:34 +0200 Subject: [PATCH] Adds polymorphism --- project/plugins.sbt | 1 + src/main/scala/cats/effect/std/IOAsync.scala | 87 +++++++++++-------- .../scala/com/example/HelloWorldSuite.scala | 11 ++- src/test/scala/com/example/OptionTSuite.scala | 22 +++++ 4 files changed, 83 insertions(+), 38 deletions(-) create mode 100644 project/plugins.sbt create mode 100644 src/test/scala/com/example/OptionTSuite.scala diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..9886dc5 --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1 @@ +addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.1.17") diff --git a/src/main/scala/cats/effect/std/IOAsync.scala b/src/main/scala/cats/effect/std/IOAsync.scala index dd4a3be..f718c50 100644 --- a/src/main/scala/cats/effect/std/IOAsync.scala +++ b/src/main/scala/cats/effect/std/IOAsync.scala @@ -1,31 +1,27 @@ package cats.effect.std -import java.util.Objects +import scala.annotation.compileTimeOnly +import scala.reflect.macros.whitebox +// import language.experimental.macros -import scala.util.{Failure, Success, Try} import cats.effect.std.Dispatcher -import cats.effect.IO import cats.effect.kernel.Outcome -import cats.effect.kernel.Outcome.Canceled -import cats.effect.kernel.Outcome.Errored -import cats.effect.kernel.Outcome.Succeeded -import java.util.concurrent.CancellationException +import cats.effect.kernel.Sync +import cats.effect.kernel.Async +import cats.effect.syntax.all._ +import cats.effect.IO -import language.experimental.macros -import scala.annotation.compileTimeOnly -import scala.reflect.macros.blackbox -import scala.annotation.compileTimeOnly -import scala.reflect.macros.whitebox -import cats.syntax +object IOAsync extends AsyncAwaitDsl[IO] -object IOAsync { +class AsyncAwaitDsl[F[_]](implicit F: Async[F]) { - type Callback = Either[Throwable, AnyRef] => Unit + /** Type member used by the macro expansion to recover what `F` is without typetags + */ + type _AsyncContext[A] = F[A] - /** Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of - * a `Future` are needed; this is translated into non-blocking code. + /** Value member used by the macro expansion to recover the Async instance associated to the block. */ - def async[T](body: => T): IO[T] = macro asyncImpl[T] + implicit val _AsyncInstance: Async[F] = F /** Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block. * @@ -33,10 +29,23 @@ object IOAsync { * in the `onComplete` handler of `awaitable`, and will *not* block a thread. */ @compileTimeOnly("[async] `await` must be enclosed in an `async` block") - def await[T](awaitable: IO[T]): T = + def await[T](awaitable: F[T]): T = ??? // No implementation here, as calls to this are translated to `onComplete` by the macro. - def asyncImpl[T: c.WeakTypeTag](c: whitebox.Context)(body: c.Tree): c.Tree = { + /** Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of + * a `Future` are needed; this is translated into non-blocking code. + */ + def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T] + +} + +object AsyncAwaitDsl { + + type Callback = Either[Throwable, AnyRef] => Unit + + def asyncImpl[F[_], T]( + c: whitebox.Context + )(body: c.Tree): c.Tree = { import c.universe._ if (!c.compilerSettings.contains("-Xasync")) { c.abort( @@ -45,7 +54,7 @@ object IOAsync { ) } else try { - val awaitSym = typeOf[IOAsync.type].decl(TermName("await")) + val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await")) def mark(t: DefDef): Tree = { import language.reflectiveCalls c.internal @@ -68,14 +77,16 @@ object IOAsync { val name = TypeName("stateMachine$async") // format: off q""" - final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[IO], callback: _root_.cats.effect.std.IOAsync.Callback) extends _root_.cats.effect.std.IOStateMachine(dispatcher, callback) { - ${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[_root_.cats.effect.IO, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")} + final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) { + ${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[${c.prefix}._AsyncContext, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")} } - _root_.cats.effect.std.Dispatcher[IO].use { dispatcher => - _root_.cats.effect.IO.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start()) - }.handleErrorWith { - case _root_.cats.effect.std.IOAsync.CancelBridge => _root_.cats.effect.IO.canceled - case _root_.scala.util.control.NonFatal(other) => _root_.cats.effect.IO.raiseError(other) + ${c.prefix}._AsyncInstance.recoverWith { + _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher => + ${c.prefix}._AsyncInstance.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start()) + } + }{ + case _root_.cats.effect.std.AsyncAwaitDsl.CancelBridge => + ${c.prefix}._AsyncInstance.map(${c.prefix}._AsyncInstance.canceled)(_ => null.asInstanceOf[AnyRef]) }.asInstanceOf[${c.macroApplication.tpe}] """ } catch { @@ -91,10 +102,10 @@ object IOAsync { object CancelBridge extends Throwable with scala.util.control.NoStackTrace } -abstract class IOStateMachine( - dispatcher: Dispatcher[IO], - callback: IOAsync.Callback -) extends Function1[Outcome[IO, Throwable, AnyRef], Unit] { +abstract class AsyncAwaitStateMachine[F[_]]( + dispatcher: Dispatcher[F], + callback: AsyncAwaitDsl.Callback +)(implicit F: Sync[F]) extends Function1[Outcome[F, Throwable, AnyRef], Unit] { // FSM translated method //def apply(v1: Outcome[IO, Throwable, AnyRef]): Unit = ??? @@ -114,15 +125,16 @@ abstract class IOStateMachine( callback(Right(value)) } - protected def onComplete(f: IO[AnyRef]): Unit = { - dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => IO(this(outcome)))) + protected def onComplete(f: F[AnyRef]): Unit = { + dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => F.delay(this(outcome)))) } - protected def getCompleted(f: IO[AnyRef]): Outcome[IO, Throwable, AnyRef] = { + protected def getCompleted(f: F[AnyRef]): Outcome[F, Throwable, AnyRef] = { + val _ = f null } - protected def tryGet(tr: Outcome[IO, Throwable, AnyRef]): AnyRef = + protected def tryGet(tr: Outcome[F, Throwable, AnyRef]): AnyRef = tr match { case Outcome.Succeeded(value) => dispatcher.unsafeRunSync(value) @@ -130,11 +142,12 @@ abstract class IOStateMachine( callback(Left(e)) this // sentinel value to indicate the dispatch loop should exit. case Outcome.Canceled() => - callback(Left(IOAsync.CancelBridge)) + callback(Left(AsyncAwaitDsl.CancelBridge)) this } def start(): Unit = { + // Required to kickstart the async state machine. // `def apply` does not consult its argument when `state == 0`. apply(null) } diff --git a/src/test/scala/com/example/HelloWorldSuite.scala b/src/test/scala/com/example/HelloWorldSuite.scala index 707095e..a8f3e87 100644 --- a/src/test/scala/com/example/HelloWorldSuite.scala +++ b/src/test/scala/com/example/HelloWorldSuite.scala @@ -1,6 +1,6 @@ package com.example -import cats.effect.{IO, SyncIO} +import cats.effect.IO import munit.CatsEffectSuite import cats.effect.std.IOAsync._ import scala.concurrent.duration._ @@ -51,4 +51,13 @@ class HelloWorldSuite extends CatsEffectSuite { assertEquals(result, 0) } } + + test("side effects in the async block are suspended") { + var x = 0 + + val io = async { x += 1; await(IO(x)) } + + IO(assertEquals(x, 0)) *> io *> IO(assertEquals(x, 1)) + } + } diff --git a/src/test/scala/com/example/OptionTSuite.scala b/src/test/scala/com/example/OptionTSuite.scala new file mode 100644 index 0000000..bdeb411 --- /dev/null +++ b/src/test/scala/com/example/OptionTSuite.scala @@ -0,0 +1,22 @@ +package com.example + +import cats.effect.IO +import munit.CatsEffectSuite +import cats.effect.std.AsyncAwaitDsl +import cats.data.OptionT + +class OptionTSuite extends CatsEffectSuite { + + type Foo[A] = OptionT[IO, A] + object dsl extends AsyncAwaitDsl[Foo] + import dsl._ + + test("optionT") { + + val io: Foo[Int] = OptionT.none[IO, Int] + + val foo = async(await(io)) + + foo.value.map(it => assertEquals(it, None)) + } +}