diff --git a/internal/pointcut/pointcut.go b/internal/pointcut/pointcut.go new file mode 100644 index 00000000000..778e3a9c140 --- /dev/null +++ b/internal/pointcut/pointcut.go @@ -0,0 +1,43 @@ +package pointcut + +import ( + "sync" +) + +type Hook[T any] interface { + Invoke(args ...any) T + Set(func(args ...any) T) + Reset() +} + +type hook[T any] struct { + mu sync.RWMutex + fn func(args ...any) T +} + +func NewHook[T any]() Hook[T] { + return &hook[T]{} +} + +func (h *hook[T]) Invoke(args ...any) T { + h.mu.RLock() + defer h.mu.RUnlock() + + var result T + if h.fn != nil { + result = h.fn() + } + return result +} + +// TODO: allow to set maximum +func (h *hook[T]) Set(fn func(args ...any) T) { + h.mu.Lock() + defer h.mu.Unlock() + + h.fn = fn +} + +func (h *hook[T]) Reset() { + h.Set(nil) +} diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go index 2a3a7b3be28..ce0d3de3888 100644 --- a/service/history/api/multioperation/api.go +++ b/service/history/api/multioperation/api.go @@ -38,6 +38,7 @@ import ( "go.temporal.io/server/common/definition" "go.temporal.io/server/common/locks" "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/internal/pointcut" "go.temporal.io/server/service/history/api" "go.temporal.io/server/service/history/api/startworkflow" "go.temporal.io/server/service/history/api/updateworkflow" @@ -46,7 +47,8 @@ import ( ) var ( - multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.") + InBetweenLockAndStart = pointcut.NewHook[any]() + multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.") ) type ( @@ -206,6 +208,9 @@ func Invoke( } // workflow hasn't been started yet: start and then apply update + InBetweenLockAndStart.Invoke(req.WorkflowId) + + // Workflow hasn't been started yet: start and then apply update. resp, err := startAndUpdateWorkflow(ctx, starter, updater) var noStartErr *noStartError if errors.As(err, &noStartErr) { diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index fd4c4e5af46..99a6635619a 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -52,6 +52,7 @@ import ( "go.temporal.io/server/common/testing/protoutils" "go.temporal.io/server/common/testing/taskpoller" "go.temporal.io/server/common/testing/testvars" + "go.temporal.io/server/service/history/api/multioperation" "go.temporal.io/server/tests/testcore" "google.golang.org/protobuf/types/known/durationpb" ) @@ -5319,6 +5320,48 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() { }) }) + s.Run("workflow start conflict", func() { + + s.Run("workflow id conflict policy fail: use-existing", func() { + tv := testvars.New(s.T()) + + startReq := startWorkflowReq(tv) + startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING + updateReq := s.updateWorkflowRequest(tv, + &updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1") + + // simulate a race condition + defer multioperation.InBetweenLockAndStart.Reset() + multioperation.InBetweenLockAndStart.Set(func(args ...any) any { + if args[0] == tv.WorkflowID() { + _, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq) + s.NoError(err) + } + return nil + }) + + uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq) + + _, err := s.TaskPoller.PollAndHandleWorkflowTask(tv, + func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) { + fmt.Println("PROCESS TASK #1") + return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil + }) + s.NoError(err) + + _, err = s.TaskPoller.PollAndHandleWorkflowTask(tv, + func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) { + fmt.Println("PROCESS TASK #2") + return &workflowservice.RespondWorkflowTaskCompletedRequest{ + Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"), + }, nil + }) + s.NoError(err) + + <-uwsCh + }) + }) + s.Run("return update rate limit error", func() { // lower maximum total number of updates for testing purposes maxTotalUpdates := 0