Skip to content

Commit

Permalink
feat: add optional tag to ignore fields when nil when rlp encoding/de…
Browse files Browse the repository at this point in the history
…coding
  • Loading branch information
paologalligit committed Dec 11, 2024
1 parent c74017e commit f31c793
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 21 deletions.
28 changes: 21 additions & 7 deletions rlp/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ type Decoder interface {
// rules for the field such that input values of size zero decode as a nil
// pointer. This tag can be useful when decoding recursive types.
//
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
//
// To decode into a slice, the input must be a list and the resulting
// slice will contain the input elements in order. For byte slices,
Expand All @@ -113,8 +113,8 @@ type Decoder interface {
// To decode into an interface value, Decode stores one of these
// in the value:
//
// []interface{}, for RLP lists
// []byte, for RLP strings
// []interface{}, for RLP lists
// []byte, for RLP strings
//
// Non-empty interface types are not supported, nor are booleans,
// signed integers, floating point numbers, maps, channels and
Expand All @@ -124,7 +124,7 @@ type Decoder interface {
// and may be vulnerable to panics cause by huge value sizes. If
// you need an input limit, use
//
// NewStream(r, limit).Decode(val)
// NewStream(r, limit).Decode(val)
func Decode(r io.Reader, val interface{}) error {
// TODO: this could use a Stream from a pool.
return NewStream(r, 0).Decode(val)
Expand Down Expand Up @@ -438,9 +438,16 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
if _, err := s.List(); err != nil {
return wrapStreamError(err, typ)
}
for _, f := range fields {
for i, f := range fields {
err := f.info.decoder(s, val.Field(f.index))
if err == EOL {
if f.optional {
// The field is optional, so reaching the end of the list before
// reaching the last field is acceptable. All remaining undecoded
// fields are zeroed.
zeroFields(val, fields[i:])
break
}
return &decodeError{msg: "too few elements", typ: typ}
} else if err != nil {
return addErrorContext(err, "."+typ.Field(f.index).Name)
Expand All @@ -451,6 +458,13 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil
}

func zeroFields(structval reflect.Value, fields []field) {
for _, f := range fields {
fv := structval.Field(f.index)
fv.Set(reflect.Zero(fv.Type()))
}
}

// makePtrDecoder creates a decoder that decodes into
// the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) {
Expand Down
164 changes: 161 additions & 3 deletions rlp/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,29 @@ var (
)
)

type hasIgnoredField struct {
type optionalFields struct {
A uint
B uint `rlp:"optional"`
C uint `rlp:"optional"`
}
type optionalAndTailField struct {
A uint
B uint `rlp:"optional"`
Tail []uint `rlp:"tail"`
}
type optionalBigIntField struct {
A uint
B *big.Int `rlp:"optional"`
}
type optionalPtrField struct {
A uint
B *[3]byte `rlp:"optional"`
}
type optionalPtrFieldNil struct {
A uint
B *[3]byte `rlp:"optional,nil"`
}
type ignoredField struct {
A uint
B uint `rlp:"-"`
C uint
Expand Down Expand Up @@ -514,8 +536,112 @@ var decodeTests = []decodeTest{
// struct tag "-"
{
input: "C20102",
ptr: new(hasIgnoredField),
value: hasIgnoredField{A: 1, C: 2},
ptr: new(ignoredField),
value: ignoredField{A: 1, C: 2},
},

// struct tag "optional"
{
input: "C101",
ptr: new(optionalFields),
value: optionalFields{1, 0, 0},
},
{
input: "C20102",
ptr: new(optionalFields),
value: optionalFields{1, 2, 0},
},
{
input: "C3010203",
ptr: new(optionalFields),
value: optionalFields{1, 2, 3},
},
{
input: "C401020304",
ptr: new(optionalFields),
error: "rlp: input list has too many elements for rlp.optionalFields",
},
{
input: "C101",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1},
},
{
input: "C20102",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
},
{
input: "C401020304",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}},
},
{
input: "C101",
ptr: new(optionalBigIntField),
value: optionalBigIntField{A: 1, B: nil},
},
{
input: "C20102",
ptr: new(optionalBigIntField),
value: optionalBigIntField{A: 1, B: big.NewInt(2)},
},
{
input: "C101",
ptr: new(optionalPtrField),
value: optionalPtrField{A: 1},
},
{
input: "C20180", // not accepted because "optional" doesn't enable "nil"
ptr: new(optionalPtrField),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
},
{
input: "C20102",
ptr: new(optionalPtrField),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
},
{
input: "C50183010203",
ptr: new(optionalPtrField),
value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}},
},
{
input: "C101",
ptr: new(optionalPtrFieldNil),
value: optionalPtrFieldNil{A: 1},
},
{
input: "C20180", // accepted because "nil" tag allows empty input
ptr: new(optionalPtrFieldNil),
value: optionalPtrFieldNil{A: 1},
},
{
input: "C20102",
ptr: new(optionalPtrFieldNil),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B",
},

// struct tag "optional" field clearing
{
input: "C101",
ptr: &optionalFields{A: 9, B: 8, C: 7},
value: optionalFields{A: 1, B: 0, C: 0},
},
{
input: "C20102",
ptr: &optionalFields{A: 9, B: 8, C: 7},
value: optionalFields{A: 1, B: 2, C: 0},
},
{
input: "C20102",
ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}},
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
},
{
input: "C101",
ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}},
value: optionalPtrField{A: 1},
},

// RawValue
Expand Down Expand Up @@ -817,3 +943,35 @@ func unhex(str string) []byte {
}
return b
}

// This tests the validity checks for fields with struct tag "optional".
func TestInvalidOptionalField(t *testing.T) {
type (
invalid1 struct {
A uint `rlp:"optional"`
B uint
}
invalid2 struct {
T []uint `rlp:"tail,optional"`
}
invalid3 struct {
T []uint `rlp:"optional,tail"`
}
)
tests := []struct {
v interface{}
err string
}{
{v: new(invalid1), err: `rlp: struct field rlp.invalid1.B needs "optional" tag`},
{v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (cannot be used with "tail")`},
{v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (cannot be used with "optional")`},
}
for _, test := range tests {
err := DecodeBytes(unhex("C20102"), test.v)
if err == nil {
t.Errorf("no error for %T", test.v)
} else if err.Error() != test.err {
t.Errorf("wrong error for %T: %v", test.v, err.Error())
}
}
}
38 changes: 31 additions & 7 deletions rlp/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,39 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
if err != nil {
return nil, err
}
writer := func(val reflect.Value, w *encbuf) error {
lh := w.list()
for _, f := range fields {
if err := f.info.writer(val.Field(f.index), w); err != nil {
return err
var writer writer
firstOptionalField := firstOptionalField(fields)
if firstOptionalField == len(fields) {
// This is the writer function for structs without any optional fields.
writer = func(val reflect.Value, w *encbuf) error {
lh := w.list()
for _, f := range fields {
if err := f.info.writer(val.Field(f.index), w); err != nil {
return err
}
}
w.listEnd(lh)
return nil
}
} else {
// If there are any "optional" fields, the writer needs to perform additional
// checks to determine the output list length.
writer = func(val reflect.Value, w *encbuf) error {
lastField := len(fields) - 1
for ; lastField >= firstOptionalField; lastField-- {
if !val.Field(fields[lastField].index).IsZero() {
break
}
}
lh := w.list()
for i := 0; i <= lastField; i++ {
if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil {
return err
}
}
w.listEnd(lh)
return nil
}
w.listEnd(lh)
return nil
}
return writer, nil
}
Expand Down
16 changes: 15 additions & 1 deletion rlp/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ var encTests = []encTest{
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"},
{val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"},
{val: &tailRaw{A: 1, Tail: nil}, output: "C101"},
{val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"},
{val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"},

// nil
{val: (*uint)(nil), output: "80"},
Expand All @@ -232,6 +232,20 @@ var encTests = []encTest{
{val: (*[]struct{ uint })(nil), output: "C0"},
{val: (*interface{})(nil), output: "C0"},

// struct tag "optional"
{val: &optionalFields{}, output: "C180"},
{val: &optionalFields{A: 1}, output: "C101"},
{val: &optionalFields{A: 1, B: 2}, output: "C20102"},
{val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"},
{val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"},
{val: &optionalAndTailField{A: 1}, output: "C101"},
{val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"},
{val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
{val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
{val: &optionalBigIntField{A: 1}, output: "C101"},
{val: &optionalPtrField{A: 1}, output: "C101"},
{val: &optionalPtrFieldNil{A: 1}, output: "C101"},

// interfaces
{val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct

Expand Down
Loading

0 comments on commit f31c793

Please sign in to comment.