Skip to content

Commit

Permalink
Pointcut
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanos committed Dec 5, 2024
1 parent c051713 commit e6bd3b2
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 1 deletion.
43 changes: 43 additions & 0 deletions internal/pointcut/pointcut.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package pointcut

import (
"sync"
)

type Hook[T any] interface {
Invoke(args ...any) T
Set(func() T)
Reset()
}

type hook[T any] struct {
mu sync.RWMutex
fn func() T
}

func NewHook[T any]() Hook[T] {
return &hook[T]{}
}

func (h *hook[T]) Invoke(args ...any) T {
h.mu.RLock()
defer h.mu.RUnlock()

var result T
if h.fn != nil {
result = h.fn()
}
return result
}

// TODO: allow to set maximum
func (h *hook[T]) Set(fn func() T) {
h.mu.Lock()
defer h.mu.Unlock()

h.fn = fn
}

func (h *hook[T]) Reset() {
h.Set(nil)
}
7 changes: 6 additions & 1 deletion service/history/api/multioperation/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"go.temporal.io/server/common/definition"
"go.temporal.io/server/common/locks"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/internal/pointcut"
"go.temporal.io/server/service/history/api"
"go.temporal.io/server/service/history/api/startworkflow"
"go.temporal.io/server/service/history/api/updateworkflow"
Expand All @@ -46,7 +47,8 @@ import (
)

var (
multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
InBetweenLockAndStart = pointcut.NewHook[any]()
multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
)

type (
Expand Down Expand Up @@ -206,6 +208,9 @@ func Invoke(
}

// workflow hasn't been started yet: start and then apply update
InBetweenLockAndStart.Invoke()

// Workflow hasn't been started yet: start and then apply update.
resp, err := startAndUpdateWorkflow(ctx, starter, updater)
var noStartErr *noStartError
if errors.As(err, &noStartErr) {
Expand Down
109 changes: 109 additions & 0 deletions tests/update_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"go.temporal.io/server/common/testing/protoutils"
"go.temporal.io/server/common/testing/taskpoller"
"go.temporal.io/server/common/testing/testvars"
"go.temporal.io/server/service/history/api/multioperation"
"go.temporal.io/server/tests/testcore"
"google.golang.org/protobuf/types/known/durationpb"
)
Expand Down Expand Up @@ -5319,6 +5320,114 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() {
})
})

s.Run("workflow start conflict", func() {

s.Run("workflow id conflict policy fail: fail", func() {
tv := testvars.New(s.T())

startReq := startWorkflowReq(tv)
startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_FAIL
updateReq := s.updateWorkflowRequest(tv,
&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")

// simulate a race condition
defer multioperation.InBetweenLockAndStart.Reset()
multioperation.InBetweenLockAndStart.Set(func() any {
_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
s.NoError(err)
return nil
})

uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)

_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
reply := &workflowservice.RespondWorkflowTaskCompletedRequest{}
if len(task.Messages) > 0 {
reply.Messages = s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1")
}
return reply, nil
})
s.NoError(err)

uwsRes := <-uwsCh
s.Error(uwsRes.err)
})

s.Run("workflow id conflict policy fail: use-existing", func() {
tv := testvars.New(s.T())

startReq := startWorkflowReq(tv)
startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING
updateReq := s.updateWorkflowRequest(tv,
&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")

// simulate a race condition
defer multioperation.InBetweenLockAndStart.Reset()
multioperation.InBetweenLockAndStart.Set(func() any {
fmt.Println("INJECT START")
_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
s.NoError(err)
return nil
})

uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)

_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
fmt.Println("PROCESS TASK #1")
return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
})
s.NoError(err)

_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
fmt.Println("PROCESS TASK #2")
return &workflowservice.RespondWorkflowTaskCompletedRequest{
Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"),
}, nil
})
s.NoError(err)

<-uwsCh
})

s.Run("dedup request", func() {
tv := testvars.New(s.T())

startReq := startWorkflowReq(tv)
startReq.RequestId = "request_id" // dedup key
updateReq := s.updateWorkflowRequest(tv,
&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")

// simulate a race condition
defer multioperation.InBetweenLockAndStart.Reset()
multioperation.InBetweenLockAndStart.Set(func() any {
_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
s.NoError(err)
return nil
})

uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)

_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
})
s.NoError(err)

_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
return &workflowservice.RespondWorkflowTaskCompletedRequest{
Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"),
}, nil
})
s.NoError(err)

<-uwsCh
})
})

s.Run("return update rate limit error", func() {
// lower maximum total number of updates for testing purposes
maxTotalUpdates := 0
Expand Down

0 comments on commit e6bd3b2

Please sign in to comment.