Skip to content

Commit

Permalink
asserts,confdb: define confdb-control assertion (#14705)
Browse files Browse the repository at this point in the history
* asserts: define confdb-control assertion

* asserts: convert []string to []AuthenticationMethod in check

* asserts,registry: fixes after review

* registry,asserts: better errors

* asserts,registry: rename Group to ControlGroup

* asserts,registry: fixes after code review

* registry: convert ControlGroup.Views to a struct

* registry: add doc comment to ViewRef

* testtime: fix export_test
  • Loading branch information
st3v3nmw authored Dec 9, 2024
1 parent c2a092d commit 3cefe04
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 6 deletions.
2 changes: 2 additions & 0 deletions asserts/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ var (
DeviceSessionRequestType = &AssertionType{"device-session-request", []string{"brand-id", "model", "serial"}, nil, assembleDeviceSessionRequest, noAuthority}
SerialRequestType = &AssertionType{"serial-request", nil, nil, assembleSerialRequest, noAuthority}
AccountKeyRequestType = &AssertionType{"account-key-request", []string{"public-key-sha3-384"}, nil, assembleAccountKeyRequest, noAuthority}
ConfdbControlType = &AssertionType{"confdb-control", []string{"brand-id", "model", "serial"}, nil, assembleConfdbControl, noAuthority}
)

var typeRegistry = map[string]*AssertionType{
Expand All @@ -178,6 +179,7 @@ var typeRegistry = map[string]*AssertionType{
DeviceSessionRequestType.Name: DeviceSessionRequestType,
SerialRequestType.Name: SerialRequestType,
AccountKeyRequestType.Name: AccountKeyRequestType,
ConfdbControlType.Name: ConfdbControlType,
}

// Type returns the AssertionType with name or nil
Expand Down
4 changes: 3 additions & 1 deletion asserts/asserts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func (as *assertsSuite) TestTypeNames(c *C) {
"account-key-request",
"base-declaration",
"confdb",
"confdb-control",
"device-session-request",
"model",
"preseed",
Expand Down Expand Up @@ -1207,7 +1208,8 @@ func (as *assertsSuite) TestWithAuthority(c *C) {
"validation-set",
"repair",
}
c.Check(withAuthority, HasLen, asserts.NumAssertionType-3) // excluding device-session-request, serial-request, account-key-request
// excluding device-session-request, serial-request, account-key-request, confdb-control
c.Check(withAuthority, HasLen, asserts.NumAssertionType-4)
for _, name := range withAuthority {
typ := asserts.Type(name)
_, err := asserts.AssembleAndSignInTest(typ, nil, []byte("{}"), testPrivKey1)
Expand Down
108 changes: 108 additions & 0 deletions asserts/confdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package asserts

import (
"encoding/json"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -108,3 +109,110 @@ func assembleConfdb(assert assertionBase) (Assertion, error) {
timestamp: timestamp,
}, nil
}

// ConfdbControl holds a confdb-control assertion, which holds lists of
// views delegated by the device to operators.
type ConfdbControl struct {
assertionBase

// the key is the operator ID
operators map[string]*confdb.Operator
}

// BrandID returns the brand identifier of the device.
func (cc *ConfdbControl) BrandID() string {
return cc.HeaderString("brand-id")
}

// Model returns the model name identifier of the device.
func (cc *ConfdbControl) Model() string {
return cc.HeaderString("model")
}

// Serial returns the serial identifier of the device.
// Together with brand-id and model, they form the device's unique identifier.
func (cc *ConfdbControl) Serial() string {
return cc.HeaderString("serial")
}

// assembleConfdbControl creates a new confdb-control assertion after validating
// all required fields and constraints.
func assembleConfdbControl(assert assertionBase) (Assertion, error) {
_, err := checkStringMatches(assert.headers, "brand-id", validAccountID)
if err != nil {
return nil, err
}

if _, err := checkModel(assert.headers); err != nil {
return nil, err
}

groups, err := checkList(assert.headers, "groups")
if err != nil {
return nil, err
}
if groups == nil {
return nil, errors.New(`"groups" stanza is mandatory`)
}

operators, err := parseConfdbControlGroups(groups)
if err != nil {
return nil, err
}

cc := &ConfdbControl{
assertionBase: assert,
operators: operators,
}
return cc, nil
}

func parseConfdbControlGroups(rawGroups []interface{}) (map[string]*confdb.Operator, error) {
operators := map[string]*confdb.Operator{}
for i, rawGroup := range rawGroups {
errPrefix := fmt.Sprintf("cannot parse group at position %d", i+1)

group, ok := rawGroup.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("%s: must be a map", errPrefix)
}

operatorID, err := checkNotEmptyStringWhat(group, "operator-id", "field")
if err != nil {
return nil, fmt.Errorf("%s: %w", errPrefix, err)
}

// Currently, operatorIDs must be snap store account IDs
if !IsValidAccountID(operatorID) {
return nil, fmt.Errorf(`%s: invalid "operator-id" %s`, errPrefix, operatorID)
}

operator, ok := operators[operatorID]
if !ok {
operator = &confdb.Operator{ID: operatorID}
operators[operatorID] = operator
}

auth, err := checkStringListInMap(group, "authentication", "field", nil)
if err != nil {
return nil, fmt.Errorf(`%s: "authentication" %w`, errPrefix, err)
}
if auth == nil {
return nil, fmt.Errorf(`%s: "authentication" must be provided`, errPrefix)
}

views, err := checkStringListInMap(group, "views", "field", nil)
if err != nil {
return nil, fmt.Errorf(`%s: "views" %w`, errPrefix, err)
}
if views == nil {
return nil, fmt.Errorf(`%s: "views" must be provided`, errPrefix)
}

if err := operator.AddControlGroup(views, auth); err != nil {
return nil, fmt.Errorf(`%s: %w`, errPrefix, err)
}
}

return operators, nil
}
151 changes: 151 additions & 0 deletions asserts/confdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
. "gopkg.in/check.v1"

"github.com/snapcore/snapd/asserts"
"github.com/snapcore/snapd/confdb"
)

type confdbSuite struct {
Expand Down Expand Up @@ -199,3 +200,153 @@ func (s *confdbSuite) TestAssembleAndSignChecksSchemaFormatFail(c *C) {
_, err := asserts.AssembleAndSignInTest(asserts.ConfdbType, headers, []byte(schema), testPrivKey0)
c.Assert(err, ErrorMatches, `assertion confdb: JSON in body must be indented with 2 spaces and sort object entries by key`)
}

type confdbCtrlSuite struct{}

var _ = Suite(&confdbCtrlSuite{})

const (
confdbControlExample = `type: confdb-control
brand-id: generic
model: generic-classic
serial: 03961d5d-26e5-443f-838d-6db046126bea
groups:
-
operator-id: john
authentication:
- operator-key
views:
- canonical/network/control-device
- canonical/network/observe-device
-
operator-id: john
authentication:
- store
views:
- canonical/network/control-interfaces
-
operator-id: jane
authentication:
- store
- operator-key
views:
- canonical/network/observe-interfaces
sign-key-sha3-384: t9yuKGLyiezBq_PXMJZsGdkTukmL7MgrgqXAlxxiZF4TYryOjZcy48nnjDmEHQDp
AXNpZw==`
)

func (s *confdbCtrlSuite) TestDecodeOK(c *C) {
encoded := confdbControlExample

a, err := asserts.Decode([]byte(encoded))
c.Assert(err, IsNil)
c.Assert(a, NotNil)
c.Assert(a.Type(), Equals, asserts.ConfdbControlType)

cc := a.(*asserts.ConfdbControl)
c.Assert(cc.BrandID(), Equals, "generic")
c.Assert(cc.Model(), Equals, "generic-classic")
c.Assert(cc.Serial(), Equals, "03961d5d-26e5-443f-838d-6db046126bea")
c.Assert(cc.AuthorityID(), Equals, "")

operators := cc.Operators()

john, ok := operators["john"]
c.Assert(ok, Equals, true)
c.Assert(john.ID, Equals, "john")
c.Assert(len(john.Groups), Equals, 2)

g := john.Groups[0]
c.Assert(g.Authentication, DeepEquals, []confdb.AuthenticationMethod{"operator-key"})
expectedViews := []*confdb.ViewRef{
{Account: "canonical", Confdb: "network", View: "control-device"},
{Account: "canonical", Confdb: "network", View: "observe-device"},
}
c.Assert(g.Views, DeepEquals, expectedViews)

g = john.Groups[1]
c.Assert(g.Authentication, DeepEquals, []confdb.AuthenticationMethod{"store"})
expectedViews = []*confdb.ViewRef{
{Account: "canonical", Confdb: "network", View: "control-interfaces"},
}
c.Assert(g.Views, DeepEquals, expectedViews)

jane, ok := operators["jane"]
c.Assert(ok, Equals, true)
c.Assert(jane.ID, Equals, "jane")
c.Assert(len(jane.Groups), Equals, 1)

g = jane.Groups[0]
c.Assert(g.Authentication, DeepEquals, []confdb.AuthenticationMethod{"operator-key", "store"})
expectedViews = []*confdb.ViewRef{
{Account: "canonical", Confdb: "network", View: "observe-interfaces"},
}
c.Assert(g.Views, DeepEquals, expectedViews)
}

func (s *confdbCtrlSuite) TestDecodeInvalid(c *C) {
encoded := confdbControlExample
const validationSetErrPrefix = "assertion confdb-control: "

invalidTests := []struct{ original, invalid, expectedErr string }{
{"brand-id: generic\n", "", `"brand-id" header is mandatory`},
{"brand-id: generic\n", "brand-id: \n", `"brand-id" header should not be empty`},
{"brand-id: generic\n", "brand-id: 456#\n", `"brand-id" header contains invalid characters: "456#"`},
{"model: generic-classic\n", "", `"model" header is mandatory`},
{"model: generic-classic\n", "model: \n", `"model" header should not be empty`},
{"model: generic-classic\n", "model: #\n", `"model" header contains invalid characters: "#"`},
{"serial: 03961d5d-26e5-443f-838d-6db046126bea\n", "", `"serial" header is mandatory`},
{"serial: 03961d5d-26e5-443f-838d-6db046126bea\n", "serial: \n", `"serial" header should not be empty`},
{"groups:", "groups: foo\nviews:", `"groups" header must be a list`},
{"groups:", "views:", `"groups" stanza is mandatory`},
{"groups:", "groups:\n - bar", `cannot parse group at position 1: must be a map`},
{" operator-id: jane\n", "", `cannot parse group at position 3: "operator-id" field is mandatory`},
{
"operator-id: jane\n",
"operator-id: \n",
`cannot parse group at position 3: "operator-id" field should not be empty`,
},
{
"operator-id: jane\n",
"operator-id: @op\n",
`cannot parse group at position 3: invalid "operator-id" @op`,
},
{
" authentication:\n - store",
" authentication: abcd",
`cannot parse group at position 2: "authentication" field must be a list of strings`,
},
{
" authentication:\n - store",
" foo: bar",
`cannot parse group at position 2: "authentication" must be provided`,
},
{
" views:\n - canonical/network/control-interfaces",
" views: abcd",
`cannot parse group at position 2: "views" field must be a list of strings`,
},
{
" views:\n - canonical/network/control-interfaces",
" foo: bar",
`cannot parse group at position 2: "views" must be provided`,
},
{
" - operator-key",
" - foo-bar",
"cannot parse group at position 1: cannot add group: invalid authentication method: foo-bar",
},
{
"canonical/network/control-interfaces",
"canonical",
`cannot parse group at position 2: view "canonical" must be in the format account/confdb/view`,
},
}

for i, test := range invalidTests {
invalid := strings.Replace(encoded, test.original, test.invalid, 1)
_, err := asserts.Decode([]byte(invalid))
c.Assert(err, ErrorMatches, validationSetErrPrefix+test.expectedErr, Commentf("test %d/%d failed", i+1, len(invalidTests)))
}
}
6 changes: 6 additions & 0 deletions asserts/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"time"

"github.com/snapcore/snapd/asserts/internal"
"github.com/snapcore/snapd/confdb"
"github.com/snapcore/snapd/testutil"
)

Expand Down Expand Up @@ -368,3 +369,8 @@ func MockAssertionPrereqs(f func(a Assertion) []*Ref) func() {
assertionPrereqs = f
return r
}

// ConfdbControl.operators exposed for tests
func (cc *ConfdbControl) Operators() map[string]*confdb.Operator {
return cc.operators
}
17 changes: 17 additions & 0 deletions asserts/header_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,20 @@ func checkMapWhat(m map[string]interface{}, name, what string) (map[string]inter
}
return mv, nil
}

func checkList(headers map[string]interface{}, name string) ([]interface{}, error) {
return checkListWhat(headers, name, "header")
}

func checkListWhat(m map[string]interface{}, name, what string) ([]interface{}, error) {
value, ok := m[name]
if !ok {
return nil, nil
}

list, ok := value.([]interface{})
if !ok {
return nil, fmt.Errorf("%q %s must be a list", name, what)
}
return list, nil
}
2 changes: 1 addition & 1 deletion confdb/confdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ func getPlaceholders(viewStr string) map[string]bool {
return placeholders
}

// View returns an view from the confdb.
// View returns a view from the confdb.
func (db *Confdb) View(view string) *View {
return db.views[view]
}
Expand Down
Loading

0 comments on commit 3cefe04

Please sign in to comment.