Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace mutex-guarded map with sync.Map #63

Merged
merged 5 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions safemap/safemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,43 @@ package safemap
import "sync"

type SafeMap[K comparable, V any] struct {
mu sync.RWMutex
data map[K]V
data sync.Map
}

func New[K comparable, V any]() *SafeMap[K, V] {
return &SafeMap[K, V]{
data: make(map[K]V),
data: sync.Map{},
}
}

func (s *SafeMap[K, V]) Set(k K, v V) {
s.mu.Lock()
defer s.mu.Unlock()
s.data[k] = v
s.data.Store(k, v)
}

func (s *SafeMap[K, V]) Get(k K) (V, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
val, ok := s.data[k]
return val, ok
val, ok := s.data.Load(k)
if !ok {
return *new(V), false // Return zero value of type V and false
}
return val.(V), ok
}

func (s *SafeMap[K, V]) Delete(k K) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, k)
s.data.Delete(k)
}

func (s *SafeMap[K, V]) Len() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.data)
count := 0
s.data.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}

func (s *SafeMap[K, V]) ForEach(f func(K, V)) {
s.mu.RLock()
defer s.mu.RUnlock()
for key, val := range s.data {
f(key, val)
}
s.data.Range(func(key, value interface{}) bool {
f(key.(K), value.(V))
return true
})
}
121 changes: 121 additions & 0 deletions safemap/safemap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package safemap

import (
"math/rand"
"testing"
)

func TestNewAndSet(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")

if sm.Len() != 2 {
t.Errorf("Expected length 2, got %d", sm.Len())
}
}

func TestGet(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")

val, ok := sm.Get(1)
if !ok || val != "one" {
t.Errorf("Expected 'one', got %s", val)
}
_, ok = sm.Get(2)
if ok {
t.Errorf("Expected false, got true")
}
}

func TestDelete(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Delete(1)
_, ok := sm.Get(1)
if ok {
t.Errorf("Expected key 1 to be deleted")
}

sm.Delete(2) // just make sure this doesn't panic
}

func TestLen(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")
sm.Set(3, "three")
sm.Delete(2)

if sm.Len() != 2 {
t.Errorf("Expected length 2, got %d", sm.Len())
}
}

func TestForEach(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")

keys := make([]int, 0)
sm.ForEach(func(k int, v string) {
keys = append(keys, k)
})

if len(keys) != 2 {
t.Errorf("Expected 2 keys, got %d", len(keys))
}

// Check if keys 1 and 2 are present
if !contains(keys, 1) || !contains(keys, 2) {
t.Errorf("Expected keys 1 and 2, got %v", keys)
}
}

// Helper function to check if a slice contains a specific element.
func contains(slice []int, element int) bool {
for _, a := range slice {
if a == element {
return true
}
}
return false
}

// Benchmark Get
func BenchmarkGetConcurrent(b *testing.B) {
ds := New[uint64, uint64]()
// Pre-fill the data store with test data if needed
for i := 0; i < 100000; i++ {
ds.Set(uint64(i), uint64(i))
}
b.SetParallelism(100)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r := rand.Uint64() % 100000
_, _ = ds.Get(r)
}
})
}

// fool the compiler
var silly uint64

// Benchmark Set
func BenchmarkSetConcurrent(b *testing.B) {
ds := New[uint64, uint64]()
b.SetParallelism(100)
b.ResetTimer()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r := rand.Uint64() % 100000
silly, _ = ds.Get(r)
ds.Set(r, r)
silly, _ = ds.Get(r)

}
})
}