diff --git a/pkg/didcomm/protocol/legacyconnection/models.go b/pkg/didcomm/protocol/legacyconnection/models.go index 49452d94d6..8e3bfb74fb 100644 --- a/pkg/didcomm/protocol/legacyconnection/models.go +++ b/pkg/didcomm/protocol/legacyconnection/models.go @@ -106,6 +106,17 @@ type legacyDoc struct { Proof []interface{} `json:"proof,omitempty"` } +type problemReport struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + Thread *decorator.Thread `json:"~thread,omitempty"` + ProblemCode string `json:"problem-code,omitempty"` + Explain string `json:"explain,omitempty"` + Localization struct { + Locale string `json:"locale,omitempty"` + } `json:"~l10n,omitempty"` +} + // JSONBytes converts Connection to json bytes. func (con *Connection) toLegacyJSONBytes() ([]byte, error) { if con.DIDDoc == nil { diff --git a/pkg/didcomm/protocol/legacyconnection/service.go b/pkg/didcomm/protocol/legacyconnection/service.go index 7fb76bd778..2b54d144b0 100644 --- a/pkg/didcomm/protocol/legacyconnection/service.go +++ b/pkg/didcomm/protocol/legacyconnection/service.go @@ -45,7 +45,9 @@ const ( // ResponseMsgType defines the legacy-connection response message type. ResponseMsgType = PIURI + "/response" // AckMsgType defines the legacy-connection ack message type. - AckMsgType = "https://didcomm.org/notification/1.0/ack" + AckMsgType = "https://didcomm.org/notification/1.0/ack" + // ProblemReportMsgType defines the protocol problem-report message type. + ProblemReportMsgType = PIURI + "/problem-report" routerConnsMetadataKey = "routerConnections" ) @@ -295,7 +297,8 @@ func (s *Service) Accept(msgType string) bool { return msgType == InvitationMsgType || msgType == RequestMsgType || msgType == ResponseMsgType || - msgType == AckMsgType + msgType == AckMsgType || + msgType == ProblemReportMsgType } // HandleOutbound handles outbound connection messages. @@ -318,6 +321,10 @@ func (s *Service) nextState(msgType, thID string) (state, error) { logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID) + if msgType == ProblemReportMsgType { + return &responded{}, nil + } + next, err := stateFromMsgType(msgType) if err != nil { return nil, err @@ -636,7 +643,7 @@ func (s *Service) connectionRecord(msg service.DIDCommMsg, ctx service.DIDCommCo return s.requestMsgRecord(msg, ctx) case ResponseMsgType: return s.responseMsgRecord(msg) - case AckMsgType: + case AckMsgType, ProblemReportMsgType: return s.fetchConnectionRecord(theirNSPrefix, msg) } diff --git a/pkg/didcomm/protocol/legacyconnection/service_test.go b/pkg/didcomm/protocol/legacyconnection/service_test.go index 9998572cf6..036d543bb9 100644 --- a/pkg/didcomm/protocol/legacyconnection/service_test.go +++ b/pkg/didcomm/protocol/legacyconnection/service_test.go @@ -176,7 +176,7 @@ func TestService_Handle_Inviter(t *testing.T) { require.NoError(t, err) completedFlag := make(chan struct{}) - respondedFlag := make(chan struct{}) + respondedFlag := make(chan string) go msgEventListener(t, statusCh, respondedFlag, completedFlag) @@ -247,7 +247,153 @@ func TestService_Handle_Inviter(t *testing.T) { validateState(t, s, thid, findNamespace(AckMsgType), (&completed{}).Name()) } -func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag, completedFlag chan struct{}) { +func TestService_Handle_Inviter_With_ProblemReport(t *testing.T) { + mockStore := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)} + storeProv := mockstorage.NewCustomMockStoreProvider(mockStore) + k := newKMS(t, storeProv) + prov := &protocol.MockProvider{ + StoreProvider: storeProv, + ServiceMap: map[string]interface{}{ + mediator.Coordination: &mockroute.MockMediatorSvc{}, + }, + CustomKMS: k, + KeyTypeValue: kms.ED25519Type, + KeyAgreementTypeValue: kms.X25519ECDHKWType, + } + + ctx := &context{ + outboundDispatcher: prov.OutboundDispatcher(), + crypto: &tinkcrypto.Crypto{}, + kms: k, + keyType: kms.ED25519Type, + keyAgreementType: kms.X25519ECDHKWType, + } + + _, pubKey, err := ctx.kms.CreateAndExportPubKeyBytes(kms.ED25519Type) + require.NoError(t, err) + + ctx.vdRegistry = &mockvdr.MockVDRegistry{CreateValue: createDIDDocWithKey(pubKey)} + + connRec, err := connection.NewRecorder(prov) + require.NoError(t, err) + require.NotNil(t, connRec) + + ctx.connectionRecorder = connRec + + doc, err := ctx.vdRegistry.Create(testMethod, nil) + require.NoError(t, err) + + s, err := New(prov) + require.NoError(t, err) + + actionCh := make(chan service.DIDCommAction, 10) + err = s.RegisterActionEvent(actionCh) + require.NoError(t, err) + + statusCh := make(chan service.StateMsg, 10) + err = s.RegisterMsgEvent(statusCh) + require.NoError(t, err) + + completedFlag := make(chan struct{}) + respondedFlag := make(chan string) + + go msgEventListener(t, statusCh, respondedFlag, completedFlag) + go func() { service.AutoExecuteActionEvent(actionCh) }() + + invitation := &Invitation{ + Type: InvitationMsgType, + ID: randomString(), + Label: "Bob", + RecipientKeys: []string{base58.Encode(pubKey)}, + ServiceEndpoint: "http://alice.agent.example.com:8081", + } + + err = ctx.connectionRecorder.SaveInvitation(invitation.ID, invitation) + require.NoError(t, err) + + thid := randomString() + + // Invitation was previously sent by Alice to Bob. + // Bob now sends a connection Request + connRequest, err := json.Marshal( + &Request{ + Type: RequestMsgType, + ID: thid, + Label: "Bob", + Thread: &decorator.Thread{ + PID: invitation.ID, + }, + Connection: &Connection{ + DID: doc.DIDDocument.ID, + DIDDoc: doc.DIDDocument, + }, + }) + require.NoError(t, err) + requestMsg, err := service.ParseDIDCommMsgMap(connRequest) + require.NoError(t, err) + _, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + var connID string + select { + case connID = <-respondedFlag: + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive connection ID") + } + + connRecord, err := s.connectionRecorder.GetConnectionRecord(connID) + require.NoError(t, err) + + // Alice automatically sends connection Response to Bob + // Bob replies with Problem Report + prbRpt, err := json.Marshal( + &problemReport{ + ID: randomString(), + Type: ProblemReportMsgType, + Thread: &decorator.Thread{ID: connRecord.ThreadID}, + }) + require.NoError(t, err) + + prbRptMsg, err := service.ParseDIDCommMsgMap(prbRpt) + require.NoError(t, err) + + _, err = s.HandleInbound(prbRptMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + validateState(t, s, thid, findNamespace(RequestMsgType), (&responded{}).Name()) + + _, err = ctx.connectionRecorder.GetConnectionRecord(connID) + require.ErrorContains(t, err, "data not found") + + _, err = s.HandleInbound(requestMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + // Finally Bob replies with an ACK + ack, err := json.Marshal( + &model.Ack{ + Type: AckMsgType, + ID: randomString(), + Status: "OK", + Thread: &decorator.Thread{ID: connRecord.ThreadID}, + }) + require.NoError(t, err) + + ackMsg, err := service.ParseDIDCommMsgMap(ack) + require.NoError(t, err) + + _, err = s.HandleInbound(ackMsg, service.NewDIDCommContext(doc.DIDDocument.ID, "", nil)) + require.NoError(t, err) + + select { + case <-completedFlag: + case <-time.After(4 * time.Second): + require.Fail(t, "didn't receive post event complete") + } + + validateState(t, s, connRecord.ThreadID, findNamespace(AckMsgType), (&completed{}).Name()) +} + +func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFlag chan string, completedFlag chan struct{}) { for e := range statusCh { require.Equal(t, LegacyConnection, e.ProtocolName) @@ -272,11 +418,11 @@ func msgEventListener(t *testing.T, statusCh chan service.StateMsg, respondedFla close(completedFlag) } - if e.StateID == "responded" { + if e.StateID == "responded" && e.Msg.Type() != ProblemReportMsgType { // validate connectionID received during state transition with original connectionID require.NotNil(t, prop.ConnectionID()) require.NotNil(t, prop.InvitationID()) - close(respondedFlag) + respondedFlag <- prop.ConnectionID() } } } diff --git a/pkg/didcomm/protocol/legacyconnection/states.go b/pkg/didcomm/protocol/legacyconnection/states.go index 8156b44627..67c4833ab9 100644 --- a/pkg/didcomm/protocol/legacyconnection/states.go +++ b/pkg/didcomm/protocol/legacyconnection/states.go @@ -204,7 +204,7 @@ func (s *responded) Name() string { } func (s *responded) CanTransitionTo(next state) bool { - return StateIDCompleted == next.Name() + return StateIDCompleted == next.Name() || StateIDRequested == next.Name() } func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) (*connectionstore.Record, @@ -226,6 +226,13 @@ func (s *responded) ExecuteInbound(msg *stateMachineMsg, _ string, ctx *context) return connRecord, &noOp{}, action, nil case ResponseMsgType: return msg.connRecord, &completed{}, func() error { return nil }, nil + case ProblemReportMsgType: + err := ctx.connectionRecorder.RemoveConnection(msg.connRecord.ConnectionID) + if err != nil { + return nil, nil, nil, fmt.Errorf("delete connection record is failed: %w", err) + } + + return msg.connRecord, &noOp{}, func() error { return nil }, nil default: return nil, nil, nil, fmt.Errorf("illegal msg type %s for state %s", msg.Type(), s.Name()) } diff --git a/pkg/didcomm/protocol/legacyconnection/states_test.go b/pkg/didcomm/protocol/legacyconnection/states_test.go index ae200849ab..ebd21a91d4 100644 --- a/pkg/didcomm/protocol/legacyconnection/states_test.go +++ b/pkg/didcomm/protocol/legacyconnection/states_test.go @@ -92,7 +92,7 @@ func TestRespondedState(t *testing.T) { require.Equal(t, "responded", res.Name()) require.False(t, res.CanTransitionTo(&null{})) require.False(t, res.CanTransitionTo(&invited{})) - require.False(t, res.CanTransitionTo(&requested{})) + require.True(t, res.CanTransitionTo(&requested{})) require.False(t, res.CanTransitionTo(res)) require.True(t, res.CanTransitionTo(&completed{})) } @@ -388,6 +388,32 @@ func TestRespondedState_Execute(t *testing.T) { require.NotNil(t, connRec) require.Equal(t, (&completed{}).Name(), followup.Name()) }) + t.Run("followup to 'noop' on inbound problem report message", func(t *testing.T) { + connRec := &connection.Record{ + State: (&responded{}).Name(), + ThreadID: request.ID, + ConnectionID: "123", + Namespace: findNamespace(ResponseMsgType), + } + err = ctx.connectionRecorder.SaveConnectionRecordWithMappings(connRec) + require.NoError(t, err) + + problemReportPayload, err := json.Marshal(&problemReport{Type: ProblemReportMsgType}) + require.NoError(t, err) + + connRec, followup, _, e := (&responded{}).ExecuteInbound( + &stateMachineMsg{ + DIDCommMsg: bytesToDIDCommMsg(t, problemReportPayload), + connRecord: connRec, + }, "", ctx) + require.NoError(t, e) + require.NotNil(t, connRec) + + _, e = ctx.connectionRecorder.GetConnectionRecord(connRec.ConnectionID) + require.Error(t, e) + require.ErrorContains(t, e, "data not found") + require.Equal(t, (&noOp{}).Name(), followup.Name()) + }) t.Run("handle inbound request unmarshalling error", func(t *testing.T) { _, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{