Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cogen for union types #51

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/main/derived/extras/UnionCogens.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.github.martinhh.derived.extras

import org.scalacheck.Cogen
import org.scalacheck.rng.Seed

import scala.compiletime.summonInline
import scala.reflect.TypeTest

// combines a Cogen with a TypeTest that allows matching on the type of the Cogen
private case class TypedCogen[A](typeTest: TypeTest[Any, A], cogen: Cogen[A]):
def tryPerturb(seed: Seed, a: Any): Option[Seed] = typeTest.unapply(a).map(cogen.perturb(seed, _))

private object TypedCogen:
inline given derived[A]: TypedCogen[A] =
TypedCogen(summonInline[TypeTest[Any, A]], summonInline[Cogen[A]])

// type for accumulating the TypedCogen-instances of a union
private sealed trait TypedCogens[A]:
def instances: List[TypedCogen[?]]

private object TypedCogens:
inline given derived[A]: SingleTypedCogens[A] =
SingleTypedCogens(summonInline[TypedCogen[A]])

private case class SingleTypedCogens[A](instance: TypedCogen[A]) extends TypedCogens[A]:
override def instances: List[TypedCogen[?]] = List(instance)

private case class UnionTypedCogens[A](instances: List[TypedCogen[?]]) extends TypedCogens[A]:
def toCogen: Cogen[A] =
Cogen { (seed: Seed, a: A) =>
object TheUnapply:
def unapply(typedCogen: TypedCogen[?]): Option[Seed] = typedCogen.tryPerturb(seed, a)
val seedOpt = instances.zipWithIndex.collectFirst { case (TheUnapply(seed), i) =>
Cogen.perturb(seed, i)
}
assert(seedOpt.isDefined, "This case should be unreachable")
seedOpt.get
}

private trait UnionCogens:
transparent inline given unionTypedCogensMacro[X]: UnionTypedCogens[X] =
io.github.martinhh.derived.extras.unionTypedCogensMacro

transparent inline given cogenUnion[X](using inline bg: UnionTypedCogens[X]): Cogen[X] =
bg.toCogen
4 changes: 2 additions & 2 deletions src/main/derived/extras/api.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.github.martinhh.derived.extras

object union extends UnionArbitraries
object union extends UnionArbitraries with UnionCogens

object literal extends LiteralArbitraries

object all extends LiteralArbitraries with UnionArbitraries
object all extends LiteralArbitraries with UnionArbitraries with UnionCogens
21 changes: 20 additions & 1 deletion src/main/derived/extras/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.github.martinhh.derived.extras

import scala.quoted.*

// macro based on this StackOverflow answer by Dmytro Mitin: https://stackoverflow.com/a/78567397/6152669
// macros for union based on this StackOverflow answer by Dmytro Mitin: https://stackoverflow.com/a/78567397/6152669
private def unionGens[X: Type](using Quotes): Expr[UnionGens[X]] =
import quotes.reflect.*
TypeRepr.of[X] match
Expand All @@ -21,3 +21,22 @@ private def unionGens[X: Type](using Quotes): Expr[UnionGens[X]] =

private transparent inline given unionGensMacro[X]: UnionGens[X] =
${ unionGens[X] }

private def unionTypedCogens[X: Type](using Quotes): Expr[UnionTypedCogens[X]] =
import quotes.reflect.*
TypeRepr.of[X] match
case OrType(l, r) =>
(l.asType, r.asType) match
case ('[a], '[b]) =>
(Expr.summon[TypedCogens[a]], Expr.summon[TypedCogens[b]]) match
case (Some(aInst), Some(bInst)) =>
'{
val x = $aInst
val y = $bInst
UnionTypedCogens[X](x.instances ++ y.instances)
}.asExprOf[UnionTypedCogens[X]]
case (_, _) =>
report.errorAndAbort(s"Could not summon UnionTypedCogens")

private transparent inline given unionTypedCogensMacro[X]: UnionTypedCogens[X] =
${ unionTypedCogens[X] }
2 changes: 1 addition & 1 deletion src/test/ArbitrarySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.scalacheck.Gen
import org.scalacheck.Gen.Parameters
import org.scalacheck.rng.Seed

class ArbitrarySuite extends munit.FunSuite:
trait ArbitrarySuite extends munit.BaseFunSuite:

protected def equalValues[T](
expectedGen: Gen[T],
Expand Down
16 changes: 1 addition & 15 deletions src/test/CogenDerivingSuite.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
package io.github.martinhh

import org.scalacheck.Arbitrary
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Cogen
import org.scalacheck.Gen
import org.scalacheck.Prop
import org.scalacheck.rng.Seed

class CogenDerivingSuite extends munit.ScalaCheckSuite:

private def equalValues[T](
expectedCogen: Cogen[T]
)(using arbSeed: Arbitrary[Seed], arbT: Arbitrary[T], derivedCogen: Cogen[T]): Prop =
Prop.forAll { (s: Seed, t: T) =>
assertEquals(derivedCogen.perturb(s, t), expectedCogen.perturb(s, t))
}

import CogenDerivingSuite.arbSeed
class CogenDerivingSuite extends CogenSuite:

test("deriveCogen allows to derive a given without loop of given definition") {
given cogen: Cogen[SimpleCaseClass] = derived.scalacheck.deriveCogen
Expand Down Expand Up @@ -110,6 +99,3 @@ class CogenDerivingSuite extends munit.ScalaCheckSuite:
test("supports enums with up to 24 members (if -Xmax-inlines=32)") {
summon[Cogen[MaxEnum]]
}

object CogenDerivingSuite:
private given arbSeed: Arbitrary[Seed] = Arbitrary(arbitrary[Long].map(Seed.apply))
18 changes: 18 additions & 0 deletions src/test/CogenSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.github.martinhh

import org.scalacheck.Arbitrary
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Cogen
import org.scalacheck.Prop
import org.scalacheck.rng.Seed

trait CogenSuite extends munit.ScalaCheckSuite:

protected given arbSeed: Arbitrary[Seed] = Arbitrary(arbitrary[Long].map(Seed.apply))

protected def equalValues[T](
expectedCogen: Cogen[T]
)(using arbSeed: Arbitrary[Seed], arbT: Arbitrary[T], derivedCogen: Cogen[T]): Prop =
Prop.forAll { (s: Seed, t: T) =>
assertEquals(derivedCogen.perturb(s, t), expectedCogen.perturb(s, t))
}
38 changes: 36 additions & 2 deletions src/test/ExtrasSuite.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package io.github.martinhh

import io.github.martinhh.derived.extras.TypedCogen
import io.github.martinhh.derived.extras.all.given

import org.scalacheck.Arbitrary
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Cogen
import org.scalacheck.Gen

class ExtrasSuite extends ArbitrarySuite:
class ExtrasSuite extends ArbitrarySuite with CogenSuite:

test("Arbitrary for union of two types") {
type TheUnion = String | Int
Expand Down Expand Up @@ -50,3 +51,36 @@ class ExtrasSuite extends ArbitrarySuite:
)
equalValues[TheUnion](expectedGen)
}

test("Cogen for union of three types") {
type TheUnion = String | Int | Boolean
val expectedCogen: Cogen[TheUnion] =
Cogen { (seed, value) =>
value match
case s: String =>
Cogen.perturb(
Cogen.perturb(
seed,
s
),
0
)
case i: Int =>
Cogen.perturb(
Cogen.perturb(
seed,
i
),
1
)
case b: Boolean =>
Cogen.perturb(
Cogen.perturb(
seed,
b
),
2
)
}
equalValues(expectedCogen)
}
Loading