Skip to content

Commit

Permalink
Improved Compatibility Around LAST_INSERT_ID (#17408)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Signed-off-by: Harshit Gangal <[email protected]>
Co-authored-by: Harshit Gangal <[email protected]>
  • Loading branch information
systay and harshit-gangal authored Dec 20, 2024
1 parent 9714713 commit e750d22
Show file tree
Hide file tree
Showing 59 changed files with 2,472 additions and 1,584 deletions.
2 changes: 1 addition & 1 deletion go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
return 0, vterrors.Wrapf(err, "io.ReadFull(header size) failed")
}

sequence := uint8(c.header[3])
sequence := c.header[3]
if sequence != c.sequence {
return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, sequence)
}
Expand Down
1 change: 1 addition & 0 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result,
return &sqltypes.Result{
RowsAffected: packetOk.affectedRows,
InsertID: packetOk.lastInsertID,
InsertIDChanged: packetOk.lastInsertID > 0,
SessionStateChanges: packetOk.sessionStateData,
StatusFlags: packetOk.statusFlags,
Info: packetOk.info,
Expand Down
16 changes: 7 additions & 9 deletions go/mysql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,13 @@ import (
"sync"
"testing"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/mysql/sqlerror"

"vitess.io/vitess/go/mysql/collations"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/sqlerror"
"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand Down Expand Up @@ -393,8 +389,9 @@ func TestQueries(t *testing.T) {

// Typical Insert result
checkQuery(t, "insert", sConn, cConn, &sqltypes.Result{
RowsAffected: 0x8010203040506070,
InsertID: 0x0102030405060708,
RowsAffected: 0x8010203040506070,
InsertID: 0x0102030405060708,
InsertIDChanged: true,
})

// Typical Select with TYPE_AND_NAME.
Expand Down Expand Up @@ -702,6 +699,7 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *
got = &sqltypes.Result{}
got.RowsAffected = result.RowsAffected
got.InsertID = result.InsertID
got.InsertIDChanged = result.InsertIDUpdated()
got.Fields, err = cConn.Fields()
if err != nil {
fatalError = fmt.Sprintf("Fields(%v) failed: %v", query, err)
Expand Down
3 changes: 3 additions & 0 deletions go/sqltypes/proto3.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func ResultToProto3(qr *Result) *querypb.QueryResult {
Fields: qr.Fields,
RowsAffected: qr.RowsAffected,
InsertId: qr.InsertID,
InsertIdChanged: qr.InsertIDChanged,
Rows: RowsToProto3(qr.Rows),
Info: qr.Info,
SessionStateChanges: qr.SessionStateChanges,
Expand All @@ -119,6 +120,7 @@ func Proto3ToResult(qr *querypb.QueryResult) *Result {
Fields: qr.Fields,
RowsAffected: qr.RowsAffected,
InsertID: qr.InsertId,
InsertIDChanged: qr.InsertIdChanged,
Rows: proto3ToRows(qr.Fields, qr.Rows),
Info: qr.Info,
SessionStateChanges: qr.SessionStateChanges,
Expand All @@ -136,6 +138,7 @@ func CustomProto3ToResult(fields []*querypb.Field, qr *querypb.QueryResult) *Res
Fields: qr.Fields,
RowsAffected: qr.RowsAffected,
InsertID: qr.InsertId,
InsertIDChanged: qr.InsertIdChanged,
Rows: proto3ToRows(fields, qr.Rows),
Info: qr.Info,
SessionStateChanges: qr.SessionStateChanges,
Expand Down
70 changes: 40 additions & 30 deletions go/sqltypes/proto3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ func TestResult(t *testing.T) {
Type: Float64,
}}
sqlResult := &Result{
Fields: fields,
InsertID: 1,
RowsAffected: 2,
Fields: fields,
InsertID: 1,
InsertIDChanged: true,
RowsAffected: 2,
Rows: [][]Value{{
TestValue(VarChar, "aa"),
TestValue(Int64, "1"),
Expand All @@ -53,9 +54,10 @@ func TestResult(t *testing.T) {
}},
}
p3Result := &querypb.QueryResult{
Fields: fields,
InsertId: 1,
RowsAffected: 2,
Fields: fields,
InsertId: 1,
InsertIdChanged: true,
RowsAffected: 2,
Rows: []*querypb.Row{{
Lengths: []int64{2, 1, 1},
Values: []byte("aa12"),
Expand Down Expand Up @@ -105,36 +107,40 @@ func TestResults(t *testing.T) {
Type: Float64,
}}
sqlResults := []Result{{
Fields: fields1,
InsertID: 1,
RowsAffected: 2,
Fields: fields1,
InsertID: 1,
InsertIDChanged: true,
RowsAffected: 2,
Rows: [][]Value{{
TestValue(VarChar, "aa"),
TestValue(Int64, "1"),
TestValue(Float64, "2"),
}},
}, {
Fields: fields2,
InsertID: 3,
RowsAffected: 4,
Fields: fields2,
InsertID: 3,
InsertIDChanged: true,
RowsAffected: 4,
Rows: [][]Value{{
TestValue(VarChar, "bb"),
TestValue(Int64, "3"),
TestValue(Float64, "4"),
}},
}}
p3Results := []*querypb.QueryResult{{
Fields: fields1,
InsertId: 1,
RowsAffected: 2,
Fields: fields1,
InsertId: 1,
InsertIdChanged: true,
RowsAffected: 2,
Rows: []*querypb.Row{{
Lengths: []int64{2, 1, 1},
Values: []byte("aa12"),
}},
}, {
Fields: fields2,
InsertId: 3,
RowsAffected: 4,
Fields: fields2,
InsertId: 3,
InsertIdChanged: true,
RowsAffected: 4,
Rows: []*querypb.Row{{
Lengths: []int64{2, 1, 1},
Values: []byte("bb34"),
Expand Down Expand Up @@ -176,9 +182,10 @@ func TestQueryReponses(t *testing.T) {
queryResponses := []QueryResponse{
{
QueryResult: &Result{
Fields: fields1,
InsertID: 1,
RowsAffected: 2,
Fields: fields1,
InsertID: 1,
InsertIDChanged: true,
RowsAffected: 2,
Rows: [][]Value{{
TestValue(VarChar, "aa"),
TestValue(Int64, "1"),
Expand All @@ -188,9 +195,10 @@ func TestQueryReponses(t *testing.T) {
QueryError: nil,
}, {
QueryResult: &Result{
Fields: fields2,
InsertID: 3,
RowsAffected: 4,
Fields: fields2,
InsertID: 3,
InsertIDChanged: true,
RowsAffected: 4,
Rows: [][]Value{{
TestValue(VarChar, "bb"),
TestValue(Int64, "3"),
Expand All @@ -208,9 +216,10 @@ func TestQueryReponses(t *testing.T) {
{
Error: nil,
Result: &querypb.QueryResult{
Fields: fields1,
InsertId: 1,
RowsAffected: 2,
Fields: fields1,
InsertId: 1,
InsertIdChanged: true,
RowsAffected: 2,
Rows: []*querypb.Row{{
Lengths: []int64{2, 1, 1},
Values: []byte("aa12"),
Expand All @@ -219,9 +228,10 @@ func TestQueryReponses(t *testing.T) {
}, {
Error: nil,
Result: &querypb.QueryResult{
Fields: fields2,
InsertId: 3,
RowsAffected: 4,
Fields: fields2,
InsertId: 3,
InsertIdChanged: true,
RowsAffected: 4,
Rows: []*querypb.Row{{
Lengths: []int64{2, 1, 1},
Values: []byte("bb34"),
Expand Down
22 changes: 15 additions & 7 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Result struct {
Fields []*querypb.Field `json:"fields"`
RowsAffected uint64 `json:"rows_affected"`
InsertID uint64 `json:"insert_id"`
InsertIDChanged bool `json:"insert_id_changed"`
Rows []Row `json:"rows"`
SessionStateChanges string `json:"session_state_changes"`
StatusFlags uint16 `json:"status_flags"`
Expand Down Expand Up @@ -92,6 +93,7 @@ func (result *Result) Copy() *Result {
out := &Result{
RowsAffected: result.RowsAffected,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
SessionStateChanges: result.SessionStateChanges,
StatusFlags: result.StatusFlags,
Info: result.Info,
Expand All @@ -116,6 +118,7 @@ func (result *Result) ShallowCopy() *Result {
return &Result{
Fields: result.Fields,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand All @@ -129,6 +132,7 @@ func (result *Result) Metadata() *Result {
return &Result{
Fields: result.Fields,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand All @@ -153,6 +157,7 @@ func (result *Result) Truncate(l int) *Result {

out := &Result{
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDChanged,
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand Down Expand Up @@ -198,6 +203,7 @@ func (result *Result) Equal(other *Result) bool {
return FieldsEqual(result.Fields, other.Fields) &&
result.RowsAffected == other.RowsAffected &&
result.InsertID == other.InsertID &&
result.InsertIDChanged == other.InsertIDChanged &&
slices.EqualFunc(result.Rows, other.Rows, func(a, b Row) bool {
return RowEqual(a, b)
})
Expand Down Expand Up @@ -324,15 +330,13 @@ func (result *Result) StripMetadata(incl querypb.ExecuteOptions_IncludedFields)
// to another result.Note currently it doesn't handle cases like
// if two results have different fields.We will enhance this function.
func (result *Result) AppendResult(src *Result) {
if src.RowsAffected == 0 && len(src.Rows) == 0 && len(src.Fields) == 0 {
return
}
if result.Fields == nil {
result.Fields = src.Fields
}
result.RowsAffected += src.RowsAffected
if src.InsertID != 0 {
if src.InsertIDUpdated() {
result.InsertID = src.InsertID
result.InsertIDChanged = true
}
if len(result.Fields) == 0 {
result.Fields = src.Fields
}
result.Rows = append(result.Rows, src.Rows...)
}
Expand All @@ -351,3 +355,7 @@ func (result *Result) IsMoreResultsExists() bool {
func (result *Result) IsInTransaction() bool {
return result.StatusFlags&ServerStatusInTrans == ServerStatusInTrans
}

func (result *Result) InsertIDUpdated() bool {
return result.InsertIDChanged || result.InsertID > 0
}
Loading

0 comments on commit e750d22

Please sign in to comment.