Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix map task cache misses (#363)
Browse files Browse the repository at this point in the history
* Add input hash to workItemID

Signed-off-by: Bernhard Stadlbauer <[email protected]>

* Add test to ensure different IDs

Signed-off-by: Bernhard Stadlbauer <[email protected]>

* Cleanup `make lint`

Signed-off-by: Bernhard Stadlbauer <[email protected]>

* Move `emptyLiteralMap` to `hashing.go`

Signed-off-by: Bernhard Stadlbauer <[email protected]>

* Fix import ordering

Signed-off-by: Bernhard Stadlbauer <[email protected]>

---------

Signed-off-by: Bernhard Stadlbauer <[email protected]>
  • Loading branch information
bstadlbauer authored Jun 27, 2023
1 parent 1d0f3e9 commit 5a5ad82
Show file tree
Hide file tree
Showing 6 changed files with 788 additions and 14 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 // indirect
github.com/aws/smithy-go v1.1.0 // indirect
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3E
github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/aws/smithy-go v1.1.0 h1:D6CSsM3gdxaGaqXnPgOBCeL6Mophqzu7KJOu7zW78sU=
github.com/aws/smithy-go v1.1.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 h1:VRtJdDi2lqc3MFwmouppm2jlm6icF+7H3WYKpLENMTo=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand Down
28 changes: 21 additions & 7 deletions go/tasks/pluginmachinery/catalog/async_client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ import (
"hash/fnv"
"reflect"

"github.com/flyteorg/flytestdlib/promutils"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flytestdlib/bitarray"

"github.com/flyteorg/flytestdlib/errors"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flytestdlib/promutils"
)

const specialEncoderKey = "abcdefghijklmnopqrstuvwxyz123456"
Expand Down Expand Up @@ -41,6 +39,18 @@ func consistentHash(str string) (string, error) {
return base32Encoder.EncodeToString(b), nil
}

func hashInputs(ctx context.Context, key Key) (string, error) {
inputs := &core.LiteralMap{}
if key.TypedInterface.Inputs != nil {
retInputs, err := key.InputReader.Get(ctx)
if err != nil {
return "", err
}
inputs = retInputs
}
return HashLiteralMap(ctx, inputs)
}

func (c AsyncClientImpl) Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) {
status := ResponseStatusReady
cachedResults := bitarray.NewBitSet(uint(len(requests)))
Expand Down Expand Up @@ -95,8 +105,12 @@ func (c AsyncClientImpl) Upload(ctx context.Context, requests ...UploadRequest)
status := ResponseStatusReady
var respErr error
for idx, request := range requests {
workItemID := formatWorkItemID(request.Key, idx, "")
err := c.Writer.Queue(ctx, workItemID, NewWriterWorkItem(
inputHash, err := hashInputs(ctx, request.Key)
if err != nil {
return nil, errors.Wrapf(ErrSystemError, err, "Failed to hash inputs for item: %v", request.Key)
}
workItemID := formatWorkItemID(request.Key, idx, inputHash)
err = c.Writer.Queue(ctx, workItemID, NewWriterWorkItem(
request.Key,
request.ArtifactData,
request.ArtifactMetadata))
Expand Down
98 changes: 91 additions & 7 deletions go/tasks/pluginmachinery/catalog/async_client_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,62 @@ import (
"reflect"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue/mocks"
"github.com/flyteorg/flytestdlib/bitarray"
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue"
)

var exampleInterface = &core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"a": {
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_INTEGER,
},
},
},
},
},
}
var input1 = &core.LiteralMap{
Literals: map[string]*core.Literal{
"a": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
},
},
}
var input2 = &core.LiteralMap{
Literals: map[string]*core.Literal{
"a": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 2,
},
},
},
},
},
},
},
}

func TestAsyncClientImpl_Download(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -61,24 +109,50 @@ func TestAsyncClientImpl_Download(t *testing.T) {
func TestAsyncClientImpl_Upload(t *testing.T) {
ctx := context.Background()

inputHash1 := "{UNSPECIFIED {} [] 0}:-0-DNhkpTTPC5YDtRGb4yT-PFxgMSgHzHrKAQKgQGEfGRY"
inputHash2 := "{UNSPECIFIED {} [] 0}:-1-26M4dwarvBVJqJSUC4JC1GtRYgVBIAmQfsFSdLVMlAc"

q := &mocks.IndexedWorkQueue{}
info := &mocks.WorkItemInfo{}
info.OnItem().Return(NewReaderWorkItem(Key{}, &mocks2.OutputWriter{}))
info.OnStatus().Return(workqueue.WorkStatusSucceeded)
q.OnGet("{UNSPECIFIED {} [] 0}:-0-").Return(info, true, nil)
q.OnGet(inputHash1).Return(info, true, nil)
q.OnGet(inputHash2).Return(info, true, nil)
q.OnQueueMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

inputReader1 := &mocks2.InputReader{}
inputReader1.OnGetMatch(mock.Anything).Return(input1, nil)
inputReader2 := &mocks2.InputReader{}
inputReader2.OnGetMatch(mock.Anything).Return(input2, nil)

tests := []struct {
name string
requests []UploadRequest
wantPutFuture UploadFuture
wantErr bool
}{
{"UploadSucceeded", []UploadRequest{
{
Key: Key{},
{
"UploadSucceeded",
// The second request has the same Key.Identifier and Key.Cache version but a different
// Key.InputReader. This should lead to a different WorkItemID in the queue.
// See https://github.com/flyteorg/flyte/issues/3787 for more details
[]UploadRequest{
{
Key: Key{
TypedInterface: *exampleInterface,
InputReader: inputReader1,
},
},
{
Key: Key{
TypedInterface: *exampleInterface,
InputReader: inputReader2,
},
},
},
}, newUploadFuture(ResponseStatusReady, nil), false},
newUploadFuture(ResponseStatusReady, nil),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -93,6 +167,16 @@ func TestAsyncClientImpl_Upload(t *testing.T) {
if !reflect.DeepEqual(gotPutFuture, tt.wantPutFuture) {
t.Errorf("AsyncClientImpl.Sidecar() = %v, want %v", gotPutFuture, tt.wantPutFuture)
}
expectedWorkItemIDs := []string{inputHash1, inputHash2}
gottenWorkItemIDs := make([]string, 0)
for _, mockCall := range q.Calls {
if mockCall.Method == "Get" {
gottenWorkItemIDs = append(gottenWorkItemIDs, mockCall.Arguments[0].(string))
}
}
if !reflect.DeepEqual(gottenWorkItemIDs, expectedWorkItemIDs) {
t.Errorf("Retrieved workitem IDs = %v, want %v", gottenWorkItemIDs, expectedWorkItemIDs)
}
})
}
}
Expand Down
78 changes: 78 additions & 0 deletions go/tasks/pluginmachinery/catalog/hashing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package catalog

import (
"context"
"encoding/base64"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/pbhash"
)

var emptyLiteralMap = core.LiteralMap{Literals: map[string]*core.Literal{}}

// Hashify a literal, in other words, produce a new literal where the corresponding value is removed in case
// the literal hash is set.
func hashify(literal *core.Literal) *core.Literal {
// Two recursive cases:
// 1. A collection of literals or
// 2. A map of literals

if literal.GetCollection() != nil {
literals := literal.GetCollection().Literals
literalsHash := make([]*core.Literal, 0)
for _, lit := range literals {
literalsHash = append(literalsHash, hashify(lit))
}
return &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: literalsHash,
},
},
}
}
if literal.GetMap() != nil {
literalsMap := make(map[string]*core.Literal)
for key, lit := range literal.GetMap().Literals {
literalsMap[key] = hashify(lit)
}
return &core.Literal{
Value: &core.Literal_Map{
Map: &core.LiteralMap{
Literals: literalsMap,
},
},
}
}

// And a base case that consists of a scalar, where the hash might be set
if literal.GetHash() != "" {
return &core.Literal{
Hash: literal.GetHash(),
}
}
return literal
}

func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, error) {
if literalMap == nil || len(literalMap.Literals) == 0 {
literalMap = &emptyLiteralMap
}

// Hashify, i.e. generate a copy of the literal map where each literal value is removed
// in case the corresponding hash is set.
hashifiedLiteralMap := make(map[string]*core.Literal, len(literalMap.Literals))
for name, literal := range literalMap.Literals {
hashifiedLiteralMap[name] = hashify(literal)
}
hashifiedInputs := &core.LiteralMap{
Literals: hashifiedLiteralMap,
}

inputsHash, err := pbhash.ComputeHash(ctx, hashifiedInputs)
if err != nil {
return "", err
}

return base64.RawURLEncoding.EncodeToString(inputsHash), nil
}
Loading

0 comments on commit 5a5ad82

Please sign in to comment.