From f31c7938e0661254e69f9edb8d9c976337c69188 Mon Sep 17 00:00:00 2001 From: paologalligit Date: Wed, 11 Dec 2024 16:13:25 +0100 Subject: [PATCH] feat: add optional tag to ignore fields when nil when rlp encoding/decoding --- rlp/decode.go | 28 ++++++-- rlp/decode_test.go | 164 ++++++++++++++++++++++++++++++++++++++++++++- rlp/encode.go | 38 +++++++++-- rlp/encode_test.go | 16 ++++- rlp/typecache.go | 36 +++++++++- 5 files changed, 261 insertions(+), 21 deletions(-) diff --git a/rlp/decode.go b/rlp/decode.go index dbbe599597a8..7191cc400e28 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -91,9 +91,9 @@ type Decoder interface { // rules for the field such that input values of size zero decode as a nil // pointer. This tag can be useful when decoding recursive types. // -// type StructWithEmptyOK struct { -// Foo *[20]byte `rlp:"nil"` -// } +// type StructWithEmptyOK struct { +// Foo *[20]byte `rlp:"nil"` +// } // // To decode into a slice, the input must be a list and the resulting // slice will contain the input elements in order. For byte slices, @@ -113,8 +113,8 @@ type Decoder interface { // To decode into an interface value, Decode stores one of these // in the value: // -// []interface{}, for RLP lists -// []byte, for RLP strings +// []interface{}, for RLP lists +// []byte, for RLP strings // // Non-empty interface types are not supported, nor are booleans, // signed integers, floating point numbers, maps, channels and @@ -124,7 +124,7 @@ type Decoder interface { // and may be vulnerable to panics cause by huge value sizes. If // you need an input limit, use // -// NewStream(r, limit).Decode(val) +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { // TODO: this could use a Stream from a pool. return NewStream(r, 0).Decode(val) @@ -438,9 +438,16 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) } - for _, f := range fields { + for i, f := range fields { err := f.info.decoder(s, val.Field(f.index)) if err == EOL { + if f.optional { + // The field is optional, so reaching the end of the list before + // reaching the last field is acceptable. All remaining undecoded + // fields are zeroed. + zeroFields(val, fields[i:]) + break + } return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) @@ -451,6 +458,13 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } +func zeroFields(structval reflect.Value, fields []field) { + for _, f := range fields { + fv := structval.Field(f.index) + fv.Set(reflect.Zero(fv.Type())) + } +} + // makePtrDecoder creates a decoder that decodes into // the pointer's element type. func makePtrDecoder(typ reflect.Type) (decoder, error) { diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd001281..679bb17d7b97 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -354,7 +354,29 @@ var ( ) ) -type hasIgnoredField struct { +type optionalFields struct { + A uint + B uint `rlp:"optional"` + C uint `rlp:"optional"` +} +type optionalAndTailField struct { + A uint + B uint `rlp:"optional"` + Tail []uint `rlp:"tail"` +} +type optionalBigIntField struct { + A uint + B *big.Int `rlp:"optional"` +} +type optionalPtrField struct { + A uint + B *[3]byte `rlp:"optional"` +} +type optionalPtrFieldNil struct { + A uint + B *[3]byte `rlp:"optional,nil"` +} +type ignoredField struct { A uint B uint `rlp:"-"` C uint @@ -514,8 +536,112 @@ var decodeTests = []decodeTest{ // struct tag "-" { input: "C20102", - ptr: new(hasIgnoredField), - value: hasIgnoredField{A: 1, C: 2}, + ptr: new(ignoredField), + value: ignoredField{A: 1, C: 2}, + }, + + // struct tag "optional" + { + input: "C101", + ptr: new(optionalFields), + value: optionalFields{1, 0, 0}, + }, + { + input: "C20102", + ptr: new(optionalFields), + value: optionalFields{1, 2, 0}, + }, + { + input: "C3010203", + ptr: new(optionalFields), + value: optionalFields{1, 2, 3}, + }, + { + input: "C401020304", + ptr: new(optionalFields), + error: "rlp: input list has too many elements for rlp.optionalFields", + }, + { + input: "C101", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C401020304", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}}, + }, + { + input: "C101", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: nil}, + }, + { + input: "C20102", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: big.NewInt(2)}, + }, + { + input: "C101", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1}, + }, + { + input: "C20180", // not accepted because "optional" doesn't enable "nil" + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C20102", + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C50183010203", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}}, + }, + { + input: "C101", + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20180", // accepted because "nil" tag allows empty input + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalPtrFieldNil), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B", + }, + + // struct tag "optional" field clearing + { + input: "C101", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 0, C: 0}, + }, + { + input: "C20102", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 2, C: 0}, + }, + { + input: "C20102", + ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}}, + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C101", + ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}}, + value: optionalPtrField{A: 1}, }, // RawValue @@ -817,3 +943,35 @@ func unhex(str string) []byte { } return b } + +// This tests the validity checks for fields with struct tag "optional". +func TestInvalidOptionalField(t *testing.T) { + type ( + invalid1 struct { + A uint `rlp:"optional"` + B uint + } + invalid2 struct { + T []uint `rlp:"tail,optional"` + } + invalid3 struct { + T []uint `rlp:"optional,tail"` + } + ) + tests := []struct { + v interface{} + err string + }{ + {v: new(invalid1), err: `rlp: struct field rlp.invalid1.B needs "optional" tag`}, + {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (cannot be used with "tail")`}, + {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (cannot be used with "optional")`}, + } + for _, test := range tests { + err := DecodeBytes(unhex("C20102"), test.v) + if err == nil { + t.Errorf("no error for %T", test.v) + } else if err.Error() != test.err { + t.Errorf("wrong error for %T: %v", test.v, err.Error()) + } + } +} diff --git a/rlp/encode.go b/rlp/encode.go index 445b4b5b2104..6e0cc0716b9f 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -529,15 +529,39 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } - writer := func(val reflect.Value, w *encbuf) error { - lh := w.list() - for _, f := range fields { - if err := f.info.writer(val.Field(f.index), w); err != nil { - return err + var writer writer + firstOptionalField := firstOptionalField(fields) + if firstOptionalField == len(fields) { + // This is the writer function for structs without any optional fields. + writer = func(val reflect.Value, w *encbuf) error { + lh := w.list() + for _, f := range fields { + if err := f.info.writer(val.Field(f.index), w); err != nil { + return err + } } + w.listEnd(lh) + return nil + } + } else { + // If there are any "optional" fields, the writer needs to perform additional + // checks to determine the output list length. + writer = func(val reflect.Value, w *encbuf) error { + lastField := len(fields) - 1 + for ; lastField >= firstOptionalField; lastField-- { + if !val.Field(fields[lastField].index).IsZero() { + break + } + } + lh := w.list() + for i := 0; i <= lastField; i++ { + if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil } - w.listEnd(lh) - return nil } return writer, nil } diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 827960f7c15a..f9b909f2f35b 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -218,7 +218,7 @@ var encTests = []encTest{ {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"}, {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, - {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"}, // nil {val: (*uint)(nil), output: "80"}, @@ -232,6 +232,20 @@ var encTests = []encTest{ {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // struct tag "optional" + {val: &optionalFields{}, output: "C180"}, + {val: &optionalFields{A: 1}, output: "C101"}, + {val: &optionalFields{A: 1, B: 2}, output: "C20102"}, + {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"}, + {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"}, + {val: &optionalAndTailField{A: 1}, output: "C101"}, + {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalBigIntField{A: 1}, output: "C101"}, + {val: &optionalPtrField{A: 1}, output: "C101"}, + {val: &optionalPtrFieldNil{A: 1}, output: "C101"}, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct diff --git a/rlp/typecache.go b/rlp/typecache.go index 3df799e1ecd5..c14178e1e30b 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -37,6 +37,9 @@ type typeinfo struct { type tags struct { // rlp:"nil" controls whether empty input results in a nil pointer. nilOK bool + // rlp:"optional" allows for a field to be missing in the input list. + // If this is set, all subsequent fields must also be optional. + optional bool // rlp:"tail" controls whether this field swallows additional list // elements. It can only be set for the last field, which must be // of slice type. @@ -91,11 +94,14 @@ func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { } type field struct { - index int - info *typeinfo + index int + info *typeinfo + optional bool } func structFields(typ reflect.Type) (fields []field, err error) { + anyOptional := false + for i := 0; i < typ.NumField(); i++ { if f := typ.Field(i); f.PkgPath == "" { // exported tags, err := parseStructTag(typ, i) @@ -105,16 +111,32 @@ func structFields(typ reflect.Type) (fields []field, err error) { if tags.ignored { continue } + // If any field has the "optional" tag, subsequent fields must also have it. + if tags.optional || tags.tail { + anyOptional = true + } else if anyOptional { + return nil, fmt.Errorf(`rlp: struct field %v.%s needs "optional" tag`, typ, f.Name) + } info, err := cachedTypeInfo1(f.Type, tags) if err != nil { return nil, err } - fields = append(fields, field{i, info}) + fields = append(fields, field{i, info, tags.optional}) } } return fields, nil } +// anyOptionalFields returns the index of the first field with "optional" tag. +func firstOptionalField(fields []field) int { + for i, f := range fields { + if f.optional { + return i + } + } + return len(fields) +} + func parseStructTag(typ reflect.Type, fi int) (tags, error) { f := typ.Field(fi) var ts tags @@ -125,11 +147,19 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) { ts.ignored = true case "nil": ts.nilOK = true + case "optional": + ts.optional = true + if ts.tail { + return ts, fmt.Errorf(`rlp: invalid struct tag "optional" for %v.%s (cannot be used with "tail")`, typ, f.Name) + } case "tail": ts.tail = true if fi != typ.NumField()-1 { return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) } + if ts.optional { + return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%v (cannot be used with "optional")`, typ, f.Name) + } if f.Type.Kind() != reflect.Slice { return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) }