diff --git a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala index 89abe2af..64347caf 100644 --- a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala +++ b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala @@ -133,33 +133,28 @@ In SBT settings: """) } - val repr = TypeRepr.of[A](using tpe) + val repr = TypeRepr.of[A] val tpeSym = repr.typeSymbol val valueRepr = TypeRepr.of[ValueType] - val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten - - val (valueField, valueParamIndex): (Symbol, Int) = ctorParams.zipWithIndex.find{ case (p, _) => p.name == "value"}.getOrElse { - report.errorAndAbort(s"Could not find 'value' field in ${tpeSym.name}") - } + val valueParamIndex = tpeSym.primaryConstructor.paramSymss + .filterNot(_.exists(_.isType)) + .flatten + .zipWithIndex + .collectFirst { + case (p, i) if p.name == "value" => i + } type IsValue[T <: ValueType] = T object ConstVal { @annotation.tailrec def unapply(tree: Tree): Option[Constant] = tree match { - case NamedArg(nme, v) if (nme == valueField.name) => - unapply(v) - - case ValDef(nme, _, Some(v)) if (nme == valueField.name) => - unapply(v) - - case lit @ Literal(const) if (lit.tpe <:< valueRepr) => - Some(const) - - case _ => - None + case NamedArg("value", v) => unapply(v) + case ValDef("value", _, Some(v)) => unapply(v) + case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const) + case _ => None } } @@ -175,24 +170,18 @@ In SBT settings: (for { vof <- Expr.summon[ValueOf[h]] constValue <- htpr.typeSymbol.tree match { - case ClassDef(_, _, spr, _, rhs) => { - val fromCtor = spr + case ClassDef(_, _, parents, _, statements) => { + val fromCtor = parents .collectFirst { case Apply(Select(New(id), _), args) if id.tpe <:< repr => args case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args } - .flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) => + .flatMap(_.lift(valueParamIndex.getOrElse(-1)).collect { case ConstVal(const) => const }) - - fromCtor - .orElse(rhs.collectFirst { case ConstVal(v) => v }) - .flatMap { const => - cls.unapply(const.value) - } - + def fromBody = statements.collectFirst { case ConstVal(v) => v } + fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) } } - case _ => Option.empty[ValueType] } @@ -202,7 +191,7 @@ In SBT settings: case None => report.errorAndAbort( - s"Fails to check value entry ${htpr.show} for enum ${repr.show}" + s"Failed to check value entry ${htpr.show} for enum ${repr.show}" ) } } @@ -212,8 +201,7 @@ In SBT settings: case Some(sum) => sum.asTerm.tpe.asType match { case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values) - - case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]") + case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]") } case None => @@ -230,7 +218,7 @@ In SBT settings: } .mkString(", ") - Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}") + Left(s"Values for the `value` field field are not discriminated subtypes: ${details}") } else { Right(Expr ofList instances.reverse) } diff --git a/macros/src/test/scala/enumeratum/CompilationSpec.scala b/macros/src/test/scala/enumeratum/CompilationSpec.scala index 8fa943b0..b94eba68 100644 --- a/macros/src/test/scala/enumeratum/CompilationSpec.scala +++ b/macros/src/test/scala/enumeratum/CompilationSpec.scala @@ -136,3 +136,9 @@ object F { case object F4 extends F(value = 4, "mike") } + +sealed abstract class G(val name: String, val value: Int) +object G { + val values = FindValEnums[G] + case object G1 extends G("gerald", 1) +}