From df708af06073ae7767a93fdce7109d25883c63ad Mon Sep 17 00:00:00 2001 From: Martijn Hoekstra Date: Sun, 7 Jul 2024 14:37:31 +0200 Subject: [PATCH] allow value constructor param to be any constructor parameter, rather than only the first (#398) * don't check for the correct tree shape for empty tree * fix constructor param fallback value * keep error messages per PR comments * use Option.flatMap rather than sentinel * fmt --- .../enumeratum/values/ValueEnumSpec.scala | 11 +++ .../scala-3/enumeratum/ValueEnumMacros.scala | 78 ++++++------------- .../scala/enumeratum/CompilationSpec.scala | 6 ++ 3 files changed, 41 insertions(+), 54 deletions(-) diff --git a/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala b/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala index ecfabbb9..9d14142b 100644 --- a/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala +++ b/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala @@ -118,6 +118,17 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers { """ shouldNot compile } + it("should compile when the value constructor parameter is not first") { + """ + sealed abstract class MyStatus(final val idx: Int, final val value: String) extends StringEnumEntry + + object MyStatus extends StringEnum[MyStatus] { + case object PENDING extends MyStatus(1, "PENDING") + val values = findValues + } + """ should compile + } + it("should compile even when values are repeated if AllowAlias is extended") { """ sealed abstract class ContentTypeRepeated(val value: Long, name: String) extends LongEnumEntry with AllowAlias diff --git a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala index 4ab3b219..d3f856dc 100644 --- a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala +++ b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala @@ -133,51 +133,28 @@ In SBT settings: """) } - val repr = TypeRepr.of[A](using tpe) + val repr = TypeRepr.of[A] val tpeSym = repr.typeSymbol val valueRepr = TypeRepr.of[ValueType] - val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten - - val enumFields = repr.typeSymbol.fieldMembers.flatMap { field => - ctorParams.zipWithIndex.find { case (p, i) => - p.name == field.name && (p.tree match { - case term: Term => - term.tpe <:< valueRepr - - case _ => - false - }) + val valueParamIndex = tpeSym.primaryConstructor.paramSymss + .filterNot(_.exists(_.isType)) + .flatten + .zipWithIndex + .collectFirst { + case (p, i) if p.name == "value" => i } - }.toSeq - - val (valueField, valueParamIndex): (Symbol, Int) = { - if (enumFields.size == 1) { - enumFields.headOption - } else { - enumFields.find(_._1.name == "value") - } - }.getOrElse { - Symbol.newVal(tpeSym, "value", valueRepr, Flags.Abstract, Symbol.noSymbol) -> 0 - } type IsValue[T <: ValueType] = T object ConstVal { @annotation.tailrec def unapply(tree: Tree): Option[Constant] = tree match { - case NamedArg(nme, v) if (nme == valueField.name) => - unapply(v) - - case ValDef(nme, _, Some(v)) if (nme == valueField.name) => - unapply(v) - - case lit @ Literal(const) if (lit.tpe <:< valueRepr) => - Some(const) - - case _ => - None + case NamedArg("value", v) => unapply(v) + case ValDef("value", _, Some(v)) => unapply(v) + case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const) + case _ => None } } @@ -193,24 +170,18 @@ In SBT settings: (for { vof <- Expr.summon[ValueOf[h]] constValue <- htpr.typeSymbol.tree match { - case ClassDef(_, _, spr, _, rhs) => { - val fromCtor = spr - .collectFirst { - case Apply(Select(New(id), _), args) if id.tpe <:< repr => args - case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args - } - .flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) => - const - }) - - fromCtor - .orElse(rhs.collectFirst { case ConstVal(v) => v }) - .flatMap { const => - cls.unapply(const.value) - } - + case ClassDef(_, _, parents, _, statements) => { + val fromCtor = valueParamIndex.flatMap { (ix: Int) => + parents + .collectFirst { + case Apply(Select(New(id), _), args) if id.tpe <:< repr => args + case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args + } + .flatMap(_.lift(ix).collect { case ConstVal(const) => const }) + } + def fromBody = statements.collectFirst { case ConstVal(v) => v } + fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) } } - case _ => Option.empty[ValueType] } @@ -230,8 +201,7 @@ In SBT settings: case Some(sum) => sum.asTerm.tpe.asType match { case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values) - - case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]") + case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]") } case None => @@ -248,7 +218,7 @@ In SBT settings: } .mkString(", ") - Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}") + Left(s"Values value are not discriminated subtypes: ${details}") } else { Right(Expr ofList instances.reverse) } diff --git a/macros/src/test/scala/enumeratum/CompilationSpec.scala b/macros/src/test/scala/enumeratum/CompilationSpec.scala index 8fa943b0..b94eba68 100644 --- a/macros/src/test/scala/enumeratum/CompilationSpec.scala +++ b/macros/src/test/scala/enumeratum/CompilationSpec.scala @@ -136,3 +136,9 @@ object F { case object F4 extends F(value = 4, "mike") } + +sealed abstract class G(val name: String, val value: Int) +object G { + val values = FindValEnums[G] + case object G1 extends G("gerald", 1) +}