diff --git a/session_storage.go b/session_storage.go index bc2cbedf..0b09c184 100644 --- a/session_storage.go +++ b/session_storage.go @@ -60,46 +60,41 @@ func (c *smap) Range(f func(key string, value any) bool) { type ( ConcurrentMap[K comparable, V any] struct { - hasher maphash.Hasher[K] - sharding uint64 - buckets []*Map[K, V] + hasher maphash.Hasher[K] + num uint64 + shardings []*Map[K, V] } ) // NewConcurrentMap create a new concurrency-safe map // arg0 represents the number of shardings; arg1 represents the initialized capacity of a sharding. func NewConcurrentMap[K comparable, V any](size ...uint64) *ConcurrentMap[K, V] { - sharding, capacity := uint64(0), uint64(0) - if len(size) >= 1 { - sharding = size[0] - } - if len(size) >= 2 { - capacity = size[1] - } - sharding = internal.SelectValue(sharding <= 0, 16, sharding) - sharding = internal.ToBinaryNumber(sharding) + size = append(size, 0, 0) + num, capacity := size[0], size[1] + num = internal.ToBinaryNumber(internal.SelectValue(num <= 0, 16, num)) var cm = &ConcurrentMap[K, V]{ - hasher: maphash.NewHasher[K](), - sharding: sharding, - buckets: make([]*Map[K, V], sharding), + hasher: maphash.NewHasher[K](), + num: num, + shardings: make([]*Map[K, V], num), } - for i, _ := range cm.buckets { - cm.buckets[i] = &Map[K, V]{m: make(map[K]V, capacity)} + for i, _ := range cm.shardings { + cm.shardings[i] = NewMap[K, V](int(capacity)) } return cm } // GetSharding returns a map sharding for a key +// the operations inside the sharding is lockless, and need to be locked manually. func (c *ConcurrentMap[K, V]) GetSharding(key K) *Map[K, V] { var hashCode = c.hasher.Hash(key) - var index = hashCode & (c.sharding - 1) - return c.buckets[index] + var index = hashCode & (c.num - 1) + return c.shardings[index] } // Len returns the number of elements in the map func (c *ConcurrentMap[K, V]) Len() int { var length = 0 - for _, b := range c.buckets { + for _, b := range c.shardings { b.Lock() length += b.Len() b.Unlock() @@ -142,8 +137,8 @@ func (c *ConcurrentMap[K, V]) Range(f func(key K, value V) bool) { next = f(k, v) return next } - for i := uint64(0); i < c.sharding && next; i++ { - var b = c.buckets[i] + for i := uint64(0); i < c.num && next; i++ { + var b = c.shardings[i] b.Lock() b.Range(cb) b.Unlock() @@ -155,6 +150,16 @@ type Map[K comparable, V any] struct { m map[K]V } +func NewMap[K comparable, V any](size ...int) *Map[K, V] { + var capacity = 0 + if len(size) > 0 { + capacity = size[0] + } + c := new(Map[K, V]) + c.m = make(map[K]V, capacity) + return c +} + func (c *Map[K, V]) Len() int { return len(c.m) } func (c *Map[K, V]) Load(key K) (value V, ok bool) { diff --git a/session_storage_test.go b/session_storage_test.go index 883230e9..d4c8b5e5 100644 --- a/session_storage_test.go +++ b/session_storage_test.go @@ -100,7 +100,7 @@ func TestConcurrentMap(t *testing.T) { var as = assert.New(t) var m1 = make(map[string]any) var m2 = NewConcurrentMap[string, uint32]() - as.Equal(m2.sharding, uint64(16)) + as.Equal(m2.num, uint64(16)) var count = internal.AlphabetNumeric.Intn(1000) for i := 0; i < count; i++ { var key = string(internal.AlphabetNumeric.Generate(10)) @@ -128,7 +128,7 @@ func TestConcurrentMap(t *testing.T) { t.Run("", func(t *testing.T) { var sum = 0 var cm = NewConcurrentMap[string, int](8, 8) - for _, item := range cm.buckets { + for _, item := range cm.shardings { sum += len(item.m) } assert.Equal(t, sum, 0)