Skip to content

Commit

Permalink
Replace mutex-guarded map with sync.Map (#63)
Browse files Browse the repository at this point in the history
* randomize listen port. makes the tests pass on macos.

* rip the mutex out of safemap.

* Benchmark the safemap.

* did  some tests, found a panic. fixed it.
  • Loading branch information
perbu authored Nov 26, 2023
1 parent 244f0bb commit c7be50d
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 21 deletions.
40 changes: 19 additions & 21 deletions safemap/safemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,43 @@ package safemap
import "sync"

type SafeMap[K comparable, V any] struct {
mu sync.RWMutex
data map[K]V
data sync.Map
}

func New[K comparable, V any]() *SafeMap[K, V] {
return &SafeMap[K, V]{
data: make(map[K]V),
data: sync.Map{},
}
}

func (s *SafeMap[K, V]) Set(k K, v V) {
s.mu.Lock()
defer s.mu.Unlock()
s.data[k] = v
s.data.Store(k, v)
}

func (s *SafeMap[K, V]) Get(k K) (V, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
val, ok := s.data[k]
return val, ok
val, ok := s.data.Load(k)
if !ok {
return *new(V), false // Return zero value of type V and false
}
return val.(V), ok
}

func (s *SafeMap[K, V]) Delete(k K) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, k)
s.data.Delete(k)
}

func (s *SafeMap[K, V]) Len() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.data)
count := 0
s.data.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}

func (s *SafeMap[K, V]) ForEach(f func(K, V)) {
s.mu.RLock()
defer s.mu.RUnlock()
for key, val := range s.data {
f(key, val)
}
s.data.Range(func(key, value interface{}) bool {
f(key.(K), value.(V))
return true
})
}
121 changes: 121 additions & 0 deletions safemap/safemap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package safemap

import (
"math/rand"
"testing"
)

func TestNewAndSet(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")

if sm.Len() != 2 {
t.Errorf("Expected length 2, got %d", sm.Len())
}
}

func TestGet(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")

val, ok := sm.Get(1)
if !ok || val != "one" {
t.Errorf("Expected 'one', got %s", val)
}
_, ok = sm.Get(2)
if ok {
t.Errorf("Expected false, got true")
}
}

func TestDelete(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Delete(1)
_, ok := sm.Get(1)
if ok {
t.Errorf("Expected key 1 to be deleted")
}

sm.Delete(2) // just make sure this doesn't panic
}

func TestLen(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")
sm.Set(3, "three")
sm.Delete(2)

if sm.Len() != 2 {
t.Errorf("Expected length 2, got %d", sm.Len())
}
}

func TestForEach(t *testing.T) {
sm := New[int, string]()
sm.Set(1, "one")
sm.Set(2, "two")

keys := make([]int, 0)
sm.ForEach(func(k int, v string) {
keys = append(keys, k)
})

if len(keys) != 2 {
t.Errorf("Expected 2 keys, got %d", len(keys))
}

// Check if keys 1 and 2 are present
if !contains(keys, 1) || !contains(keys, 2) {
t.Errorf("Expected keys 1 and 2, got %v", keys)
}
}

// Helper function to check if a slice contains a specific element.
func contains(slice []int, element int) bool {
for _, a := range slice {
if a == element {
return true
}
}
return false
}

// Benchmark Get
func BenchmarkGetConcurrent(b *testing.B) {
ds := New[uint64, uint64]()
// Pre-fill the data store with test data if needed
for i := 0; i < 100000; i++ {
ds.Set(uint64(i), uint64(i))
}
b.SetParallelism(100)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r := rand.Uint64() % 100000
_, _ = ds.Get(r)
}
})
}

// fool the compiler
var silly uint64

// Benchmark Set
func BenchmarkSetConcurrent(b *testing.B) {
ds := New[uint64, uint64]()
b.SetParallelism(100)
b.ResetTimer()

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
r := rand.Uint64() % 100000
silly, _ = ds.Get(r)
ds.Set(r, r)
silly, _ = ds.Get(r)

}
})
}

0 comments on commit c7be50d

Please sign in to comment.