Skip to content

Commit

Permalink
allow value constructor param to be any constructor parameter, rather…
Browse files Browse the repository at this point in the history
… than only the first (#398)

* don't check for the correct tree shape for empty tree

* fix constructor param fallback value

* keep error messages per PR comments

* use Option.flatMap rather than sentinel

* fmt
  • Loading branch information
martijnhoekstra authored Jul 7, 2024
1 parent 1fbbd04 commit df708af
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers {
""" shouldNot compile
}

it("should compile when the value constructor parameter is not first") {
"""
sealed abstract class MyStatus(final val idx: Int, final val value: String) extends StringEnumEntry
object MyStatus extends StringEnum[MyStatus] {
case object PENDING extends MyStatus(1, "PENDING")
val values = findValues
}
""" should compile
}

it("should compile even when values are repeated if AllowAlias is extended") {
"""
sealed abstract class ContentTypeRepeated(val value: Long, name: String) extends LongEnumEntry with AllowAlias
Expand Down
78 changes: 24 additions & 54 deletions macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,51 +133,28 @@ In SBT settings:
""")
}

val repr = TypeRepr.of[A](using tpe)
val repr = TypeRepr.of[A]
val tpeSym = repr.typeSymbol

val valueRepr = TypeRepr.of[ValueType]

val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten

val enumFields = repr.typeSymbol.fieldMembers.flatMap { field =>
ctorParams.zipWithIndex.find { case (p, i) =>
p.name == field.name && (p.tree match {
case term: Term =>
term.tpe <:< valueRepr

case _ =>
false
})
val valueParamIndex = tpeSym.primaryConstructor.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collectFirst {
case (p, i) if p.name == "value" => i
}
}.toSeq

val (valueField, valueParamIndex): (Symbol, Int) = {
if (enumFields.size == 1) {
enumFields.headOption
} else {
enumFields.find(_._1.name == "value")
}
}.getOrElse {
Symbol.newVal(tpeSym, "value", valueRepr, Flags.Abstract, Symbol.noSymbol) -> 0
}

type IsValue[T <: ValueType] = T

object ConstVal {
@annotation.tailrec
def unapply(tree: Tree): Option[Constant] = tree match {
case NamedArg(nme, v) if (nme == valueField.name) =>
unapply(v)

case ValDef(nme, _, Some(v)) if (nme == valueField.name) =>
unapply(v)

case lit @ Literal(const) if (lit.tpe <:< valueRepr) =>
Some(const)

case _ =>
None
case NamedArg("value", v) => unapply(v)
case ValDef("value", _, Some(v)) => unapply(v)
case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const)
case _ => None
}
}

Expand All @@ -193,24 +170,18 @@ In SBT settings:
(for {
vof <- Expr.summon[ValueOf[h]]
constValue <- htpr.typeSymbol.tree match {
case ClassDef(_, _, spr, _, rhs) => {
val fromCtor = spr
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) =>
const
})

fromCtor
.orElse(rhs.collectFirst { case ConstVal(v) => v })
.flatMap { const =>
cls.unapply(const.value)
}

case ClassDef(_, _, parents, _, statements) => {
val fromCtor = valueParamIndex.flatMap { (ix: Int) =>
parents
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(ix).collect { case ConstVal(const) => const })
}
def fromBody = statements.collectFirst { case ConstVal(v) => v }
fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) }
}

case _ =>
Option.empty[ValueType]
}
Expand All @@ -230,8 +201,7 @@ In SBT settings:
case Some(sum) =>
sum.asTerm.tpe.asType match {
case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values)

case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
}

case None =>
Expand All @@ -248,7 +218,7 @@ In SBT settings:
}
.mkString(", ")

Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}")
Left(s"Values value are not discriminated subtypes: ${details}")
} else {
Right(Expr ofList instances.reverse)
}
Expand Down
6 changes: 6 additions & 0 deletions macros/src/test/scala/enumeratum/CompilationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,9 @@ object F {
case object F4 extends F(value = 4, "mike")

}

sealed abstract class G(val name: String, val value: Int)
object G {
val values = FindValEnums[G]
case object G1 extends G("gerald", 1)
}

0 comments on commit df708af

Please sign in to comment.