diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala index b30094384d906..1027468553828 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala @@ -64,6 +64,8 @@ private[transaction] sealed trait TransactionState { * Get the name of this state. This is exposed through the `DescribeTransactions` API. */ def name: String + + def validPreviousStates: Set[TransactionState] } /** @@ -75,6 +77,7 @@ private[transaction] sealed trait TransactionState { private[transaction] case object Empty extends TransactionState { val id: Byte = 0 val name: String = "Empty" + val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit, CompleteAbort) } /** @@ -88,6 +91,7 @@ private[transaction] case object Empty extends TransactionState { private[transaction] case object Ongoing extends TransactionState { val id: Byte = 1 val name: String = "Ongoing" + val validPreviousStates: Set[TransactionState] = Set(Ongoing, Empty, CompleteCommit, CompleteAbort) } /** @@ -98,6 +102,7 @@ private[transaction] case object Ongoing extends TransactionState { private[transaction] case object PrepareCommit extends TransactionState { val id: Byte = 2 val name: String = "PrepareCommit" + val validPreviousStates: Set[TransactionState] = Set(Ongoing) } /** @@ -108,6 +113,7 @@ private[transaction] case object PrepareCommit extends TransactionState { private[transaction] case object PrepareAbort extends TransactionState { val id: Byte = 3 val name: String = "PrepareAbort" + val validPreviousStates: Set[TransactionState] = Set(Ongoing, PrepareEpochFence) } /** @@ -118,6 +124,7 @@ private[transaction] case object PrepareAbort extends TransactionState { private[transaction] case object CompleteCommit extends TransactionState { val id: Byte = 4 val name: String = "CompleteCommit" + val validPreviousStates: Set[TransactionState] = Set(PrepareCommit) } /** @@ -128,6 +135,7 @@ private[transaction] case object CompleteCommit extends TransactionState { private[transaction] case object CompleteAbort extends TransactionState { val id: Byte = 5 val name: String = "CompleteAbort" + val validPreviousStates: Set[TransactionState] = Set(PrepareAbort) } /** @@ -136,6 +144,7 @@ private[transaction] case object CompleteAbort extends TransactionState { private[transaction] case object Dead extends TransactionState { val id: Byte = 6 val name: String = "Dead" + val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteAbort, CompleteCommit) } /** @@ -145,6 +154,7 @@ private[transaction] case object Dead extends TransactionState { private[transaction] case object PrepareEpochFence extends TransactionState { val id: Byte = 7 val name: String = "PrepareEpochFence" + val validPreviousStates: Set[TransactionState] = Set(Ongoing) } private[transaction] object TransactionMetadata { @@ -162,20 +172,6 @@ private[transaction] object TransactionMetadata { new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - def isValidTransition(oldState: TransactionState, newState: TransactionState): Boolean = - TransactionMetadata.validPreviousStates(newState).contains(oldState) - - private val validPreviousStates: Map[TransactionState, Set[TransactionState]] = - Map(Empty -> Set(Empty, CompleteCommit, CompleteAbort), - Ongoing -> Set(Ongoing, Empty, CompleteCommit, CompleteAbort), - PrepareCommit -> Set(Ongoing), - PrepareAbort -> Set(Ongoing, PrepareEpochFence), - CompleteCommit -> Set(PrepareCommit), - CompleteAbort -> Set(PrepareAbort), - Dead -> Set(Empty, CompleteAbort, CompleteCommit), - PrepareEpochFence -> Set(Ongoing) - ) - def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 } @@ -385,7 +381,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch") // check that the new state transition is valid and update the pending state if necessary - if (TransactionMetadata.validPreviousStates(newState).contains(state)) { + if (newState.validPreviousStates.contains(state)) { val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState, newTopicPartitions, newTxnStartTimestamp, updateTimestamp) debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata")