Skip to content

Commit

Permalink
fix constructor param fallback value
Browse files Browse the repository at this point in the history
  • Loading branch information
martijnhoekstra committed Jun 23, 2024
1 parent a1a0c80 commit d3a872c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
52 changes: 20 additions & 32 deletions macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,33 +133,28 @@ In SBT settings:
""")
}

val repr = TypeRepr.of[A](using tpe)
val repr = TypeRepr.of[A]
val tpeSym = repr.typeSymbol

val valueRepr = TypeRepr.of[ValueType]

val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten

val (valueField, valueParamIndex): (Symbol, Int) = ctorParams.zipWithIndex.find{ case (p, _) => p.name == "value"}.getOrElse {
report.errorAndAbort(s"Could not find 'value' field in ${tpeSym.name}")
}
val valueParamIndex = tpeSym.primaryConstructor.paramSymss
.filterNot(_.exists(_.isType))
.flatten
.zipWithIndex
.collectFirst {
case (p, i) if p.name == "value" => i
}

type IsValue[T <: ValueType] = T

object ConstVal {
@annotation.tailrec
def unapply(tree: Tree): Option[Constant] = tree match {
case NamedArg(nme, v) if (nme == valueField.name) =>
unapply(v)

case ValDef(nme, _, Some(v)) if (nme == valueField.name) =>
unapply(v)

case lit @ Literal(const) if (lit.tpe <:< valueRepr) =>
Some(const)

case _ =>
None
case NamedArg("value", v) => unapply(v)
case ValDef("value", _, Some(v)) => unapply(v)
case lit @ Literal(const) if (lit.tpe <:< valueRepr) => Some(const)
case _ => None
}
}

Expand All @@ -175,24 +170,18 @@ In SBT settings:
(for {
vof <- Expr.summon[ValueOf[h]]
constValue <- htpr.typeSymbol.tree match {
case ClassDef(_, _, spr, _, rhs) => {
val fromCtor = spr
case ClassDef(_, _, parents, _, statements) => {
val fromCtor = parents
.collectFirst {
case Apply(Select(New(id), _), args) if id.tpe <:< repr => args
case Apply(TypeApply(Select(New(id), _), _), args) if id.tpe <:< repr => args
}
.flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) =>
.flatMap(_.lift(valueParamIndex.getOrElse(-1)).collect { case ConstVal(const) =>
const
})

fromCtor
.orElse(rhs.collectFirst { case ConstVal(v) => v })
.flatMap { const =>
cls.unapply(const.value)
}

def fromBody = statements.collectFirst { case ConstVal(v) => v }
fromCtor.orElse(fromBody).flatMap { const => cls.unapply(const.value) }
}

case _ =>
Option.empty[ValueType]
}
Expand All @@ -202,7 +191,7 @@ In SBT settings:

case None =>
report.errorAndAbort(
s"Fails to check value entry ${htpr.show} for enum ${repr.show}"
s"Failed to check value entry ${htpr.show} for enum ${repr.show}"
)
}
}
Expand All @@ -212,8 +201,7 @@ In SBT settings:
case Some(sum) =>
sum.asTerm.tpe.asType match {
case '[SumOf[a, t]] => collect[Tuple.Concat[t, tail]](instances, values)

case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
case _ => Left(s"Invalid `Mirror.SumOf[${TypeRepr.of[h].show}]")
}

case None =>
Expand All @@ -230,7 +218,7 @@ In SBT settings:
}
.mkString(", ")

Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}")
Left(s"Values for the `value` field field are not discriminated subtypes: ${details}")
} else {
Right(Expr ofList instances.reverse)
}
Expand Down
6 changes: 6 additions & 0 deletions macros/src/test/scala/enumeratum/CompilationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,9 @@ object F {
case object F4 extends F(value = 4, "mike")

}

sealed abstract class G(val name: String, val value: Int)
object G {
val values = FindValEnums[G]
case object G1 extends G("gerald", 1)
}

0 comments on commit d3a872c

Please sign in to comment.