diff --git a/middleware/mocks/cache.go b/middleware/mocks/cache.go index 181ecef..86d65f2 100644 --- a/middleware/mocks/cache.go +++ b/middleware/mocks/cache.go @@ -9,19 +9,19 @@ import ( // CacheSvc is a mock implementation of cache.Service. // -// func TestSomethingThatUsesService(t *testing.T) { +// func TestSomethingThatUsesService(t *testing.T) { // -// // make and configure a mocked cache.Service -// mockedService := &CacheSvc{ -// GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { -// panic("mock out the Get method") -// }, -// } +// // make and configure a mocked cache.Service +// mockedService := &CacheSvc{ +// GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { +// panic("mock out the Get method") +// }, +// } // -// // use mockedService in code that requires cache.Service -// // and then make assertions. +// // use mockedService in code that requires cache.Service +// // and then make assertions. // -// } +// } type CacheSvc struct { // GetFunc mocks the Get method. GetFunc func(key string, fn func() (interface{}, error)) (interface{}, error) diff --git a/middleware/mocks/circuit_breaker.go b/middleware/mocks/circuit_breaker.go index f8b3483..f9f6c15 100644 --- a/middleware/mocks/circuit_breaker.go +++ b/middleware/mocks/circuit_breaker.go @@ -9,19 +9,19 @@ import ( // CircuitBreakerSvcMock is a mock implementation of middleware.CircuitBreakerSvc. // -// func TestSomethingThatUsesCircuitBreakerSvc(t *testing.T) { +// func TestSomethingThatUsesCircuitBreakerSvc(t *testing.T) { // -// // make and configure a mocked middleware.CircuitBreakerSvc -// mockedCircuitBreakerSvc := &CircuitBreakerSvcMock{ -// ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { -// panic("mock out the Execute method") -// }, -// } +// // make and configure a mocked middleware.CircuitBreakerSvc +// mockedCircuitBreakerSvc := &CircuitBreakerSvcMock{ +// ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { +// panic("mock out the Execute method") +// }, +// } // -// // use mockedCircuitBreakerSvc in code that requires middleware.CircuitBreakerSvc -// // and then make assertions. +// // use mockedCircuitBreakerSvc in code that requires middleware.CircuitBreakerSvc +// // and then make assertions. // -// } +// } type CircuitBreakerSvcMock struct { // ExecuteFunc mocks the Execute method. ExecuteFunc func(req func() (interface{}, error)) (interface{}, error) diff --git a/middleware/mocks/logger.go b/middleware/mocks/logger.go index 357704a..153fe31 100644 --- a/middleware/mocks/logger.go +++ b/middleware/mocks/logger.go @@ -9,19 +9,19 @@ import ( // LoggerSvc is a mock implementation of logger.Service. // -// func TestSomethingThatUsesService(t *testing.T) { +// func TestSomethingThatUsesService(t *testing.T) { // -// // make and configure a mocked logger.Service -// mockedService := &LoggerSvc{ -// LogfFunc: func(format string, args ...interface{}) { -// panic("mock out the Logf method") -// }, -// } +// // make and configure a mocked logger.Service +// mockedService := &LoggerSvc{ +// LogfFunc: func(format string, args ...interface{}) { +// panic("mock out the Logf method") +// }, +// } // -// // use mockedService in code that requires logger.Service -// // and then make assertions. +// // use mockedService in code that requires logger.Service +// // and then make assertions. // -// } +// } type LoggerSvc struct { // LogfFunc mocks the Logf method. LogfFunc func(format string, args ...interface{}) diff --git a/middleware/mocks/repeater.go b/middleware/mocks/repeater.go index 7765d8a..b430bb5 100644 --- a/middleware/mocks/repeater.go +++ b/middleware/mocks/repeater.go @@ -4,32 +4,35 @@ package mocks import ( + "context" "sync" ) // RepeaterSvcMock is a mock implementation of middleware.RepeaterSvc. // -// func TestSomethingThatUsesRepeaterSvc(t *testing.T) { +// func TestSomethingThatUsesRepeaterSvc(t *testing.T) { // -// // make and configure a mocked middleware.RepeaterSvc -// mockedRepeaterSvc := &RepeaterSvcMock{ -// DoFunc: func(fun func() error, errs ...error) error { -// panic("mock out the Do method") -// }, -// } +// // make and configure a mocked middleware.RepeaterSvc +// mockedRepeaterSvc := &RepeaterSvcMock{ +// DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { +// panic("mock out the Do method") +// }, +// } // -// // use mockedRepeaterSvc in code that requires middleware.RepeaterSvc -// // and then make assertions. +// // use mockedRepeaterSvc in code that requires middleware.RepeaterSvc +// // and then make assertions. // -// } +// } type RepeaterSvcMock struct { // DoFunc mocks the Do method. - DoFunc func(fun func() error, errs ...error) error + DoFunc func(ctx context.Context, fun func() error, errs ...error) error // calls tracks calls to the methods. calls struct { // Do holds details about calls to the Do method. Do []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Fun is the fun argument value. Fun func() error // Errs is the errs argument value. @@ -40,31 +43,35 @@ type RepeaterSvcMock struct { } // Do calls DoFunc. -func (mock *RepeaterSvcMock) Do(fun func() error, errs ...error) error { +func (mock *RepeaterSvcMock) Do(ctx context.Context, fun func() error, errs ...error) error { if mock.DoFunc == nil { panic("RepeaterSvcMock.DoFunc: method is nil but RepeaterSvc.Do was just called") } callInfo := struct { + Ctx context.Context Fun func() error Errs []error }{ + Ctx: ctx, Fun: fun, Errs: errs, } mock.lockDo.Lock() mock.calls.Do = append(mock.calls.Do, callInfo) mock.lockDo.Unlock() - return mock.DoFunc(fun, errs...) + return mock.DoFunc(ctx, fun, errs...) } // DoCalls gets all the calls that were made to Do. // Check the length with: // len(mockedRepeaterSvc.DoCalls()) func (mock *RepeaterSvcMock) DoCalls() []struct { + Ctx context.Context Fun func() error Errs []error } { var calls []struct { + Ctx context.Context Fun func() error Errs []error } diff --git a/middleware/repeater.go b/middleware/repeater.go index 2ba40bf..e2f91b0 100644 --- a/middleware/repeater.go +++ b/middleware/repeater.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "errors" "fmt" "net/http" @@ -8,7 +9,7 @@ import ( // RepeaterSvc defines repeater interface type RepeaterSvc interface { - Do(fun func() error, errs ...error) (err error) + Do(ctx context.Context, fun func() error, errs ...error) (err error) } // Repeater sets middleware with provided RepeaterSvc to retry failed requests @@ -23,7 +24,7 @@ func Repeater(repeater RepeaterSvc, failOnCodes ...int) RoundTripperHandler { var resp *http.Response var err error - e := repeater.Do(func() error { + e := repeater.Do(req.Context(), func() error { resp, err = next.RoundTrip(req) if err != nil { return err diff --git a/middleware/repeater_test.go b/middleware/repeater_test.go index 3c2a3f2..87a4037 100644 --- a/middleware/repeater_test.go +++ b/middleware/repeater_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "errors" "net/http" "sync/atomic" @@ -23,7 +24,7 @@ func TestRepeater_Passed(t *testing.T) { return resp, errors.New("blah") }} - repeater := &mocks.RepeaterSvcMock{DoFunc: func(fun func() error, errs ...error) (err error) { + repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { for i := 0; i < 5; i++ { if err = fun(); err == nil { return nil @@ -51,7 +52,7 @@ func TestRepeater_Failed(t *testing.T) { return resp, errors.New("http error") }} - repeater := &mocks.RepeaterSvcMock{DoFunc: func(fun func() error, errs ...error) (err error) { + repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { for i := 0; i < 5; i++ { if err = fun(); err == nil { return nil @@ -78,7 +79,7 @@ func TestRepeater_FailedStatus(t *testing.T) { return resp, nil }} - repeater := &mocks.RepeaterSvcMock{DoFunc: func(fun func() error, errs ...error) (err error) { + repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { for i := 0; i < 5; i++ { if err = fun(); err == nil { return nil