diff --git a/internal/dremel/write_optional.go b/internal/dremel/write_optional.go index 53a6650..edc6dfd 100644 --- a/internal/dremel/write_optional.go +++ b/internal/dremel/write_optional.go @@ -114,31 +114,24 @@ type ifElseCase struct { // ifelses returns an if else block for the given definition and repetition level func ifelses(def, rep int, orig fields.Field) ifElses { opts := optionals(def, orig) - var seen []fields.RepetitionType - - di := int(orig.DefIndex(def)) - - for _, rt := range orig.RepetitionTypes[:di+1] { - if rt == fields.Required { - seen = append(seen, fields.Repeated) - } else { - break - } - } - var cases ifElseCases for _, o := range opts { f := orig.Copy() - if len(orig.Seen) <= len(seen) { - f.Seen = append(seen[:0:0], seen...) - } + f.Seen = seens(o) cases = append(cases, ifElseCase{f: f, p: f.Parent(o + 1)}) - seen = append(seen, fields.Repeated) } return cases.ifElses(def, rep, int(orig.MaxDef())) } +func seens(i int) fields.RepetitionTypes { + out := make([]fields.RepetitionType, i) + for i := range out { + out[i] = fields.Repeated + } + return fields.RepetitionTypes(out) +} + type ifElseCases []ifElseCase func (i ifElseCases) ifElses(def, rep, md int) ifElses { @@ -150,16 +143,13 @@ func (i ifElseCases) ifElses(def, rep, md int) ifElses { } var leftovers []ifElseCase - if def == md { + if len(i) > 1 { out.Else = &ifElse{ Val: i[len(i)-1].f.Init(def, rep), } - if len(i) > 1 { + if len(i) > 2 { leftovers = i[1 : len(i)-1] } - - } else if len(i) > 1 { - leftovers = i[1:] } for _, iec := range leftovers { @@ -177,16 +167,19 @@ func (i ifElseCases) ifElses(def, rep, md int) ifElses { func optionals(def int, f fields.Field) []int { var out []int di := f.DefIndex(def) + seen := append(f.Seen[:0:0], f.Seen...) + + if len(seen) > di+1 { + seen = seen[:di+1] + } + for i, rt := range f.RepetitionTypes[:di+1] { - if rt == fields.Optional { + if rt >= fields.Optional { out = append(out, i) } - } - - if def == int(f.MaxDef()) && len(f.RepetitionTypes) > di+1 && f.RepetitionTypes[di+1] == fields.Required { - out = append(out, out[len(out)-1]) - } else if def == int(f.MaxDef()) && len(out) == 1 { - out = append(out, out[len(out)-1]) + if i > len(seen)-1 && rt >= fields.Optional { + break + } } return out diff --git a/internal/dremel/write_repeated.go b/internal/dremel/write_repeated.go index 1e49749..6ee8a1f 100644 --- a/internal/dremel/write_repeated.go +++ b/internal/dremel/write_repeated.go @@ -134,7 +134,7 @@ func initRepeated(def, rep int, seen fields.RepetitionTypes, f fields.Field) str rep = def } - if useIfElse(def, rep, seen, f) { + if useIfElse(def, rep, append(seen[:0:0], seen...), f) { ie := ifelses(def, rep, f) var buf bytes.Buffer if err := ifTpl.Execute(&buf, ie); err != nil { @@ -153,12 +153,16 @@ func useIfElse(def, rep int, seen fields.RepetitionTypes, f fields.Field) bool { } i := f.DefIndex(def) - p := f.Parent(i + 1) - if !p.Optional() || seen.Repeated() { + + if i+1 > len(seen) && f.RepetitionTypes[:len(seen)].Required() { return false } - if def == f.MaxDef() && rep > 0 { + if len(seen) > i+1 { + seen = seen[:i+1] + } + + if seen.Repeated() || (def == f.MaxDef() && rep > 0) { return false } diff --git a/internal/dremel/write_test.go b/internal/dremel/write_test.go index ec956de..009e35e 100644 --- a/internal/dremel/write_test.go +++ b/internal/dremel/write_test.go @@ -150,15 +150,9 @@ func TestWrite(t *testing.T) { def := defs[0] switch def { case 1: - if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{} - } + x.Friend.Hobby = &Item{} case 2: - if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: pstring(vals[0])} - } else { - x.Friend.Hobby.Name = pstring(vals[0]) - } + x.Friend.Hobby = &Item{Name: pstring(vals[0])} return 1, 1 } @@ -206,7 +200,7 @@ func TestWrite(t *testing.T) { { name: "nested 3 deep all optional and seen by optional field", fields: []fields.Field{ - {FieldNames: []string{"Friend", "Rank"}, FieldTypes: []string{"Entity", "int"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional}}, + {FieldNames: []string{"Friend", "Rank"}, FieldTypes: []string{"Entity", "int"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Optional}}, {Type: "Person", TypeName: "*string", FieldNames: []string{"Friend", "Hobby", "Name"}, FieldTypes: []string{"Entity", "Item", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Optional, fields.Optional}}, }, result: `func writeFriendHobbyName(x *Person, vals []string, defs, reps []uint8) (int, int) { @@ -219,16 +213,14 @@ func TestWrite(t *testing.T) { case 2: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{}} - } else if x.Friend.Hobby == nil { + } else { x.Friend.Hobby = &Item{} } case 3: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{Name: pstring(vals[0])}} - } else if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: pstring(vals[0])} } else { - x.Friend.Hobby.Name = pstring(vals[0]) + x.Friend.Hobby = &Item{Name: pstring(vals[0])} } return 1, 1 } @@ -272,26 +264,20 @@ func TestWrite(t *testing.T) { case 2: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{}} - } else if x.Friend.Hobby == nil { + } else { x.Friend.Hobby = &Item{} } case 3: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{Name: &Name{}}} - } else if x.Friend.Hobby == nil { + } else { x.Friend.Hobby = &Item{Name: &Name{}} - } else if x.Friend.Hobby.Name == nil { - x.Friend.Hobby.Name = &Name{} } case 4: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{Name: &Name{First: pstring(vals[0])}}} - } else if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: &Name{First: pstring(vals[0])}} - } else if x.Friend.Hobby.Name == nil { - x.Friend.Hobby.Name = &Name{First: pstring(vals[0])} } else { - x.Friend.Hobby.Name.First = pstring(vals[0]) + x.Friend.Hobby = &Item{Name: &Name{First: pstring(vals[0])}} } return 1, 1 } @@ -318,7 +304,7 @@ func TestWrite(t *testing.T) { }`, }, { - name: "four deep mixed and seen by optional field", + name: "four deep mixed and seen by a required sub-field", fields: []fields.Field{ {FieldNames: []string{"Friend", "Rank"}, FieldTypes: []string{"Entity", "int"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional}}, {Type: "Person", TypeName: "*string", FieldNames: []string{"Friend", "Hobby", "Name", "First"}, FieldTypes: []string{"Entity", "Item", "Name", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Optional, fields.Optional}}, @@ -327,23 +313,11 @@ func TestWrite(t *testing.T) { def := defs[0] switch def { case 1: - if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{} - } + x.Friend.Hobby = &Item{} case 2: - if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: &Name{}} - } else if x.Friend.Hobby.Name == nil { - x.Friend.Hobby.Name = &Name{} - } + x.Friend.Hobby = &Item{Name: &Name{}} case 3: - if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: &Name{First: pstring(vals[0])}} - } else if x.Friend.Hobby.Name == nil { - x.Friend.Hobby.Name = &Name{First: pstring(vals[0])} - } else { - x.Friend.Hobby.Name.First = pstring(vals[0]) - } + x.Friend.Hobby = &Item{Name: &Name{First: pstring(vals[0])}} return 1, 1 } @@ -369,7 +343,7 @@ func TestWrite(t *testing.T) { }`, }, { - name: "four deep mixed v2 and seen by optional fields", + name: "four deep mixed v2 and seen by an optional field", fields: []fields.Field{ {FieldNames: []string{"Friend", "Rank"}, FieldTypes: []string{"Entity", "int"}, RepetitionTypes: []fields.RepetitionType{fields.Optional}}, {Type: "Person", TypeName: "*string", FieldNames: []string{"Friend", "Hobby", "Name", "First"}, FieldTypes: []string{"Entity", "Item", "Name", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Optional, fields.Optional, fields.Required}}, @@ -384,18 +358,14 @@ func TestWrite(t *testing.T) { case 2: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{}} - } else if x.Friend.Hobby == nil { + } else { x.Friend.Hobby = &Item{} } case 3: if x.Friend == nil { x.Friend = &Entity{Hobby: &Item{Name: &Name{First: vals[0]}}} - } else if x.Friend.Hobby == nil { - x.Friend.Hobby = &Item{Name: &Name{First: vals[0]}} - } else if x.Friend.Hobby.Name == nil { - x.Friend.Hobby.Name = &Name{First: vals[0]} } else { - x.Friend.Hobby.Name.First = vals[0] + x.Friend.Hobby = &Item{Name: &Name{First: vals[0]}} } return 1, 1 } @@ -660,7 +630,6 @@ func TestWrite(t *testing.T) { t.Run(fmt.Sprintf("%02d %s", i, tc.name), func(t *testing.T) { s := dremel.Write(len(tc.fields)-1, tc.fields) gocode, err := format.Source([]byte(s)) - //fmt.Println(string(gocode)) assert.NoError(t, err) assert.Equal(t, tc.result, string(gocode)) }) diff --git a/internal/fields/fields.go b/internal/fields/fields.go index c44bcb5..50fb733 100644 --- a/internal/fields/fields.go +++ b/internal/fields/fields.go @@ -268,6 +268,15 @@ func (f Field) start(def, rep int) int { seen = seen[:di+1] } + if len(f.RepetitionTypes)-1 > di { + for _, rt := range f.RepetitionTypes[di+1:] { + if rt >= Optional { + break + } + di++ + } + } + if rep == 0 { rep = int(seen.MaxRep()) + 1 } @@ -283,21 +292,12 @@ func (f Field) start(def, rep int) int { reps++ } - if rt == Optional && (!seen.Repeated() || len(seen) <= i) { - break - } - if reps == rep { break } - } - if len(seen) == def && f.RepetitionTypes[di] == Optional && def == int(f.MaxDef()) && i < len(f.RepetitionTypes)-1 { - for _, rt := range f.RepetitionTypes[i+1:] { - if rt == Optional || rt == Repeated { - break - } - i++ + if rt >= Optional && i >= len(seen) { + break } } diff --git a/internal/fields/fields_test.go b/internal/fields/fields_test.go index df40e76..04d13c2 100644 --- a/internal/fields/fields_test.go +++ b/internal/fields/fields_test.go @@ -28,13 +28,6 @@ func TestFields(t *testing.T) { rep: 0, expected: "x.Link = &Link{Backward: []int64{vals[nVals]}}", }, - { - field: fields.Field{TypeName: "int64", FieldNames: []string{"Link", "Backward"}, FieldTypes: []string{"Link", "int64"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Repeated}}, - def: 2, - rep: 0, - seen: []fields.RepetitionType{fields.Optional}, - expected: "x.Link = &Link{Backward: []int64{vals[nVals]}}", - }, { field: fields.Field{TypeName: "int64", FieldNames: []string{"Link", "Backward"}, FieldTypes: []string{"Link", "int64"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Repeated}}, def: 2, @@ -239,12 +232,6 @@ func TestFields(t *testing.T) { def: 1, expected: "x.Friend = &Entity{}", }, - { - field: fields.Field{FieldNames: []string{"Friend", "Hobby", "Name", "First"}, FieldTypes: []string{"Entity", "Item", "Name", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Optional, fields.Optional, fields.Required}}, - def: 2, - seen: []fields.RepetitionType{fields.Optional}, - expected: "x.Friend = &Entity{Hobby: &Item{}}", - }, { field: fields.Field{FieldNames: []string{"Friend", "Hobby", "Name", "First"}, FieldTypes: []string{"Entity", "Item", "Name", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Optional, fields.Optional, fields.Required}}, def: 2, @@ -288,8 +275,8 @@ func TestFields(t *testing.T) { { field: fields.Field{FieldNames: []string{"Link", "Forward"}, FieldTypes: []string{"Link", "int64"}, RepetitionTypes: []fields.RepetitionType{fields.Optional, fields.Repeated}}, def: 2, - seen: []fields.RepetitionType{fields.Optional}, - expected: "x.Link = &Link{Forward: []int64{vals[nVals]}}", + seen: []fields.RepetitionType{fields.Repeated}, + expected: "x.Link.Forward = append(x.Link.Forward, vals[nVals])", }, { field: fields.Field{FieldNames: []string{"LuckyNumbers"}, FieldTypes: []string{"int64"}, RepetitionTypes: []fields.RepetitionType{fields.Repeated}}, @@ -303,6 +290,35 @@ func TestFields(t *testing.T) { rep: 1, expected: "x.LuckyNumbers = append(x.LuckyNumbers, vals[nVals])", }, + { + field: fields.Field{FieldNames: []string{"A", "B", "C", "D", "E", "F"}, FieldTypes: []string{"A", "B", "C", "D", "E", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Required, fields.Repeated, fields.Required, fields.Optional}}, + def: 3, + expected: "x.A.B = &B{C: C{D: []D{{E: E{F: pstring(vals[nVals])}}}}}", + }, + { + field: fields.Field{FieldNames: []string{"A", "B", "C", "D", "E", "F"}, FieldTypes: []string{"A", "B", "C", "D", "E", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Required, fields.Repeated, fields.Required, fields.Optional}}, + def: 3, + seen: []fields.RepetitionType{fields.Repeated}, + expected: "x.A.B = &B{C: C{D: []D{{E: E{F: pstring(vals[nVals])}}}}}", + }, + { + field: fields.Field{FieldNames: []string{"A", "B", "C", "D", "E", "F"}, FieldTypes: []string{"A", "B", "C", "D", "E", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Required, fields.Repeated, fields.Required, fields.Optional}}, + def: 3, + seen: []fields.RepetitionType{fields.Repeated, fields.Repeated}, + expected: "x.A.B.C.D = []D{{E: E{F: pstring(vals[nVals])}}}", + }, + { + field: fields.Field{FieldNames: []string{"A", "B", "C", "D", "E", "F"}, FieldTypes: []string{"A", "B", "C", "D", "E", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Required, fields.Repeated, fields.Required, fields.Optional}}, + def: 3, + seen: []fields.RepetitionType{fields.Repeated, fields.Repeated, fields.Repeated}, + expected: "x.A.B.C.D = []D{{E: E{F: pstring(vals[nVals])}}}", + }, + { + field: fields.Field{FieldNames: []string{"A", "B", "C", "D", "E", "F"}, FieldTypes: []string{"A", "B", "C", "D", "E", "string"}, RepetitionTypes: []fields.RepetitionType{fields.Required, fields.Optional, fields.Required, fields.Repeated, fields.Required, fields.Optional}}, + def: 3, + seen: []fields.RepetitionType{fields.Repeated, fields.Repeated, fields.Repeated, fields.Repeated}, + expected: "x.A.B.C.D[ind[0]].E.F = pstring(vals[nVals])", + }, } for i, tc := range testCases {