Skip to content

Commit

Permalink
optimize Gen.zipWith code generator
Browse files Browse the repository at this point in the history
  • Loading branch information
satorg committed Aug 9, 2024
1 parent 13f7163 commit 8df466c
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions project/codegen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ object codegen {
s"$g.flatMap(($t: $T) => $acc)"
}

def flatMappedGeneratorsWithFun(i: Int, f: String, s: Seq[(String, String)]): String =
s.init.foldRight(s"${s.last._2}.map { ${s.last._1} => $f(${vals(i)}) }") {
case ((t, g), acc) =>
val T = t.toUpperCase
s"$g.flatMap(($t: $T) => $acc)"
}

def vals(i: Int) = csv(idents("t", i))

def coImplicits(i: Int) = (1 to i).map(n => s"co$n: Cogen[T$n]").mkString(",")
Expand Down Expand Up @@ -114,14 +107,30 @@ object codegen {
}

def zipWith(i: Int) = {
val gens = flatMappedGeneratorsWithFun(i, "f", idents("t", i) zip idents("g", i))
val f = "f"
val tR = "R"
val ts = idents("t", i) // Seq(t1, ... ti)
val tTs = idents("T", i) // Seq(T1, ..., Ti)
val gs = idents("g", i) // Seq(g1, ..., gi)
val tTsCsv = csv(tTs) // "T1, ..., Ti"
val tsCsv = csv(ts) // "t1, ..., ti"
val tTts = tTs.zip(ts) // Seq((T1, t1), ..., (Ti, ti))
val tTtgs = tTts.zip(gs) // Seq(((T1, t1), g1), ..., ((Ti, ti), gi))

val ((_, ti), gi) = tTtgs.last
val gens =
tTtgs.init.foldRight(s"$gi.map { $ti => $f($tsCsv) }") {
case (((tT, t), g), acc) =>
s"$g.flatMap(($t: $tT) => $acc)"
}

s"""
| /** Combines the given generators into a new generator of the given result type
| * with help of the given mapping function. */
| def zipWith[${types(i)},R](
| def zipWith[$tTsCsv, $tR](
| ${wrappedArgs("Gen", i)}
| )(
| f: (${types(i)}) => R
| $f: ($tTsCsv) => $tR
| ): Gen[R] =
| $gens
|""".stripMargin
Expand Down

0 comments on commit 8df466c

Please sign in to comment.