From 15ac8488d04702e9dca87db082740fafabbaf57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B2=90?= Date: Mon, 8 Apr 2024 10:27:52 +0800 Subject: [PATCH] feat: pass context down to gorm, remove old ContextAdapter (#234) * fix: pass context down to gorm * fix: delete context_adapter_test.go * fix: go mod tidy * fix: update README & delete ContextAdapter --- README.md | 4 +- adapter.go | 48 +++++++++++--- context_adapter.go | 85 ------------------------ context_adapter_test.go | 141 ---------------------------------------- go.mod | 1 - go.sum | 7 -- 6 files changed, 41 insertions(+), 245 deletions(-) delete mode 100644 context_adapter.go delete mode 100644 context_adapter_test.go diff --git a/README.md b/README.md index 63bdfdf..6eae47c 100644 --- a/README.md +++ b/README.md @@ -228,11 +228,11 @@ func TestGetAllowedRecordsForUser(t *testing.T) { `gormadapter` supports adapter with context, the following is a timeout control implemented using context ```go -ca, _ := NewContextAdapter("mysql", "root:@tcp(127.0.0.1:3306)/", "casbin") +a, _ := gormadapter.NewAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/") // Your driver and data source. // Limited time 300s ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) defer cancel() -err := ca.AddPolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"}) +err := a.AddPolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"}) if err != nil { panic(err) } diff --git a/adapter.go b/adapter.go index 99d34ac..d101e7a 100755 --- a/adapter.go +++ b/adapter.go @@ -478,8 +478,13 @@ func loadPolicyLine(line CasbinRule, model model.Model) error { // LoadPolicy loads policy from database. func (a *Adapter) LoadPolicy(model model.Model) error { + return a.LoadPolicyCtx(context.Background(), model) +} + +// LoadPolicyCtx loads policy from database. +func (a *Adapter) LoadPolicyCtx(ctx context.Context, model model.Model) error { var lines []CasbinRule - if err := a.db.Order("ID").Find(&lines).Error; err != nil { + if err := a.db.WithContext(ctx).Order("ID").Find(&lines).Error; err != nil { return err } err := a.Preview(&lines, model) @@ -596,8 +601,13 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { // SavePolicy saves policy to database. func (a *Adapter) SavePolicy(model model.Model) error { + return a.SavePolicyCtx(context.Background(), model) +} + +// SavePolicyCtx saves policy to database. +func (a *Adapter) SavePolicyCtx(ctx context.Context, model model.Model) error { var err error - tx := a.db.Clauses(dbresolver.Write).Begin() + tx := a.db.WithContext(ctx).Clauses(dbresolver.Write).Begin() err = a.truncateTable() @@ -646,15 +656,25 @@ func (a *Adapter) SavePolicy(model model.Model) error { // AddPolicy adds a policy rule to the storage. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { + return a.AddPolicyCtx(context.Background(), sec, ptype, rule) +} + +// AddPolicyCtx adds a policy rule to the storage. +func (a *Adapter) AddPolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error { line := a.savePolicyLine(ptype, rule) - err := a.db.Create(&line).Error + err := a.db.WithContext(ctx).Create(&line).Error return err } // RemovePolicy removes a policy rule from the storage. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { + return a.RemovePolicyCtx(context.Background(), sec, ptype, rule) +} + +// RemovePolicyCtx removes a policy rule from the storage. +func (a *Adapter) RemovePolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error { line := a.savePolicyLine(ptype, rule) - err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html + err := a.rawDelete(ctx, a.db, line) //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html return err } @@ -709,10 +729,15 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro // RemovePolicies removes multiple policy rules from the storage. func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { + return a.RemovePoliciesCtx(context.Background(), sec, ptype, rules) +} + +// RemovePoliciesCtx removes multiple policy rules from the storage. +func (a *Adapter) RemovePoliciesCtx(ctx context.Context, sec string, ptype string, rules [][]string) error { return a.db.Transaction(func(tx *gorm.DB) error { for _, rule := range rules { line := a.savePolicyLine(ptype, rule) - if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html + if err := a.rawDelete(ctx, tx, line); err != nil { //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html } } return nil @@ -721,12 +746,17 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err // RemoveFilteredPolicy removes policy rules that match the filter from the storage. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return a.RemoveFilteredPolicyCtx(context.Background(), sec, ptype, fieldIndex, fieldValues...) +} + +// RemoveFilteredPolicyCtx removes policy rules that match the filter from the storage. +func (a *Adapter) RemoveFilteredPolicyCtx(ctx context.Context, sec string, ptype string, fieldIndex int, fieldValues ...string) error { line := a.getTableInstance() line.Ptype = ptype if fieldIndex == -1 { - return a.rawDelete(a.db, *line) + return a.rawDelete(ctx, a.db, *line) } err := checkQueryField(fieldValues) @@ -752,7 +782,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { line.V5 = fieldValues[5-fieldIndex] } - err = a.rawDelete(a.db, *line) + err = a.rawDelete(ctx, a.db, *line) return err } @@ -766,7 +796,7 @@ func checkQueryField(fieldValues []string) error { return errors.New("the query field cannot all be empty string (\"\"), please check") } -func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error { +func (a *Adapter) rawDelete(ctx context.Context, db *gorm.DB, line CasbinRule) error { queryArgs := []interface{}{line.Ptype} queryStr := "ptype = ?" @@ -795,7 +825,7 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error { queryArgs = append(queryArgs, line.V5) } args := append([]interface{}{queryStr}, queryArgs...) - err := db.Delete(a.getTableInstance(), args...).Error + err := db.WithContext(ctx).Delete(a.getTableInstance(), args...).Error return err } diff --git a/context_adapter.go b/context_adapter.go deleted file mode 100644 index d6a066f..0000000 --- a/context_adapter.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2023 The casbin Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gormadapter - -import ( - "context" - - "github.com/casbin/casbin/v2/model" -) - -type ContextAdapter struct { - *Adapter -} - -func NewContextAdapter(driverName string, dataSourceName string, params ...interface{}) (*ContextAdapter, error) { - a, err := NewAdapter(driverName, dataSourceName, params...) - return &ContextAdapter{ - a, - }, err -} - -// executeWithContext is a helper function to execute a function with context and return the result or error. -func executeWithContext(ctx context.Context, fn func() error) error { - done := make(chan error) - go func() { - done <- fn() - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - -// LoadPolicyCtx loads all policy rules from the storage with context. -func (ca *ContextAdapter) LoadPolicyCtx(ctx context.Context, model model.Model) error { - return executeWithContext(ctx, func() error { - return ca.LoadPolicy(model) - }) -} - -// SavePolicyCtx saves all policy rules to the storage with context. -func (ca *ContextAdapter) SavePolicyCtx(ctx context.Context, model model.Model) error { - return executeWithContext(ctx, func() error { - return ca.SavePolicy(model) - }) -} - -// AddPolicyCtx adds a policy rule to the storage with context. -// This is part of the Auto-Save feature. -func (ca *ContextAdapter) AddPolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error { - return executeWithContext(ctx, func() error { - return ca.AddPolicy(sec, ptype, rule) - }) -} - -// RemovePolicyCtx removes a policy rule from the storage with context. -// This is part of the Auto-Save feature. -func (ca *ContextAdapter) RemovePolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error { - return executeWithContext(ctx, func() error { - return ca.RemovePolicy(sec, ptype, rule) - }) -} - -// RemoveFilteredPolicyCtx removes policy rules that match the filter from the storage with context. -// This is part of the Auto-Save feature. -func (ca *ContextAdapter) RemoveFilteredPolicyCtx(ctx context.Context, sec string, ptype string, fieldIndex int, fieldValues ...string) error { - return executeWithContext(ctx, func() error { - return ca.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...) - }) -} diff --git a/context_adapter_test.go b/context_adapter_test.go deleted file mode 100644 index d360090..0000000 --- a/context_adapter_test.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2023 The casbin Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gormadapter - -import ( - "context" - "testing" - "time" - - "github.com/agiledragon/gomonkey/v2" - "github.com/casbin/casbin/v2" - "github.com/stretchr/testify/assert" -) - -func mockExecuteWithContextTimeOut(ctx context.Context, fn func() error) error { - done := make(chan error) - go func() { - time.Sleep(500 * time.Microsecond) - done <- fn() - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - -func clearDBPolicy() (*casbin.Enforcer, *ContextAdapter) { - ca, err := NewContextAdapter("mysql", "root:@tcp(127.0.0.1:3306)/", "casbin") - if err != nil { - panic(err) - } - e, err := casbin.NewEnforcer("examples/rbac_model.conf", ca) - if err != nil { - panic(err) - } - e.ClearPolicy() - _ = e.SavePolicy() - - return e, ca -} - -func TestContextAdapter_LoadPolicyCtx(t *testing.T) { - e, ca := clearDBPolicy() - - db, _ := openDBConnection("mysql", "root:@tcp(127.0.0.1:3306)/casbin") - policy := &CasbinRule{ - Ptype: "p", - V0: "alice", - V1: "data1", - V2: "read", - } - db.Create(policy) - - assert.NoError(t, ca.LoadPolicyCtx(context.Background(), e.GetModel())) - e, _ = casbin.NewEnforcer(e.GetModel(), ca) - testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}}) - - var p = gomonkey.ApplyFunc(executeWithContext, mockExecuteWithContextTimeOut) - defer p.Reset() - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) - defer cancel() - assert.EqualError(t, ca.LoadPolicyCtx(ctx, e.GetModel()), "context deadline exceeded") -} - -func TestContextAdapter_SavePolicyCtx(t *testing.T) { - e, ca := clearDBPolicy() - - e.EnableAutoSave(false) - _, _ = e.AddPolicy("alice", "data1", "read") - assert.NoError(t, ca.SavePolicyCtx(context.Background(), e.GetModel())) - _ = e.LoadPolicy() - testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}}) - - var p = gomonkey.ApplyFunc(executeWithContext, mockExecuteWithContextTimeOut) - defer p.Reset() - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) - defer cancel() - assert.EqualError(t, ca.SavePolicyCtx(ctx, e.GetModel()), "context deadline exceeded") -} - -func TestContextAdapter_AddPolicyCtx(t *testing.T) { - e, ca := clearDBPolicy() - - assert.NoError(t, ca.AddPolicyCtx(context.Background(), "p", "p", []string{"alice", "data1", "read"})) - _ = e.LoadPolicy() - testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}}) - - var p = gomonkey.ApplyFunc(executeWithContext, mockExecuteWithContextTimeOut) - defer p.Reset() - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) - defer cancel() - assert.EqualError(t, ca.AddPolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"}), "context deadline exceeded") -} - -func TestContextAdapter_RemovePolicyCtx(t *testing.T) { - e, ca := clearDBPolicy() - - _ = ca.AddPolicy("p", "p", []string{"alice", "data1", "read"}) - _ = ca.AddPolicy("p", "p", []string{"alice", "data2", "read"}) - assert.NoError(t, ca.RemovePolicyCtx(context.Background(), "p", "p", []string{"alice", "data1", "read"})) - _ = e.LoadPolicy() - testGetPolicy(t, e, [][]string{{"alice", "data2", "read"}}) - - var p = gomonkey.ApplyFunc(executeWithContext, mockExecuteWithContextTimeOut) - defer p.Reset() - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) - defer cancel() - assert.EqualError(t, ca.RemovePolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"}), "context deadline exceeded") -} - -func TestContextAdapter_RemoveFilteredPolicyCtx(t *testing.T) { - e, ca := clearDBPolicy() - - _ = ca.AddPolicy("p", "p", []string{"alice", "data1", "read"}) - _ = ca.AddPolicy("p", "p", []string{"alice", "data1", "write"}) - _ = ca.AddPolicy("p", "p", []string{"alice", "data2", "read"}) - assert.NoError(t, ca.RemoveFilteredPolicyCtx(context.Background(), "p", "p", 1, "data1")) - _ = e.LoadPolicy() - testGetPolicy(t, e, [][]string{{"alice", "data2", "read"}}) - - var p = gomonkey.ApplyFunc(executeWithContext, mockExecuteWithContextTimeOut) - defer p.Reset() - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond) - defer cancel() - assert.EqualError(t, ca.RemoveFilteredPolicyCtx(ctx, "p", "p", 1, "data1"), "context deadline exceeded") -} diff --git a/go.mod b/go.mod index aafa9ce..718b7d9 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/casbin/gorm-adapter/v3 go 1.20 require ( - github.com/agiledragon/gomonkey/v2 v2.2.0 github.com/casbin/casbin/v2 v2.77.1 github.com/glebarez/sqlite v1.7.0 github.com/go-sql-driver/mysql v1.7.0 diff --git a/go.sum b/go.sum index d4da577..0c37585 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,6 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpC github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/agiledragon/gomonkey/v2 v2.2.0 h1:QJWqpdEhGV/JJy70sZ/LDnhbSlMrqHAWHcNOjz1kyuI= -github.com/agiledragon/gomonkey/v2 v2.2.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= github.com/casbin/casbin/v2 v2.77.1 h1:+H46VamJCTlmCPcb0N99Zaj4tSorfuvBh3v5lyGopeU= github.com/casbin/casbin/v2 v2.77.1/go.mod h1:mzGx0hYW9/ksOSpw3wNjk3NRAroq5VMFYUQ6G43iGPk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -49,7 +47,6 @@ github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71 github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -71,7 +68,6 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= @@ -95,8 +91,6 @@ github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQ github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -175,7 +169,6 @@ golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=