Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow value constructor param to be any constructor parameter, rather than only the first #398

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers {
""" shouldNot compile
}

it("should compile when the value constructor parameter is not first") {
"""
sealed abstract class MyStatus(final val idx: Int, final val value: String) extends StringEnumEntry

object MyStatus extends StringEnum[MyStatus] {
case object PENDING extends MyStatus(1, "PENDING")
val values = findValues
}
""" should compile
}

it("should compile even when values are repeated if AllowAlias is extended") {
"""
sealed abstract class ContentTypeRepeated(val value: Long, name: String) extends LongEnumEntry with AllowAlias
Expand Down
68 changes: 19 additions & 49 deletions macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,51 +133,28 @@ In SBT settings:
""")
}

val repr = TypeRepr.of[A](using tpe)
val repr = TypeRepr.of[A]
val tpeSym = repr.typeSymbol

val valueRepr = TypeRepr.of[ValueType]

val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten

val enumFields = repr.typeSymbol.fieldMembers.flatMap { field =>
ctorParams.zipWithIndex.find { case (p, i) =>
p.name == field.name && (p.tree match {
case term: Term =>
term.tpe <:< valueRepr

case _ =>
false
})
}
}.toSeq

val (valueField, valueParamIndex): (Symbol, Int) = {
if (enumFields.size == 1) {
enumFields.headOption
} else {
enumFields.find(_._1.name == "value")
val valueParamIndex = tpeSym.primaryConstructor.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collectFirst {
case (p, i) if p.name == "value" => i
}
}.getOrElse {
Symbol.newVal(tpeSym, "value", valueRepr, Flags.Abstract, Symbol.noSymbol) -> 0
}

type IsValue[T <: ValueType] = T

object ConstVal {
@annotation.tailrec
def unapply(tree: Tree): Option[Constant] = tree match {
case NamedArg(nme, v) if (nme == valueField.name) =>
unapply(v)

case ValDef(nme, _, Some(v)) if (nme == valueField.name) =>
unapply(v)

case lit @ Literal(const) if (lit.tpe <:< valueRepr) =>
Some(const)

case _ =>
None
case NamedArg("value", v) => unapply(v)
case ValDef("value", _, Some(v)) => unapply(v)
case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const)
case _ => None
}
}

Expand All @@ -193,24 +170,18 @@ In SBT settings:
(for {
vof <- Expr.summon[ValueOf[h]]
constValue <- htpr.typeSymbol.tree match {
case ClassDef(_, _, spr, _, rhs) => {
val fromCtor = spr
case ClassDef(_, _, parents, _, statements) => {
val fromCtor = parents
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) =>
.flatMap(_.lift(valueParamIndex.getOrElse(-1)).collect { case ConstVal(const) =>
lloydmeta marked this conversation as resolved.
Show resolved Hide resolved
const
})

fromCtor
.orElse(rhs.collectFirst { case ConstVal(v) => v })
.flatMap { const =>
cls.unapply(const.value)
}

def fromBody = statements.collectFirst { case ConstVal(v) => v }
fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) }
}

case _ =>
Option.empty[ValueType]
}
Expand All @@ -220,7 +191,7 @@ In SBT settings:

case None =>
report.errorAndAbort(
s"Fails to check value entry ${htpr.show} for enum ${repr.show}"
s"Failed to check value entry ${htpr.show} for enum ${repr.show}"
martijnhoekstra marked this conversation as resolved.
Show resolved Hide resolved
)
}
}
Expand All @@ -230,8 +201,7 @@ In SBT settings:
case Some(sum) =>
sum.asTerm.tpe.asType match {
case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values)

case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
}

case None =>
Expand All @@ -248,7 +218,7 @@ In SBT settings:
}
.mkString(", ")

Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}")
Left(s"Values for the `value` field field are not discriminated subtypes: ${details}")
martijnhoekstra marked this conversation as resolved.
Show resolved Hide resolved
} else {
Right(Expr ofList instances.reverse)
}
Expand Down
6 changes: 6 additions & 0 deletions macros/src/test/scala/enumeratum/CompilationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,9 @@ object F {
case object F4 extends F(value = 4, "mike")

}

sealed abstract class G(val name: String, val value: Int)
object G {
val values = FindValEnums[G]
case object G1 extends G("gerald", 1)
}
Loading