From e6bd3b2f11ace4ac4e9494f44070f4b6b8f7bcea Mon Sep 17 00:00:00 2001
From: Stephan Behnke
Date: Thu, 5 Dec 2024 07:58:46 -0800
Subject: [PATCH] Pointcut

---
 internal/pointcut/pointcut.go             |  50 ++++++++++
 service/history/api/multioperation/api.go |   8 +-
 tests/update_workflow_test.go             | 106 ++++++++++++++++++++++
 3 files changed, 162 insertions(+), 2 deletions(-)
 create mode 100644 internal/pointcut/pointcut.go

diff --git a/internal/pointcut/pointcut.go b/internal/pointcut/pointcut.go
new file mode 100644
index 00000000000..65229771bea
--- /dev/null
+++ b/internal/pointcut/pointcut.go
@@ -0,0 +1,50 @@
+package pointcut
+
+import (
+	"sync"
+)
+
+// Hook is an injection point: production code calls Invoke, and tests can
+// register a function via Set to run custom logic at that point.
+type Hook[T any] interface {
+	Invoke(args ...any) T
+	Set(func() T)
+	Reset()
+}
+
+type hook[T any] struct {
+	mu sync.RWMutex
+	fn func() T
+}
+
+// NewHook returns a Hook with no function registered.
+func NewHook[T any]() Hook[T] {
+	return &hook[T]{}
+}
+
+// Invoke calls the registered function, if any, and returns its result.
+// The variadic args are currently ignored.
+func (h *hook[T]) Invoke(args ...any) T {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+
+	var result T
+	if h.fn != nil {
+		result = h.fn()
+	}
+	return result
+}
+
+// Set registers fn to be called by Invoke.
+// TODO: allow setting a maximum
+func (h *hook[T]) Set(fn func() T) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	h.fn = fn
+}
+
+// Reset removes the registered function.
+func (h *hook[T]) Reset() {
+	h.Set(nil)
+}
diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go
index 2a3a7b3be28..992ca8b92c0 100644
--- a/service/history/api/multioperation/api.go
+++ b/service/history/api/multioperation/api.go
@@ -38,6 +38,7 @@ import (
 	"go.temporal.io/server/common/definition"
 	"go.temporal.io/server/common/locks"
 	"go.temporal.io/server/common/persistence/visibility/manager"
+	"go.temporal.io/server/internal/pointcut"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/api/startworkflow"
 	"go.temporal.io/server/service/history/api/updateworkflow"
@@ -46,7 +47,8 @@ import (
 )
 
 var (
-	multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
+	InBetweenLockAndStart = pointcut.NewHook[any]()
+	multiOpAbortedErr     = serviceerror.NewMultiOperationAborted("Operation was aborted.")
 )
 
 type (
@@ -206,6 +208,8 @@ func Invoke(
 	}
 
-	// workflow hasn't been started yet: start and then apply update
+	InBetweenLockAndStart.Invoke()
+
+	// Workflow hasn't been started yet: start and then apply update.
 	resp, err := startAndUpdateWorkflow(ctx, starter, updater)
 	var noStartErr *noStartError
 	if errors.As(err, &noStartErr) {
diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go
index fd4c4e5af46..5bc56769d42 100644
--- a/tests/update_workflow_test.go
+++ b/tests/update_workflow_test.go
@@ -52,6 +52,7 @@ import (
 	"go.temporal.io/server/common/testing/protoutils"
 	"go.temporal.io/server/common/testing/taskpoller"
 	"go.temporal.io/server/common/testing/testvars"
+	"go.temporal.io/server/service/history/api/multioperation"
 	"go.temporal.io/server/tests/testcore"
 	"google.golang.org/protobuf/types/known/durationpb"
 )
@@ -5319,6 +5320,111 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() {
 		})
 	})
 
+	s.Run("workflow start conflict", func() {
+
+		s.Run("workflow id conflict policy: fail", func() {
+			tv := testvars.New(s.T())
+
+			startReq := startWorkflowReq(tv)
+			startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL
+			updateReq := s.updateWorkflowRequest(tv,
+				&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")
+
+			// simulate a race condition
+			defer multioperation.InBetweenLockAndStart.Reset()
+			multioperation.InBetweenLockAndStart.Set(func() any {
+				_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
+				s.NoError(err)
+				return nil
+			})
+
+			uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)
+
+			_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					reply := &workflowservice.RespondWorkflowTaskCompletedRequest{}
+					if len(task.Messages) > 0 {
+						reply.Messages = s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1")
+					}
+					return reply, nil
+				})
+			s.NoError(err)
+
+			uwsRes := <-uwsCh
+			s.Error(uwsRes.err)
+		})
+
+		s.Run("workflow id conflict policy: use-existing", func() {
+			tv := testvars.New(s.T())
+
+			startReq := startWorkflowReq(tv)
+			startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING
+			updateReq := s.updateWorkflowRequest(tv,
+				&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")
+
+			// simulate a race condition
+			defer multioperation.InBetweenLockAndStart.Reset()
+			multioperation.InBetweenLockAndStart.Set(func() any {
+				_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
+				s.NoError(err)
+				return nil
+			})
+
+			uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)
+
+			_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
+				})
+			s.NoError(err)
+
+			_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{
+						Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"),
+					}, nil
+				})
+			s.NoError(err)
+
+			<-uwsCh
+		})
+
+		s.Run("dedup request", func() {
+			tv := testvars.New(s.T())
+
+			startReq := startWorkflowReq(tv)
+			startReq.RequestId = "request_id" // dedup key
+			updateReq := s.updateWorkflowRequest(tv,
+				&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")
+
+			// simulate a race condition
+			defer multioperation.InBetweenLockAndStart.Reset()
+			multioperation.InBetweenLockAndStart.Set(func() any {
+				_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
+				s.NoError(err)
+				return nil
+			})
+
+			uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)
+
+			_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
+				})
+			s.NoError(err)
+
+			_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{
+						Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"),
+					}, nil
+				})
+			s.NoError(err)
+
+			<-uwsCh
+		})
+	})
+
 	s.Run("return update rate limit error", func() {
 		// lower maximum total number of updates for testing purposes
 		maxTotalUpdates := 0