Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow value constructor param to be any constructor parameter, rather than only the first #398

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers {
""" shouldNot compile
}

it("should compile when the value constructor parameter is not first") {
"""
sealed abstract class MyStatus(final val idx: Int, final val value: String) extends StringEnumEntry

object MyStatus extends StringEnum[MyStatus] {
case object PENDING extends MyStatus(1, "PENDING")
val values = findValues
}
""" should compile
}

it("should compile even when values are repeated if AllowAlias is extended") {
"""
sealed abstract class ContentTypeRepeated(val value: Long, name: String) extends LongEnumEntry with AllowAlias
Expand Down
78 changes: 24 additions & 54 deletions macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,51 +133,28 @@ In SBT settings:
""")
}

val repr = TypeRepr.of[A](using tpe)
val repr = TypeRepr.of[A]
val tpeSym = repr.typeSymbol

val valueRepr = TypeRepr.of[ValueType]

val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten

val enumFields = repr.typeSymbol.fieldMembers.flatMap { field =>
ctorParams.zipWithIndex.find { case (p, i) =>
p.name == field.name && (p.tree match {
case term: Term =>
term.tpe <:< valueRepr

case _ =>
false
})
val valueParamIndex = tpeSym.primaryConstructor.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collectFirst {
case (p, i) if p.name == "value" => i
}
}.toSeq

val (valueField, valueParamIndex): (Symbol, Int) = {
if (enumFields.size == 1) {
enumFields.headOption
} else {
enumFields.find(_._1.name == "value")
}
}.getOrElse {
Symbol.newVal(tpeSym, "value", valueRepr, Flags.Abstract, Symbol.noSymbol) -> 0
}

type IsValue[T <: ValueType] = T

object ConstVal {
@annotation.tailrec
def unapply(tree: Tree): Option[Constant] = tree match {
case NamedArg(nme, v) if (nme == valueField.name) =>
unapply(v)

case ValDef(nme, _, Some(v)) if (nme == valueField.name) =>
unapply(v)

case lit @ Literal(const) if (lit.tpe <:< valueRepr) =>
Some(const)

case _ =>
None
case NamedArg("value", v) => unapply(v)
case ValDef("value", _, Some(v)) => unapply(v)
case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const)
case _ => None
}
}

Expand All @@ -193,24 +170,18 @@ In SBT settings:
(for {
vof <- Expr.summon[ValueOf[h]]
constValue <- htpr.typeSymbol.tree match {
case ClassDef(_, _, spr, _, rhs) => {
val fromCtor = spr
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) =>
const
})

fromCtor
.orElse(rhs.collectFirst { case ConstVal(v) => v })
.flatMap { const =>
cls.unapply(const.value)
}

case ClassDef(_, _, parents, _, statements) => {
val fromCtor = valueParamIndex.flatMap { (ix: Int) =>
parents
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(ix).collect { case ConstVal(const) => const })
}
def fromBody = statements.collectFirst { case ConstVal(v) => v }
fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) }
}

case _ =>
Option.empty[ValueType]
}
Expand All @@ -230,8 +201,7 @@ In SBT settings:
case Some(sum) =>
sum.asTerm.tpe.asType match {
case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values)

case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
}

case None =>
Expand All @@ -248,7 +218,7 @@ In SBT settings:
}
.mkString(", ")

Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}")
Left(s"Values value are not discriminated subtypes: ${details}")
} else {
Right(Expr ofList instances.reverse)
}
Expand Down
6 changes: 6 additions & 0 deletions macros/src/test/scala/enumeratum/CompilationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,9 @@ object F {
case object F4 extends F(value = 4, "mike")

}

sealed abstract class G(val name: String, val value: Int)
object G {
val values = FindValEnums[G]
case object G1 extends G("gerald", 1)
}
Loading