Skip to content

Commit

Permalink
Pointcut
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanos committed Dec 5, 2024
1 parent c051713 commit d51df51
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
43 changes: 43 additions & 0 deletions internal/pointcut/pointcut.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package pointcut

import (
"sync"
)

type Hook[T any] interface {
Invoke(args ...any) T
Set(func(args ...any) T)
Reset()
}

type hook[T any] struct {
mu sync.RWMutex
fn func(args ...any) T
}

func NewHook[T any]() Hook[T] {
return &hook[T]{}
}

func (h *hook[T]) Invoke(args ...any) T {
h.mu.RLock()
defer h.mu.RUnlock()

var result T
if h.fn != nil {
result = h.fn()
}
return result
}

// TODO: allow to set maximum invocations
func (h *hook[T]) Set(fn func(args ...any) T) {
h.mu.Lock()
defer h.mu.Unlock()

h.fn = fn
}

func (h *hook[T]) Reset() {
h.Set(nil)
}
6 changes: 5 additions & 1 deletion service/history/api/multioperation/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"go.temporal.io/server/common/definition"
"go.temporal.io/server/common/locks"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/internal/pointcut"
"go.temporal.io/server/service/history/api"
"go.temporal.io/server/service/history/api/startworkflow"
"go.temporal.io/server/service/history/api/updateworkflow"
Expand All @@ -46,7 +47,8 @@ import (
)

var (
multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
InBetweenLockAndStart = pointcut.NewHook[any]()
multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
)

type (
Expand Down Expand Up @@ -205,6 +207,8 @@ func Invoke(
}
}

InBetweenLockAndStart.Invoke(req.WorkflowId)

// workflow hasn't been started yet: start and then apply update
resp, err := startAndUpdateWorkflow(ctx, starter, updater)
var noStartErr *noStartError
Expand Down
43 changes: 43 additions & 0 deletions tests/update_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"go.temporal.io/server/common/testing/protoutils"
"go.temporal.io/server/common/testing/taskpoller"
"go.temporal.io/server/common/testing/testvars"
"go.temporal.io/server/service/history/api/multioperation"
"go.temporal.io/server/tests/testcore"
"google.golang.org/protobuf/types/known/durationpb"
)
Expand Down Expand Up @@ -5319,6 +5320,48 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() {
})
})

s.Run("workflow start conflict", func() {

s.Run("workflow id conflict policy fail: use-existing", func() {
tv := testvars.New(s.T())

startReq := startWorkflowReq(tv)
startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING
updateReq := s.updateWorkflowRequest(tv,
&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1")

// simulate a race condition
defer multioperation.InBetweenLockAndStart.Reset()
multioperation.InBetweenLockAndStart.Set(func(args ...any) any {
if args[0] == tv.WorkflowID() {
_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
s.NoError(err)
}
return nil
})

uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)

_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
fmt.Println("PROCESS TASK #1")
return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
})
s.NoError(err)

_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
fmt.Println("PROCESS TASK #2")
return &workflowservice.RespondWorkflowTaskCompletedRequest{
Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"),
}, nil
})
s.NoError(err)

<-uwsCh
})
})

s.Run("return update rate limit error", func() {
// lower maximum total number of updates for testing purposes
maxTotalUpdates := 0
Expand Down

0 comments on commit d51df51

Please sign in to comment.